Fix keyword-parameter handling bug in formparams.
[wrw.git] / wrw / form.py
1 import urllib.parse
2 from . import proto
3
4 __all__ = ["formdata"]
5
def formparse(req):
    """Collect the form fields of a request into a dict.

    Fields are taken from the query string (``req.query``) and, when the
    request's Content-Type is ``application/x-www-form-urlencoded``, from
    the request body (``req.input``) as well; body fields override query
    fields of the same name.  Blank values are kept.

    Errors are *returned*, not raised, so that the caller (see formdata)
    can decide how to react: the return value is either the dict of
    fields or an exception instance (IOError from reading the body,
    ValueError for an over-long body, LookupError for an unknown charset).
    """
    # Upper bound on the accepted body size; reaching it counts as too long.
    maxbody = 2 ** 20

    buf = {}
    buf.update(urllib.parse.parse_qsl(req.query, keep_blank_values=True))
    ctype, ctpars = proto.pmimehead(req.ihead.get("Content-Type", ""))
    if ctype == "application/x-www-form-urlencoded":
        try:
            rbody = req.input.read(maxbody)
        except IOError as exc:
            return exc
        if len(rbody) >= maxbody:
            return ValueError("x-www-form-urlencoded data is absurdly long")
        try:
            # latin1 maps every byte to one code point, so the decode cannot
            # fail; percent-escapes are then decoded using the declared
            # charset.
            pairs = urllib.parse.parse_qsl(rbody.decode("latin1"),
                                           encoding=ctpars.get("charset", "utf-8"),
                                           keep_blank_values=True)
        except LookupError as exc:
            # The charset name comes from the client; an unknown codec must
            # be reported like any other bad input rather than escape as an
            # unhandled exception.
            return exc
        buf.update(pairs)
    return buf
19
class badmultipart(IOError):
    """Raised when a multipart form body is malformed or cannot be parsed."""
22
class formpart(object):
    """A single part of a multipart form body, readable like a binary stream.

    The part pulls its data from the owning form object, using the form's
    ``req.input`` stream, its ``boundary`` bytes and its shared look-ahead
    buffer ``buf``.  Data is consumed up to the next boundary line; when
    the closing boundary is reached the form's ``eof`` flag is set as well.

    Usable as a context manager: leaving the ``with`` block drains the rest
    of the part so that the next part can be read.
    """

    def __init__(self, form):
        self.form = form
        # Bytes already extracted for this part but not yet handed out.
        self.buf = b""
        # True once this part's terminating boundary has been seen.
        self.eof = False
        # Parsed part headers, keyed by lower-cased header name.
        self.head = {}

    def fillbuf(self, sz):
        """Move data from the form buffer into this part's buffer.

        Stops when the part's boundary has been found, or, if *sz* is
        non-negative, as soon as at least *sz* bytes are buffered.  A
        negative *sz* reads the whole part.

        Raises badmultipart if the input ends before a boundary is found.
        """
        req = self.form.req
        mboundary = b"\r\n--" + self.form.boundary + b"\r\n"
        lboundary = b"\r\n--" + self.form.boundary + b"--\r\n"
        while not self.eof:
            p = self.form.buf.find(mboundary)
            if p >= 0:
                self.buf += self.form.buf[:p]
                self.form.buf = self.form.buf[p + len(mboundary):]
                self.eof = True
                break
            p = self.form.buf.find(lboundary)
            if p >= 0:
                self.buf += self.form.buf[:p]
                self.form.buf = self.form.buf[p + len(lboundary):]
                self.eof = True
                self.form.eof = True
                break
            # No boundary yet: everything except the last len(lboundary)
            # bytes is certainly part data.  The tail is held back because
            # it may be the beginning of a boundary that is still incomplete.
            self.buf += self.form.buf[:-len(lboundary)]
            self.form.buf = self.form.buf[-len(lboundary):]
            if sz >= 0 and len(self.buf) >= sz:
                break
            while len(self.form.buf) <= len(lboundary):
                ret = req.input.read(8192)
                if ret == b"":
                    raise badmultipart("Missing last multipart boundary")
                self.form.buf += ret

    def read(self, limit=-1):
        """Return up to *limit* bytes of part data, or all of it if negative.

        Returns b"" only once the part is exhausted.
        """
        self.fillbuf(limit)
        if limit >= 0:
            ret = self.buf[:limit]
            self.buf = self.buf[limit:]
        else:
            ret = self.buf
            self.buf = b""
        return ret

    def readline(self, limit=-1):
        """Return the next line of the part, including its trailing newline.

        If *limit* is non-negative, at most *limit* bytes are returned, so
        the result lacks a trailing newline when the line is longer than
        that.  At the end of the part the remaining data is returned, which
        may be b"".
        """
        last = 0
        while True:
            # Only look for the newline within the first *limit* bytes.
            end = len(self.buf) if limit < 0 else min(len(self.buf), limit)
            p = self.buf.find(b'\n', last, end)
            if p >= 0:
                cut = p + 1
                break
            if limit >= 0 and len(self.buf) >= limit:
                # No newline within the allowed length: return a truncated
                # line instead of buffering without bound.
                cut = limit
                break
            if self.eof:
                cut = len(self.buf)
                break
            last = len(self.buf)
            want = last + 128
            if limit >= 0:
                want = min(want, limit)
            self.fillbuf(want)
        ret = self.buf[:cut]
        self.buf = self.buf[cut:]
        return ret

    def close(self):
        """Discard whatever is left of the part, up to its boundary."""
        while self.read(8192) != b"":
            pass

    def __enter__(self):
        return self

    def __exit__(self, *excinfo):
        self.close()
        return False

    def parsehead(self, charset):
        """Read and interpret the part's header block.

        Header lines are decoded using *charset* and stored in ``self.head``.
        From them ``self.name``, ``self.filename``, ``self.ctype`` and
        ``self.charset`` are set.

        Raises badmultipart if the header block is malformed, is not a
        ``form-data`` part with a name, or uses a transfer encoding other
        than ``binary``.
        """
        # NOTE(review): this cap was requested by the original code but never
        # enforced; raise it if longer header lines (e.g. long filenames)
        # must be accepted.
        maxline = 256

        def headline():
            ln = self.readline(maxline)
            if not ln.endswith(b'\n'):
                if len(ln) < maxline:
                    # The part ended before the blank line closing the header.
                    raise badmultipart("Form part ended inside its header")
                raise badmultipart("Too long header line in part")
            try:
                return ln.decode(charset).rstrip()
            except UnicodeError:
                raise badmultipart("Form part header is not in assumed charset")

        ln = headline()
        while ln != "":
            buf = ln
            while True:
                ln = headline()
                # A line starting with whitespace continues the previous
                # header; anything else (including the empty line that ends
                # the header block) starts something new.
                if not ln[:1].isspace():
                    break
                buf += " " + ln.lstrip()
            p = buf.find(':')
            if p < 0:
                raise badmultipart("Malformed multipart header line")
            self.head[buf[:p].strip().lower()] = buf[p + 1:].lstrip()

        val, par = proto.pmimehead(self.head.get("content-disposition", ""))
        if val != "form-data":
            raise badmultipart("Unexpected Content-Disposition in form part: %r" % val)
        if "name" not in par:
            raise badmultipart("Missing name in form part")
        self.name = par["name"]
        self.filename = par.get("filename")
        val, par = proto.pmimehead(self.head.get("content-type", ""))
        self.ctype = val
        self.charset = par.get("charset")
        encoding = self.head.get("content-transfer-encoding", "binary")
        if encoding != "binary":
            raise badmultipart("Form part uses unexpected transfer encoding: %r" % encoding)
137
class multipart(object):
    """Iterator over the parts of a multipart/form-data POST request.

    Each iteration step yields a formpart whose headers have already been
    parsed (using *charset* to decode them).  A part must be read to its
    end before the next one is requested.
    """

    def __init__(self, req, charset):
        ctype, params = proto.pmimehead(req.ihead.get("Content-Type", ""))
        if not (req.method == "POST" and ctype == "multipart/form-data"):
            raise badmultipart("Request is not a multipart form")
        if "boundary" not in params:
            raise badmultipart("Multipart form lacks boundary")
        try:
            encoded = params["boundary"].encode("us-ascii")
        except UnicodeError:
            raise badmultipart("Multipart boundary must be ASCII string")
        self.boundary = encoded
        self.req = req
        self.headcs = charset
        self.eof = False
        # Seeding the buffer with a line break lets the very first boundary
        # line match the "\r\n--boundary" pattern that formpart searches for.
        self.buf = b"\r\n"
        # Whatever precedes the first boundary is read as a throw-away part.
        preamble = formpart(self)
        self.lastpart = preamble
        preamble.close()

    def __iter__(self):
        return self

    def __next__(self):
        if self.lastpart.eof:
            if self.eof:
                raise StopIteration()
            part = formpart(self)
            self.lastpart = part
            part.parsehead(self.headcs)
            return part
        raise RuntimeError("All form parts must be read entirely")
167
def formdata(req, onerror=Exception):
    """Return the parsed form fields of *req*.

    The fields are obtained through ``req.item(formparse)``.  If that
    yields an exception object instead of data, the exception is raised
    when *onerror* is left at its default; otherwise *onerror* itself is
    returned in place of the data.
    """
    result = req.item(formparse)
    if not isinstance(result, Exception):
        return result
    # The Exception class itself serves as the "no fallback given" marker.
    if onerror is not Exception:
        return onerror
    raise result