-import cgi
+import urllib.parse
from . import proto
__all__ = ["formdata"]
-class formwrap(object):
- def __init__(self, req):
- if req.ihead.get("Content-Type") == "application/x-www-form-urlencoded":
- self.cf = cgi.parse(environ = req.env, fp = req.env["wsgi.input"])
- else:
- self.cf = cgi.parse(environ = req.env)
-
- def __getitem__(self, key):
- return self.cf[key][0]
-
- def get(self, key, default = ""):
- if key in self:
- return self.cf[key][0]
- return default
-
- def __contains__(self, key):
- return key in self.cf and len(self.cf[key]) > 0
-
- def __iter__(self):
- return iter(self.cf)
-
- def items(self):
- def iter():
- for key, list in self.cf.items():
- for val in list:
- yield key, val
- return list(iter())
-
- def keys(self):
- return list(self.cf.keys())
-
- def values(self):
- return [val for key, val in self.items()]
-
-class badmultipart(Exception):
+def formparse(req):
+ buf = {}
+ buf.update(urllib.parse.parse_qsl(req.query, keep_blank_values=True))
+ ctype, ctpars = proto.pmimehead(req.ihead.get("Content-Type", ""))
+ if ctype == "application/x-www-form-urlencoded":
+ try:
+ rbody = req.input.read(2 ** 20)
+ except IOError as exc:
+ return exc
+ if len(rbody) >= 2 ** 20:
+ return ValueError("x-www-form-urlencoded data is absurdly long")
+ buf.update(urllib.parse.parse_qsl(rbody.decode("latin1"), encoding=ctpars.get("charset", "utf-8"), keep_blank_values=True))
+ return buf
+
+class badmultipart(IOError):
pass
class formpart(object):
def __init__(self, form):
self.form = form
- self.buf = ""
+ self.buf = b""
self.eof = False
self.head = {}
def fillbuf(self, sz):
req = self.form.req
- mboundary = "\r\n--" + self.form.boundary + "\r\n"
- lboundary = "\r\n--" + self.form.boundary + "--\r\n"
+ mboundary = b"\r\n--" + self.form.boundary + b"\r\n"
+ lboundary = b"\r\n--" + self.form.boundary + b"--\r\n"
while not self.eof:
p = self.form.buf.find(mboundary)
if p >= 0:
if sz >= 0 and len(self.buf) >= sz:
break
while len(self.form.buf) <= len(lboundary):
- ret = req.env["wsgi.input"].read(8192)
- if ret == "":
+ ret = req.input.read(8192)
+ if ret == b"":
raise badmultipart("Missing last multipart boundary")
self.form.buf += ret
- def read(self, limit = -1):
+ def read(self, limit=-1):
self.fillbuf(limit)
if limit >= 0:
ret = self.buf[:limit]
self.buf = self.buf[limit:]
else:
ret = self.buf
- self.buf = ""
+ self.buf = b""
return ret
- def readline(self, limit = -1):
+ def readline(self, limit=-1):
last = 0
while True:
- p = self.buf.find('\n', last)
+ p = self.buf.find(b'\n', last)
if p < 0:
if self.eof:
ret = self.buf
- self.buf = ""
+ self.buf = b""
return ret
last = len(self.buf)
self.fillbuf(last + 128)
return ret
def close(self):
- self.fillbuf(-1)
+ while True:
+ if self.read(8192) == b"":
+ break
def __enter__(self):
return self
def __exit__(self, *excinfo):
+ self.close()
return False
- def parsehead(self):
+ def parsehead(self, charset):
def headline():
ln = self.readline(256)
- if ln[-1] != '\n':
+ if ln[-1] != ord(b'\n'):
raise badmultipart("Too long header line in part")
- return ln.rstrip()
+ try:
+ return ln.decode(charset).rstrip()
+ except UnicodeError:
+ raise badmultipart("Form part header is not in assumed charset")
ln = headline()
while True:
raise badmultipart("Form part uses unexpected transfer encoding: %r" % encoding)
class multipart(object):
- def __init__(self, req):
+ def __init__(self, req, charset):
val, par = proto.pmimehead(req.ihead.get("Content-Type", ""))
if req.method != "POST" or val != "multipart/form-data":
raise badmultipart("Request is not a multipart form")
if "boundary" not in par:
raise badmultipart("Multipart form lacks boundary")
- self.boundary = par["boundary"]
+ try:
+ self.boundary = par["boundary"].encode("us-ascii")
+ except UnicodeError:
+ raise badmultipart("Multipart boundary must be ASCII string")
self.req = req
- self.buf = "\r\n"
+ self.buf = b"\r\n"
self.eof = False
+ self.headcs = charset
self.lastpart = formpart(self)
self.lastpart.close()
def __iter__(self):
return self
- def next(self):
+ def __next__(self):
if not self.lastpart.eof:
raise RuntimeError("All form parts must be read entirely")
if self.eof:
raise StopIteration()
self.lastpart = formpart(self)
- self.lastpart.parsehead()
+ self.lastpart.parsehead(self.headcs)
return self.lastpart
-def formdata(req):
- return req.item(formwrap)
+def formdata(req, onerror=Exception):
+ data = req.item(formparse)
+ if isinstance(data, Exception):
+ if onerror is Exception:
+ raise data
+ return onerror
+ return data