Merge branch 'python3' of git.dolda2000.com:/srv/git/r/wrw into python3
[wrw.git] / wrw / form.py
1 import cgi
2 from . import proto
3
4 __all__ = ["formdata"]
5
class formwrap(object):
    """Read-only, dictionary-like view of a request's form fields.

    The fields come from ``cgi.parse()``, which maps every field name to the
    list of values submitted for it. Item access and ``get()`` return the
    first value, while ``items()``/``values()`` expose every value.

    NOTE: the ``cgi`` module is deprecated since Python 3.11 and removed in
    3.13 (PEP 594); this class will need a replacement parser there.
    """

    def __init__(self, req):
        # Compare only the media type. Clients may legitimately send
        # "application/x-www-form-urlencoded; charset=UTF-8". An exact string
        # comparison sent such requests down the branch without ``fp``, and
        # cgi.parse() then read the POST body from sys.stdin, not req.input.
        ctype = req.ihead.get("Content-Type") or ""
        mtype = ctype.split(";", 1)[0].strip().lower()
        if mtype == "application/x-www-form-urlencoded":
            self.cf = cgi.parse(environ = req.env, fp = req.input)
        else:
            # Without ``fp`` cgi.parse() falls back to sys.stdin. For GET it
            # only consults QUERY_STRING. NOTE(review): a POST with another
            # body type cgi understands (multipart) would be read from
            # stdin here; confirm callers never reach that case.
            self.cf = cgi.parse(environ = req.env)

    def __getitem__(self, key):
        """Return the first value of field ``key``; KeyError if absent."""
        return self.cf[key][0]

    def get(self, key, default = ""):
        """Return the first value of field ``key``, or ``default``."""
        if key in self:
            return self.cf[key][0]
        return default

    def __contains__(self, key):
        # A name only counts as present if it carries at least one value.
        return key in self.cf and len(self.cf[key]) > 0

    def __iter__(self):
        """Iterate over the field names."""
        return iter(self.cf)

    def items(self):
        """Return every (name, value) pair, repeating names with many values."""
        return [(key, val) for key, vals in self.cf.items() for val in vals]

    def keys(self):
        """Return the list of field names."""
        return list(self.cf.keys())

    def values(self):
        """Return every submitted value, in the same order as ``items()``."""
        return [val for _, val in self.items()]
39
class badmultipart(Exception):
    """Raised when a request cannot be parsed as a multipart/form-data body."""
42
class formpart(object):
    """One part of a multipart/form-data body, readable like a binary file.

    A part pulls raw bytes from its owning form object (``form``). The form's
    ``buf`` holds data already read from ``form.req.input`` but not yet
    assigned to any part. Bytes are moved into this part's own ``buf`` up to
    the next boundary delimiter, and ``eof`` becomes true once that delimiter
    has been consumed. Header attributes (``name``, ``filename``, ``ctype``,
    ``charset``) exist only after ``parsehead()`` has run.
    """

    def __init__(self, form):
        self.form = form
        self.buf = b""    # payload bytes already separated from form.buf
        self.eof = False  # True once this part's terminating delimiter was seen
        self.head = {}    # lower-cased header name -> value, set by parsehead()

    def fillbuf(self, sz):
        """Move payload bytes from the form buffer into ``self.buf``.

        Returns once the next boundary delimiter has been consumed (setting
        ``eof``, and also ``form.eof`` for the closing delimiter) or, when
        ``sz`` >= 0, once ``self.buf`` holds at least ``sz`` bytes. A negative
        ``sz`` reads to the end of the part.

        Raises badmultipart if the input ends before the closing delimiter.
        """
        req = self.form.req
        mboundary = b"\r\n--" + self.form.boundary + b"\r\n"
        lboundary = b"\r\n--" + self.form.boundary + b"--\r\n"
        # lboundary is the longer delimiter. Holding back that many bytes
        # ensures that a delimiter split across two reads is still found.
        keep = len(lboundary)
        while not self.eof:
            buf = self.form.buf
            mp = buf.find(mboundary)
            lp = buf.find(lboundary)
            # Whichever delimiter occurs first terminates this part.
            if mp >= 0 and (lp < 0 or mp < lp):
                self.buf += buf[:mp]
                self.form.buf = buf[mp + len(mboundary):]
                self.eof = True
                break
            if lp >= 0:
                self.buf += buf[:lp]
                self.form.buf = buf[lp + len(lboundary):]
                self.eof = True
                self.form.eof = True
                break
            self.buf += buf[:-keep]
            self.form.buf = buf[-keep:]
            if sz >= 0 and len(self.buf) >= sz:
                break
            while len(self.form.buf) <= keep:
                ret = req.input.read(8192)
                # Binary input signals EOF with b"". The old comparison
                # against the str "" never matched and looped forever.
                if not ret:
                    raise badmultipart("Missing last multipart boundary")
                self.form.buf += ret

    def read(self, limit = -1):
        """Return up to ``limit`` payload bytes (all remaining if negative)."""
        self.fillbuf(limit)
        if limit >= 0:
            ret = self.buf[:limit]
            self.buf = self.buf[limit:]
        else:
            ret = self.buf
            self.buf = b""
        return ret

    def readline(self, limit = -1):
        """Return the next line, including its trailing ``b"\\n"``.

        If ``limit`` >= 0, at most ``limit`` bytes are returned, so a missing
        newline cannot make the part buffer without bound. At the end of the
        part the remaining bytes (possibly ``b""``) are returned.
        """
        last = 0
        while True:
            p = self.buf.find(b"\n", last)
            if p >= 0:
                end = p + 1
            elif limit >= 0 and len(self.buf) >= limit:
                end = limit
            elif self.eof:
                end = len(self.buf)
            else:
                # Bytes before ``last`` are known to hold no newline.
                last = len(self.buf)
                self.fillbuf(last + 128)
                continue
            if limit >= 0:
                end = min(end, limit)
            ret = self.buf[:end]
            self.buf = self.buf[end:]
            return ret

    def close(self):
        """Skip the rest of this part so that the next one can be read.

        Unread payload is discarded chunk by chunk instead of accumulated,
        so skipping a large upload does not buffer it entirely in memory.
        """
        while not self.eof:
            self.buf = b""
            self.fillbuf(8192)
        self.buf = b""

    def __enter__(self):
        return self

    def __exit__(self, *excinfo):
        self.close()
        return False

    def parsehead(self, charset):
        """Read and validate this part's header section.

        Header lines are decoded with ``charset``. Folded continuation lines
        (starting with whitespace) are joined to the preceding header with a
        single space. Sets ``head``, ``name``, ``filename``, ``ctype`` and
        ``charset``.

        Raises badmultipart for oversized, truncated, undecodable or malformed
        header lines, for a Content-Disposition other than form-data, for a
        missing name, or for a Content-Transfer-Encoding other than binary.
        """
        def headline():
            ln = self.readline(256)
            # Check with endswith(): indexing ln[-1] fails on an empty result.
            if not ln.endswith(b"\n"):
                if len(ln) >= 256:
                    raise badmultipart("Too long header line in part")
                raise badmultipart("Unterminated header section in form part")
            try:
                return ln.decode(charset).rstrip()
            except UnicodeError:
                raise badmultipart("Form part header is not in assumed charset")

        ln = headline()
        while ln != "":
            buf = ln
            while True:
                ln = headline()
                # A line starting with whitespace continues the previous
                # header. The former ln[1:] test could never match after
                # rstrip(), so folded headers were rejected as malformed.
                if not ln[:1].isspace():
                    break
                buf += " " + ln.lstrip()
            p = buf.find(':')
            if p < 0:
                raise badmultipart("Malformed multipart header line")
            self.head[buf[:p].strip().lower()] = buf[p + 1:].lstrip()

        val, par = proto.pmimehead(self.head.get("content-disposition", ""))
        if val != "form-data":
            raise badmultipart("Unexpected Content-Disposition in form part: %r" % val)
        if not "name" in par:
            raise badmultipart("Missing name in form part")
        self.name = par["name"]
        self.filename = par.get("filename")
        val, par = proto.pmimehead(self.head.get("content-type", ""))
        self.ctype = val
        self.charset = par.get("charset")
        encoding = self.head.get("content-transfer-encoding", "binary")
        if encoding != "binary":
            raise badmultipart("Form part uses unexpected transfer encoding: %r" % encoding)
155
class multipart(object):
    """Iterator over the parts of a multipart/form-data request body.

    Construction validates the request method and Content-Type, extracts the
    boundary as ASCII bytes, and skips everything before the first delimiter.
    Each iteration yields a formpart whose headers have been parsed with the
    ``charset`` given here. The previous part must be read to its end (or
    closed) before the next one is requested.
    """

    def __init__(self, req, charset):
        mtype, params = proto.pmimehead(req.ihead.get("Content-Type", ""))
        if not (req.method == "POST" and mtype == "multipart/form-data"):
            raise badmultipart("Request is not a multipart form")
        if "boundary" not in params:
            raise badmultipart("Multipart form lacks boundary")
        try:
            raw_boundary = params["boundary"].encode("us-ascii")
        except UnicodeError:
            raise badmultipart("Multipart boundary must be ASCII string")
        self.boundary = raw_boundary
        self.req = req
        # Seed with CRLF so the very first delimiter, which has no preceding
        # line break in the body, matches the same b"\r\n--<boundary>"
        # pattern that formpart searches for.
        self.buf = b"\r\n"
        self.eof = False
        self.headcs = charset
        # The bytes before the first delimiter belong to no field. They are
        # read into a throwaway part and discarded.
        self.lastpart = formpart(self)
        self.lastpart.close()

    def __iter__(self):
        return self

    def __next__(self):
        previous_done = self.lastpart.eof
        if not previous_done:
            raise RuntimeError("All form parts must be read entirely")
        if self.eof:
            raise StopIteration()
        # Record the new part before parsing its headers, so that a header
        # error leaves it as the current (unfinished) part.
        self.lastpart = formpart(self)
        self.lastpart.parsehead(self.headcs)
        return self.lastpart
185
def formdata(req):
    """Return the parsed form fields of ``req``.

    The result is whatever ``req.item`` returns when handed the formwrap
    class. Presumably that is a formwrap built from ``req`` and cached per
    request; confirm against the request implementation.
    """
    wrapper_class = formwrap
    return req.item(wrapper_class)