Added a, perhaps overly simplistic, multipart form handler.
[wrw.git] / wrw / form.py
1 import cgi
2 import proto
3
4 __all__ = ["formdata"]
5
class formwrap(object):
    """Read-only, dict-like view of the form fields of a request.

    The fields are decoded by cgi.parse(), which yields a dictionary
    mapping each field name to a list of values.  Item access and get()
    expose the first value of a field; items() and values() expose
    every value.
    """

    def __init__(self, req):
        """Parse the form data of `req`.

        `req` must provide `ihead` (the request headers, with a
        dict-style get()) and `env` (the WSGI environment).
        """
        # Compare the bare media type only.  Clients may append
        # parameters ("...; charset=UTF-8") and media types are
        # case-insensitive, so an exact string comparison would miss
        # such requests and leave their body unread.
        ctype = req.ihead.get("Content-Type", "") or ""
        ctype = ctype.split(";", 1)[0].strip().lower()
        if ctype == "application/x-www-form-urlencoded":
            self.cf = cgi.parse(environ = req.env, fp = req.env["wsgi.input"])
        else:
            # NOTE(review): without `fp`, cgi.parse() falls back to
            # sys.stdin for POST bodies; only the query string is
            # reliably available on this path.
            self.cf = cgi.parse(environ = req.env)

    def __getitem__(self, key):
        """Return the first value of field `key`."""
        return self.cf[key][0]

    def get(self, key, default = ""):
        """Return the first value of `key`, or `default` if it has none."""
        if key in self:
            return self.cf[key][0]
        return default

    def __contains__(self, key):
        """Tell whether `key` is present with at least one value."""
        return key in self.cf and len(self.cf[key]) > 0

    def __iter__(self):
        """Iterate over the field names."""
        return iter(self.cf)

    def items(self):
        """Return a list of (name, value) pairs, one per submitted value."""
        return [(key, val) for key, vals in self.cf.items() for val in vals]

    def keys(self):
        """Return the field names."""
        return self.cf.keys()

    def values(self):
        """Return a list of every submitted value."""
        return [val for key, val in self.items()]
39
class badmultipart(Exception):
    """Signals that a request is not a well-formed multipart form."""
42
class formpart(object):
    """One part of a multipart form body, readable like a file.

    A part takes its data from the parent form's shared buffer
    (`form.buf`), which is refilled from the request's "wsgi.input"
    stream, and ends at the next boundary line.  Once parsehead() has
    run, the part also carries `head`, `name`, `filename`, `ctype` and
    `charset`.
    """

    def __init__(self, form):
        self.form = form
        # Data belonging to this part that has not been handed out yet.
        self.buf = ""
        # Set once the boundary terminating this part has been consumed.
        self.eof = False
        # Part headers, keyed by lower-cased header name.
        self.head = {}

    def fillbuf(self, sz):
        """Move data of this part from the form into self.buf.

        Stops once self.buf holds at least `sz` characters (any amount
        when `sz` is negative means "until the end of the part") or
        when the boundary ending the part has been found.  Finding the
        final boundary also marks the whole form as exhausted.

        Raises badmultipart if the input ends before a boundary.
        """
        req = self.form.req
        mboundary = "\r\n--" + self.form.boundary + "\r\n"
        lboundary = "\r\n--" + self.form.boundary + "--\r\n"
        while not self.eof:
            p = self.form.buf.find(mboundary)
            if p >= 0:
                self.buf += self.form.buf[:p]
                self.form.buf = self.form.buf[p + len(mboundary):]
                self.eof = True
                break
            p = self.form.buf.find(lboundary)
            if p >= 0:
                self.buf += self.form.buf[:p]
                self.form.buf = self.form.buf[p + len(lboundary):]
                self.eof = True
                self.form.eof = True
                break
            # No boundary yet: everything except a tail as long as the
            # longest boundary is certainly part data.  The tail is
            # kept back because a boundary may straddle two reads.
            self.buf += self.form.buf[:-len(lboundary)]
            self.form.buf = self.form.buf[-len(lboundary):]
            if sz >= 0 and len(self.buf) >= sz:
                break
            while len(self.form.buf) <= len(lboundary):
                ret = req.env["wsgi.input"].read(8192)
                if ret == "":
                    raise badmultipart("Missing last multipart boundary")
                self.form.buf += ret

    def read(self, limit = -1):
        """Return up to `limit` characters, or the rest of the part.

        A negative `limit` reads until the end of the part.  Fewer than
        `limit` characters are returned only when the part ends.
        """
        self.fillbuf(limit)
        if limit >= 0:
            ret = self.buf[:limit]
            self.buf = self.buf[limit:]
        else:
            ret = self.buf
            self.buf = ""
        return ret

    def readline(self, limit = -1):
        """Return the next line of the part, including its newline.

        With a non-negative `limit`, at most `limit` characters are
        returned, so the result may lack the trailing newline.  The
        last line of a part may lack it as well; an empty string means
        the part is exhausted.
        """
        last = 0
        while True:
            p = self.buf.find('\n', last)
            if p >= 0 and (limit < 0 or p < limit):
                # A complete line that fits within the limit.
                cut = p + 1
            elif 0 <= limit <= len(self.buf):
                # No newline within the first `limit` characters.
                cut = limit
            elif self.eof:
                cut = len(self.buf)
            else:
                # Nothing conclusive yet; fetch more and resume the
                # search where the previous one stopped.
                last = len(self.buf)
                self.fillbuf(last + 128)
                continue
            ret = self.buf[:cut]
            self.buf = self.buf[cut:]
            return ret

    def close(self):
        """Consume the input up to the end of this part.

        The remaining data is kept in self.buf, so it can still be
        read afterwards.
        """
        self.fillbuf(-1)

    def __enter__(self):
        return self

    def __exit__(self, *excinfo):
        # Exceptions are never suppressed; the part is not drained.
        return False

    def parsehead(self):
        """Read and interpret the header lines of this part.

        Fills self.head and sets `name`, `filename`, `ctype` and
        `charset` from the Content-Disposition and Content-Type
        headers.

        Raises badmultipart for overlong or malformed header lines, a
        part that ends inside its headers, a Content-Disposition other
        than "form-data", a missing field name, or a transfer encoding
        other than "binary".
        """
        def headline():
            # One header line, stripped of trailing whitespace.
            ln = self.readline(256)
            if not ln.endswith('\n'):
                if ln == "":
                    raise badmultipart("Unexpected end of part while reading headers")
                raise badmultipart("Too long header line in part")
            return ln.rstrip()

        ln = headline()
        while ln != "":
            buf = ln
            while True:
                ln = headline()
                # A line starting with whitespace continues the
                # previous header; anything else (including the empty
                # line that ends the headers) starts something new.
                if not ln[:1].isspace():
                    break
                # Unfold: the line break plus indentation becomes a
                # single space.
                buf += " " + ln.lstrip()
            p = buf.find(':')
            if p < 0:
                raise badmultipart("Malformed multipart header line")
            self.head[buf[:p].strip().lower()] = buf[p + 1:].lstrip()

        val, par = proto.pmimehead(self.head.get("content-disposition", ""))
        if val != "form-data":
            raise badmultipart("Unexpected Content-Disposition in form part: %r" % val)
        if "name" not in par:
            raise badmultipart("Missing name in form part")
        self.name = par["name"]
        self.filename = par.get("filename")
        val, par = proto.pmimehead(self.head.get("content-type", ""))
        self.ctype = val
        self.charset = par.get("charset")
        encoding = self.head.get("content-transfer-encoding", "binary")
        if encoding != "binary":
            raise badmultipart("Form part uses unexpected transfer encoding: %r" % encoding)
151
class multipart(object):
    """Iterator over the parts of a multipart/form-data POST request.

    Each call to next() yields a formpart whose headers have already
    been parsed.  A part must be read to its end before the following
    one is requested.
    """

    def __init__(self, req):
        """Validate `req` and skip ahead to the first part.

        Raises badmultipart if the request is not a multipart form POST
        or its Content-Type carries no boundary parameter.
        """
        mtype, params = proto.pmimehead(req.ihead.get("Content-Type", ""))
        if req.method != "POST":
            raise badmultipart("Request is not a multipart form")
        if mtype != "multipart/form-data":
            raise badmultipart("Request is not a multipart form")
        if "boundary" not in params:
            raise badmultipart("Multipart form lacks boundary")
        self.boundary = params["boundary"]
        self.req = req
        # Seeding the buffer with a line break lets a boundary at the
        # very start of the body match the "\r\n--" prefixed pattern.
        self.buf = "\r\n"
        self.eof = False
        # Everything before the first boundary is read as a throwaway
        # part and consumed right away.
        preamble = formpart(self)
        self.lastpart = preamble
        preamble.close()

    def __iter__(self):
        return self

    def next(self):
        """Return the next part with its headers parsed.

        Raises RuntimeError if the previous part was not read to its
        end, and StopIteration after the last part.
        """
        previous = self.lastpart
        if not previous.eof:
            raise RuntimeError("All form parts must be read entirely")
        if self.eof:
            raise StopIteration()
        part = formpart(self)
        self.lastpart = part
        part.parsehead()
        return part
177
def formdata(req):
    """Return the formwrap holding the form fields of `req`.

    The object is obtained through req.item(formwrap).
    NOTE(review): presumably req.item() builds the wrapper once per
    request and reuses it afterwards -- confirm against the request
    class.
    """
    wrapped = req.item(formwrap)
    return wrapped