Make badmultipart an IOError.
[wrw.git] / wrw / form.py
1 import urllib.parse
2 from . import proto
3
4 __all__ = ["formdata"]
5
def formparse(req):
    """Collect the form fields of *req* into a dict.

    Fields from the query string are gathered first.  If the request body is
    ``application/x-www-form-urlencoded``, its fields are merged in on top,
    so a body field replaces a query field of the same name.  Within each
    source the last occurrence of a repeated name wins.

    Errors are *returned* instead of raised, so the caller (formdata) can
    decide how to report them:

    * the IOError raised while reading the body, or
    * a ValueError when the body is longer than 2**20 bytes.

    On success the dict of fields is returned.
    """
    buf = {}
    buf.update(urllib.parse.parse_qsl(req.query))
    # Compare only the media type.  Clients may append parameters such as
    # "; charset=UTF-8", and media types are case-insensitive.
    ctype = req.ihead.get("Content-Type", "").partition(";")[0].strip().lower()
    if ctype == "application/x-www-form-urlencoded":
        limit = 2 ** 20
        try:
            # Read one byte past the limit so that a body of exactly `limit`
            # bytes can be told apart from one that is too long.
            rbody = req.input.read(limit + 1)
        except IOError as exc:
            return exc
        if len(rbody) > limit:
            return ValueError("x-www-form-urlencoded data is absurdly long")
        buf.update(urllib.parse.parse_qsl(rbody.decode("latin1")))
    return buf
18
class badmultipart(IOError):
    """Raised when a request body is not a well-formed multipart form.

    Because it subclasses IOError, any ``except IOError`` handler also
    catches it.
    """
21
class formpart(object):
    """A single part of a multipart/form-data body, with a file-like API.

    *form* is the parent `multipart` object; anything exposing ``req``,
    ``boundary``, ``buf`` and ``eof`` works.  Data is taken from the form's
    shared buffer, which is topped up from ``form.req.input``.  Reading stops
    at the boundary delimiter that ends this part.

    When used as a context manager, the part is drained on exit so the form
    can move on to the next part.
    """

    # Longest header line, terminator included, that parsehead() accepts.
    # This bounds the memory used per header line.
    # NOTE(review): the value is chosen generously so that
    # Content-Disposition headers carrying long (e.g. UTF-8) filenames
    # still fit.  Tighten it if needed.
    maxheadline = 4096

    def __init__(self, form):
        self.form = form
        self.buf = b""      # this part's data fetched but not yet returned
        self.eof = False    # set once this part's terminating delimiter is seen
        self.head = {}      # lower-cased header name -> value (parsehead)

    def fillbuf(self, sz):
        """Move data from the form's buffer into this part's buffer.

        Stops once at least *sz* bytes are buffered (a negative *sz* means
        no limit) or once the delimiter ending this part has been reached.
        A closing ``--boundary--`` delimiter also marks the whole form as
        finished.

        Raises badmultipart if the input ends before the closing delimiter.
        """
        req = self.form.req
        mboundary = b"\r\n--" + self.form.boundary + b"\r\n"
        lboundary = b"\r\n--" + self.form.boundary + b"--\r\n"
        while not self.eof:
            fbuf = self.form.buf
            pm = fbuf.find(mboundary)
            pl = fbuf.find(lboundary)
            # Honor whichever delimiter comes first.  The epilogue after the
            # closing delimiter may contain arbitrary text, including
            # something that looks like another part boundary.
            if pl >= 0 and (pm < 0 or pl < pm):
                self.buf += fbuf[:pl]
                self.form.buf = fbuf[pl + len(lboundary):]
                self.eof = True
                self.form.eof = True
                break
            if pm >= 0:
                self.buf += fbuf[:pm]
                self.form.buf = fbuf[pm + len(mboundary):]
                self.eof = True
                break
            # No delimiter yet.  Everything belongs to this part except a
            # tail long enough to hold a delimiter split across reads.
            self.buf += fbuf[:-len(lboundary)]
            self.form.buf = fbuf[-len(lboundary):]
            if sz >= 0 and len(self.buf) >= sz:
                break
            while len(self.form.buf) <= len(lboundary):
                ret = req.input.read(8192)
                if ret == b"":
                    # RFC 2046 makes the CRLF after the close delimiter
                    # optional, so accept a body that ends right after "--".
                    if self.form.buf.endswith(lboundary[:-2]):
                        self.form.buf += b"\r\n"
                        break
                    raise badmultipart("Missing last multipart boundary")
                self.form.buf += ret

    def read(self, limit=-1):
        """Return up to *limit* bytes of the part.

        A negative *limit* returns everything that remains.  Returns b""
        once the part is exhausted.
        """
        self.fillbuf(limit)
        if limit >= 0:
            ret = self.buf[:limit]
            self.buf = self.buf[limit:]
        else:
            ret = self.buf
            self.buf = b""
        return ret

    def readline(self, limit=-1):
        """Return the next line of the part, including its trailing b"\\n".

        If *limit* is non-negative, at most *limit* bytes are returned, as
        with io's readline().  A result without a trailing newline therefore
        means either the limit was hit or the part ended.  Returns b"" at
        the end of the part.
        """
        last = 0
        while True:
            p = self.buf.find(b"\n", last)
            if p >= 0:
                end = p + 1
                break
            if limit >= 0 and len(self.buf) >= limit:
                end = limit
                break
            if self.eof:
                end = len(self.buf)
                break
            # On the next pass only scan the newly fetched data.
            last = len(self.buf)
            self.fillbuf(last + 128)
        if limit >= 0:
            end = min(end, limit)
        ret = self.buf[:end]
        self.buf = self.buf[end:]
        return ret

    def close(self):
        """Read and discard the rest of the part, up to its delimiter."""
        while self.read(8192):
            pass

    def __enter__(self):
        return self

    def __exit__(self, *excinfo):
        # Drain the part; never suppress exceptions from the with-body.
        self.close()
        return False

    def parsehead(self, charset):
        """Read this part's header block and set the part's attributes.

        Header lines are decoded with *charset*.  Folded continuation lines,
        which start with whitespace, are joined to the preceding header.

        Populates ``head``.  It also sets ``name`` and ``filename`` from
        Content-Disposition, and ``ctype`` and ``charset`` from Content-Type.

        Raises badmultipart when:

        * a header line is over-long or unterminated,
        * a header line cannot be decoded,
        * a header line has no colon,
        * the disposition is not form-data or has no name,
        * the transfer encoding is not binary.
        """
        def headline():
            ln = self.readline(self.maxheadline)
            # No newline means the line exceeded the limit or the part ended
            # inside its header block.  Either way the header is unusable.
            if not ln.endswith(b"\n"):
                raise badmultipart("Too long header line in part")
            try:
                return ln.decode(charset).rstrip()
            except UnicodeError:
                raise badmultipart("Form part header is not in assumed charset")

        ln = headline()
        while ln != "":
            buf = ln
            while True:
                ln = headline()
                # A line starting with whitespace continues the previous
                # header (RFC 5322 folding).
                if not ln[:1].isspace():
                    break
                buf += " " + ln.lstrip()
            p = buf.find(':')
            if p < 0:
                raise badmultipart("Malformed multipart header line")
            self.head[buf[:p].strip().lower()] = buf[p + 1:].lstrip()

        val, par = proto.pmimehead(self.head.get("content-disposition", ""))
        if val != "form-data":
            raise badmultipart("Unexpected Content-Disposition in form part: %r" % val)
        if not "name" in par:
            raise badmultipart("Missing name in form part")
        self.name = par["name"]
        self.filename = par.get("filename")
        val, par = proto.pmimehead(self.head.get("content-type", ""))
        self.ctype = val
        self.charset = par.get("charset")
        encoding = self.head.get("content-transfer-encoding", "binary")
        if encoding != "binary":
            raise badmultipart("Form part uses unexpected transfer encoding: %r" % encoding)
136
class multipart(object):
    """Iterator over the parts of a multipart/form-data POST request.

    Construction validates the request's Content-Type and boundary, then
    reads and discards everything before the first part.  Each iteration
    parses the next part's headers and yields a formpart object.

    *charset* is the encoding used to decode part headers.  Every yielded
    part must be read to its end, e.g. with close() or a with-block, before
    the next one is requested.
    """

    def __init__(self, req, charset):
        mtype, params = proto.pmimehead(req.ihead.get("Content-Type", ""))
        is_form = req.method == "POST" and mtype == "multipart/form-data"
        if not is_form:
            raise badmultipart("Request is not a multipart form")
        if "boundary" not in params:
            raise badmultipart("Multipart form lacks boundary")
        try:
            self.boundary = params["boundary"].encode("us-ascii")
        except UnicodeError:
            raise badmultipart("Multipart boundary must be ASCII string")
        self.req = req
        # Seed the buffer with CRLF so that a boundary at the very start of
        # the body still matches the CRLF-prefixed delimiter that
        # formpart.fillbuf searches for.
        self.buf = b"\r\n"
        self.eof = False
        self.headcs = charset
        # The text before the first boundary is read as a throwaway part.
        preamble = formpart(self)
        self.lastpart = preamble
        preamble.close()

    def __iter__(self):
        return self

    def __next__(self):
        previous = self.lastpart
        if not previous.eof:
            raise RuntimeError("All form parts must be read entirely")
        if self.eof:
            raise StopIteration()
        part = self.lastpart = formpart(self)
        part.parsehead(self.headcs)
        return part
166
def formdata(req, onerror=Exception):
    """Return the parsed form fields of *req* (see formparse).

    Parsing goes through ``req.item(formparse)``; presumably req.item
    memoizes the result per request, but confirm this against the request
    class.

    If parsing produced an exception, the behavior depends on *onerror*:

    * left at its default, the exception is raised.  The Exception class
      itself serves as the "not given" sentinel.
    * otherwise, *onerror* is returned in place of the fields.
    """
    parsed = req.item(formparse)
    if not isinstance(parsed, Exception):
        return parsed
    if onerror is not Exception:
        return onerror
    raise parsed