Classify trunacted input as its own exception type.
[wrw.git] / wrw / req.py
index 0689cbd..010e907 100644 (file)
@@ -1,3 +1,5 @@
+import io
+
 __all__ = ["request"]
 
 class headdict(object):
@@ -19,7 +21,7 @@ class headdict(object):
     def __iter__(self):
         return iter((list[0] for list in self.dict.values()))
     
-    def get(self, key, default = ""):
+    def get(self, key, default=""):
         if key.lower() in self.dict:
             return self.dict[key.lower()][1]
         return default
@@ -49,6 +51,75 @@ def fixcase(str):
         i += 1
     return str
 
+class shortinput(IOError, EOFError):
+    def __init__(self):
+        super().__init__("Unexpected EOF")
+
+class limitreader(object):
+    def __init__(self, back, limit):
+        self.bk = back
+        self.limit = limit
+        self.rb = 0
+        self.buf = bytearray()
+
+    def close(self):
+        pass
+
+    def read(self, size=-1):
+        ra = self.limit - self.rb
+        if size >= 0:
+            ra = min(ra, size)
+        while len(self.buf) < ra:
+            ret = self.bk.read(ra - len(self.buf))
+            if ret == b"":
+                raise shortinput()
+            self.buf.extend(ret)
+            self.rb += len(ret)
+        ret = bytes(self.buf[:ra])
+        self.buf = self.buf[ra:]
+        return ret
+
+    def readline(self, size=-1):
+        off = 0
+        while True:
+            p = self.buf.find(b'\n', off)
+            if p >= 0:
+                ret = bytes(self.buf[:p + 1])
+                self.buf = self.buf[p + 1:]
+                return ret
+            off = len(self.buf)
+            if size >= 0 and len(self.buf) >= size:
+                ret = bytes(self.buf[:size])
+                self.buf = self.buf[size:]
+                return ret
+            if self.rb == self.limit:
+                ret = bytes(self.buf)
+                self.buf = bytearray()
+                return ret
+            ra = self.limit - self.rb
+            if size >= 0:
+                ra = min(ra, size)
+            ra = min(ra, 1024)
+            ret = self.bk.read(ra)
+            if ret == b"":
+                raise shortinput()
+            self.buf.extend(ret)
+            self.rb += len(ret)
+
+    def readlines(self, hint=None):
+        return list(self)
+
+    def __iter__(rd):
+        class lineiter(object):
+            def __iter__(self):
+                return self
+            def __next__(self):
+                ret = rd.readline()
+                if ret == b"":
+                    raise StopIteration()
+                return ret
+        return lineiter()
+
 class request(object):
     def copy(self):
         return copyrequest(self)
@@ -62,6 +133,7 @@ class request(object):
 class origrequest(request):
     def __init__(self, env):
         self.env = env
+        self.method = env["REQUEST_METHOD"].upper()
         self.uriname = env["SCRIPT_NAME"]
         self.filename = env.get("SCRIPT_FILENAME")
         self.uri = env["REQUEST_URI"]
@@ -69,8 +141,23 @@ class origrequest(request):
         self.query = env["QUERY_STRING"]
         self.remoteaddr = env["REMOTE_ADDR"]
         self.serverport = env["SERVER_PORT"]
+        self.servername = env["SERVER_NAME"]
         self.https = "HTTPS" in env
         self.ihead = headdict()
+        if "CONTENT_TYPE" in env:
+            self.ihead["Content-Type"] = env["CONTENT_TYPE"]
+            if "CONTENT_LENGTH" in env:
+                clen = self.ihead["Content-Length"] = env["CONTENT_LENGTH"]
+                if clen.isdigit():
+                    self.input = limitreader(env["wsgi.input"], int(clen))
+                else:
+                    # XXX: What to do?
+                    self.input = io.BytesIO(b"")
+            else:
+                # Assume input is chunked and read until ordinary EOF.
+                self.input = env["wsgi.input"]
+        else:
+            self.input = None
         self.ohead = headdict()
         for k, v in env.items():
             if k[:5] == "HTTP_":
@@ -136,6 +223,7 @@ class copyrequest(request):
         self.parent = p
         self.top = p.topreq()
         self.env = p.env
+        self.method = p.method
         self.uriname = p.uriname
         self.filename = p.filename
         self.uri = p.uri
@@ -146,6 +234,7 @@ class copyrequest(request):
         self.https = p.https
         self.ihead = p.ihead
         self.ohead = p.ohead
+        self.input = p.input
 
     def status(self, code, msg):
         return self.parent.status(code, msg)