9f1e6d4020461b07f00162c112593781d1510ee3
[wrw.git] / wrw / form.py
1 import cgi
2 import proto
3
# Names exported by "from ... import *": only formdata.  The multipart
# parser classes (multipart, formpart, badmultipart) must be imported by name.
__all__ = ["formdata"]
5
class formwrap(object):
    """Dict-like, read-only view of a request's form fields.

    Fields are parsed with ``cgi.parse`` from the query string and, when
    the request body is ``application/x-www-form-urlencoded``, from
    ``wsgi.input`` as well.  Each key maps to a list of values; item access
    and ``get`` return the first of them.
    """

    def __init__(self, req):
        # Compare the media type alone: parameters such as
        # "; charset=UTF-8" must not divert a urlencoded body away from
        # wsgi.input.
        ctype = req.ihead.get("Content-Type", "").split(";", 1)[0]
        if ctype.strip().lower() == "application/x-www-form-urlencoded":
            # Give cgi.parse a copy of the environ: it may overwrite
            # QUERY_STRING in the mapping it receives, and it falls back to
            # sys.argv[1] when QUERY_STRING is missing.
            env = dict(req.env)
            env.setdefault("QUERY_STRING", "")
            self.cf = cgi.parse(environ = env, fp = req.env["wsgi.input"])
        else:
            # Parse the query string only.  Presenting the request to
            # cgi.parse as a GET keeps it from reading a POST body (e.g. a
            # multipart one) from its default stream, sys.stdin, which is
            # not the request body under WSGI.
            env = {"REQUEST_METHOD": "GET",
                   "QUERY_STRING": req.env.get("QUERY_STRING", "")}
            self.cf = cgi.parse(environ = env)

    def __getitem__(self, key):
        """Return the first value of ``key``; KeyError if absent."""
        return self.cf[key][0]

    def get(self, key, default = ""):
        """Return the first value of ``key``, or ``default`` if absent."""
        if key in self:
            return self.cf[key][0]
        return default

    def __contains__(self, key):
        """True if ``key`` is present with at least one value."""
        return key in self.cf and len(self.cf[key]) > 0

    def __iter__(self):
        return iter(self.cf)

    def items(self):
        """Return all (key, value) pairs, one pair per value."""
        return [(key, val) for key, vals in self.cf.items() for val in vals]

    def keys(self):
        return self.cf.keys()

    def values(self):
        """Return every value of every key, in ``items()`` order."""
        return [val for vals in self.cf.values() for val in vals]
39
class badmultipart(Exception):
    """Error raised when a request cannot be parsed as multipart/form-data.

    Raised for requests that are not multipart POSTs or lack a boundary,
    and for parts whose delimiters or headers are malformed.
    """
42
class formpart(object):
    """One part of a multipart/form-data request body.

    A part reads its data lazily from the owning ``multipart`` object
    (``form``), which holds the request stream and a look-ahead buffer
    (``form.buf``) shared by all of its parts.  The part is finished
    (``eof``) once the delimiter that terminates it has been consumed.
    """

    # Longest header line parsehead() accepts, including its line
    # terminator; bounds the memory one malformed header line can use.
    maxheadline = 4096

    def __init__(self, form):
        self.form = form
        self.buf = ""     # part data already split off from form.buf
        self.eof = False  # True once this part's terminating delimiter is consumed
        self.head = {}    # lower-cased header name -> value, set by parsehead()

    def fillbuf(self, sz):
        """Move data from the form's stream into this part's buffer.

        Stops once at least ``sz`` characters are buffered (``sz < 0``
        means "until the end of the part") or when the part's terminating
        delimiter has been consumed.

        Raises badmultipart if the stream ends without a closing delimiter.
        """
        if self.eof:
            return
        form = self.form
        stream = form.req.env["wsgi.input"]
        mboundary = "\r\n--" + form.boundary + "\r\n"
        lboundary = "\r\n--" + form.boundary + "--\r\n"
        keep = len(lboundary)
        while not self.eof:
            # Cut at whichever delimiter comes first; preferring the
            # next-part delimiter could run past the closing delimiter and
            # pull epilogue text into this part.
            mp = form.buf.find(mboundary)
            lp = form.buf.find(lboundary)
            if mp >= 0 and (lp < 0 or mp < lp):
                self.buf += form.buf[:mp]
                form.buf = form.buf[mp + len(mboundary):]
                self.eof = True
                break
            if lp >= 0:
                self.buf += form.buf[:lp]
                form.buf = form.buf[lp + len(lboundary):]
                self.eof = True
                form.eof = True
                break
            # No delimiter yet: all but the last len(lboundary) characters
            # are certainly part data; the tail may hold the start of a
            # delimiter and has to wait for more input.
            self.buf += form.buf[:-keep]
            form.buf = form.buf[-keep:]
            if sz >= 0 and len(self.buf) >= sz:
                break
            while len(form.buf) <= keep:
                ret = stream.read(8192)
                if not ret:
                    # RFC 2046 makes the CRLF after the closing delimiter
                    # optional, so "--boundary--" may end the stream.
                    if form.buf.endswith(lboundary[:-2]):
                        form.buf += "\r\n"
                        break
                    raise badmultipart("Missing last multipart boundary")
                form.buf += ret

    def read(self, limit = -1):
        """Return up to ``limit`` characters of part data (all if negative)."""
        self.fillbuf(limit)
        if limit < 0:
            limit = len(self.buf)
        ret = self.buf[:limit]
        self.buf = self.buf[limit:]
        return ret

    def readline(self, limit = -1):
        """Return the next line of part data, including its newline.

        When ``limit`` is non-negative at most ``limit`` characters are
        returned, so the result may then lack the newline.  An empty string
        means the part is exhausted.
        """
        last = 0
        while True:
            p = self.buf.find("\n", last)
            if p >= 0:
                end = p + 1
                break
            if limit >= 0 and len(self.buf) >= limit:
                end = limit
                break
            if self.eof:
                end = len(self.buf)
                break
            # Only newly buffered data needs to be searched next time.
            last = len(self.buf)
            self.fillbuf(last + 128)
        if limit >= 0:
            end = min(end, limit)
        ret = self.buf[:end]
        self.buf = self.buf[end:]
        return ret

    def close(self):
        """Skip the unread rest of this part, up to its delimiter.

        The data is discarded in bounded chunks rather than accumulated,
        so skipping a large part does not hold it in memory.
        """
        while not self.eof:
            self.fillbuf(65536)
            self.buf = ""
        self.buf = ""

    def __enter__(self):
        return self

    def __exit__(self, *excinfo):
        self.close()
        return False

    def parsehead(self):
        """Read and decode this part's MIME headers.

        Fills ``head`` and sets ``name``, ``filename``, ``ctype`` and
        ``charset``.  Raises badmultipart for malformed or unsupported
        headers.
        """
        def headline():
            ln = self.readline(self.maxheadline)
            if not ln.endswith("\n"):
                # readline() stops short of a newline only at the length
                # cap or at the end of the part.
                if len(ln) >= self.maxheadline:
                    raise badmultipart("Too long header line in part")
                raise badmultipart("Truncated header in part")
            return ln.rstrip()

        ln = headline()
        # An empty line terminates the header block.
        while ln != "":
            buf = ln
            while True:
                ln = headline()
                # A line starting with whitespace continues (folds) the
                # previous header line.
                if not ln[:1].isspace():
                    break
                buf += " " + ln.lstrip()
            p = buf.find(":")
            if p < 0:
                raise badmultipart("Malformed multipart header line")
            self.head[buf[:p].strip().lower()] = buf[p + 1:].lstrip()

        val, par = proto.pmimehead(self.head.get("content-disposition", ""))
        if val != "form-data":
            raise badmultipart("Unexpected Content-Disposition in form part: %r" % val)
        if "name" not in par:
            raise badmultipart("Missing name in form part")
        self.name = par["name"]
        self.filename = par.get("filename")
        val, par = proto.pmimehead(self.head.get("content-type", ""))
        self.ctype = val
        self.charset = par.get("charset")
        encoding = self.head.get("content-transfer-encoding", "binary").strip().lower()
        # 7bit and 8bit are identity encodings just like binary (RFC 2045),
        # so the data can be used as-is.
        if encoding not in ("binary", "7bit", "8bit"):
            raise badmultipart("Form part uses unexpected transfer encoding: %r" % encoding)
152
class multipart(object):
    """Iterator over the parts of a multipart/form-data POST request.

    Each step yields a ``formpart`` whose headers have already been
    parsed.  All parts share one input stream, so each yielded part must
    be read to its end (or closed) before the next one is requested.
    """

    def __init__(self, req):
        val, par = proto.pmimehead(req.ihead.get("Content-Type", ""))
        if req.method != "POST" or val != "multipart/form-data":
            raise badmultipart("Request is not a multipart form")
        # An empty boundary cannot delimit anything meaningfully.
        if not par.get("boundary"):
            raise badmultipart("Multipart form lacks boundary")
        self.boundary = par["boundary"]
        self.req = req
        # Seed the look-ahead buffer with a CRLF so that a boundary at the
        # very start of the body matches the "\r\n--boundary" delimiter
        # pattern that formpart.fillbuf searches for.
        self.buf = "\r\n"
        self.eof = False
        # Skip the preamble: a throw-away part ending at the first boundary.
        self.lastpart = formpart(self)
        self.lastpart.close()

    def __iter__(self):
        return self

    def next(self):
        """Return the next part; raise StopIteration after the last one.

        Raises RuntimeError if the previously returned part has not been
        read (or closed) to its end.
        """
        if not self.lastpart.eof:
            raise RuntimeError("All form parts must be read entirely")
        if self.eof:
            raise StopIteration()
        self.lastpart = formpart(self)
        self.lastpart.parsehead()
        return self.lastpart

    # Python 3 spelling of the iterator-protocol method.
    __next__ = next
178
def formdata(req):
    """Return ``req.item(formwrap)``: the request's form-field wrapper.

    NOTE(review): ``req.item`` is defined elsewhere; it presumably builds
    and caches one ``formwrap`` per request -- confirm against the request
    class.
    """
    wrapper_class = formwrap
    return req.item(wrapper_class)