Converted the multipart form reader to Python3.
[wrw.git] / wrw / form.py
1 import cgi
2 from . import proto
3
4 __all__ = ["formdata"]
5
class formwrap(object):
    """Read-only, dict-like view of a request's form fields.

    Fields are taken from the URL query string and, for POST requests whose
    Content-Type is application/x-www-form-urlencoded, from the request body
    as well.  A field may carry several values: item access and get() return
    the first one, while items() and values() return every value.

    Other request bodies (multipart/form-data in particular) are never read
    here; the multipart class handles those.  The WSGI environ is not
    modified.
    """

    def __init__(self, req):
        import urllib.parse

        env = req.env
        query = env.get("QUERY_STRING", "")
        # Compare only the media type: clients commonly append parameters such
        # as "; charset=UTF-8", which an exact string comparison would miss.
        ctype = req.ihead.get("Content-Type", "").split(";", 1)[0].strip().lower()
        if env.get("REQUEST_METHOD") == "POST" and ctype == "application/x-www-form-urlencoded":
            try:
                clen = int(env.get("CONTENT_LENGTH") or 0)
            except ValueError:
                clen = 0
            # Never read past Content-Length (or at all without one); the
            # input stream might otherwise block waiting for more data.
            body = env["wsgi.input"].read(clen) if clen > 0 else b""
            # Body first, then query string, joined as one query string.
            # Decoded as latin-1, which is lossless; callers may re-encode.
            # NOTE(review): latin-1 mirrors what cgi.parse did for a binary
            # wsgi.input; UTF-8 forms therefore come back as mojibake.
            qs = body.decode("latin-1")
            if query:
                qs = qs + "&" + query if qs else query
            self.cf = urllib.parse.parse_qs(qs, encoding = "latin-1")
        else:
            # Only the query string; the request body is left untouched.
            self.cf = urllib.parse.parse_qs(query)

    def __getitem__(self, key):
        """Return the first value of field *key*; KeyError if absent."""
        return self.cf[key][0]

    def get(self, key, default = ""):
        """Return the first value of field *key*, or *default* if absent."""
        if key in self:
            return self.cf[key][0]
        return default

    def __contains__(self, key):
        return key in self.cf and len(self.cf[key]) > 0

    def __iter__(self):
        return iter(self.cf)

    def items(self):
        """Return (name, value) pairs, one per value of multi-valued fields."""
        return [(key, val) for key, vals in self.cf.items() for val in vals]

    def keys(self):
        """Return the list of field names."""
        return list(self.cf.keys())

    def values(self):
        """Return every field value, in the same order as items()."""
        return [val for key, val in self.items()]
39
class badmultipart(Exception):
    """Raised when a request is not a well-formed multipart/form-data submission."""
42
class formpart(object):
    """One part of a multipart/form-data body, readable like a binary file.

    Part data is pulled on demand from the enclosing multipart object
    (*form*); its ``buf``/``eof`` attributes and the request's
    ``wsgi.input`` stream are shared by all parts.  ``eof`` becomes true
    once this part's closing delimiter has been consumed.  After
    parsehead(), ``head``, ``name``, ``filename``, ``ctype`` and
    ``charset`` describe the part.
    """

    def __init__(self, form):
        self.form = form
        self.buf = b""
        self.eof = False
        self.head = {}

    def fillbuf(self, sz):
        """Move this part's data from the form buffer into self.buf.

        Reads from the request input until self.buf holds at least *sz*
        bytes (when *sz* is non-negative) or this part's delimiter has been
        found.  The delimiter itself is consumed; the final delimiter also
        marks the whole form as finished.

        Raises badmultipart if the input ends before the final delimiter.
        """
        form = self.form
        mboundary = b"\r\n--" + form.boundary + b"\r\n"
        lboundary = b"\r\n--" + form.boundary + b"--\r\n"
        while not self.eof:
            pm = form.buf.find(mboundary)
            pl = form.buf.find(lboundary)
            # Use whichever delimiter comes first; anything after the final
            # delimiter is epilogue and must not be mistaken for a boundary.
            if pm >= 0 and (pl < 0 or pm < pl):
                self.buf += form.buf[:pm]
                form.buf = form.buf[pm + len(mboundary):]
                self.eof = True
                break
            if pl >= 0:
                self.buf += form.buf[:pl]
                form.buf = form.buf[pl + len(lboundary):]
                self.eof = True
                form.eof = True
                break
            # No delimiter yet.  All but a tail as long as the longest
            # delimiter is certainly part data; the tail is held back in case
            # a delimiter straddles the next read.
            self.buf += form.buf[:-len(lboundary)]
            form.buf = form.buf[-len(lboundary):]
            if sz >= 0 and len(self.buf) >= sz:
                break
            while len(form.buf) <= len(lboundary):
                ret = form.req.env["wsgi.input"].read(8192)
                # Binary streams signal end of input with b"", never "".
                if not ret:
                    raise badmultipart("Missing last multipart boundary")
                form.buf += ret

    def read(self, limit = -1):
        """Return up to *limit* bytes of part data (all remaining if negative).

        Fewer than *limit* bytes are returned only at the end of the part.
        """
        self.fillbuf(limit)
        if limit < 0:
            ret, self.buf = self.buf, b""
        else:
            ret, self.buf = self.buf[:limit], self.buf[limit:]
        return ret

    def readline(self, limit = -1):
        """Return the next line of part data, including its trailing newline.

        When *limit* is non-negative at most *limit* bytes are returned, so a
        longer line comes back without its newline.  The last line of a part
        may lack a newline; b"" is returned once the part is exhausted.
        """
        last = 0
        while True:
            p = self.buf.find(b"\n", last)
            if p >= 0:
                end = p + 1
                break
            if self.eof or (limit >= 0 and len(self.buf) >= limit):
                end = len(self.buf)
                break
            last = len(self.buf)
            self.fillbuf(last + 128)
        if limit >= 0:
            end = min(end, limit)
        ret = self.buf[:end]
        self.buf = self.buf[end:]
        return ret

    def close(self):
        """Skip and discard whatever remains of this part.

        Data is dropped chunk by chunk, so skipping a large unwanted upload
        does not hold it all in memory.
        """
        while not self.eof:
            self.buf = b""
            self.fillbuf(8192)
        self.buf = b""

    def __enter__(self):
        return self

    def __exit__(self, *excinfo):
        # On a clean exit, drain the rest of the part so the enclosing
        # multipart iterator can move on to the next one.  While an exception
        # is propagating, leave the stream alone rather than risk masking it
        # with a secondary parse error.
        if excinfo[0] is None:
            self.close()
        return False

    def parsehead(self, charset):
        """Parse this part's header block and set the part's attributes.

        Header lines are decoded with *charset*.  Sets self.head (lower-cased
        header name -> value), self.name and self.filename from
        Content-Disposition, and self.ctype and self.charset from
        Content-Type.

        Raises badmultipart for over-long, truncated, undecodable or malformed
        header lines; a Content-Disposition other than form-data or one
        without a name; and any Content-Transfer-Encoding other than binary.
        """
        def headline():
            ln = self.readline(256)
            if not ln.endswith(b"\n"):
                # readline() stops short of a newline either because the line
                # reached the length limit or because the part ended.
                if len(ln) >= 256:
                    raise badmultipart("Too long header line in part")
                raise badmultipart("Truncated header block in form part")
            try:
                return ln.decode(charset).rstrip()
            except UnicodeError:
                raise badmultipart("Form part header is not in assumed charset")

        ln = headline()
        while ln != "":
            buf = ln
            # A line starting with whitespace continues (folds) the previous
            # header line.
            while True:
                ln = headline()
                if not ln[:1].isspace():
                    break
                buf += " " + ln.lstrip()
            key, sep, value = buf.partition(":")
            if not sep:
                raise badmultipart("Malformed multipart header line")
            self.head[key.strip().lower()] = value.lstrip()

        val, par = proto.pmimehead(self.head.get("content-disposition", ""))
        if val != "form-data":
            raise badmultipart("Unexpected Content-Disposition in form part: %r" % val)
        if "name" not in par:
            raise badmultipart("Missing name in form part")
        self.name = par["name"]
        self.filename = par.get("filename")
        val, par = proto.pmimehead(self.head.get("content-type", ""))
        self.ctype = val
        self.charset = par.get("charset")
        encoding = self.head.get("content-transfer-encoding", "binary")
        if encoding != "binary":
            raise badmultipart("Form part uses unexpected transfer encoding: %r" % encoding)
154
class multipart(object):
    """Iterator over the parts of a multipart/form-data POST request.

    The constructor checks the request's Content-Type, extracts the boundary
    and skips the preamble.  Each iteration yields a formpart whose headers
    have already been parsed (decoded with *charset*).  A part must be read
    to its end before the next one is requested.
    """

    def __init__(self, req, charset):
        mtype, params = proto.pmimehead(req.ihead.get("Content-Type", ""))
        if not (req.method == "POST" and mtype == "multipart/form-data"):
            raise badmultipart("Request is not a multipart form")
        if "boundary" not in params:
            raise badmultipart("Multipart form lacks boundary")
        try:
            boundary = params["boundary"].encode("us-ascii")
        except UnicodeError:
            raise badmultipart("Multipart boundary must be ASCII string")
        self.boundary = boundary
        self.req = req
        # Seed with CRLF so a body that starts directly with "--boundary"
        # still matches the CRLF-prefixed delimiter that formpart looks for.
        self.buf = b"\r\n"
        self.eof = False
        self.headcs = charset
        # Whatever precedes the first delimiter is preamble: consume it via a
        # throwaway part so iteration begins at the first real part.
        self.lastpart = preamble = formpart(self)
        preamble.close()

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next part, with its headers parsed.

        Raises RuntimeError if the previous part was not read to its end,
        and StopIteration after the final part.
        """
        previous = self.lastpart
        if not previous.eof:
            raise RuntimeError("All form parts must be read entirely")
        if self.eof:
            raise StopIteration()
        # Record the new part before parsing so it is current even if the
        # header parsing raises.
        self.lastpart = part = formpart(self)
        part.parsehead(self.headcs)
        return part
184
def formdata(req):
    """Return the formwrap view of *req*'s form fields.

    Construction is delegated to req.item(); presumably that caches one
    instance per request (req.item is defined elsewhere -- confirm).
    """
    wrapper_class = formwrap
    return req.item(wrapper_class)