Decode binary form input.
[wrw.git] / wrw / form.py
1 import urllib.parse
2 from . import proto
3
# Only formdata is exported by "from ... import *"; the multipart helpers
# (multipart, formpart, badmultipart) must be imported by name.
__all__ = ["formdata"]
5
def formparse(req):
    """Collect simple form fields from the query string and urlencoded body.

    Query-string fields are always parsed.  The request body is parsed as
    well when its Content-Type media type is
    application/x-www-form-urlencoded (parameters such as "; charset=..."
    are ignored for this test).

    Returns a dict mapping field names to values.  For repeated names the
    last occurrence wins, so body fields override query fields.

    Raises ValueError if req.input.limit exceeds 2 ** 20
    (presumably the remaining body length in bytes -- confirm).
    """
    buf = {}
    buf.update(urllib.parse.parse_qsl(req.query))
    # Compare only the media type.  Clients commonly append parameters such
    # as "; charset=UTF-8", and media types are case-insensitive; an exact
    # string comparison would silently skip the body in those cases.
    ctype = req.ihead.get("Content-Type", "").partition(";")[0].strip().lower()
    if ctype == "application/x-www-form-urlencoded":
        if req.input.limit > 2 ** 20:
            raise ValueError("x-www-form-urlencoded data is absurdly long")
        rbody = req.input.read()
        buf.update(urllib.parse.parse_qsl(rbody.decode("latin1")))
    return buf
15
class badmultipart(Exception):
    """Raised when a multipart/form-data request cannot be parsed."""
18
class formpart(object):
    """One part of a multipart/form-data body, read as a binary stream.

    Data is pulled lazily from the owning form object (its ``buf`` plus
    ``req.input``) up to the next boundary delimiter.  After parsehead(),
    ``head``, ``name``, ``filename``, ``ctype`` and ``charset`` describe
    the part.
    """

    def __init__(self, form):
        self.form = form    # owning multipart object; its buf/req/boundary/eof are shared
        self.buf = b""      # part data already split off the form buffer
        self.eof = False    # True once this part's trailing boundary was consumed
        self.head = {}      # lowercased header name -> header value
        # Set by parsehead(); defaulted here so the attributes always exist.
        self.name = None
        self.filename = None
        self.ctype = None
        self.charset = None

    def fillbuf(self, sz):
        """Move form data into self.buf until this part's boundary is found.

        When sz >= 0, stop early once self.buf holds at least sz bytes.
        Raises badmultipart if the input ends before the final boundary.
        """
        req = self.form.req
        mboundary = b"\r\n--" + self.form.boundary + b"\r\n"
        lboundary = b"\r\n--" + self.form.boundary + b"--\r\n"
        while not self.eof:
            p = self.form.buf.find(mboundary)
            if p >= 0:
                self.buf += self.form.buf[:p]
                self.form.buf = self.form.buf[p + len(mboundary):]
                self.eof = True
                break
            p = self.form.buf.find(lboundary)
            if p >= 0:
                self.buf += self.form.buf[:p]
                self.form.buf = self.form.buf[p + len(lboundary):]
                self.eof = True
                self.form.eof = True
                break
            # Hold back len(lboundary) bytes so a delimiter split across two
            # reads is still found on the next pass.
            self.buf += self.form.buf[:-len(lboundary)]
            self.form.buf = self.form.buf[-len(lboundary):]
            if sz >= 0 and len(self.buf) >= sz:
                break
            while len(self.form.buf) <= len(lboundary):
                ret = req.input.read(8192)
                # read() yields bytes; comparing against "" never matched and
                # turned a truncated body into an endless loop.
                if not ret:
                    raise badmultipart("Missing last multipart boundary")
                self.form.buf += ret

    def read(self, limit=-1):
        """Return up to limit bytes of part data (everything left if limit < 0)."""
        self.fillbuf(limit)
        if limit >= 0:
            ret = self.buf[:limit]
            self.buf = self.buf[limit:]
        else:
            ret = self.buf
            self.buf = b""
        return ret

    def readline(self, limit=-1):
        """Return one line including its b"\\n", or the remainder at part end.

        When limit >= 0, at most limit bytes are returned, so a result that
        does not end in b"\\n" means the line was longer than limit (or the
        part ended).
        """
        last = 0
        while True:
            p = self.buf.find(b"\n", last)
            if p >= 0:
                end = p + 1
                break
            if limit >= 0 and len(self.buf) >= limit:
                end = limit
                break
            if self.eof:
                end = len(self.buf)
                break
            # Only newly appended data needs scanning on the next pass.
            last = len(self.buf)
            self.fillbuf(last + 128)
        if limit >= 0:
            end = min(end, limit)
        ret = self.buf[:end]
        self.buf = self.buf[end:]
        return ret

    def close(self):
        """Skip the remainder of the part, discarding it in bounded chunks."""
        self.buf = b""
        while not self.eof:
            self.fillbuf(8192)
            self.buf = b""

    def __enter__(self):
        return self

    def __exit__(self, *excinfo):
        self.close()
        return False

    def parsehead(self, charset):
        """Parse the part's MIME header block and validate it.

        charset is used to decode header lines.  Fills self.head and sets
        self.name, self.filename, self.ctype and self.charset.

        Raises badmultipart for an overlong, truncated or undecodable header
        line, a line without ':', a Content-Disposition other than form-data,
        a missing name parameter, or a Content-Transfer-Encoding other than
        binary.
        """
        def headline():
            ln = self.readline(256)
            # readline() honours the limit, so a missing newline means the
            # line was overlong or the part ended inside the headers.
            if not ln.endswith(b"\n"):
                if len(ln) >= 256:
                    raise badmultipart("Too long header line in part")
                raise badmultipart("Truncated header in form part")
            try:
                return ln.decode(charset).rstrip()
            except UnicodeError:
                raise badmultipart("Form part header is not in assumed charset")

        ln = headline()
        while ln != "":
            buf = ln
            while True:
                ln = headline()
                # A line starting with whitespace continues the previous header.
                if not ln[:1].isspace():
                    break
                buf += " " + ln.lstrip()
            p = buf.find(":")
            if p < 0:
                raise badmultipart("Malformed multipart header line")
            self.head[buf[:p].strip().lower()] = buf[p + 1:].lstrip()

        val, par = proto.pmimehead(self.head.get("content-disposition", ""))
        if val != "form-data":
            raise badmultipart("Unexpected Content-Disposition in form part: %r" % val)
        if "name" not in par:
            raise badmultipart("Missing name in form part")
        self.name = par["name"]
        self.filename = par.get("filename")
        val, par = proto.pmimehead(self.head.get("content-type", ""))
        self.ctype = val
        self.charset = par.get("charset")
        encoding = self.head.get("content-transfer-encoding", "binary")
        if encoding != "binary":
            raise badmultipart("Form part uses unexpected transfer encoding: %r" % encoding)
131
class multipart(object):
    """Iterate over the parts of a multipart/form-data POST request.

    Each step yields a formpart whose headers were already parsed using the
    header charset given to the constructor.  The previous part must have
    reached its end (read fully or closed) before the next one is requested.
    """

    def __init__(self, req, charset):
        mtype, params = proto.pmimehead(req.ihead.get("Content-Type", ""))
        if not (req.method == "POST" and mtype == "multipart/form-data"):
            raise badmultipart("Request is not a multipart form")
        if "boundary" not in params:
            raise badmultipart("Multipart form lacks boundary")
        try:
            self.boundary = params["boundary"].encode("us-ascii")
        except UnicodeError:
            raise badmultipart("Multipart boundary must be ASCII string")
        self.req = req
        # Seeding with CRLF lets the opening "--boundary" line match the same
        # CRLF-prefixed delimiter that formpart searches for between parts.
        self.buf = b"\r\n"
        self.eof = False
        self.headcs = charset
        # Consume everything before the first delimiter as a throwaway part.
        self.lastpart = preamble = formpart(self)
        preamble.close()

    def __iter__(self):
        return self

    def __next__(self):
        if self.lastpart.eof:
            if self.eof:
                raise StopIteration()
            part = formpart(self)
            self.lastpart = part
            part.parsehead(self.headcs)
            return part
        raise RuntimeError("All form parts must be read entirely")
161
def formdata(req):
    """Return the simple form fields of req, as produced by formparse.

    The work is delegated through req.item(); NOTE(review): that presumably
    memoizes the result per request -- confirm against the request class.
    """
    parser = formparse
    return req.item(parser)