Classify trunacted input as its own exception type.
[wrw.git] / wrw / req.py
index 4b9dd71..010e907 100644 (file)
@@ -19,9 +19,9 @@ class headdict(object):
         del self.dict[key.lower()]
 
     def __iter__(self):
-        return iter((list[0] for list in self.dict.itervalues()))
+        return iter((list[0] for list in self.dict.values()))
     
-    def get(self, key, default = ""):
+    def get(self, key, default=""):
         if key.lower() in self.dict:
             return self.dict[key.lower()][1]
         return default
@@ -51,6 +51,10 @@ def fixcase(str):
         i += 1
     return str
 
+class shortinput(IOError, EOFError):
+    def __init__(self):
+        super().__init__("Unexpected EOF")
+
 class limitreader(object):
     def __init__(self, back, limit):
         self.bk = back
@@ -67,29 +71,29 @@ class limitreader(object):
             ra = min(ra, size)
         while len(self.buf) < ra:
             ret = self.bk.read(ra - len(self.buf))
-            if ret == "":
-                raise IOError("Unexpected EOF")
+            if ret == b"":
+                raise shortinput()
             self.buf.extend(ret)
             self.rb += len(ret)
-        ret = str(self.buf[:ra])
+        ret = bytes(self.buf[:ra])
         self.buf = self.buf[ra:]
         return ret
 
     def readline(self, size=-1):
         off = 0
         while True:
-            p = self.buf.find('\n', off)
+            p = self.buf.find(b'\n', off)
             if p >= 0:
-                ret = str(self.buf[:p + 1])
+                ret = bytes(self.buf[:p + 1])
                 self.buf = self.buf[p + 1:]
                 return ret
             off = len(self.buf)
             if size >= 0 and len(self.buf) >= size:
-                ret = str(self.buf[:size])
+                ret = bytes(self.buf[:size])
                 self.buf = self.buf[size:]
                 return ret
             if self.rb == self.limit:
-                ret = str(self.buf)
+                ret = bytes(self.buf)
                 self.buf = bytearray()
                 return ret
             ra = self.limit - self.rb
@@ -97,8 +101,8 @@ class limitreader(object):
                 ra = min(ra, size)
             ra = min(ra, 1024)
             ret = self.bk.read(ra)
-            if ret == "":
-                raise IOError("Unpexpected EOF")
+            if ret == b"":
+                raise shortinput()
             self.buf.extend(ret)
             self.rb += len(ret)
 
@@ -109,9 +113,9 @@ class limitreader(object):
         class lineiter(object):
             def __iter__(self):
                 return self
-            def next(self):
+            def __next__(self):
                 ret = rd.readline()
-                if ret == "":
+                if ret == b"":
                     raise StopIteration()
                 return ret
         return lineiter()
@@ -140,15 +144,20 @@ class origrequest(request):
         self.servername = env["SERVER_NAME"]
         self.https = "HTTPS" in env
         self.ihead = headdict()
-        self.input = None
         if "CONTENT_TYPE" in env:
             self.ihead["Content-Type"] = env["CONTENT_TYPE"]
-        if "CONTENT_LENGTH" in env:
-            clen = self.ihead["Content-Length"] = env["CONTENT_LENGTH"]
-            if clen.isdigit():
-                self.input = limitreader(env["wsgi.input"], int(clen))
-        if self.input is None:
-            self.input = io.BytesIO("")
+            if "CONTENT_LENGTH" in env:
+                clen = self.ihead["Content-Length"] = env["CONTENT_LENGTH"]
+                if clen.isdigit():
+                    self.input = limitreader(env["wsgi.input"], int(clen))
+                else:
+                    # XXX: What to do?
+                    self.input = io.BytesIO(b"")
+            else:
+                # Assume input is chunked and read until ordinary EOF.
+                self.input = env["wsgi.input"]
+        else:
+            self.input = None
         self.ohead = headdict()
         for k, v in env.items():
             if k[:5] == "HTTP_":
@@ -225,6 +234,7 @@ class copyrequest(request):
         self.https = p.https
         self.ihead = p.ihead
         self.ohead = p.ohead
+        self.input = p.input
 
     def status(self, code, msg):
         return self.parent.status(code, msg)