Save references to wrapped functions.
[wrw.git] / wrw / util.py
1 import inspect, math
2 import req, dispatch, session, form, resp, proto
3
def wsgiwrap(callable):
    """Adapt a request handler into a WSGI application callable.

    The returned ``(env, startreq)`` function forwards both WSGI arguments,
    together with *callable*, to ``dispatch.handleenv``. The original handler
    is exposed on the returned function as ``__wrapped__``.
    """
    handler = callable

    def wrapper(env, startreq):
        return dispatch.handleenv(env, startreq, handler)

    wrapper.__wrapped__ = handler
    return wrapper
9
def formparams(callable):
    """Decorator: call *callable* with the request's form fields as keywords.

    The request object itself is passed as the ``req`` keyword. Unless
    *callable* accepts ``**kwargs``, form fields that do not name one of its
    parameters are dropped. If a parameter that has no default value is not
    supplied, ``resp.httperror(400, ...)`` is raised.

    The original function is exposed on the wrapper as ``__wrapped__``.
    """
    def wrapper(req):
        data = form.formdata(req)
        spec = inspect.getargspec(callable)
        args = dict(data.items())
        args["req"] = req
        if not spec.keywords:
            # No **kwargs to absorb extra fields: keep declared parameters only.
            args = {name: val for name, val in args.items() if name in spec.args}
        # getargspec reports ``defaults`` as None (not an empty tuple) when the
        # callable has no default arguments; treat that as "all required".
        nrequired = len(spec.args) - len(spec.defaults or ())
        for arg in spec.args[:nrequired]:
            if arg not in args:
                raise resp.httperror(400, "Missing parameter", ("The query parameter `", resp.h.code(arg), "' is required but not supplied."))
        return callable(**args)
    wrapper.__wrapped__ = callable
    return wrapper
26
def funplex(*funs, **nfuns):
    """Return a request handler that dispatches on the first path segment.

    Positional functions are registered under their ``__name__``. Keyword
    arguments register handlers under explicit names and take precedence on
    a clash. The path ``/`` selects the ``__index__`` entry.

    The selected handler is called with ``req.shift(n)``, where *n* covers
    the leading slash plus the matched segment.

    An empty path raises ``resp.redirect`` to the URI plus a trailing slash.
    A path that does not start with ``/``, or whose segment is not
    registered, raises ``resp.notfound()``.
    """
    # Called ``table`` rather than ``dir`` so the builtin is not shadowed.
    table = dict((fun.__name__, fun) for fun in funs)
    table.update(nfuns)

    def handler(req):
        path = req.pathinfo
        if path == "":
            raise resp.redirect(req.uriname + "/")
        if path[:1] != "/":
            raise resp.notfound()
        rest = path[1:]
        if rest:
            name = rest.partition("/")[0]
            skip = len(name) + 1
        else:
            name, skip = "__index__", 1
        if name not in table:
            raise resp.notfound()
        return table[name](req.shift(skip))

    return handler
47
def persession(data = None):
    """Decorator factory: keep one handler object per session.

    The decorated *callable* is a factory whose result must have a
    ``handle(req)`` method. The first time a session reaches the wrapper,
    the factory is invoked and its result is cached in the session under the
    factory itself. Every request is then delegated to that cached object.

    If *data* is given, it is a factory for a per-session object. That object
    is likewise created once and cached in the session under *data*, and it
    (not the factory) is passed to *callable*.

    The original function is exposed on the wrapper as ``__wrapped__``.
    """
    def dec(callable):
        def wrapper(req):
            sess = session.get(req)
            if callable not in sess:
                if data is None:
                    sess[callable] = callable()
                else:
                    if data not in sess:
                        sess[data] = data()
                    # Pass the cached per-session instance. Passing the bare
                    # factory, as before, left sess[data] created but unused.
                    sess[callable] = callable(sess[data])
            return sess[callable].handle(req)
        wrapper.__wrapped__ = callable
        return wrapper
    return dec
63
class preiter(object):
    """Iterator wrapper that fetches one item ahead of its consumer.

    The first item of *real* is fetched as soon as the wrapper is
    constructed. Any code a generator runs before its first ``yield``,
    including exceptions it raises, therefore executes at construction time
    rather than on first iteration.
    """
    __slots__ = ["bk", "bki", "_next"]
    # Sentinel stored in _next once the underlying iterator is exhausted.
    end = object()

    def __init__(self, real):
        self.bk = real          # original iterable, kept so close() can reach it
        self.bki = iter(real)   # iterator actually being consumed
        self._next = None
        # Prime the look-ahead buffer; the None returned here is discarded.
        self.next()

    def __iter__(self):
        return self

    def next(self):
        """Return the buffered item and fetch the one after it."""
        if self._next is self.end:
            raise StopIteration()
        ret = self._next
        try:
            self._next = next(self.bki)
        except StopIteration:
            self._next = self.end
        return ret

    # Python 3 spelling of the iterator protocol method; without it a
    # preiter is not an iterator under Python 3.
    __next__ = next

    def close(self):
        """Close the wrapped iterable if it has a ``close()`` method."""
        if hasattr(self.bk, "close"):
            self.bk.close()
89
def pregen(callable):
    """Decorator: make each call of a generator function return a `preiter`.

    The resulting generator is advanced to its first item immediately. The
    original function is exposed on the wrapper as ``__wrapped__``.
    """
    def wrapper(*args, **kwargs):
        gen = callable(*args, **kwargs)
        return preiter(gen)

    wrapper.__wrapped__ = callable
    return wrapper
95
class sessiondata(object):
    """Base for objects stored at most once per session, keyed by class."""

    @classmethod
    def get(cls, req, create = True):
        """Return this class's instance from *req*'s session.

        If none is stored yet: when *create* is true, one is built as
        ``cls(req, sess)``, stored under ``cls`` and returned; otherwise
        None is returned. Lookup and creation happen while holding the
        session's ``lock``.
        """
        sess = cls.sessdb().get(req)
        with sess.lock:
            try:
                found = sess[cls]
            except KeyError:
                pass
            else:
                return found
            if not create:
                return None
            found = cls(req, sess)
            sess[cls] = found
            return found

    @classmethod
    def sessdb(cls):
        """Return the session store that `get` consults (``session.default.val``)."""
        return session.default.val
113
class autodirty(sessiondata):
    """Session data that flags itself dirty on every attribute change.

    The ``_is_dirty`` flag is only installed by `get`, after the instance
    exists, so attributes assigned during construction are not counted as
    modifications. The flag is read and written through ``__dict__`` so that
    maintaining it never goes through ``__setattr__`` itself.
    """
    @classmethod
    def get(cls, req, create = True):
        """Return the session's instance, installing the dirty flag if needed.

        *create* is forwarded to ``sessiondata.get``. When it is false and no
        instance exists yet, None is returned.
        """
        ret = super(autodirty, cls).get(req, create=create)
        if ret is not None and "_is_dirty" not in ret.__dict__:
            ret.__dict__["_is_dirty"] = False
        return ret

    def sessfrozen(self):
        """Clear the dirty flag."""
        self.__dict__["_is_dirty"] = False

    def sessdirty(self):
        """Return whether an attribute was set or deleted since the last clear."""
        return self._is_dirty

    def __setattr__(self, name, value):
        super(autodirty, self).__setattr__(name, value)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True

    def __delattr__(self, name):
        # __delattr__ takes only the name. The previous code also passed an
        # undefined ``value``, so every ``del obj.attr`` raised NameError.
        super(autodirty, self).__delattr__(name)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True
137
class manudirty(object):
    """Mixin for session data whose modified state is tracked by hand.

    Call `dirty` after changing the object. The flag starts out clear and is
    cleared again by `sessfrozen`.
    """
    def __init__(self, *args, **kwargs):
        super(manudirty, self).__init__(*args, **kwargs)
        self.__setflag(False)

    def __setflag(self, flag):
        # Sole writer of the name-mangled flag attribute (_manudirty__dirty).
        self.__dirty = flag

    def dirty(self):
        """Mark the object as modified."""
        self.__setflag(True)

    def sessfrozen(self):
        """Clear the modified flag."""
        self.__setflag(False)

    def sessdirty(self):
        """Return whether the object is currently marked as modified."""
        return self.__dirty
151
class specslot(object):
    """Data descriptor backing one named slot of a `specdirty` instance.

    The value lives in the instance's ``__sslots__`` list at position
    ``idx``. The sentinel `specslot.unbound` marks a slot that holds no value.
    """
    __slots__ = ["nm", "idx", "dirty"]
    # Sentinel stored in the backing list for slots without a value.
    unbound = object()

    def __init__(self, nm, idx, dirty):
        self.nm = nm        # attribute name, used in error messages
        self.idx = idx      # position in the instance's backing list
        self.dirty = dirty  # whether modifying the slot marks the instance dirty

    @staticmethod
    def slist(ins):
        """Return the backing value list of *ins*."""
        # Read through the raw ``__sslots__`` slot descriptor, avoiding a call
        # to the instance's ``__getattribute__``.
        return specdirty.__sslots__.__get__(ins, type(ins))

    def __get__(self, ins, cls):
        # Class-level access (``SomeClass.slotname``) has no instance; return
        # the descriptor itself instead of indexing a member descriptor.
        if ins is None:
            return self
        val = self.slist(ins)[self.idx]
        if val is specslot.unbound:
            raise AttributeError("specslot %r is unbound" % self.nm)
        return val

    def __set__(self, ins, val):
        self.slist(ins)[self.idx] = val
        if self.dirty:
            ins.dirty()

    def __delete__(self, ins):
        self.slist(ins)[self.idx] = specslot.unbound
        # Deletion is a modification too; apply the same dirty policy as
        # __set__ rather than marking non-dirty slots unconditionally.
        if self.dirty:
            ins.dirty()
180
class specclass(type):
    """Metaclass that installs a `specslot` for every declared slot name.

    Each class in the MRO may declare:

    ``__saveslots__``
        Slot names collected into ``__sslots_l__``. By default they also
        mark the instance dirty when modified.
    ``__dirtyslots__``
        If present, replaces that class's ``__saveslots__`` as its set of
        dirty-marking names.

    Every collected name gets a `specslot` at its index in ``__sslots_a__``.
    """
    def __init__(self, name, bases, tdict):
        super(specclass, self).__init__(name, bases, tdict)
        saved, dirtying = set(), set()
        for klass in self.__mro__:
            own = klass.__dict__.get("__saveslots__", ())
            saved.update(own)
            dirtying.update(klass.__dict__.get("__dirtyslots__", own))
        self.__sslots_l__ = list(saved)
        allslots = list(saved | dirtying)
        self.__sslots_a__ = allslots
        for pos, slotname in enumerate(allslots):
            setattr(self, slotname, specslot(slotname, pos, slotname in dirtying))
194
class specdirty(sessiondata):
    """Session data whose state is held in `specslot` descriptors.

    Subclasses declare ``__saveslots__`` (and optionally ``__dirtyslots__``).
    The `specclass` metaclass turns each name into a `specslot` that stores
    its value in this instance's ``__sslots__`` list.

    ``__getstate__`` pickles only those slot values. The session reference is
    supplied back to ``__new__`` through ``__getnewargs__``.
    """
    # NOTE(review): a ``__metaclass__`` attribute is honoured only by Python 2.
    __metaclass__ = specclass
    __slots__ = ["session", "__sslots__", "_is_dirty"]

    def __specinit__(self):
        """Hook for subclasses to set initial slot values; called from ``__new__``."""
        pass

    @staticmethod
    def __new__(cls, req, sess):
        """Create an instance bound to *sess* with every slot unbound.

        ``__specinit__`` runs before the dirty flag is set, so slot
        assignments it makes do not leave the instance marked dirty.
        """
        self = super(specdirty, cls).__new__(cls)
        self.session = sess
        # One backing cell per slot name collected by the metaclass.
        self.__sslots__ = [specslot.unbound] * len(cls.__sslots_a__)
        self.__specinit__()
        # Must come after __specinit__: clears any dirty mark its writes set.
        self._is_dirty = False
        return self

    def __getnewargs__(self):
        """Arguments for ``__new__`` on unpickling (pickle protocol 2+).

        There is no request at load time, hence None for *req*.
        """
        return (None, self.session)

    def dirty(self):
        """Mark the instance as modified."""
        self._is_dirty = True

    def sessfrozen(self):
        """Clear the modified flag."""
        self._is_dirty = False

    def sessdirty(self):
        """Return whether the instance is marked as modified."""
        return self._is_dirty

    def __getstate__(self):
        """Return ``{slot name: (bound, value)}`` for every slot.

        Unbound slots are recorded as ``(False, None)``.
        """
        ret = {}
        for nm, val in zip(type(self).__sslots_a__, specslot.slist(self)):
            if val is specslot.unbound:
                ret[nm] = False, None
            else:
                ret[nm] = True, val
        return ret

    def __setstate__(self, st):
        """Restore slot values from a ``__getstate__`` mapping.

        Names missing from *st* become unbound, and names in *st* that the
        class no longer declares are ignored. Note that entries are removed
        from *st* via ``pop``.
        """
        ss = specslot.slist(self)
        for i, nm in enumerate(type(self).__sslots_a__):
            bound, val = st.pop(nm, (False, None))
            if not bound:
                ss[i] = specslot.unbound
            else:
                ss[i] = val
240
def datecheck(req, mtime):
    """Check the If-Modified-Since request header against *mtime*.

    If the header's date is not older than *mtime* (compared at whole-second
    precision), this raises ``resp.unmodified()``, presumably a 304
    response. Otherwise *mtime* is recorded as the response's
    ``Last-Modified`` header.
    """
    headers = req.ihead
    if "If-Modified-Since" in headers:
        since = proto.phttpdate(headers["If-Modified-Since"])
        # HTTP dates carry whole seconds, so drop mtime's fractional part.
        cutoff = math.floor(mtime)
        if since >= cutoff:
            raise resp.unmodified()
    req.ohead["Last-Modified"] = proto.httpdate(mtime)