Make badmultipart an IOError.
[wrw.git] / wrw / form.py
index 10b10f5..d39d4dd 100644 (file)
@@ -1,43 +1,22 @@
-import cgi
+import urllib.parse
 from . import proto
 
 __all__ = ["formdata"]
 
-class formwrap(object):
-    def __init__(self, req):
-        if req.ihead.get("Content-Type") == "application/x-www-form-urlencoded":
-            self.cf = cgi.parse(environ = req.env, fp = req.env["wsgi.input"])
-        else:
-            self.cf = cgi.parse(environ = req.env)
-
-    def __getitem__(self, key):
-        return self.cf[key][0]
-
-    def get(self, key, default = ""):
-        if key in self:
-            return self.cf[key][0]
-        return default
-
-    def __contains__(self, key):
-        return key in self.cf and len(self.cf[key]) > 0
-
-    def __iter__(self):
-        return iter(self.cf)
-
-    def items(self):
-        def iter():
-            for key, list in self.cf.items():
-                for val in list:
-                    yield key, val
-        return list(iter())
-
-    def keys(self):
-        return list(self.cf.keys())
-
-    def values(self):
-        return [val for key, val in self.items()]
-
-class badmultipart(Exception):
+def formparse(req):
+    buf = {}
+    buf.update(urllib.parse.parse_qsl(req.query))
+    if req.ihead.get("Content-Type") == "application/x-www-form-urlencoded":
+        try:
+            rbody = req.input.read(2 ** 20)
+        except IOError as exc:
+            return exc
+        if len(rbody) >= 2 ** 20:
+            return ValueError("x-www-form-urlencoded data is absurdly long")
+        buf.update(urllib.parse.parse_qsl(rbody.decode("latin1")))
+    return buf
+
+class badmultipart(IOError):
     pass
 
 class formpart(object):
@@ -73,29 +52,29 @@ class formpart(object):
             if sz >= 0 and len(self.buf) >= sz:
                 break
             while len(self.form.buf) <= len(lboundary):
-                ret = req.env["wsgi.input"].read(8192)
-                if ret == "":
+                ret = req.input.read(8192)
+                if ret == b"":
                     raise badmultipart("Missing last multipart boundary")
                 self.form.buf += ret
 
-    def read(self, limit = -1):
+    def read(self, limit=-1):
         self.fillbuf(limit)
         if limit >= 0:
             ret = self.buf[:limit]
             self.buf = self.buf[limit:]
         else:
             ret = self.buf
-            self.buf = ""
+            self.buf = b""
         return ret
 
-    def readline(self, limit = -1):
+    def readline(self, limit=-1):
         last = 0
         while True:
             p = self.buf.find(b'\n', last)
             if p < 0:
                 if self.eof:
                     ret = self.buf
-                    self.buf = ""
+                    self.buf = b""
                     return ret
                 last = len(self.buf)
                 self.fillbuf(last + 128)
@@ -105,12 +84,15 @@ class formpart(object):
                 return ret
 
     def close(self):
-        self.fillbuf(-1)
+        while True:
+            if self.read(8192) == b"":
+                break
 
     def __enter__(self):
         return self
 
     def __exit__(self, *excinfo):
+        self.close()
         return False
 
     def parsehead(self, charset):
@@ -182,5 +164,10 @@ class multipart(object):
         self.lastpart.parsehead(self.headcs)
         return self.lastpart
 
-def formdata(req):
-    return req.item(formwrap)
+def formdata(req, onerror=Exception):
+    data = req.item(formparse)
+    if isinstance(data, Exception):
+        if onerror is Exception:
+            raise data
+        return onerror
+    return data