e89901d0221c95432740b20251630b1b45c13918
[wrw.git] / wrw / form.py
1 import urllib.parse
2 from . import proto
3
# Public API of this module: only ``formdata`` is exported by
# ``from ... import *``; the other names remain importable explicitly.
__all__ = ["formdata"]
5
def formparse(req):
    """Collect the form fields of a request into a dictionary.

    Fields are parsed from the query string (``req.query``) and, when the
    request declares an ``application/x-www-form-urlencoded`` body, from
    up to 1 MiB of ``req.input`` as well.  Body fields override
    query-string fields of the same name, and when a name is repeated the
    last value wins.

    The Content-Type header is matched on its media type only: parameters
    (such as ``; charset=UTF-8``) and letter case are ignored, so that
    browsers and clients which decorate the header are still recognized.

    Raises:
        ValueError: if the urlencoded body is 1 MiB or larger.
    """
    maxbody = 2 ** 20
    fields = dict(urllib.parse.parse_qsl(req.query))
    # Strip MIME parameters and normalize case before comparing; a missing
    # header is treated like an empty one.
    ctype = req.ihead.get("Content-Type") or ""
    mediatype = ctype.split(";", 1)[0].strip().lower()
    if mediatype == "application/x-www-form-urlencoded":
        rbody = req.input.read(maxbody)
        if len(rbody) >= maxbody:
            raise ValueError("x-www-form-urlencoded data is absurdly long")
        # latin1 maps every byte to a code point, so decoding cannot fail;
        # percent-escapes are decoded afterwards by parse_qsl.
        fields.update(urllib.parse.parse_qsl(rbody.decode("latin1")))
    return fields
15
class badmultipart(Exception):
    """Raised when a request is not, or does not contain, a valid multipart form."""
18
class formpart(object):
    """One part of a multipart form body, readable as a byte stream.

    The part pulls its data from the owning form object, which must
    provide ``req`` (with a readable ``req.input``), ``boundary`` (bytes),
    ``buf`` (bytes read from the request but not yet assigned to a part)
    and ``eof`` (set by this class once the final boundary is seen).

    The object can be used as a context manager; leaving the block
    consumes whatever remains of the part.
    """

    def __init__(self, form):
        self.form = form
        # Data belonging to this part that has not yet been handed out.
        self.buf = b""
        # True once the boundary terminating this part has been found.
        self.eof = False
        # Part headers, keyed by lower-cased header name (see parsehead).
        self.head = {}

    def fillbuf(self, sz):
        """Move data from the form into this part's buffer.

        Reads until the buffer holds at least *sz* bytes or the part's
        terminating boundary is found; a negative *sz* means "read to the
        end of the part".

        Raises:
            badmultipart: if the input ends before the final boundary.
        """
        req = self.form.req
        mboundary = b"\r\n--" + self.form.boundary + b"\r\n"
        lboundary = b"\r\n--" + self.form.boundary + b"--\r\n"
        while not self.eof:
            # A boundary followed by CRLF ends this part; more parts follow.
            p = self.form.buf.find(mboundary)
            if p >= 0:
                self.buf += self.form.buf[:p]
                self.form.buf = self.form.buf[p + len(mboundary):]
                self.eof = True
                break
            # A boundary followed by "--" ends this part and the whole form.
            p = self.form.buf.find(lboundary)
            if p >= 0:
                self.buf += self.form.buf[:p]
                self.form.buf = self.form.buf[p + len(lboundary):]
                self.eof = True
                self.form.eof = True
                break
            # No boundary yet: take everything except a tail as long as the
            # longest boundary, since that tail may be the start of one.
            self.buf += self.form.buf[:-len(lboundary)]
            self.form.buf = self.form.buf[-len(lboundary):]
            if sz >= 0 and len(self.buf) >= sz:
                break
            # Refill the form buffer until it exceeds the held-back tail.
            while len(self.form.buf) <= len(lboundary):
                ret = req.input.read(8192)
                if ret == b"":
                    raise badmultipart("Missing last multipart boundary")
                self.form.buf += ret

    def read(self, limit=-1):
        """Return up to *limit* bytes of the part, or all of it if negative.

        Returns ``b""`` once the part is exhausted.
        """
        self.fillbuf(limit)
        if limit >= 0:
            ret = self.buf[:limit]
            self.buf = self.buf[limit:]
        else:
            ret = self.buf
            self.buf = b""
        return ret

    def readline(self, limit=-1):
        """Return the next line of the part, including its newline.

        If *limit* is non-negative, at most *limit* bytes are returned, so
        the result lacks a trailing newline when the line is longer than
        that.  The last line of the part is returned as is, and ``b""`` is
        returned once the part is exhausted.
        """
        scanned = 0
        while True:
            if limit >= 0:
                # Only a newline within the first `limit` bytes counts.
                nl = self.buf.find(b"\n", scanned, limit)
            else:
                nl = self.buf.find(b"\n", scanned)
            if nl >= 0:
                cut = nl + 1
            elif limit >= 0 and len(self.buf) >= limit:
                # Line is longer than allowed; hand out a truncated piece
                # instead of buffering without bound.
                cut = limit
            elif self.eof:
                cut = len(self.buf)
            else:
                # Remember how far we have searched and fetch more data.
                scanned = len(self.buf)
                self.fillbuf(scanned + 128)
                continue
            ret = self.buf[:cut]
            self.buf = self.buf[cut:]
            return ret

    def close(self):
        """Consume and discard the rest of the part."""
        while True:
            if self.read(8192) == b"":
                break

    def __enter__(self):
        return self

    def __exit__(self, *excinfo):
        self.close()
        return False

    def parsehead(self, charset):
        """Read and interpret the headers at the start of the part.

        Header lines are decoded using *charset* and stored in
        ``self.head`` under lower-cased names.  Lines beginning with
        whitespace are treated as continuations of the previous header.
        Afterwards ``name``, ``filename``, ``ctype`` and ``charset`` are
        set from the Content-Disposition and Content-Type headers.

        Raises:
            badmultipart: on over-long, truncated, undecodable or
                malformed header lines, on a Content-Disposition other
                than ``form-data`` or lacking a name, or on a transfer
                encoding other than ``binary``.
        """
        maxline = 256

        def headline():
            ln = self.readline(maxline)
            if not ln.endswith(b"\n"):
                if len(ln) >= maxline:
                    raise badmultipart("Too long header line in part")
                # The part ran out before the blank line ending the headers.
                raise badmultipart("Form part ended before end of headers")
            try:
                return ln.decode(charset).rstrip()
            except UnicodeError:
                raise badmultipart("Form part header is not in assumed charset")

        ln = headline()
        while ln != "":
            buf = ln
            while True:
                ln = headline()
                # A continuation line starts with whitespace; anything else
                # (including the blank line) begins something new.
                if not ln[:1].isspace():
                    break
                buf += " " + ln.lstrip()
            p = buf.find(':')
            if p < 0:
                raise badmultipart("Malformed multipart header line")
            self.head[buf[:p].strip().lower()] = buf[p + 1:].lstrip()

        val, par = proto.pmimehead(self.head.get("content-disposition", ""))
        if val != "form-data":
            raise badmultipart("Unexpected Content-Disposition in form part: %r" % val)
        if "name" not in par:
            raise badmultipart("Missing name in form part")
        self.name = par["name"]
        self.filename = par.get("filename")
        val, par = proto.pmimehead(self.head.get("content-type", ""))
        self.ctype = val
        self.charset = par.get("charset")
        encoding = self.head.get("content-transfer-encoding", "binary")
        if encoding != "binary":
            raise badmultipart("Form part uses unexpected transfer encoding: %r" % encoding)
133
class multipart(object):
    """Iterator over the parts of a multipart/form-data POST request.

    Each iteration step yields a ``formpart`` whose headers have already
    been parsed.  A part must be read to its end before the next one is
    requested.
    """

    def __init__(self, req, charset):
        """Validate the request and skip ahead to the first part.

        *charset* is the character set used to decode part headers.

        Raises:
            badmultipart: if the request is not a POST of type
                multipart/form-data, or its boundary parameter is missing
                or not ASCII.
        """
        ctype, params = proto.pmimehead(req.ihead.get("Content-Type", ""))
        if not (req.method == "POST" and ctype == "multipart/form-data"):
            raise badmultipart("Request is not a multipart form")
        if "boundary" in params:
            rawbound = params["boundary"]
        else:
            raise badmultipart("Multipart form lacks boundary")
        try:
            self.boundary = rawbound.encode("us-ascii")
        except UnicodeError:
            raise badmultipart("Multipart boundary must be ASCII string")
        self.req = req
        self.headcs = charset
        self.eof = False
        # Seed the buffer with CRLF so that a boundary on the very first
        # line of the body matches the CRLF-prefixed pattern formpart
        # searches for.
        self.buf = b"\r\n"
        # Whatever precedes the first boundary is read as a part of its
        # own and thrown away.
        preamble = formpart(self)
        self.lastpart = preamble
        preamble.close()

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next part, with its headers parsed.

        Raises:
            RuntimeError: if the previously returned part has not been
                read to its end.
            StopIteration: once the final boundary has been consumed.
        """
        if not self.lastpart.eof:
            raise RuntimeError("All form parts must be read entirely")
        if self.eof:
            raise StopIteration()
        part = formpart(self)
        self.lastpart = part
        part.parsehead(self.headcs)
        return part
163
def formdata(req):
    """Return the form data of *req*, obtained via ``req.item(formparse)``.

    NOTE(review): ``req.item`` is defined outside this module; presumably
    it calls ``formparse(req)`` and caches the result per request --
    confirm against the request class.
    """
    parser = formparse
    return req.item(parser)