Merge branch 'master' into jython
[wrw.git] / wrw / req.py
# Names exported by "from <module> import *"; only the request base class
# is listed here.
__all__ = ["request"]
2
class headdict(object):
    """Case-insensitive, multi-valued mapping for HTTP header fields.

    Entries live in ``self.dict`` keyed by the lower-cased header name.
    Each entry is a list ``[name, value1, value2, ...]`` whose first
    element keeps the spelling the header was stored under.
    """

    def __init__(self):
        # lower-cased name -> [original-case name, value, value, ...]
        self.dict = {}

    def __getitem__(self, key):
        """Return the first value of header *key* (case-insensitive).

        Raises KeyError if the header is absent.
        """
        return self.dict[key.lower()][1]

    def __setitem__(self, key, val):
        """Replace all values of *key* with the single value *val*."""
        self.dict[key.lower()] = [key, val]

    def __contains__(self, key):
        """Return True if header *key* is present (case-insensitive)."""
        return key.lower() in self.dict

    def __delitem__(self, key):
        """Remove every value of *key*; raises KeyError if absent."""
        del self.dict[key.lower()]

    def __iter__(self):
        """Iterate over header names in their stored (original) case."""
        # values() instead of itervalues() so this runs on Python 2 and 3;
        # the list snapshot also tolerates mutation during iteration.
        return iter([entry[0] for entry in self.dict.values()])

    def get(self, key, default=""):
        """Return the first value of *key*, or *default* if it is absent."""
        entry = self.dict.get(key.lower())
        if entry is None:
            return default
        return entry[1]

    def getlist(self, key):
        """Return a new list of all values of *key* ([] if absent).

        This is a pure lookup: a missing key is not inserted, so it does
        not affect membership tests, iteration or item access afterwards.
        """
        return self.dict.get(key.lower(), [key])[1:]

    def add(self, key, val):
        """Append *val* to the values of *key*, creating the header if needed.

        The name's spelling is only recorded when the header is created.
        """
        self.dict.setdefault(key.lower(), [key]).append(val)

    def __repr__(self):
        return repr(self.dict)

    def __str__(self):
        return str(self.dict)
37         return str(self.dict)
38
39 def fixcase(str):
40     str = str.lower()
41     i = 0
42     b = True
43     while i < len(str):
44         if b:
45             str = str[:i] + str[i].upper() + str[i + 1:]
46         b = False
47         if str[i] == '-':
48             b = True
49         i += 1
50     return str
51
52 class request(object):
53     def copy(self):
54         return copyrequest(self)
55
56     def shift(self, n):
57         new = self.copy()
58         new.uriname = self.uriname + self.pathinfo[:n]
59         new.pathinfo = self.pathinfo[n:]
60         return new
61
62 class origrequest(request):
63     def __init__(self, env):
64         self.env = env
65         self.method = env["REQUEST_METHOD"].upper()
66         self.uriname = env["SCRIPT_NAME"]
67         self.filename = env.get("SCRIPT_FILENAME")
68         self.uri = env["REQUEST_URI"]
69         self.pathinfo = env["PATH_INFO"]
70         self.query = env["QUERY_STRING"]
71         self.remoteaddr = env["REMOTE_ADDR"]
72         self.serverport = env["SERVER_PORT"]
73         self.servername = env["SERVER_NAME"]
74         self.https = "HTTPS" in env
75         self.ihead = headdict()
76         if "CONTENT_TYPE" in env:
77             self.ihead["Content-Type"] = env["CONTENT_TYPE"]
78         if "CONTENT_LENGTH" in env:
79             self.ihead["Content-Length"] = env["CONTENT_LENGTH"]
80         self.ohead = headdict()
81         for k, v in env.items():
82             if k[:5] == "HTTP_":
83                 self.ihead.add(fixcase(k[5:].replace("_", "-")), v)
84         self.items = {}
85         self.statuscode = (200, "OK")
86         self.ohead["Content-Type"] = "text/html"
87         self.resources = set()
88         self.clean = set()
89         self.commitfuns = []
90
91     def status(self, code, msg):
92         self.statuscode = code, msg
93
94     def item(self, id):
95         if id in self.items:
96             return self.items[id]
97         self.items[id] = new = id(self)
98         if hasattr(new, "__enter__") and hasattr(new, "__exit__"):
99             self.withres(new)
100         return new
101
102     def withres(self, res):
103         if res not in self.resources:
104             done = False
105             res.__enter__()
106             try:
107                 self.resources.add(res)
108                 self.clean.add(res.__exit__)
109                 done = True
110             finally:
111                 if not done:
112                     res.__exit__(None, None, None)
113                     self.resources.discard(res)
114
115     def cleanup(self):
116         def clean1(list):
117             if len(list) > 0:
118                 try:
119                     list[0]()
120                 finally:
121                     clean1(list[1:])
122         clean1(list(self.clean))
123
124     def oncommit(self, fn):
125         if fn not in self.commitfuns:
126             self.commitfuns.append(fn)
127
128     def commit(self, startreq):
129         for fun in reversed(self.commitfuns):
130             fun(self)
131         hdrs = []
132         for nm in self.ohead:
133             for val in self.ohead.getlist(nm):
134                 hdrs.append((nm, val))
135         startreq("%s %s" % self.statuscode, hdrs)
136
137     def topreq(self):
138         return self
139
140 class copyrequest(request):
141     def __init__(self, p):
142         self.parent = p
143         self.top = p.topreq()
144         self.env = p.env
145         self.method = p.method
146         self.uriname = p.uriname
147         self.filename = p.filename
148         self.uri = p.uri
149         self.pathinfo = p.pathinfo
150         self.query = p.query
151         self.remoteaddr = p.remoteaddr
152         self.serverport = p.serverport
153         self.https = p.https
154         self.ihead = p.ihead
155         self.ohead = p.ohead
156
157     def status(self, code, msg):
158         return self.parent.status(code, msg)
159
160     def item(self, id):
161         return self.top.item(id)
162
163     def withres(self, res):
164         return self.top.withres(res)
165
166     def oncommit(self, fn):
167         return self.top.oncommit(fn)
168
169     def topreq(self):
170         return self.parent.topreq()