X-Git-Url: http://dolda2000.com/gitweb/?p=wrw.git;a=blobdiff_plain;f=wrw%2Freq.py;h=010e907a5d203c86281f0a61e298498a8504518b;hp=f3dc31f10d93528d639f423b5f4e976145ef2971;hb=d30502c8fe37bedb30ad9f3ddecd5191c5b077fb;hpb=b409a33843abb3221edd27016558c39cf33a6510 diff --git a/wrw/req.py b/wrw/req.py index f3dc31f..010e907 100644 --- a/wrw/req.py +++ b/wrw/req.py @@ -1,3 +1,5 @@ +import io + __all__ = ["request"] class headdict(object): @@ -17,9 +19,9 @@ class headdict(object): del self.dict[key.lower()] def __iter__(self): - return iter((list[0] for list in self.dict.itervalues())) + return iter((list[0] for list in self.dict.values())) - def get(self, key, default = ""): + def get(self, key, default=""): if key.lower() in self.dict: return self.dict[key.lower()][1] return default @@ -49,9 +51,89 @@ def fixcase(str): i += 1 return str +class shortinput(IOError, EOFError): + def __init__(self): + super().__init__("Unexpected EOF") + +class limitreader(object): + def __init__(self, back, limit): + self.bk = back + self.limit = limit + self.rb = 0 + self.buf = bytearray() + + def close(self): + pass + + def read(self, size=-1): + ra = self.limit - self.rb + if size >= 0: + ra = min(ra, size) + while len(self.buf) < ra: + ret = self.bk.read(ra - len(self.buf)) + if ret == b"": + raise shortinput() + self.buf.extend(ret) + self.rb += len(ret) + ret = bytes(self.buf[:ra]) + self.buf = self.buf[ra:] + return ret + + def readline(self, size=-1): + off = 0 + while True: + p = self.buf.find(b'\n', off) + if p >= 0: + ret = bytes(self.buf[:p + 1]) + self.buf = self.buf[p + 1:] + return ret + off = len(self.buf) + if size >= 0 and len(self.buf) >= size: + ret = bytes(self.buf[:size]) + self.buf = self.buf[size:] + return ret + if self.rb == self.limit: + ret = bytes(self.buf) + self.buf = bytearray() + return ret + ra = self.limit - self.rb + if size >= 0: + ra = min(ra, size) + ra = min(ra, 1024) + ret = self.bk.read(ra) + if ret == b"": + raise shortinput() + self.buf.extend(ret) + self.rb += len(ret) + + def readlines(self, hint=None): + return list(self) + + def __iter__(rd): + class lineiter(object): + def __iter__(self): + return self + def __next__(self): + ret = rd.readline() + if ret == b"": + raise StopIteration() + return ret + return lineiter() + class request(object): + def copy(self): + return copyrequest(self) + + def shift(self, n): + new = self.copy() + new.uriname = self.uriname + self.pathinfo[:n] + new.pathinfo = self.pathinfo[n:] + return new + +class origrequest(request): def __init__(self, env): self.env = env + self.method = env["REQUEST_METHOD"].upper() self.uriname = env["SCRIPT_NAME"] self.filename = env.get("SCRIPT_FILENAME") self.uri = env["REQUEST_URI"] @@ -59,8 +141,23 @@ class request(object): self.query = env["QUERY_STRING"] self.remoteaddr = env["REMOTE_ADDR"] self.serverport = env["SERVER_PORT"] + self.servername = env["SERVER_NAME"] self.https = "HTTPS" in env self.ihead = headdict() + if "CONTENT_TYPE" in env: + self.ihead["Content-Type"] = env["CONTENT_TYPE"] + if "CONTENT_LENGTH" in env: + clen = self.ihead["Content-Length"] = env["CONTENT_LENGTH"] + if clen.isdigit(): + self.input = limitreader(env["wsgi.input"], int(clen)) + else: + # XXX: What to do? + self.input = io.BytesIO(b"") + else: + # Assume input is chunked and read until ordinary EOF. + self.input = env["wsgi.input"] + else: + self.input = None self.ohead = headdict() for k, v in env.items(): if k[:5] == "HTTP_": @@ -117,3 +214,39 @@ class request(object): for val in self.ohead.getlist(nm): hdrs.append((nm, val)) startreq("%s %s" % self.statuscode, hdrs) + + def topreq(self): + return self + +class copyrequest(request): + def __init__(self, p): + self.parent = p + self.top = p.topreq() + self.env = p.env + self.method = p.method + self.uriname = p.uriname + self.filename = p.filename + self.uri = p.uri + self.pathinfo = p.pathinfo + self.query = p.query + self.remoteaddr = p.remoteaddr + self.serverport = p.serverport + self.https = p.https + self.ihead = p.ihead + self.ohead = p.ohead + self.input = p.input + + def status(self, code, msg): + return self.parent.status(code, msg) + + def item(self, id): + return self.top.item(id) + + def withres(self, res): + return self.top.withres(res) + + def oncommit(self, fn): + return self.top.oncommit(fn) + + def topreq(self): + return self.parent.topreq()