Make badmultipart an IOError.
[wrw.git] / wrw / form.py
index 3bad8fc..d39d4dd 100644 (file)
@@ -1,25 +1,28 @@
-import urlparse
-import proto
+import urllib.parse
+from . import proto
 
 __all__ = ["formdata"]
 
 def formparse(req):
     buf = {}
-    buf.update(urlparse.parse_qsl(req.query))
+    buf.update(urllib.parse.parse_qsl(req.query))
     if req.ihead.get("Content-Type") == "application/x-www-form-urlencoded":
-        if req.input.limit > 2 ** 20:
-            raise ValueError("x-www-form-urlencoded data is absurdly long")
-        rbody = req.input.read()
-        buf.update(urlparse.parse_qsl(rbody))
+        try:
+            rbody = req.input.read(2 ** 20)  # read is capped at 2 ** 20 bytes
+        except IOError as exc:
+            return exc  # errors are returned, not raised; formdata() decides whether to raise
+        if len(rbody) >= 2 ** 20:  # reaching the read cap is treated as an oversized body
+            return ValueError("x-www-form-urlencoded data is absurdly long")
+        buf.update(urllib.parse.parse_qsl(rbody.decode("latin1")))  # body bytes are decoded as latin1 before parsing
     return buf
 
-class badmultipart(Exception):
+class badmultipart(IOError):
     pass
 
 class formpart(object):
     def __init__(self, form):
         self.form = form
-        self.buf = ""
+        self.buf = b""
         self.eof = False
         self.head = {}
 
@@ -28,8 +31,8 @@ class formpart(object):
 
     def fillbuf(self, sz):
         req = self.form.req
-        mboundary = "\r\n--" + self.form.boundary + "\r\n"
-        lboundary = "\r\n--" + self.form.boundary + "--\r\n"
+        mboundary = b"\r\n--" + self.form.boundary + b"\r\n"
+        lboundary = b"\r\n--" + self.form.boundary + b"--\r\n"
         while not self.eof:
             p = self.form.buf.find(mboundary)
             if p >= 0:
@@ -50,7 +53,7 @@ class formpart(object):
                 break
             while len(self.form.buf) <= len(lboundary):
                 ret = req.input.read(8192)
-                if ret == "":
+                if ret == b"":
                     raise badmultipart("Missing last multipart boundary")
                 self.form.buf += ret
 
@@ -61,17 +64,17 @@ class formpart(object):
             self.buf = self.buf[limit:]
         else:
             ret = self.buf
-            self.buf = ""
+            self.buf = b""
         return ret
 
     def readline(self, limit=-1):
         last = 0
         while True:
-            p = self.buf.find('\n', last)
+            p = self.buf.find(b'\n', last)
             if p < 0:
                 if self.eof:
                     ret = self.buf
-                    self.buf = ""
+                    self.buf = b""
                     return ret
                 last = len(self.buf)
                 self.fillbuf(last + 128)
@@ -81,7 +84,9 @@ class formpart(object):
                 return ret
 
     def close(self):
-        self.fillbuf(-1)
+        while True:
+            if self.read(8192) == b"":
+                break
 
     def __enter__(self):
         return self
@@ -90,12 +95,15 @@ class formpart(object):
         self.close()
         return False
 
-    def parsehead(self):
+    def parsehead(self, charset):
         def headline():
             ln = self.readline(256)
-            if ln[-1] != '\n':
+            if ln[-1] != ord(b'\n'):
                 raise badmultipart("Too long header line in part")
-            return ln.rstrip()
+            try:
+                return ln.decode(charset).rstrip()
+            except UnicodeError:
+                raise badmultipart("Form part header is not in assumed charset")
 
         ln = headline()
         while True:
@@ -127,30 +135,39 @@ class formpart(object):
             raise badmultipart("Form part uses unexpected transfer encoding: %r" % encoding)
 
 class multipart(object):
-    def __init__(self, req):
+    def __init__(self, req, charset):
         val, par = proto.pmimehead(req.ihead.get("Content-Type", ""))
         if req.method != "POST" or val != "multipart/form-data":
             raise badmultipart("Request is not a multipart form")
         if "boundary" not in par:
             raise badmultipart("Multipart form lacks boundary")
-        self.boundary = par["boundary"]
+        try:
+            self.boundary = par["boundary"].encode("us-ascii")
+        except UnicodeError:
+            raise badmultipart("Multipart boundary must be ASCII string")
         self.req = req
-        self.buf = "\r\n"
+        self.buf = b"\r\n"
         self.eof = False
+        self.headcs = charset
         self.lastpart = formpart(self)
         self.lastpart.close()
 
     def __iter__(self):
         return self
 
-    def next(self):
+    def __next__(self):
         if not self.lastpart.eof:
             raise RuntimeError("All form parts must be read entirely")
         if self.eof:
             raise StopIteration()
         self.lastpart = formpart(self)
-        self.lastpart.parsehead()
+        self.lastpart.parsehead(self.headcs)
         return self.lastpart
 
-def formdata(req):
-    return req.item(formparse)
+def formdata(req, onerror=Exception):  # onerror: value returned on a parse error; the default Exception means "raise it"
+    data = req.item(formparse)
+    if isinstance(data, Exception):  # formparse returns its errors instead of raising them
+        if onerror is Exception:
+            raise data
+        return onerror
+    return data