Include a selfpath in requests derived from a funplex.
[wrw.git] / wrw / util.py
1 import inspect, math
2 from . import req, dispatch, session, form, resp, proto
3
def wsgiwrap(callable):
    """Expose a request handler through an ``(env, startreq)`` entry point.

    The returned function forwards its two arguments, together with
    `callable`, to ``dispatch.handleenv`` and returns whatever that
    returns.  The undecorated handler remains reachable through the
    ``__wrapped__`` attribute of the returned function.
    """
    def wrapper(env, startreq):
        # All real work is delegated; this closure only binds `callable`.
        result = dispatch.handleenv(env, startreq, callable)
        return result

    wrapper.__wrapped__ = callable
    return wrapper
9
def formparams(callable):
    """Adapt a handler so that its keyword arguments are taken from form data.

    The returned wrapper takes a single request, reads its form data with
    ``form.formdata(req)``, adds the request itself under the name ``req``,
    and calls `callable` with the result as keyword arguments.

    Unless `callable` accepts ``**kwargs``, form fields that do not match
    one of its parameter names are dropped before the call.  If a
    parameter without a default value is not supplied, a 400
    ``resp.httperror`` is raised instead of calling `callable`.

    The undecorated handler remains reachable through ``__wrapped__``.
    """
    # inspect.getargspec() was removed in Python 3.11; getfullargspec() is
    # its replacement and additionally understands keyword-only parameters.
    spec = inspect.getfullargspec(callable)

    # Names the handler can receive by keyword.
    accepted = set(spec.args) | set(spec.kwonlyargs)

    # Parameters that have no default: the leading positional ones, plus any
    # keyword-only ones missing from kwonlydefaults.  Computed once here
    # rather than on every request.
    npositional = len(spec.args) - len(spec.defaults or ())
    kwdefaults = spec.kwonlydefaults or {}
    required = list(spec.args[:npositional])
    required.extend(nm for nm in spec.kwonlyargs if nm not in kwdefaults)

    def wrapper(req):
        data = form.formdata(req)
        args = dict(data.items())
        args["req"] = req
        if not spec.varkw:
            # No **kwargs catch-all: silently discard unknown fields.
            args = {nm: val for nm, val in args.items() if nm in accepted}
        for nm in required:
            if nm not in args:
                raise resp.httperror(400, "Missing parameter", ("The query parameter `", resp.h.code(nm), "' is required but not supplied."))
        return callable(**args)

    wrapper.__wrapped__ = callable
    return wrapper
26
class funplex(object):
    """Dispatch a request to one of several handlers by first path element.

    Handlers are kept in the ``dir`` dictionary, keyed by name.  Positional
    handlers are registered under the ``__name__`` of their innermost
    ``__wrapped__`` function; keyword handlers are registered under the
    keyword.  The name ``__index__`` is used when the path is just ``/``.
    """

    def __init__(self, *funs, **nfuns):
        self.dir = {}
        for handler in funs:
            self.dir[self.unwrap(handler).__name__] = handler
        self.dir.update(nfuns)

    @staticmethod
    def unwrap(fun):
        """Follow ``__wrapped__`` links and return the innermost object."""
        inner = fun
        while True:
            try:
                inner = inner.__wrapped__
            except AttributeError:
                return inner

    def __call__(self, req):
        """Route `req` to the handler named by its first path element.

        Raises ``resp.redirect`` to ``req.uriname + "/"`` when the path info
        is empty, and ``resp.notfound`` when the path info does not start
        with a slash or names no registered handler.
        """
        if req.pathinfo == "":
            raise resp.redirect(req.uriname + "/")
        if not req.pathinfo.startswith("/"):
            raise resp.notfound()

        remainder = req.pathinfo[1:]
        if remainder:
            key = remainder.partition("/")[0]
            # Consume the leading slash plus the element itself.
            consumed = len(key) + 1
        else:
            key = "__index__"
            # Only the leading slash is consumed for the index handler.
            consumed = 1

        if key not in self.dir:
            raise resp.notfound()

        sreq = req.shift(consumed)
        # Record the full path (minus the leading slash) as seen by this
        # funplex on the derived request.
        sreq.selfpath = req.pathinfo[1:]
        return self.dir[key](sreq)

    def add(self, fun):
        """Register `fun` under its unwrapped ``__name__``; return `fun`.

        Returning the argument unchanged lets this be used as a decorator.
        """
        key = self.unwrap(fun).__name__
        self.dir[key] = fun
        return fun

    def name(self, name):
        """Return a decorator that registers its argument under `name`."""
        def dec(fun):
            self.dir[name] = fun
            return fun
        return dec
66
def persession(data=None):
    """Decorator factory giving each session its own handler instance.

    The decorated `callable` is used both as a factory and as the key
    under which its product is stored in the session returned by
    ``session.get(req)``.  On the first request in a session the factory
    is invoked (with no arguments, or with `data` if one was given) and
    the result is stored; every request is then passed to the stored
    object's ``handle(req)`` method.

    When `data` is given, ``data()`` is likewise created once per session
    and stored under the key `data`.
    """
    def dec(callable):
        def wrapper(req):
            sess = session.get(req)
            if callable in sess:
                return sess[callable].handle(req)
            if data is None:
                sess[callable] = callable()
            else:
                if data not in sess:
                    sess[data] = data()
                # NOTE(review): the factory `data` itself is passed here,
                # not the per-session sess[data] created above -- confirm
                # this is intended.
                sess[callable] = callable(data)
            return sess[callable].handle(req)

        wrapper.__wrapped__ = callable
        return wrapper
    return dec
82
class preiter(object):
    """Iterator wrapper that always holds the next item in advance.

    Constructing a preiter immediately pulls the first item from the
    wrapped iterable, so any code the iterable runs before producing its
    first item is executed at construction time rather than on first use.
    """
    __slots__ = ["bk", "bki", "_next"]
    # Sentinel stored in _next once the wrapped iterator is exhausted.
    end = object()

    def __init__(self, real):
        self.bk = real
        self.bki = iter(real)
        self._next = None
        # Prime the look-ahead slot; the placeholder None it returns is
        # deliberately discarded.
        self.__next__()

    def __iter__(self):
        return self

    def __next__(self):
        current = self._next
        if current is self.end:
            raise StopIteration()
        # Fetch one item ahead, falling back to the sentinel on exhaustion.
        self._next = next(self.bki, self.end)
        return current

    def close(self):
        """Close the wrapped object if it has a ``close`` attribute."""
        backing = self.bk
        if hasattr(backing, "close"):
            backing.close()
108
def pregen(callable):
    """Decorator that wraps the result of `callable` in a ``preiter``.

    The returned function accepts any arguments, passes them through to
    `callable` unchanged and returns ``preiter(result)``.  The undecorated
    function remains reachable through ``__wrapped__``.
    """
    def wrapper(*args, **kwargs):
        produced = callable(*args, **kwargs)
        return preiter(produced)

    wrapper.__wrapped__ = callable
    return wrapper
114
def stringwrap(charset):
    """Decorator factory that encodes every item a handler yields.

    The decorated `callable` is expected to return an iterable of objects
    with an ``encode`` method; each item is yielded as
    ``item.encode(charset)``.  The encoding generator is passed through
    ``pregen``, and ``__wrapped__`` on the result points at `callable`.
    """
    def dec(callable):
        def wrapper(*args, **kwargs):
            for chunk in callable(*args, **kwargs):
                yield chunk.encode(charset)

        wrapper = pregen(wrapper)
        # Point __wrapped__ at the user's function rather than at the
        # intermediate encoding generator that pregen recorded.
        wrapper.__wrapped__ = callable
        return wrapper
    return dec
124
class sessiondata(object):
    """Base class for objects stored once per session, keyed by class."""

    @classmethod
    def get(cls, req, create=True):
        """Return the instance of `cls` stored in the session of `req`.

        The session is obtained from ``cls.sessdb().get(req)`` and its
        ``lock`` is held for the whole lookup.  If no instance is stored
        under `cls`, a new one is built as ``cls(req, sess)``, stored and
        returned -- unless `create` is false, in which case None is
        returned.
        """
        sess = cls.sessdb().get(req)
        with sess.lock:
            try:
                return sess[cls]
            except KeyError:
                pass
            if not create:
                return None
            fresh = cls(req, sess)
            sess[cls] = fresh
            return fresh

    @classmethod
    def sessdb(cls):
        """Return the session database to use: ``session.default.val``."""
        return session.default.val
142
class autodirty(sessiondata):
    """Session data that marks itself dirty on any attribute change.

    The flag lives directly in the instance ``__dict__`` under the key
    ``_is_dirty`` and is always written through ``__dict__`` so that
    maintaining it does not itself trigger ``__setattr__``.  Tracking only
    starts once the flag exists, i.e. after the instance has been fetched
    through ``get`` or ``sessfrozen`` has been called.
    """

    @classmethod
    def get(cls, req, create=True):
        """Fetch the per-session instance and make sure it has a dirty flag.

        `create` is passed through to ``sessiondata.get``; when it is false
        and no instance is stored, None is returned.
        """
        ret = super().get(req, create)
        if ret is None:
            return None
        if "_is_dirty" not in ret.__dict__:
            ret.__dict__["_is_dirty"] = False
        return ret

    def sessfrozen(self):
        """Reset the dirty flag to False."""
        self.__dict__["_is_dirty"] = False

    def sessdirty(self):
        """Return the current dirty flag."""
        return self._is_dirty

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True

    def __delattr__(self, name):
        # __delattr__ takes only the attribute name; the previous code also
        # passed an undefined `value`, so every deletion raised NameError.
        super().__delattr__(name)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True
166
class manudirty(object):
    """Mixin providing a dirty flag that is only set on explicit request.

    The flag is stored in the private (name-mangled) ``__dirty`` attribute
    and starts out False.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative initialisation: pass everything on to the next class
        # in the MRO before setting up our own state.
        super().__init__(*args, **kwargs)
        self.__dirty = False

    def sessfrozen(self):
        """Reset the dirty flag to False."""
        self.__dirty = False

    def sessdirty(self):
        """Return the current dirty flag."""
        return self.__dirty

    def dirty(self):
        """Set the dirty flag to True."""
        self.__dirty = True
180
class specslot(object):
    """Data descriptor backing one named slot of a ``specdirty`` instance.

    Values are not stored on the descriptor but in the per-instance list
    held in the instance's ``__sslots__`` slot, at position ``idx``.
    """
    __slots__ = ["nm", "idx", "dirty"]
    # Sentinel marking a slot that currently has no value.
    unbound = object()

    def __init__(self, nm, idx, dirty):
        self.nm, self.idx, self.dirty = nm, idx, dirty

    @staticmethod
    def slist(ins):
        """Return the slot-value list of `ins`."""
        # Avoid calling __getattribute__
        owner = type(ins)
        return specdirty.__sslots__.__get__(ins, owner)

    def __get__(self, ins, cls):
        stored = self.slist(ins)[self.idx]
        if stored is not specslot.unbound:
            return stored
        raise AttributeError("specslot %r is unbound" % self.nm)

    def __set__(self, ins, val):
        values = self.slist(ins)
        values[self.idx] = val
        # Only slots flagged as dirtying mark the owner dirty on assignment.
        if self.dirty:
            ins.dirty()

    def __delete__(self, ins):
        values = self.slist(ins)
        values[self.idx] = specslot.unbound
        # NOTE(review): unlike __set__, deletion marks the owner dirty
        # regardless of self.dirty -- confirm this asymmetry is intended.
        ins.dirty()
209
class specclass(type):
    """Metaclass that installs ``specslot`` descriptors on its classes.

    Slot names are gathered from the ``__saveslots__`` and
    ``__dirtyslots__`` attributes defined directly on each class in the
    MRO.  A class that defines no ``__dirtyslots__`` of its own
    contributes its ``__saveslots__`` to the dirtying set as well.
    """

    def __init__(self, name, bases, tdict):
        super().__init__(name, bases, tdict)
        saved = set()
        dirtying = set()
        for klass in self.__mro__:
            own = vars(klass)
            declared = own.get("__saveslots__", ())
            saved.update(declared)
            dirtying.update(own.get("__dirtyslots__", declared))
        self.__sslots_l__ = list(saved)
        self.__sslots_a__ = list(saved | dirtying)
        # A slot's position in __sslots_a__ is the index its descriptor
        # uses into the per-instance value list.
        for position, slotname in enumerate(self.__sslots_a__):
            descriptor = specslot(slotname, position, slotname in dirtying)
            setattr(self, slotname, descriptor)
223
class specdirty(sessiondata, metaclass=specclass):
    """Session data whose attributes are explicitly declared slots.

    Slot values live in the list held in the ``__sslots__`` slot, one
    entry per name in the class's ``__sslots_a__`` list, and are accessed
    through the ``specslot`` descriptors installed by the metaclass.
    """
    __slots__ = ["session", "__sslots__", "_is_dirty"]

    def __specinit__(self):
        """Hook called from ``__new__`` once the slot list exists."""
        pass

    @staticmethod
    def __new__(cls, req, sess):
        self = super().__new__(cls)
        self.session = sess
        slotcount = len(cls.__sslots_a__)
        self.__sslots__ = [specslot.unbound] * slotcount
        self.__specinit__()
        # Reset last, so that anything __specinit__ assigned does not leave
        # a brand-new object marked dirty.
        self._is_dirty = False
        return self

    def __getnewargs__(self):
        # Arguments handed back to __new__ on unpickling: no request.
        return (None, self.session)

    def dirty(self):
        """Set the dirty flag to True."""
        self._is_dirty = True

    def sessfrozen(self):
        """Reset the dirty flag to False."""
        self._is_dirty = False

    def sessdirty(self):
        """Return the current dirty flag."""
        return self._is_dirty

    def __getstate__(self):
        """Return ``{name: (bound, value)}`` for every slot.

        Unbound slots are recorded as ``(False, None)``.
        """
        names = type(self).__sslots_a__
        return {
            nm: ((False, None) if val is specslot.unbound else (True, val))
            for nm, val in zip(names, specslot.slist(self))
        }

    def __setstate__(self, st):
        """Restore slot values from a mapping made by ``__getstate__``.

        Names missing from `st` are left unbound; consumed names are
        popped from `st`.
        """
        values = specslot.slist(self)
        for position, nm in enumerate(type(self).__sslots_a__):
            bound, val = st.pop(nm, (False, None))
            values[position] = val if bound else specslot.unbound
268
def datecheck(req, mtime):
    """Handle conditional-GET date checking for a resource.

    If the request carries an ``If-Modified-Since`` header that
    ``proto.phttpdate`` can parse, and the parsed time is not earlier than
    `mtime` rounded down to a whole number, ``resp.unmodified`` is raised.
    Otherwise the ``Last-Modified`` response header is set from `mtime`.
    """
    header = "If-Modified-Since"
    if header in req.ihead:
        since = proto.phttpdate(req.ihead[header])
        # phttpdate may return None; such a header is simply ignored.
        if since is not None:
            if since >= math.floor(mtime):
                raise resp.unmodified()
    req.ohead["Last-Modified"] = proto.httpdate(mtime)