X-Git-Url: http://dolda2000.com/gitweb/?a=blobdiff_plain;f=wrw%2Fform.py;h=c97b0f93102d38bab42370a7afd03766ef61a1d0;hb=784495582d53f5b759dd216ddd1268fbe2479bb9;hp=d089fd70d6a1f6321271e82c3995f0652e2fd55a;hpb=8b7cf2787e8cf2ee716c709d904f66c651cfed89;p=wrw.git diff --git a/wrw/form.py b/wrw/form.py index d089fd7..c97b0f9 100644 --- a/wrw/form.py +++ b/wrw/form.py @@ -6,14 +6,14 @@ __all__ = ["formdata"] class formwrap(object): def __init__(self, req): if req.ihead.get("Content-Type") == "application/x-www-form-urlencoded": - self.cf = cgi.parse(environ = req.env, fp = req.env["wsgi.input"]) + self.cf = cgi.parse(environ = req.env, fp = req.input) else: self.cf = cgi.parse(environ = req.env) def __getitem__(self, key): return self.cf[key][0] - def get(self, key, default = ""): + def get(self, key, default=""): if key in self: return self.cf[key][0] return default @@ -43,7 +43,7 @@ class badmultipart(Exception): class formpart(object): def __init__(self, form): self.form = form - self.buf = "" + self.buf = b"" self.eof = False self.head = {} @@ -52,8 +52,8 @@ class formpart(object): def fillbuf(self, sz): req = self.form.req - mboundary = "\r\n--" + self.form.boundary + "\r\n" - lboundary = "\r\n--" + self.form.boundary + "--\r\n" + mboundary = b"\r\n--" + self.form.boundary + b"\r\n" + lboundary = b"\r\n--" + self.form.boundary + b"--\r\n" while not self.eof: p = self.form.buf.find(mboundary) if p >= 0: @@ -73,12 +73,12 @@ class formpart(object): if sz >= 0 and len(self.buf) >= sz: break while len(self.form.buf) <= len(lboundary): - ret = req.env["wsgi.input"].read(8192) + ret = req.input.read(8192) if ret == "": raise badmultipart("Missing last multipart boundary") self.form.buf += ret - def read(self, limit = -1): + def read(self, limit=-1): self.fillbuf(limit) if limit >= 0: ret = self.buf[:limit] @@ -88,10 +88,10 @@ class formpart(object): self.buf = "" return ret - def readline(self, limit = -1): + def readline(self, limit=-1): last = 0 while True: - p = self.buf.find('\n', last) + p = self.buf.find(b'\n', last) if p < 0: if self.eof: ret = self.buf @@ -111,14 +111,18 @@ class formpart(object): return self def __exit__(self, *excinfo): + self.close() return False - def parsehead(self): + def parsehead(self, charset): def headline(): ln = self.readline(256) - if ln[-1] != '\n': + if ln[-1] != ord(b'\n'): raise badmultipart("Too long header line in part") - return ln.rstrip() + try: + return ln.decode(charset).rstrip() + except UnicodeError: + raise badmultipart("Form part header is not in assumed charset") ln = headline() while True: @@ -150,29 +154,33 @@ class formpart(object): raise badmultipart("Form part uses unexpected transfer encoding: %r" % encoding) class multipart(object): - def __init__(self, req): + def __init__(self, req, charset): val, par = proto.pmimehead(req.ihead.get("Content-Type", "")) if req.method != "POST" or val != "multipart/form-data": raise badmultipart("Request is not a multipart form") if "boundary" not in par: raise badmultipart("Multipart form lacks boundary") - self.boundary = par["boundary"] + try: + self.boundary = par["boundary"].encode("us-ascii") + except UnicodeError: + raise badmultipart("Multipart boundary must be ASCII string") self.req = req - self.buf = "\r\n" + self.buf = b"\r\n" self.eof = False + self.headcs = charset self.lastpart = formpart(self) self.lastpart.close() def __iter__(self): return self - def next(self): + def __next__(self): if not self.lastpart.eof: raise RuntimeError("All form parts must be read entirely") if self.eof: raise StopIteration() self.lastpart = formpart(self) - self.lastpart.parsehead() + self.lastpart.parsehead(self.headcs) return self.lastpart def formdata(req):