Fix keyword-parameter handling bug in formparams.
[wrw.git] / wrw / req.py
1 import io
2
# Public names for "from ... import *": only the request base class.
__all__ = ["request"]
4
class headdict(object):
    """Case-insensitive, multi-valued mapping for header fields.

    Keys are compared case-insensitively.  ``d[key]`` and ``get`` return the
    first value stored for a key, ``getlist`` returns all of them, and
    iterating yields each key in the spelling stored with its entry (the key
    passed to the latest ``d[key] = val``, or to the first ``add``).
    """

    def __init__(self):
        # lowercased key -> [stored-case key, value1, value2, ...]
        self.dict = {}

    def __getitem__(self, key):
        return self.dict[key.lower()][1]

    def __setitem__(self, key, val):
        # Replaces all previously stored values for this key.
        self.dict[key.lower()] = [key, val]

    def __contains__(self, key):
        return key.lower() in self.dict

    def __delitem__(self, key):
        del self.dict[key.lower()]

    def __iter__(self):
        return (entry[0] for entry in self.dict.values())

    def get(self, key, default=""):
        """Return the first value for ``key``, or ``default`` if absent."""
        entry = self.dict.get(key.lower())
        if entry is None:
            return default
        return entry[1]

    def getlist(self, key):
        """Return a new list of all values for ``key`` (empty if absent)."""
        # Look up without inserting: a placeholder entry with no values would
        # make `key in d` true while d[key] and d.get(key) raised IndexError.
        entry = self.dict.get(key.lower())
        if entry is None:
            return []
        return entry[1:]

    def add(self, key, val):
        """Append ``val`` to the values of ``key``, keeping existing ones."""
        self.dict.setdefault(key.lower(), [key]).append(val)

    def __repr__(self):
        return repr(self.dict)

    def __str__(self):
        return str(self.dict)
40
def fixcase(str):
    """Return header name ``str`` in canonical capitalisation.

    The name is lowercased, then its first character and every character
    that follows a ``-`` are uppercased, e.g. ``"CONTENT-TYPE"`` becomes
    ``"Content-Type"``.  Empty ``-``-separated segments are preserved.
    """
    # The parameter keeps its historical name (it shadows the builtin) so
    # keyword callers keep working.  One split/join pass replaces the former
    # loop that rebuilt the whole string for every capitalised character.
    parts = str.lower().split("-")
    return "-".join(part[:1].upper() + part[1:] for part in parts)
53
class shortinput(IOError, EOFError):
    """Error signalling that input ended earlier than expected.

    It is both an ``IOError`` and an ``EOFError``, so handlers for either
    exception type will catch it.
    """

    def __init__(self):
        super(shortinput, self).__init__("Unexpected EOF")
57
class limitreader(object):
    """Binary file-like reader exposing at most ``limit`` bytes of ``back``.

    Data is pulled from the backing stream ``back`` on demand, and never more
    than ``limit`` bytes in total.  If ``back`` returns ``b""`` before
    ``limit`` bytes have been received, ``shortinput`` is raised.
    """

    def __init__(self, back, limit):
        self.bk = back
        self.limit = limit
        self.rb = 0             # bytes read from self.bk so far (<= limit)
        self.buf = bytearray()  # bytes read from self.bk, not yet returned

    def close(self):
        # Deliberately a no-op: the backing stream is left open.
        pass

    def _take(self, n):
        """Remove and return the first ``n`` buffered bytes."""
        ret = bytes(self.buf[:n])
        del self.buf[:n]
        return ret

    def read(self, size=-1):
        """Return up to ``size`` bytes, or all remaining bytes if ``size`` is
        negative or None.

        Blocks until the requested amount (capped at what remains of
        ``limit``) has arrived.  Raises shortinput if the backing stream
        ends prematurely.
        """
        # Remaining data = what is already buffered (e.g. left over by
        # readline()) plus what may still be read from the backing stream.
        want = len(self.buf) + (self.limit - self.rb)
        if size is not None and size >= 0:
            want = min(want, size)
        # Only the shortfall is requested, so rb can never exceed limit.
        while len(self.buf) < want:
            data = self.bk.read(want - len(self.buf))
            if data == b"":
                raise shortinput()
            self.buf.extend(data)
            self.rb += len(data)
        return self._take(want)

    def readline(self, size=-1):
        """Return one line including its trailing newline, if any.

        When ``size`` is non-negative, at most ``size`` bytes are returned.
        Returns ``b""`` once all ``limit`` bytes have been consumed.  Raises
        shortinput if the backing stream ends prematurely.
        """
        if size is None:
            size = -1
        scanned = 0  # length of the buffer prefix known to hold no newline
        while True:
            # Never look for the newline past the first `size` bytes.
            end = len(self.buf) if size < 0 else min(len(self.buf), size)
            p = self.buf.find(b"\n", scanned, end)
            if p >= 0:
                return self._take(p + 1)
            if size >= 0 and len(self.buf) >= size:
                return self._take(size)
            if self.rb >= self.limit:
                return self._take(len(self.buf))
            scanned = end
            # Fetch in chunks of at most 1024 bytes, without exceeding either
            # the limit or the bytes still needed to satisfy `size`.
            ra = min(self.limit - self.rb, 1024)
            if size >= 0:
                ra = min(ra, size - len(self.buf))
            data = self.bk.read(ra)
            if data == b"":
                raise shortinput()
            self.buf.extend(data)
            self.rb += len(data)

    def readlines(self, hint=None):
        # `hint` is accepted for file-object compatibility but ignored.
        return list(self)

    def __iter__(self):
        """Yield lines until the limited input is exhausted."""
        while True:
            line = self.readline()
            if line == b"":
                return
            yield line
122
class request(object):
    """Common base class of the request objects in this module.

    Subclasses provide the ``uriname`` and ``pathinfo`` attributes that
    ``shift`` works on.
    """

    def copy(self):
        """Return a ``copyrequest`` view derived from this request."""
        return copyrequest(self)

    def shift(self, n):
        """Return a copy whose first ``n`` pathinfo elements are moved onto
        the end of ``uriname``; this request itself is left unchanged."""
        shifted = self.copy()
        moved = self.pathinfo[:n]
        remainder = self.pathinfo[n:]
        shifted.uriname = self.uriname + moved
        shifted.pathinfo = remainder
        return shifted
132
class origrequest(request):
    """Top-level request object built from the environment dict ``env``.

    It owns the request-scoped state (items, resources, cleanup functions,
    commit hooks) that ``copyrequest`` views delegate back to.
    """

    def __init__(self, env):
        self.env = env
        self.method = env["REQUEST_METHOD"].upper()
        self.uriname = env["SCRIPT_NAME"]
        self.filename = env.get("SCRIPT_FILENAME")
        self.uri = env["REQUEST_URI"]
        self.pathinfo = env["PATH_INFO"]
        self.query = env["QUERY_STRING"]
        self.remoteaddr = env["REMOTE_ADDR"]
        self.serverport = env["SERVER_PORT"]
        self.servername = env["SERVER_NAME"]
        self.https = "HTTPS" in env
        self.ihead = headdict()
        if "CONTENT_TYPE" in env:
            self.ihead["Content-Type"] = env["CONTENT_TYPE"]
            if "CONTENT_LENGTH" in env:
                clen = self.ihead["Content-Length"] = env["CONTENT_LENGTH"]
                # Accept only plain ASCII digits: str.isdigit() is also true
                # for superscript digits, which int() rejects with ValueError.
                if clen and all(c in "0123456789" for c in clen):
                    self.input = limitreader(env["wsgi.input"], int(clen))
                else:
                    # XXX: What to do?  For now, expose an empty body.
                    self.input = io.BytesIO(b"")
            else:
                # Assume input is chunked and read until ordinary EOF.
                self.input = env["wsgi.input"]
        else:
            # Without a Content-Type no request body is exposed.
            self.input = None
        self.ohead = headdict()
        # Environment keys HTTP_FOO_BAR become input headers named "Foo-Bar".
        for k, v in env.items():
            if k[:5] == "HTTP_":
                self.ihead.add(fixcase(k[5:].replace("_", "-")), v)
        self.items = {}
        self.statuscode = (200, "OK")
        self.ohead["Content-Type"] = "text/html"
        self.resources = set()
        self.clean = set()
        self.commitfuns = []

    def status(self, code, msg):
        """Set the response status to ``code`` with reason text ``msg``."""
        self.statuscode = code, msg

    def item(self, id):
        """Return the per-request object produced by factory ``id``.

        The factory is called as ``id(self)`` on first use and the result is
        cached under ``id``.  Results that are context managers are entered
        via ``withres`` so that ``cleanup`` exits them.
        """
        if id in self.items:
            return self.items[id]
        self.items[id] = new = id(self)
        if hasattr(new, "__enter__") and hasattr(new, "__exit__"):
            self.withres(new)
        return new

    def withres(self, res):
        """Enter context manager ``res`` once and schedule its exit.

        If registering the resource fails, it is exited immediately.
        """
        if res not in self.resources:
            done = False
            res.__enter__()
            try:
                self.resources.add(res)
                # cleanup() calls its entries with no arguments, but the
                # context-manager protocol passes __exit__ the exception
                # triple, so register a wrapper supplying it.
                self.clean.add(lambda: res.__exit__(None, None, None))
                done = True
            finally:
                if not done:
                    res.__exit__(None, None, None)
                    self.resources.discard(res)

    def cleanup(self):
        """Call every registered cleanup function with no arguments.

        A failing function does not stop the remaining ones from running;
        the last exception raised propagates.  ``self.clean`` is a set, so
        the call order is unspecified.
        """
        def clean1(fns):
            if len(fns) > 0:
                try:
                    fns[0]()
                finally:
                    clean1(fns[1:])
        clean1(list(self.clean))

    def oncommit(self, fn):
        """Register ``fn(req)`` to be run by ``commit`` (once per function)."""
        if fn not in self.commitfuns:
            self.commitfuns.append(fn)

    def commit(self, startreq):
        """Run commit hooks, then start the response.

        Hooks run most-recently-registered first.  ``startreq`` is called
        with the status line and a list of (name, value) header pairs --
        presumably the WSGI start_response callable; confirm with callers.
        """
        for fun in reversed(self.commitfuns):
            fun(self)
        hdrs = []
        for nm in self.ohead:
            for val in self.ohead.getlist(nm):
                hdrs.append((nm, val))
        startreq("%s %s" % self.statuscode, hdrs)

    def topreq(self):
        """Return the top-level request, which is this object."""
        return self
220
class copyrequest(request):
    """A derived view of request ``p``, as returned by ``request.copy()``.

    Request attributes are copied from ``p`` so that rebinding them on the
    copy (as ``shift`` does) leaves ``p`` unchanged.  The header dictionaries
    and the input stream are the same objects as in ``p``.  Status changes go
    to the parent; items, resources and commit hooks go to the top-level
    request.
    """

    def __init__(self, p):
        self.parent = p
        self.top = p.topreq()
        self.env = p.env
        self.method = p.method
        self.uriname = p.uriname
        self.filename = p.filename
        self.uri = p.uri
        self.pathinfo = p.pathinfo
        self.query = p.query
        self.remoteaddr = p.remoteaddr
        self.serverport = p.serverport
        # Previously missing: copies lacked `servername` entirely.
        self.servername = p.servername
        self.https = p.https
        self.ihead = p.ihead
        self.ohead = p.ohead
        self.input = p.input

    def status(self, code, msg):
        return self.parent.status(code, msg)

    def item(self, id):
        return self.top.item(id)

    def withres(self, res):
        return self.top.withres(res)

    def oncommit(self, fn):
        return self.top.oncommit(fn)

    def topreq(self):
        return self.parent.topreq()