Classify truncated input as its own exception type.
[wrw.git] / wrw / req.py
CommitLineData
609f664f
FT
1import io
2
b409a338
FT
# Only the base request class is part of the star-import public API.
__all__ = ["request"]
4
class headdict(object):
    """Case-insensitive, multi-valued header dictionary.

    Keys are matched case-insensitively, but the spelling used when a key
    was first stored is remembered and is what iteration yields.  Each
    entry is kept as ``[original_key, value1, value2, ...]`` under the
    lower-cased key.
    """

    def __init__(self):
        # lower-cased key -> [original-case key, value, value, ...]
        self.dict = {}

    def __getitem__(self, key):
        """Return the first value for *key*; raise KeyError if absent."""
        return self.dict[key.lower()][1]

    def __setitem__(self, key, val):
        """Replace all values for *key* with the single value *val*."""
        self.dict[key.lower()] = [key, val]

    def __contains__(self, key):
        return key.lower() in self.dict

    def __delitem__(self, key):
        del self.dict[key.lower()]

    def __iter__(self):
        """Iterate over the keys, in the spelling they were stored with."""
        return (entry[0] for entry in self.dict.values())

    def get(self, key, default=""):
        """Return the first value for *key*, or *default* if it is absent."""
        entry = self.dict.get(key.lower())
        if entry is None:
            return default
        return entry[1]

    def getlist(self, key):
        """Return a new list of all values for *key* (empty if absent).

        This is a pure lookup: it never creates an entry for a missing key,
        so ``in``, iteration and ``[]`` are unaffected by calling it.
        """
        entry = self.dict.get(key.lower())
        if entry is None:
            return []
        return entry[1:]

    def add(self, key, val):
        """Append *val* to the values for *key*, creating the entry if needed."""
        self.dict.setdefault(key.lower(), [key]).append(val)

    def __repr__(self):
        return repr(self.dict)

    def __str__(self):
        return str(self.dict)
40
def fixcase(str):
    """Normalize the capitalization of a hyphenated header name.

    The name is lower-cased, then the first character of every
    hyphen-delimited segment is upper-cased, e.g. ``"CONTENT-TYPE"``
    becomes ``"Content-Type"``.
    """
    segments = str.lower().split("-")
    return "-".join(seg[:1].upper() + seg[1:] for seg in segments)
53
d30502c8
FT
class shortinput(OSError, EOFError):
    """Signals that an input stream ended before its declared length.

    It subclasses both OSError (alias IOError) and EOFError, so callers
    catching either exception type will also catch it.
    """

    def __init__(self):
        super(shortinput, self).__init__("Unexpected EOF")
57
609f664f
FT
class limitreader(object):
    """File-like reader exposing at most *limit* bytes of the stream *back*.

    Data is pulled from *back* on demand.  Bytes that readline() fetched
    beyond the end of a line are buffered and returned by later read()
    and readline() calls.  If *back* reports EOF before *limit* bytes have
    been consumed, ``shortinput`` is raised.
    """

    def __init__(self, back, limit):
        self.bk = back           # underlying stream
        self.limit = limit       # total bytes that may be read from self.bk
        self.rb = 0              # bytes read from self.bk so far
        self.buf = bytearray()   # read from self.bk but not yet returned

    def close(self):
        # Deliberately does not close the backing stream.
        pass

    def _take(self, n):
        """Remove and return the first *n* buffered bytes."""
        ret = bytes(self.buf[:n])
        del self.buf[:n]
        return ret

    def _fill(self, n):
        """Append one backend read of at most *n* bytes to the buffer.

        Raises shortinput if the backend hits EOF before the limit.
        """
        ret = self.bk.read(n)
        if ret == b"":
            raise shortinput()
        self.buf.extend(ret)
        self.rb += len(ret)

    def read(self, size=-1):
        """Read up to *size* bytes; all remaining bytes if *size* is negative
        or None.

        Already-buffered bytes count toward what is available, so data
        prefetched by readline() is never skipped.
        """
        if size is None:
            size = -1
        # Everything still obtainable: buffered bytes plus unread backend bytes.
        ra = len(self.buf) + (self.limit - self.rb)
        if size >= 0:
            ra = min(ra, size)
        while len(self.buf) < ra:
            self._fill(ra - len(self.buf))
        return self._take(ra)

    def readline(self, size=-1):
        """Read one line (including its b"\\n"), at most *size* bytes if
        *size* is non-negative.  Returns b"" once the limit is exhausted."""
        if size is None:
            size = -1
        off = 0  # buffer prefix already known to contain no newline
        while True:
            p = self.buf.find(b"\n", off)
            if p >= 0 and (size < 0 or p < size):
                return self._take(p + 1)
            if size >= 0 and len(self.buf) >= size:
                # The line (or the newline) lies beyond the size cap.
                return self._take(size)
            if self.rb == self.limit:
                return self._take(len(self.buf))
            off = len(self.buf)
            ra = self.limit - self.rb
            if size >= 0:
                ra = min(ra, size - len(self.buf))
            self._fill(min(ra, 1024))

    def readlines(self, hint=None):
        """Return all remaining lines as a list (*hint* is ignored)."""
        return list(self)

    def __iter__(self):
        """Yield lines until the limit is exhausted."""
        while True:
            line = self.readline()
            if line == b"":
                return
            yield line
122
class request(object):
    """Common base for request objects: copying and path-prefix shifting."""

    def copy(self):
        """Return a copyrequest wrapping this request."""
        return copyrequest(self)

    def shift(self, n):
        """Return a copy whose first *n* characters of pathinfo have been
        moved onto the end of uriname."""
        new = self.copy()
        new.uriname, new.pathinfo = (self.uriname + self.pathinfo[:n],
                                     self.pathinfo[n:])
        return new
132
class origrequest(request):
    """The top-level request, built directly from a WSGI environ dict."""

    def __init__(self, env):
        self.env = env
        self.method = env["REQUEST_METHOD"].upper()
        self.uriname = env["SCRIPT_NAME"]
        self.filename = env.get("SCRIPT_FILENAME")
        self.uri = env["REQUEST_URI"]
        self.pathinfo = env["PATH_INFO"]
        self.query = env["QUERY_STRING"]
        self.remoteaddr = env["REMOTE_ADDR"]
        self.serverport = env["SERVER_PORT"]
        self.servername = env["SERVER_NAME"]
        self.https = "HTTPS" in env
        self.ihead = headdict()
        # NOTE: a request body is only exposed when CONTENT_TYPE is present.
        if "CONTENT_TYPE" in env:
            self.ihead["Content-Type"] = env["CONTENT_TYPE"]
            if "CONTENT_LENGTH" in env:
                clen = self.ihead["Content-Length"] = env["CONTENT_LENGTH"]
                # isdecimal() (unlike isdigit()) accepts only characters that
                # int() can parse, so e.g. "\u00b2" cannot raise ValueError.
                if clen.isdecimal():
                    self.input = limitreader(env["wsgi.input"], int(clen))
                else:
                    # XXX: What to do?
                    self.input = io.BytesIO(b"")
            else:
                # Assume input is chunked and read until ordinary EOF.
                self.input = env["wsgi.input"]
        else:
            self.input = None
        self.ohead = headdict()
        for k, v in env.items():
            if k[:5] == "HTTP_":
                self.ihead.add(fixcase(k[5:].replace("_", "-")), v)
        self.items = {}
        self.statuscode = (200, "OK")
        self.ohead["Content-Type"] = "text/html"
        self.resources = set()
        self.clean = set()
        self.commitfuns = []

    def status(self, code, msg):
        """Set the response status code and reason phrase."""
        self.statuscode = code, msg

    def item(self, id):
        """Return the per-request item for factory *id*, creating it once.

        *id* is both the cache key and a callable invoked with this request.
        Items that are context managers are entered via withres().
        """
        if id in self.items:
            return self.items[id]
        self.items[id] = new = id(self)
        if hasattr(new, "__enter__") and hasattr(new, "__exit__"):
            self.withres(new)
        return new

    def withres(self, res):
        """Enter context manager *res* and schedule its exit for cleanup().

        Registering the same resource again is a no-op.
        """
        if res not in self.resources:
            done = False
            res.__enter__()
            try:
                self.resources.add(res)
                # cleanup() calls every registered cleaner with no arguments,
                # so supply the three arguments __exit__ requires here.
                self.clean.add(lambda: res.__exit__(None, None, None))
                done = True
            finally:
                if not done:
                    res.__exit__(None, None, None)
                    self.resources.discard(res)

    def cleanup(self):
        """Call every registered no-argument cleaner.

        Every cleaner is run even if an earlier one raises; an exception from
        any of them still propagates once the remaining cleaners have run.
        """
        def clean1(fns):
            if len(fns) > 0:
                try:
                    fns[0]()
                finally:
                    clean1(fns[1:])
        clean1(list(self.clean))

    def oncommit(self, fn):
        """Register *fn* to be called with this request by commit()."""
        if fn not in self.commitfuns:
            self.commitfuns.append(fn)

    def commit(self, startreq):
        """Run commit hooks (latest first), then start the response.

        *startreq* is called with the "CODE REASON" status line and a list of
        (name, value) tuples built from the output headers.
        """
        for fun in reversed(self.commitfuns):
            fun(self)
        hdrs = []
        for nm in self.ohead:
            for val in self.ohead.getlist(nm):
                hdrs.append((nm, val))
        startreq("%s %s" % self.statuscode, hdrs)

    def topreq(self):
        """Return the top-level request, which is this object itself."""
        return self
220
class copyrequest(request):
    """A derived request that shares most state with its parent.

    Request attributes are copied from *p* at construction time.  Header
    dicts and the input stream are shared by reference.  Status changes
    are forwarded to the parent; item, resource and commit-hook bookkeeping
    is forwarded to the top-level request.
    """

    def __init__(self, p):
        self.parent = p
        self.top = p.topreq()
        self.env = p.env
        self.method = p.method
        self.uriname = p.uriname
        self.filename = p.filename
        self.uri = p.uri
        self.pathinfo = p.pathinfo
        self.query = p.query
        self.remoteaddr = p.remoteaddr
        self.serverport = p.serverport
        # The top-level request defines servername; without copying it,
        # shifted/copied requests would lack the attribute entirely.
        self.servername = p.servername
        self.https = p.https
        self.ihead = p.ihead
        self.ohead = p.ohead
        self.input = p.input

    def status(self, code, msg):
        """Forward the status change to the parent request."""
        return self.parent.status(code, msg)

    def item(self, id):
        return self.top.item(id)

    def withres(self, res):
        return self.top.withres(res)

    def oncommit(self, fn):
        return self.top.oncommit(fn)

    def topreq(self):
        """Return the top-level request by walking up the parent chain."""
        return self.parent.topreq()