dcf745ac88eb2eadf05c0f99ddeeb3f07c96af50
[wrw.git] / wrw / util.py
1 import inspect, math
2 from . import req, dispatch, session, form, resp, proto
3
def wsgiwrap(callable):
    """Expose a request handler as a WSGI application.

    The returned function takes the WSGI ``(environ, start_response)`` pair
    and hands it, together with *callable*, to ``dispatch.handleenv``.  Its
    ``__wrapped__`` attribute refers back to *callable*.
    """
    def wrapper(env, startreq):
        handler = callable
        return dispatch.handleenv(env, startreq, handler)

    wrapper.__wrapped__ = callable
    return wrapper
9
def formparams(callable):
    """Decorate a handler so its keyword parameters are filled from form data.

    The returned wrapper takes a request, reads its form data with
    ``form.formdata(req)``, adds ``req`` itself under the key ``"req"`` and
    calls *callable* with those values as keyword arguments.  If *callable*
    accepts ``**kwargs``, every form field is passed.  Otherwise only fields
    whose names match a declared parameter are passed.

    The wrapper raises ``resp.httperror(400, ...)`` when:
      * the form data cannot be read completely (``IOError``), or
      * a parameter without a default value was not supplied.
    """
    sig = inspect.signature(callable)
    params = list(sig.parameters.values())
    haskw = any(par.kind == inspect.Parameter.VAR_KEYWORD for par in params)
    # *args / **kwargs report ``Parameter.empty`` as their default but are
    # never required; without this exclusion a handler declaring ``**kw``
    # would always be rejected with "Missing parameter `kw'".
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    required = [par.name for par in params
                if par.default is inspect.Parameter.empty and par.kind not in variadic]

    def wrapper(req):
        try:
            data = dict(form.formdata(req).items())
        except IOError:
            raise resp.httperror(400, "Invalid request", "Form data was incomplete")

        data["req"] = req
        if haskw:
            args = data
        else:
            args = {par.name: data[par.name] for par in params if par.name in data}
        for name in required:
            if name not in args:
                raise resp.httperror(400, "Missing parameter", ("The query parameter `", resp.h.code(name), "' is required but not supplied."))
        return callable(**args)
    wrapper.__wrapped__ = callable
    return wrapper
33
class funplex(object):
    """Dispatch requests to named handlers by the first path component.

    Handlers given positionally are registered under the ``__name__`` of the
    innermost function reached through ``__wrapped__`` links.  Keyword
    handlers are registered under the keyword.

    Special names used by ``__call__``:
      * ``__root__`` handles an empty path info.  Without it the request is
        redirected to ``req.uriname + "/"``.
      * ``__index__`` handles a path info of exactly ``/``.
    """
    def __init__(self, *funs, **nfuns):
        self.dir = {}
        for fun in funs:
            self.dir[self.unwrap(fun).__name__] = fun
        self.dir.update(nfuns)

    @staticmethod
    def unwrap(fun):
        """Follow ``__wrapped__`` attributes down to the innermost object."""
        inner = fun
        while hasattr(inner, "__wrapped__"):
            inner = inner.__wrapped__
        return inner

    def __call__(self, req):
        path = req.pathinfo
        if path == "":
            if "__root__" not in self.dir:
                raise resp.redirect(req.uriname + "/")
            return self.dir["__root__"](req)
        if path[:1] != "/":
            raise resp.notfound()
        rest = path[1:]
        if rest == "":
            key, skip = "__index__", 1
        else:
            key = rest.partition("/")[0]
            # Skip the leading slash plus the handler name.
            skip = 1 + len(key)
        if key not in self.dir:
            raise resp.notfound()
        sub = req.shift(skip)
        sub.selfpath = req.pathinfo[1:]
        return self.dir[key](sub)

    def add(self, fun):
        """Register *fun* under its unwrapped ``__name__`` and return it."""
        key = self.unwrap(fun).__name__
        self.dir[key] = fun
        return fun

    def name(self, name):
        """Return a decorator that registers a handler under *name*."""
        def register(fun):
            self.dir[name] = fun
            return fun
        return register
75
def persession(data=None):
    """Return a decorator that keeps one handler object per session.

    The decorated *callable* is a factory whose result must have a
    ``handle(req)`` method.  On the first request in a session the factory
    is called and its result is stored in the session under the key
    *callable*:

      * If *data* is None, it is called with no arguments.
      * Otherwise it is called with the session's instance of *data*.  That
        instance is created once with ``data()`` and stored under the key
        *data*, so handlers sharing the same *data* share that instance.

    Each request is then delegated to the stored object's ``handle``.
    """
    def dec(callable):
        def wrapper(req):
            sess = session.get(req)
            if callable not in sess:
                if data is None:
                    sess[callable] = callable()
                else:
                    if data not in sess:
                        sess[data] = data()
                    # Pass the per-session instance, not the factory itself.
                    sess[callable] = callable(sess[data])
            return sess[callable].handle(req)
        wrapper.__wrapped__ = callable
        return wrapper
    return dec
91
class preiter(object):
    """Iterator that stays one element ahead of the iterable it wraps.

    The first element of *real* is fetched in the constructor.  Code that
    runs before the first item (such as a generator's setup, or an exception
    it raises) therefore executes at construction time, not on first
    iteration.
    """
    __slots__ = ["bk", "bki", "_next"]
    end = object()  # sentinel stored in _next once the source is exhausted

    def __init__(self, real):
        self.bk = real           # original iterable, kept for close()
        self.bki = iter(real)    # the iterator actually consumed
        self._next = None
        # Prime the look-ahead; the placeholder None returned here is dropped.
        self.__next__()

    def __iter__(self):
        return self

    def __next__(self):
        pending = self._next
        if pending is self.end:
            raise StopIteration()
        # next() with a default turns exhaustion into the sentinel.
        self._next = next(self.bki, self.end)
        return pending

    def close(self):
        """Call ``close()`` on the wrapped iterable if it defines one."""
        if not hasattr(self.bk, "close"):
            return
        self.bk.close()
117
def pregen(callable):
    """Decorate an iterable-returning function so it returns a ``preiter``.

    Calling the wrapper calls *callable* and immediately pulls the first
    element of its result.  For a generator, its code up to the first
    ``yield`` therefore runs at call time.
    """
    def wrapper(*args, **kwargs):
        source = callable(*args, **kwargs)
        return preiter(source)
    wrapper.__wrapped__ = callable
    return wrapper
123
def stringwrap(charset):
    """Return a decorator that encodes each yielded string with *charset*.

    The decorated function must produce an iterable of ``str``.  The wrapper
    yields the corresponding ``bytes``, wrapped through ``pregen``, so the
    first item is produced and encoded as soon as the wrapper is called.
    ``__wrapped__`` on the result points at the undecorated function.
    """
    def dec(callable):
        def encoded(*args, **kwargs):
            for string in callable(*args, **kwargs):
                yield string.encode(charset)
        wrapper = pregen(encoded)
        wrapper.__wrapped__ = callable
        return wrapper
    return dec
133
class sessiondata(object):
    """Base class for objects stored once per session, keyed by their class.

    Subclasses are constructed as ``cls(req, sess)``.  ``get`` builds the
    instance on first use and caches it in the session.
    """
    @classmethod
    def get(cls, req, create=True):
        """Return this class's instance for the session of *req*.

        If none is stored yet, build one with ``cls(req, sess)`` and store
        it, unless *create* is false, in which case return None.  The lookup
        and any creation run while holding ``sess.lock``.
        """
        sess = cls.sessdb().get(req)
        with sess.lock:
            try:
                existing = sess[cls]
            except KeyError:
                pass
            else:
                return existing
            if not create:
                return None
            fresh = cls(req, sess)
            sess[cls] = fresh
            return fresh

    @classmethod
    def sessdb(cls):
        """Return the object whose ``get(req)`` yields the session.

        The base implementation uses ``session.default.val``; subclasses
        may override it.
        """
        return session.default.val
151
class autodirty(sessiondata):
    """Session data that marks itself dirty whenever an attribute changes.

    Tracking is armed by storing ``_is_dirty`` in the instance ``__dict__``,
    which ``get`` and ``sessfrozen`` do.  Attribute writes and deletions made
    before that (e.g. in ``__init__``) do not mark the object dirty.
    """
    @classmethod
    def get(cls, req, create=True):
        """Return the session's instance with dirty tracking armed.

        *create* is forwarded to ``sessiondata.get``.  When that returns
        None (no instance and *create* false), None is returned as-is.
        """
        ret = super().get(req, create=create)
        if ret is not None and "_is_dirty" not in ret.__dict__:
            ret.__dict__["_is_dirty"] = False
        return ret

    def sessfrozen(self):
        """Reset the dirty flag (arming tracking if it was not yet armed)."""
        self.__dict__["_is_dirty"] = False

    def sessdirty(self):
        """Return whether an attribute changed since tracking was armed or reset."""
        return self._is_dirty

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        # Write through __dict__ so updating the flag does not recurse here.
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True

    def __delattr__(self, name):
        super().__delattr__(name)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True
175
class manudirty(object):
    """Mixin for session data whose modifications are flagged explicitly.

    Call ``dirty()`` after changing state.  ``sessdirty`` reports the flag
    and ``sessfrozen`` clears it.
    """
    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        # Name-mangled on purpose so the flag cannot clash with subclass
        # attributes.
        self.__dirty = False

    def sessfrozen(self):
        """Clear the modified flag."""
        self.__dirty = False

    def sessdirty(self):
        """Return True if ``dirty()`` was called since construction or the last ``sessfrozen``."""
        flag = self.__dirty
        return flag

    def dirty(self):
        """Flag this object as modified."""
        self.__dirty = True
189
class specslot(object):
    """Data descriptor for one saved slot of a ``specdirty`` instance.

    The value lives at position ``idx`` of the instance's ``__sslots__``
    list; ``specslot.unbound`` marks a slot that has no value.
    """
    __slots__ = ["nm", "idx", "dirty"]
    unbound = object()

    def __init__(self, nm, idx, dirty):
        self.nm = nm        # attribute name, used in error messages
        self.idx = idx      # index into the instance's __sslots__ list
        self.dirty = dirty  # whether changing this slot marks the instance dirty

    @staticmethod
    def slist(ins):
        # Read the __sslots__ member descriptor directly, avoiding any
        # __getattribute__ override on the instance.
        return specdirty.__sslots__.__get__(ins, type(ins))

    def __get__(self, ins, cls):
        if ins is None:
            # Accessed on the class (introspection, help(), getattr on the
            # class): return the descriptor itself rather than failing.
            return self
        val = self.slist(ins)[self.idx]
        if val is specslot.unbound:
            raise AttributeError("specslot %r is unbound" % self.nm)
        return val

    def __set__(self, ins, val):
        self.slist(ins)[self.idx] = val
        if self.dirty:
            ins.dirty()

    def __delete__(self, ins):
        self.slist(ins)[self.idx] = specslot.unbound
        # Honour the per-slot dirty setting exactly as __set__ does.
        if self.dirty:
            ins.dirty()
218
class specclass(type):
    """Metaclass that installs ``specslot`` descriptors for saved slots.

    It gathers ``__saveslots__`` and ``__dirtyslots__`` from every class in
    the MRO.  A class's ``__dirtyslots__`` defaults to its own
    ``__saveslots__``.

    The class then gets:
      * ``__sslots_l__``: the save slots.
      * ``__sslots_a__``: the union of save and dirty slots, each name
        installed as a ``specslot`` with its position as index, flagged
        dirty iff the name is a dirty slot.
    """
    def __init__(self, name, bases, tdict):
        super().__init__(name, bases, tdict)
        saved = set()
        dirtying = set()
        for klass in self.__mro__:
            own = vars(klass)
            declared = own.get("__saveslots__", ())
            saved.update(declared)
            dirtying.update(own.get("__dirtyslots__", declared))
        self.__sslots_l__ = list(saved)
        self.__sslots_a__ = list(saved | dirtying)
        for idx, slot in enumerate(self.__sslots_a__):
            descriptor = specslot(slot, idx, slot in dirtying)
            setattr(self, slot, descriptor)
232
class specdirty(sessiondata, metaclass=specclass):
    """Slot-based session data with explicit dirty tracking.

    Saved attributes are declared with ``__saveslots__`` (and optionally
    ``__dirtyslots__``).  Their values live in the per-instance
    ``__sslots__`` list, accessed through ``specslot`` descriptors that
    ``specclass`` installs; assigning a dirty slot calls ``dirty()``.

    For pickling, ``__getnewargs__`` supplies the owning session and the
    state is a name-keyed mapping of slot values.
    """
    __slots__ = ["session", "__sslots__", "_is_dirty"]

    def __specinit__(self):
        """Subclass hook run by ``__new__`` before the dirty flag is reset."""

    @staticmethod
    def __new__(cls, req, sess):
        self = super().__new__(cls)
        self.session = sess
        # Every saved slot starts out unbound.
        self.__sslots__ = [specslot.unbound for _ in cls.__sslots_a__]
        self.__specinit__()
        # Reset after the hook so slot assignments it makes are not counted
        # as modifications.
        self._is_dirty = False
        return self

    def __getnewargs__(self):
        return (None, self.session)

    def dirty(self):
        """Mark this instance as modified."""
        self._is_dirty = True

    def sessfrozen(self):
        """Clear the modified flag."""
        self._is_dirty = False

    def sessdirty(self):
        """Return whether the instance was marked modified since the last reset."""
        return self._is_dirty

    def __getstate__(self):
        """Map each slot name to ``(True, value)``, or ``(False, None)`` if unbound."""
        unbound = specslot.unbound
        values = specslot.slist(self)
        return {nm: ((False, None) if val is unbound else (True, val))
                for nm, val in zip(type(self).__sslots_a__, values)}

    def __setstate__(self, st):
        """Restore slots from ``__getstate__`` output.

        Names absent from *st* become unbound.  Entries are popped from *st*
        as they are consumed.
        """
        slots = specslot.slist(self)
        for idx, nm in enumerate(type(self).__sslots_a__):
            bound, val = st.pop(nm, (False, None))
            slots[idx] = val if bound else specslot.unbound
277
def datecheck(req, mtime):
    """Answer a conditional request for a resource last modified at *mtime*.

    If the request has an ``If-Modified-Since`` header that ``proto.phttpdate``
    parses, and the parsed time is no earlier than *mtime* rounded down to a
    whole second, raise ``resp.unmodified()``.  Otherwise set the
    ``Last-Modified`` response header from *mtime* and return normally.
    """
    header = "If-Modified-Since"
    if header in req.ihead:
        since = proto.phttpdate(req.ihead[header])
        # HTTP dates carry only whole seconds, so compare against the
        # floored modification time.
        unchanged = since is not None and since >= math.floor(mtime)
        if unchanged:
            raise resp.unmodified()
    req.ohead["Last-Modified"] = proto.httpdate(mtime)