Try to unwrap functions passed to funplex.
[wrw.git] / wrw / util.py
1 import inspect, math
2 import req, dispatch, session, form, resp, proto
3
def wsgiwrap(callable):
    """Adapt a request handler *callable* into a two-argument application.

    The returned function accepts ``(env, startreq)`` -- judging by the
    name, the WSGI environ and start_response pair -- and hands both,
    together with *callable*, to ``dispatch.handleenv``.  The original
    handler is recorded on ``__wrapped__`` so that code following that
    attribute (e.g. ``funplex``) can reach the undecorated function.
    """
    def wrapper(env, startreq):
        return dispatch.handleenv(env, startreq, callable)

    setattr(wrapper, "__wrapped__", callable)
    return wrapper
9
def formparams(callable):
    """Decorate *callable* so its keyword arguments come from form data.

    The returned wrapper takes a single request object, collects its form
    fields with ``form.formdata(req)`` and calls *callable* with every field
    as a keyword argument, plus ``req=req``.  Unless *callable* accepts
    ``**kwargs``, fields that do not name one of its parameters are dropped.
    If a parameter without a default value is not supplied, an HTTP 400
    error (``resp.httperror``) is raised instead of calling *callable*.

    The undecorated function is available as ``wrapper.__wrapped__``.
    """
    # Introspect once at decoration time instead of on every request.
    # getfullargspec exists on Python 3 (where getargspec was eventually
    # removed); Python 2 only has getargspec, whose **kwargs field is
    # called ``keywords`` rather than ``varkw``.
    getspec = getattr(inspect, "getfullargspec", None)
    if getspec is not None:
        spec = getspec(callable)
        varkw = spec.varkw
    else:
        spec = inspect.getargspec(callable)
        varkw = spec.keywords
    argnames = frozenset(spec.args)
    # spec.defaults is None (not an empty tuple) when no parameter has a
    # default value, so treat that case as zero defaults.
    ndefaults = len(spec.defaults or ())
    # Leading parameters without a default are the mandatory ones.
    required = spec.args[:len(spec.args) - ndefaults]

    def wrapper(req):
        data = form.formdata(req)
        args = dict(data.items())
        args["req"] = req
        if not varkw:
            # Drop fields the callable cannot accept as keywords.
            args = dict((k, v) for k, v in args.items() if k in argnames)
        for name in required:
            if name not in args:
                raise resp.httperror(400, "Missing parameter", ("The query parameter `", resp.h.code(name), "' is required but not supplied."))
        return callable(**args)
    wrapper.__wrapped__ = callable
    return wrapper
26
def funplex(*funs, **nfuns):
    """Build a handler that dispatches on the first path component.

    Each positional handler is registered under the ``__name__`` of the
    function reached by following its ``__wrapped__`` chain to the end, so
    decorated handlers are found under their original names.  Keyword
    handlers are registered under their keyword and take precedence over
    positional ones with the same name.

    The resulting handler:
    * redirects an empty ``pathinfo`` to ``uriname + "/"``;
    * answers "not found" if ``pathinfo`` does not start with "/";
    * dispatches "/" to the ``__index__`` entry, and "/name[/...]" to the
      ``name`` entry, passing ``req.shift(n)`` where n is the length of
      "/name";
    * answers "not found" for unknown names.
    """
    def unwrap(fun):
        inner = fun
        while hasattr(inner, "__wrapped__"):
            inner = inner.__wrapped__
        return inner

    table = dict((unwrap(fun).__name__, fun) for fun in funs)
    table.update(nfuns)

    def handler(req):
        path = req.pathinfo
        if path == "":
            raise resp.redirect(req.uriname + "/")
        if not path.startswith("/"):
            raise resp.notfound()
        rest = path[1:]
        if rest:
            name = rest.partition("/")[0]
            skip = len(name) + 1
        else:
            name, skip = "__index__", 1
        if name not in table:
            raise resp.notfound()
        return table[name](req.shift(skip))
    return handler
51
def persession(data=None):
    """Decorator factory that keeps one handler object per session.

    The decorated *callable* is used as a factory.  On each request the
    wrapper looks in ``session.get(req)``, keyed by *callable*; on first use
    it stores ``callable()`` there, or, when *data* is given,
    ``callable(shared)``, where ``shared`` is a per-session ``data()``
    instance (itself stored in the session under *data* and created on
    first use).  The request is then delegated to that object's
    ``handle(req)``.
    """
    def dec(callable):
        def wrapper(req):
            sess = session.get(req)
            if callable not in sess:
                if data is None:
                    sess[callable] = callable()
                else:
                    if data not in sess:
                        sess[data] = data()
                    # Hand over the per-session shared instance, not the
                    # factory; otherwise the stored data() object would
                    # never be used.
                    sess[callable] = callable(sess[data])
            return sess[callable].handle(req)
        wrapper.__wrapped__ = callable
        return wrapper
    return dec
67
class preiter(object):
    """Iterator wrapper that reads one item ahead of its consumer.

    The first item of *real* is fetched in the constructor, so for a
    generator its body runs up to the first ``yield`` (or to completion)
    as soon as the wrapper is created.  Iteration then yields exactly the
    items of *real*.  ``close()`` forwards to the wrapped object if it has
    a ``close`` method.
    """
    __slots__ = ["bk", "bki", "_next"]
    # Sentinel marking that the underlying iterator is exhausted.
    end = object()

    def __init__(self, real):
        self.bk = real
        self.bki = iter(real)
        self._next = None
        # Prime the look-ahead; the None returned here is discarded.
        self.next()

    def __iter__(self):
        return self

    def next(self):
        """Return the buffered item and buffer the following one."""
        if self._next is self.end:
            raise StopIteration()
        ret = self._next
        try:
            self._next = next(self.bki)
        except StopIteration:
            self._next = self.end
        return ret

    # Python 3 looks up __next__ rather than next for the iterator protocol.
    __next__ = next

    def close(self):
        """Close the wrapped iterable if it supports closing."""
        if hasattr(self.bk, "close"):
            self.bk.close()
93
def pregen(callable):
    """Decorate *callable* so its result is wrapped in a ``preiter``.

    Calling the decorated function forwards all arguments to *callable*
    and returns ``preiter`` over the result, which immediately fetches
    the first item.  The undecorated function is kept on ``__wrapped__``.
    """
    def wrapper(*args, **kwargs):
        result = callable(*args, **kwargs)
        return preiter(result)

    setattr(wrapper, "__wrapped__", callable)
    return wrapper
99
class sessiondata(object):
    """Base for objects stored in a session, one instance per class."""

    @classmethod
    def get(cls, req, create=True):
        """Return this class's instance for the session of *req*.

        The session comes from ``cls.sessdb().get(req)`` and is accessed
        while holding its ``lock``.  If no instance is stored under *cls*
        yet, one is built as ``cls(req, sess)`` and stored, unless *create*
        is false, in which case None is returned.
        """
        sess = cls.sessdb().get(req)
        with sess.lock:
            try:
                return sess[cls]
            except KeyError:
                pass
            if not create:
                return None
            inst = cls(req, sess)
            sess[cls] = inst
            return inst

    @classmethod
    def sessdb(cls):
        """Return the session database; subclasses may override this."""
        return session.default.val
117
class autodirty(sessiondata):
    """Session data that marks itself dirty on any attribute change.

    The ``_is_dirty`` flag lives directly in ``__dict__``.  Until ``get``
    has installed it, attribute writes (e.g. those made in ``__init__``)
    do not mark the instance dirty.
    """

    @classmethod
    def get(cls, req, create=True):
        """Fetch (or create) the instance and ensure the dirty flag exists.

        *create* is passed through to ``sessiondata.get``; when that
        returns None (no instance and ``create`` false), None is returned.
        """
        ret = super(autodirty, cls).get(req, create)
        if ret is None:
            return None
        if "_is_dirty" not in ret.__dict__:
            ret.__dict__["_is_dirty"] = False
        return ret

    def sessfrozen(self):
        """Clear the dirty flag (written via __dict__ to bypass __setattr__)."""
        self.__dict__["_is_dirty"] = False

    def sessdirty(self):
        """Return whether an attribute was set or deleted since freezing."""
        return self._is_dirty

    def __setattr__(self, name, value):
        super(autodirty, self).__setattr__(name, value)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True

    def __delattr__(self, name):
        # object.__delattr__ takes only the name.
        super(autodirty, self).__delattr__(name)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True
141
class manudirty(object):
    """Mixin providing a dirty flag that is only set explicitly.

    The flag is stored in the name-mangled attribute ``_manudirty__dirty``.
    """

    def __init__(self, *args, **kwargs):
        super(manudirty, self).__init__(*args, **kwargs)
        # Start clean once the rest of the MRO has initialised.
        self.__dirty = False

    def dirty(self):
        """Mark the object as modified."""
        self.__dirty = True

    def sessfrozen(self):
        """Reset the flag to clean."""
        self.__dirty = False

    def sessdirty(self):
        """Return the current flag value."""
        return self.__dirty
155
class specslot(object):
    """Data descriptor for one entry in a ``specdirty`` instance's slot list.

    *nm* is the attribute name, *idx* its position in the instance's
    ``__sslots__`` list, and *dirty* whether modifying it should call the
    instance's ``dirty()``.
    """
    __slots__ = ["nm", "idx", "dirty"]
    # Sentinel stored in the slot list for attributes with no value.
    unbound = object()

    def __init__(self, nm, idx, dirty):
        self.nm = nm
        self.idx = idx
        self.dirty = dirty

    @staticmethod
    def slist(ins):
        """Return *ins*'s slot list."""
        # Avoid calling __getattribute__
        return specdirty.__sslots__.__get__(ins, type(ins))

    def __get__(self, ins, cls):
        if ins is None:
            # Looked up on the class itself (getattr(cls, name), help(),
            # introspection): return the descriptor as property does,
            # instead of indexing a non-existent slot list.
            return self
        val = self.slist(ins)[self.idx]
        if val is specslot.unbound:
            raise AttributeError("specslot %r is unbound" % self.nm)
        return val

    def __set__(self, ins, val):
        self.slist(ins)[self.idx] = val
        if self.dirty:
            ins.dirty()

    def __delete__(self, ins):
        self.slist(ins)[self.idx] = specslot.unbound
        # Like __set__, only slots flagged as dirtying mark the instance.
        if self.dirty:
            ins.dirty()
184
class specclass(type):
    """Metaclass installing ``specslot`` descriptors for declared slots.

    Walking the MRO, names in each class's own ``__saveslots__`` are
    collected as saved slots.  Names in its own ``__dirtyslots__`` (which
    defaults to that class's ``__saveslots__``) are collected as dirtying
    slots.  Every name in the union gets a ``specslot`` with a fixed index;
    ``__sslots_l__`` lists the saved slots and ``__sslots_a__`` lists all
    slots in index order.
    """

    def __init__(self, name, bases, tdict):
        super(specclass, self).__init__(name, bases, tdict)
        saved = set()
        dirtying = set()
        for klass in self.__mro__:
            own = klass.__dict__
            declared = own.get("__saveslots__", ())
            saved.update(declared)
            dirtying.update(own.get("__dirtyslots__", declared))
        self.__sslots_l__ = list(saved)
        everything = list(saved | dirtying)
        self.__sslots_a__ = everything
        for idx, slotname in enumerate(everything):
            setattr(self, slotname, specslot(slotname, idx, slotname in dirtying))
198
class specdirty(sessiondata):
    """Session data whose saved state is a fixed list of ``specslot`` values.

    Subclasses declare ``__saveslots__`` (and optionally ``__dirtyslots__``).
    The ``specclass`` metaclass turns each name into a descriptor that
    indexes ``self.__sslots__``.  Assigning a dirtying slot calls
    ``dirty()``, and ``sessdirty()`` reports the flag.
    """
    # NOTE: ``__metaclass__`` is the Python 2 spelling; Python 3 ignores it.
    __metaclass__ = specclass
    __slots__ = ["session", "__sslots__", "_is_dirty"]

    def __specinit__(self):
        """Hook for subclasses to give slots initial values.

        Called from ``__new__`` before the dirty flag is cleared, so slot
        assignments made here do not leave the instance dirty.
        """

    @staticmethod
    def __new__(cls, req, sess):
        # Setup lives in __new__, presumably so that unpickling (which for
        # protocol 2+ calls __new__ with __getnewargs__() and skips
        # __init__) also gets a fresh slot list.
        self = super(specdirty, cls).__new__(cls)
        self.session = sess
        self.__sslots__ = [specslot.unbound] * len(cls.__sslots_a__)
        self.__specinit__()
        self._is_dirty = False
        return self

    def __getnewargs__(self):
        # The request is not kept, so None stands in for ``req``.
        return (None, self.session)

    def dirty(self):
        """Mark this instance as modified."""
        self._is_dirty = True

    def sessfrozen(self):
        """Clear the dirty flag."""
        self._is_dirty = False

    def sessdirty(self):
        """Return whether the instance has been marked dirty."""
        return self._is_dirty

    def __getstate__(self):
        """Return ``{slot name: (bound, value)}`` for every slot."""
        ret = {}
        for nm, val in zip(type(self).__sslots_a__, specslot.slist(self)):
            if val is specslot.unbound:
                ret[nm] = False, None
            else:
                ret[nm] = True, val
        return ret

    def __setstate__(self, st):
        """Restore slots from a ``__getstate__`` mapping.

        Slots absent from *st* become unbound, and keys in *st* that name
        no current slot are ignored.  *st* itself is not modified.
        """
        ss = specslot.slist(self)
        for i, nm in enumerate(type(self).__sslots_a__):
            # .get rather than .pop: restoring must not consume the
            # caller's state dict.
            bound, val = st.get(nm, (False, None))
            ss[i] = val if bound else specslot.unbound
244
def datecheck(req, mtime):
    """Answer a conditional request against modification time *mtime*.

    If the request has an ``If-Modified-Since`` header whose parsed date
    (via ``proto.phttpdate``) is at least ``floor(mtime)``, the exception
    from ``resp.unmodified()`` is raised.  Otherwise the response header
    ``Last-Modified`` is set from *mtime*.  *mtime* is floored presumably
    because HTTP dates carry whole seconds only.
    """
    header = "If-Modified-Since"
    if header in req.ihead:
        since = proto.phttpdate(req.ihead[header])
        if since >= math.floor(mtime):
            raise resp.unmodified()
    req.ohead["Last-Modified"] = proto.httpdate(mtime)