Fix keyword-parameter handling bug in formparams.
[wrw.git] / wrw / proto.py
1 import time, calendar, collections.abc, binascii, base64
2
3 statusinfo = {
4     400: ("Bad Request", "Invalid HTTP request."),
5     401: ("Unauthorized", "Authentication must be provided for the requested resource."),
6     403: ("Forbidden", "You are not authorized to request the requested resource."),
7     404: ("Not Found", "The requested resource was not found."),
8     405: ("Method Not Allowed", "The request method is not recognized or permitted by the requested resource."),
9     429: ("Too Many Requests", "Your client is sending more frequent requests than are accepted."),
10     500: ("Server Error", "An internal error occurred."),
11     501: ("Not Implemented", "The requested functionality has not been implemented."),
12     503: ("Service Unavailable", "Service is being denied at this time."),
13     }
14
15 def httpdate(ts):
16     return time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.gmtime(ts))
17
18 def phttpdate(dstr):
19     tz = dstr[-6:]
20     dstr = dstr[:-6]
21     if tz[0] != " " or (tz[1] != "+" and tz[1] != "-") or not tz[2:].isdigit():
22         return None
23     tz = int(tz[1:])
24     tz = (((tz / 100) * 60) + (tz % 100)) * 60
25     return calendar.timegm(time.strptime(dstr, "%a, %d %b %Y %H:%M:%S")) - tz
26
27 def pmimehead(hstr):
28     def pws(p):
29         while p < len(hstr) and hstr[p].isspace():
30             p += 1
31         return p
32     def token(p, sep):
33         buf = ""
34         p = pws(p)
35         if p >= len(hstr):
36             return "", p
37         if hstr[p] == '"':
38             p += 1
39             while p < len(hstr):
40                 if hstr[p] == '\\':
41                     p += 1
42                     if p < len(hstr):
43                         buf += hstr[p]
44                         p += 1
45                     else:
46                         break
47                 elif hstr[p] == '"':
48                     p += 1
49                     break
50                 else:
51                     buf += hstr[p]
52                     p += 1
53             return buf, pws(p)
54         else:
55             while p < len(hstr):
56                 if hstr[p] in sep:
57                     break
58                 buf += hstr[p]
59                 p += 1
60             return buf.strip(), pws(p)
61     p = 0
62     val, p = token(p, ";")
63     pars = {}
64     while p < len(hstr):
65         if hstr[p] != ';':
66             break
67         p += 1
68         k, p = token(p, "=")
69         if k == "" or hstr[p:p + 1] != '=':
70             break
71         p += 1
72         v, p = token(p, ';')
73         pars[k.lower()] = v
74     return val, pars
75
76 def htmlq(html):
77     ret = ""
78     for c in html:
79         if c == "&":
80             ret += "&amp;"
81         elif c == "<":
82             ret += "&lt;"
83         elif c == ">":
84             ret += "&gt;"
85         else:
86             ret += c
87     return ret
88
89 def simpleerror(env, startreq, code, title, msg):
90     buf = """<?xml version="1.0" encoding="US-ASCII"?>
91 <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.1//EN" "http://www.w3.org/TR/xhtml11/DTD/xhtml11.dtd">
92 <html xmlns="http://www.w3.org/1999/xhtml" xml:lang="en-US">
93 <head>
94 <title>%s</title>
95 </head>
96 <body>
97 <h1>%s</h1>
98 <p>%s</p>
99 </body>
100 </html>
101 """ % (title, title, htmlq(msg))
102     buf = buf.encode("us-ascii")
103     startreq("%i %s" % (code, title), [("Content-Type", "text/html"), ("Content-Length", str(len(buf)))])
104     return [buf]
105
106 def urlq(url):
107     if isinstance(url, str):
108         url = url.encode("utf-8")
109     ret = ""
110     invalid = b"%;&=+#?/\"'"
111     for c in url:
112         if c in invalid or (c <= 32) or (c >= 128):
113             ret += "%%%02X" % c
114         else:
115             ret += chr(c)
116     return ret
117
118 class urlerror(ValueError):
119     pass
120
121 def parseurl(url):
122     p = url.find("://")
123     if p < 0:
124         raise urlerror("Protocol not found in absolute URL `%s'" % url)
125     proto = url[:p]
126     l = url.find("/", p + 3)
127     if l < 0:
128         raise urlerror("Local part not found in absolute URL `%s'" % url)
129     host = url[p + 3:l]
130     local = url[l:]
131     q = local.find("?")
132     if q < 0:
133         query = ""
134     else:
135         query = local[q + 1:]
136         local = local[:q]
137     return proto, host, local, query
138
139 def consurl(proto, host, local, query=""):
140     if len(local) < 1 and local[0] != '/':
141         raise urlerror("Local part of URL must begin with a slash")
142     ret = "%s://%s%s" % (proto, host, local)
143     if len(query) > 0:
144         ret += "?" + query
145     return ret
146
147 def appendurl(url, other):
148     if "://" in other:
149         return other
150     proto, host, local, query = parseurl(url)
151     if len(other) > 0 and other[0] == '/':
152         return consurl(proto, host, other)
153     else:
154         p = local.rfind('/')
155         return consurl(proto, host, local[:p + 1] + other)
156
157 def siteurl(req):
158     host = req.ihead.get("Host", None)
159     if host is None:
160         raise Exception("Could not reconstruct URL because no Host header was sent")
161     proto = "http"
162     if req.https:
163         proto = "https"
164     return "%s://%s/" % (proto, host)
165
166 def scripturl(req):
167     s = siteurl(req)
168     if req.uriname[0] != '/':
169         raise Exception("Malformed local part when reconstructing URL")
170     return siteurl(req) + req.uriname[1:]
171
172 def requrl(req, qs=True):
173     s = siteurl(req)
174     if req.uri[0] != '/':
175         raise Exception("Malformed local part when reconstructing URL")
176     pf = req.uri[1:]
177     if not qs:
178         p = pf.find('?')
179         if not p < 0:
180             pf = pf[:p]
181     return siteurl(req) + pf
182
183 def parstring(pars={}, **augment):
184     buf = ""
185     for key in pars:
186         if key in augment:
187             val = augment[key]
188             del augment[key]
189         else:
190             val = pars[key]
191         if val is None:
192             continue
193         if buf != "": buf += "&"
194         buf += urlq(key) + "=" + urlq(str(val))
195     for key, val in augment.items():
196         if val is None:
197             continue
198         if buf != "": buf += "&"
199         buf += urlq(key) + "=" + urlq(str(val))
200     return buf
201
202 def parurl(url, pars={}, **augment):
203     qs = parstring(pars, **augment)
204     if qs != "":
205         return url + ("&" if "?" in url else "?") + qs
206     else:
207         return url
208
209 # Wrap these, since binascii is a bit funky. :P
210 def enhex(bs):
211     return base64.b16encode(bs).decode("us-ascii")
212 def unhex(es):
213     if not isinstance(es, collections.abc.ByteString):
214         try:
215             es = es.encode("us-ascii")
216         except UnicodeError:
217             raise binascii.Error("non-ascii character in hex-string")
218     return base64.b16decode(es)
219 def enb32(bs):
220     return base64.b32encode(bs).decode("us-ascii")
221 def unb32(es):
222     if not isinstance(es, collections.abc.ByteString):
223         try:
224             es = es.encode("us-ascii")
225         except UnicodeError:
226             raise binascii.Error("non-ascii character in base32-string")
227     if (len(es) % 8) != 0:
228         es += b"=" * (8 - (len(es) % 8))
229     es = es.upper()             # The whole point of Base32 is that it's case-insensitive :P
230     return base64.b32decode(es)
231 def enb64(bs):
232     return base64.b64encode(bs).decode("us-ascii")
233 def unb64(es):
234     if not isinstance(es, collections.abc.ByteString):
235         try:
236             es = es.encode("us-ascii")
237         except UnicodeError:
238             raise binascii.Error("non-ascii character in base64-string")
239     if (len(es) % 4) != 0:
240         es += b"=" * (4 - (len(es) % 4))
241     return base64.b64decode(es)
242
243 def _quoprisafe():
244     ret = [False] * 256
245     for c in "-!*+/":
246         ret[ord(c)] = True
247     for c in range(ord('0'), ord('9') + 1):
248         ret[c] = True
249     for c in range(ord('A'), ord('Z') + 1):
250         ret[c] = True
251     for c in range(ord('a'), ord('z') + 1):
252         ret[c] = True
253     return ret
254 _quoprisafe = _quoprisafe()
255 def quopri(s, charset="utf-8"):
256     bv = s.encode(charset)
257     qn = sum(not _quoprisafe[b] for b in bv)
258     if qn == 0:
259         return s
260     if qn > len(bv) / 2:
261         return "=?%s?B?%s?=" % (charset, enb64(bv))
262     else:
263         return "=?%s?Q?%s?=" % (charset, "".join(chr(b) if _quoprisafe[b] else "=%02X" % b for b in bv))
264
265 class mimeparam(object):
266     def __init__(self, name, val, fallback=None, charset="utf-8", lang=""):
267         self.name = name
268         self.val = val
269         self.fallback = fallback
270         self.charset = charset
271         self.lang = lang
272
273     def __str__(self):
274         self.name.encode("ascii")
275         try:
276             self.val.encode("ascii")
277         except UnicodeError:
278             pass
279         else:
280             return "%s=%s" % (self.name, self.val)
281         val = self.val.encode(self.charset)
282         self.charset.encode("ascii")
283         self.lang.encode("ascii")
284         ret = ""
285         if self.fallback is not None:
286             self.fallback.encode("ascii")
287             ret += "%s=%s; " % (self.name, self.fallback)
288         ret += "%s*=%s'%s'%s" % (self.name, self.charset, self.lang, urlq(val))
289         return ret
290
291 class mimeheader(object):
292     def __init__(self, name, val, *, mime_charset="utf-8", mime_lang="", **params):
293         self.name = name
294         self.val = val
295         self.params = {}
296         self.charset = mime_charset
297         self.lang = mime_lang
298         for k, v in params.items():
299             self[k] = v
300
301     def __getitem__(self, nm):
302         return self.params[nm.lower()]
303
304     def __setitem__(self, nm, val):
305         if not isinstance(val, mimeparam):
306             val = mimeparam(nm, val, charset=self.charset, lang=self.lang)
307         self.params[nm.lower()] = val
308
309     def __delitem__(self, nm):
310         del self.params[nm.lower()]
311
312     def value(self):
313         parts = []
314         if self.val != None:
315             parts.append(quopri(self.val))
316         parts.extend(str(x) for x in self.params.values())
317         return("; ".join(parts))
318
319     def __str__(self):
320         if self.name is None:
321             return self.value()
322         return "%s: %s" % (self.name, self.value())