Allow parurl to append to URLs that already have a query-string.
[wrw.git] / wrw / proto.py
1 import time, calendar, collections, binascii, base64
2
3 statusinfo = {
4     400: ("Bad Request", "Invalid HTTP request."),
5     401: ("Unauthorized", "Authentication must be provided for the requested resource."),
6     403: ("Forbidden", "You are not authorized to request the requested resource."),
7     404: ("Not Found", "The requested resource was not found."),
8     405: ("Method Not Allowed", "The request method is not recognized or permitted by the requested resource."),
9     500: ("Server Error", "An internal error occurred."),
10     501: ("Not Implemented", "The requested functionality has not been implemented."),
11     503: ("Service Unavailable", "Service is being denied at this time."),
12     }
13
14 def httpdate(ts):
15     return time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.gmtime(ts))
16
17 def phttpdate(dstr):
18     tz = dstr[-6:]
19     dstr = dstr[:-6]
20     if tz[0] != " " or (tz[1] != "+" and tz[1] != "-") or not tz[2:].isdigit():
21         return None
22     tz = int(tz[1:])
23     tz = (((tz / 100) * 60) + (tz % 100)) * 60
24     return calendar.timegm(time.strptime(dstr, "%a, %d %b %Y %H:%M:%S")) - tz
25
26 def pmimehead(hstr):
27     def pws(p):
28         while p < len(hstr) and hstr[p].isspace():
29             p += 1
30         return p
31     def token(p, sep):
32         buf = ""
33         p = pws(p)
34         if p >= len(hstr):
35             return "", p
36         if hstr[p] == '"':
37             p += 1
38             while p < len(hstr):
39                 if hstr[p] == '\\':
40                     p += 1
41                     if p < len(hstr):
42                         buf += hstr[p]
43                         p += 1
44                     else:
45                         break
46                 elif hstr[p] == '"':
47                     p += 1
48                     break
49                 else:
50                     buf += hstr[p]
51                     p += 1
52             return buf, pws(p)
53         else:
54             while p < len(hstr):
55                 if hstr[p] in sep:
56                     break
57                 buf += hstr[p]
58                 p += 1
59             return buf.strip(), pws(p)
60     p = 0
61     val, p = token(p, ";")
62     pars = {}
63     while p < len(hstr):
64         if hstr[p] != ';':
65             break
66         p += 1
67         k, p = token(p, "=")
68         if k == "" or hstr[p:p + 1] != '=':
69             break
70         p += 1
71         v, p = token(p, ';')
72         pars[k.lower()] = v
73     return val, pars
74
75 def htmlq(html):
76     ret = ""
77     for c in html:
78         if c == "&":
79             ret += "&amp;"
80         elif c == "<":
81             ret += "&lt;"
82         elif c == ">":
83             ret += "&gt;"
84         else:
85             ret += c
86     return ret
87
88 def simpleerror(env, startreq, code, title, msg):
89     buf = """<?xml version="1.0" encoding="US-ASCII"?>
90 <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.1//EN" "http://www.w3.org/TR/xhtml11/DTD/xhtml11.dtd">
91 <html xmlns="http://www.w3.org/1999/xhtml" xml:lang="en-US">
92 <head>
93 <title>%s</title>
94 </head>
95 <body>
96 <h1>%s</h1>
97 <p>%s</p>
98 </body>
99 </html>
100 """ % (title, title, htmlq(msg))
101     buf = buf.encode("us-ascii")
102     startreq("%i %s" % (code, title), [("Content-Type", "text/html"), ("Content-Length", str(len(buf)))])
103     return [buf]
104
105 def urlq(url):
106     if isinstance(url, str):
107         url = url.encode("utf-8")
108     ret = ""
109     invalid = b"%;&=#?/\"'"
110     for c in url:
111         if c in invalid or (c <= 32) or (c >= 128):
112             ret += "%%%02X" % c
113         else:
114             ret += chr(c)
115     return ret
116
117 class urlerror(ValueError):
118     pass
119
120 def parseurl(url):
121     p = url.find("://")
122     if p < 0:
123         raise urlerror("Protocol not found in absolute URL `%s'" % url)
124     proto = url[:p]
125     l = url.find("/", p + 3)
126     if l < 0:
127         raise urlerror("Local part not found in absolute URL `%s'" % url)
128     host = url[p + 3:l]
129     local = url[l:]
130     q = local.find("?")
131     if q < 0:
132         query = ""
133     else:
134         query = local[q + 1:]
135         local = local[:q]
136     return proto, host, local, query
137
138 def consurl(proto, host, local, query=""):
139     if len(local) < 1 and local[0] != '/':
140         raise urlerror("Local part of URL must begin with a slash")
141     ret = "%s://%s%s" % (proto, host, local)
142     if len(query) > 0:
143         ret += "?" + query
144     return ret
145
146 def appendurl(url, other):
147     if "://" in other:
148         return other
149     proto, host, local, query = parseurl(url)
150     if len(other) > 0 and other[0] == '/':
151         return consurl(proto, host, other)
152     else:
153         p = local.rfind('/')
154         return consurl(proto, host, local[:p + 1] + other)
155
156 def siteurl(req):
157     host = req.ihead.get("Host", None)
158     if host is None:
159         raise Exception("Could not reconstruct URL because no Host header was sent")
160     proto = "http"
161     if req.https:
162         proto = "https"
163     return "%s://%s/" % (proto, host)
164
165 def scripturl(req):
166     s = siteurl(req)
167     if req.uriname[0] != '/':
168         raise Exception("Malformed local part when reconstructing URL")
169     return siteurl(req) + req.uriname[1:]
170
171 def requrl(req):
172     s = siteurl(req)
173     if req.uri[0] != '/':
174         raise Exception("Malformed local part when reconstructing URL")
175     return siteurl(req) + req.uri[1:]
176
177 def parstring(pars={}, **augment):
178     buf = ""
179     for key in pars:
180         if key in augment:
181             val = augment[key]
182             del augment[key]
183         else:
184             val = pars[key]
185         if buf != "": buf += "&"
186         buf += urlq(key) + "=" + urlq(str(val))
187     for key in augment:
188         if buf != "": buf += "&"
189         buf += urlq(key) + "=" + urlq(str(augment[key]))
190     return buf
191
192 def parurl(url, pars={}, **augment):
193     qs = parstring(pars, **augment)
194     if qs != "":
195         return url + ("&" if "?" in url else "?") + qs
196     else:
197         return url
198
199 # Wrap these, since binascii is a bit funky. :P
200 def enhex(bs):
201     return base64.b16encode(bs).decode("us-ascii")
202 def unhex(es):
203     if not isinstance(es, collections.ByteString):
204         try:
205             es = es.encode("us-ascii")
206         except UnicodeError:
207             raise binascii.Error("non-ascii character in hex-string")
208     return base64.b16decode(es)
209 def enb32(bs):
210     return base64.b32encode(bs).decode("us-ascii")
211 def unb32(es):
212     if not isinstance(es, collections.ByteString):
213         try:
214             es = es.encode("us-ascii")
215         except UnicodeError:
216             raise binascii.Error("non-ascii character in base32-string")
217     if (len(es) % 8) != 0:
218         es += b"=" * (8 - (len(es) % 8))
219     es = es.upper()             # The whole point of Base32 is that it's case-insensitive :P
220     return base64.b32decode(es)
221 def enb64(bs):
222     return base64.b64encode(bs).decode("us-ascii")
223 def unb64(es):
224     if not isinstance(es, collections.ByteString):
225         try:
226             es = es.encode("us-ascii")
227         except UnicodeError:
228             raise binascii.Error("non-ascii character in base64-string")
229     if (len(es) % 4) != 0:
230         es += b"=" * (4 - (len(es) % 4))
231     return base64.b64decode(es)