Merge branch 'master' into python3
[wrw.git] / wrw / util.py
1 import inspect, math
2 from . import req, dispatch, session, form, resp, proto
3
def wsgiwrap(callable):
    """Turn a wrw request handler into a WSGI application.

    The returned function has the WSGI ``(env, startreq)`` signature and
    hands every call to ``dispatch.handleenv`` together with *callable*.
    The original handler stays reachable through ``__wrapped__`` so that
    helpers such as ``funplex`` can see through the wrapper.
    """
    def wrapper(env, startreq):
        handled = dispatch.handleenv(env, startreq, callable)
        return handled
    wrapper.__wrapped__ = callable
    return wrapper
9
def stringwrap(charset):
    """Decorator factory: encode every string a response body yields.

    The decorated callable must return an iterable of ``str``.  The wrapper
    is a generator that yields each item encoded with *charset*.  The
    callable is invoked lazily, on the first iteration.

    Like the other decorators in this module, the wrapper exposes the
    decorated callable as ``__wrapped__``.  ``funplex`` relies on this to
    look up the handler's real name.  If the body returned by the callable
    has a ``close()`` method, that method is called when the generator
    finishes or is closed early.
    """
    def dec(callable):
        def wrapper(*args, **kwargs):
            bk = callable(*args, **kwargs)
            try:
                for string in bk:
                    yield string.encode(charset)
            finally:
                # WSGI servers call close() on the iterable they receive,
                # which here is this generator.  Forward it to the real body
                # so its resources are released (mirrors preiter.close).
                if hasattr(bk, "close"):
                    bk.close()
        wrapper.__wrapped__ = callable
        return wrapper
    return dec
18
def formparams(callable):
    """Decorator: call *callable* with the request's form data as keywords.

    The wrapper takes a single ``req`` argument.  It gathers
    ``form.formdata(req)`` into a dict and sets ``args["req"] = req``.  Any
    form field named ``req`` is therefore overridden by the request itself.

    Unless *callable* accepts ``**kwargs``, entries that do not match one of
    its parameter names are dropped.  Keyword-only parameters are accepted
    too.  If a parameter without a default value is missing, the wrapper
    raises ``resp.httperror(400, ...)`` naming that parameter.

    The decorated callable remains reachable as ``__wrapped__``.
    """
    # Inspect the signature once, at decoration time.  getfullargspec
    # replaces inspect.getargspec, which was removed in Python 3.11.
    spec = inspect.getfullargspec(callable)
    accepted = set(spec.args) | set(spec.kwonlyargs)
    # spec.defaults is None (not an empty tuple) when no positional
    # parameter has a default.
    npositional = len(spec.args) - len(spec.defaults or ())
    kwdefaults = spec.kwonlydefaults or {}
    required = spec.args[:npositional] + [
        arg for arg in spec.kwonlyargs if arg not in kwdefaults
    ]

    def wrapper(req):
        data = form.formdata(req)
        args = dict(data.items())
        args["req"] = req
        if not spec.varkw:
            # Without **kwargs, unknown names would make the call fail.
            args = {name: val for name, val in args.items() if name in accepted}
        for arg in required:
            if arg not in args:
                raise resp.httperror(400, "Missing parameter", ("The query parameter `", resp.h.code(arg), "' is required but not supplied."))
        return callable(**args)
    wrapper.__wrapped__ = callable
    return wrapper
35
def funplex(*funs, **nfuns):
    """Build a handler that routes on the first path component.

    Positional handlers are registered under the ``__name__`` of the
    innermost function found by following their ``__wrapped__`` chain.
    Keyword handlers are registered under the keyword name and win over a
    positional handler with the same name.

    Routing of ``req.pathinfo``:
    - an empty path redirects to ``req.uriname + "/"``;
    - a path not starting with ``/`` is not found;
    - a bare ``/`` goes to the ``__index__`` handler;
    - ``/name/...`` goes to handler ``name``, called with
      ``req.shift(len(name) + 1)``.

    Unknown names raise ``resp.notfound()``.
    """
    def unwrap(fun):
        while hasattr(fun, "__wrapped__"):
            fun = fun.__wrapped__
        return fun

    table = {unwrap(fun).__name__: fun for fun in funs}
    table.update(nfuns)

    def handler(req):
        path = req.pathinfo
        if path == "":
            raise resp.redirect(req.uriname + "/")
        if path[:1] != "/":
            raise resp.notfound()
        rest = path[1:]
        if rest != "":
            name = rest.partition("/")[0]
            skip = len(name) + 1
        else:
            name, skip = "__index__", 1
        if name not in table:
            raise resp.notfound()
        return table[name](req.shift(skip))
    return handler
60
def persession(data = None):
    """Decorator factory: keep one handler object per session.

    On the first request in a session, the decorated *callable* is
    instantiated and stored in the session under the key *callable*:
    - with no *data*, as ``callable()``;
    - otherwise, after ensuring ``sess[data] = data()`` exists, as
      ``callable(data)``.

    Every request then calls ``.handle(req)`` on the stored object.
    """
    def dec(callable):
        def wrapper(req):
            sess = session.get(req)
            if callable not in sess:
                if data is not None:
                    if data not in sess:
                        sess[data] = data()
                    # NOTE(review): the handler receives `data` itself (the
                    # factory, which is also the session key), not the
                    # instance stored under sess[data] -- confirm intended.
                    handler = callable(data)
                else:
                    handler = callable()
                sess[callable] = handler
            return sess[callable].handle(req)
        wrapper.__wrapped__ = callable
        return wrapper
    return dec
76
class preiter(object):
    """Iterator wrapper that always holds one item of lookahead.

    Construction already pulls the first item from the wrapped iterable.
    For a generator, this means everything up to its first ``yield`` runs
    immediately.  Presumably the purpose is to surface errors or side
    effects before the response starts streaming.

    ``close()`` is forwarded to the wrapped object if it has one.

    NOTE: if the wrapped iterator raises while the next lookahead item is
    being fetched, that exception propagates before the current item is
    returned.
    """
    __slots__ = ["bk", "bki", "_next"]
    end = object()  # private sentinel marking exhaustion

    def __init__(self, real):
        self.bk = real
        self.bki = iter(real)
        self._next = None
        # Prime the lookahead; the returned placeholder is discarded.
        self.__next__()

    def __iter__(self):
        return self

    def __next__(self):
        current = self._next
        if current is self.end:
            raise StopIteration()
        self._next = next(self.bki, self.end)
        return current

    def close(self):
        try:
            closer = self.bk.close
        except AttributeError:
            return
        closer()
102
def pregen(callable):
    """Decorator: wrap whatever *callable* returns in a ``preiter``.

    Each call therefore eagerly fetches the first item of the returned
    iterable.  The decorated callable stays reachable as ``__wrapped__``.
    """
    def wrapper(*args, **kwargs):
        body = callable(*args, **kwargs)
        return preiter(body)
    wrapper.__wrapped__ = callable
    return wrapper
108
class sessiondata(object):
    """Base for objects kept once per session, keyed by their class."""

    @classmethod
    def get(cls, req, create = True):
        """Return this class's object from the session of *req*.

        The lookup runs while holding ``sess.lock``.  If the session has no
        entry for the class, one is built as ``cls(req, sess)`` and stored,
        unless *create* is false, in which case ``None`` is returned.
        """
        sess = cls.sessdb().get(req)
        with sess.lock:
            try:
                found = sess[cls]
            except KeyError:
                pass
            else:
                return found
            if not create:
                return None
            created = cls(req, sess)
            sess[cls] = created
            return created

    @classmethod
    def sessdb(cls):
        """Return the session database to use (the module-level default)."""
        return session.default.val
126
class autodirty(sessiondata):
    """Session data that marks itself dirty whenever an attribute changes.

    Tracking switches on only once ``get`` has put ``_is_dirty`` into the
    instance ``__dict__``.  Attribute assignments made while the object is
    being constructed therefore do not mark it dirty.  The flag is written
    through ``__dict__`` directly so that updating it does not re-enter
    ``__setattr__``.
    """

    @classmethod
    def get(cls, req, create = True):
        """Like ``sessiondata.get``, but also initialise the dirty flag.

        Returns ``None`` when *create* is false and no object exists yet.
        """
        ret = super().get(req, create=create)
        if ret is None:
            return None
        if "_is_dirty" not in ret.__dict__:
            ret.__dict__["_is_dirty"] = False
        return ret

    def sessfrozen(self):
        """Clear the dirty flag."""
        self.__dict__["_is_dirty"] = False

    def sessdirty(self):
        """Return whether any attribute changed since the last freeze."""
        return self._is_dirty

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True

    def __delattr__(self, name):
        # object.__delattr__ takes only the name; the previous code also
        # passed an undefined `value`, so every `del` raised NameError.
        super().__delattr__(name)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True
150
class manudirty(object):
    """Mixin for session data whose changes are flagged explicitly.

    Call ``dirty()`` after modifying the object.  ``sessdirty()`` reports
    the flag and ``sessfrozen()`` clears it.  Constructor arguments are
    passed through to the next class in the MRO.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dirty = False

    def dirty(self):
        """Record that the object has changed."""
        self.__dirty = True

    def sessdirty(self):
        """Return True if ``dirty()`` was called since the last freeze."""
        return self.__dirty

    def sessfrozen(self):
        """Reset the flag (presumably after the session store saved us)."""
        self.__dirty = False
164
class specslot(object):
    """Data descriptor for one saved slot of a ``specdirty`` instance.

    Values live in the instance's ``__sslots__`` list at position *idx*.
    A slot holding ``specslot.unbound`` reads as missing (AttributeError).
    Assigning marks the instance dirty only when *dirty* is true; deleting
    always marks it dirty.
    """
    __slots__ = ["nm", "idx", "dirty"]
    unbound = object()  # sentinel stored in unset slots

    def __init__(self, nm, idx, dirty):
        self.nm = nm
        self.idx = idx
        self.dirty = dirty

    @staticmethod
    def slist(ins):
        # Avoid calling __getattribute__
        return specdirty.__sslots__.__get__(ins, type(ins))

    def __get__(self, ins, cls):
        if ins is None:
            # Class-level access (hasattr(Cls, ...), help(), inspect) gets
            # the descriptor itself, per the usual descriptor convention.
            # The slot list only exists on instances, so indexing it here
            # used to fail with TypeError.
            return self
        val = self.slist(ins)[self.idx]
        if val is specslot.unbound:
            raise AttributeError("specslot %r is unbound" % self.nm)
        return val

    def __set__(self, ins, val):
        self.slist(ins)[self.idx] = val
        if self.dirty:
            ins.dirty()

    def __delete__(self, ins):
        # NOTE(review): deleting an already-unbound slot does not raise, and
        # it marks the instance dirty even for non-dirty slots.
        self.slist(ins)[self.idx] = specslot.unbound
        ins.dirty()
193
class specclass(type):
    """Metaclass that installs a ``specslot`` descriptor per saved slot.

    Across the MRO it collects two sets:
    - ``__saveslots__``: the saved slots;
    - ``__dirtyslots__``: slots whose assignment marks the instance dirty.
      A class's ``__dirtyslots__`` defaults to its own ``__saveslots__``.

    It records ``__sslots_l__`` (the saved slots) and ``__sslots_a__``
    (saved plus dirty slots).  Positions in ``__sslots_a__`` are the
    indices the descriptors use.
    """
    def __init__(cls, name, bases, tdict):
        super().__init__(name, bases, tdict)
        sslots = set()
        dslots = set()
        for klass in cls.__mro__:
            css = klass.__dict__.get("__saveslots__", ())
            sslots.update(css)
            dslots.update(klass.__dict__.get("__dirtyslots__", css))
        # Sort rather than use set order: string hashing is randomised per
        # process, so set iteration would give a different slot -> index
        # layout in every interpreter run.
        cls.__sslots_l__ = sorted(sslots)
        cls.__sslots_a__ = sorted(sslots | dslots)
        for i, slot in enumerate(cls.__sslots_a__):
            setattr(cls, slot, specslot(slot, i, slot in dslots))
207
class specdirty(sessiondata, metaclass=specclass):
    """Session data whose saved attributes live in a compact slot list.

    The ``specclass`` metaclass turns every name listed in subclasses'
    ``__saveslots__`` / ``__dirtyslots__`` into a ``specslot`` descriptor.
    Each descriptor indexes the per-instance ``__sslots__`` list.  Pickling
    saves and restores only those slots, by name.
    """
    # session: the session object passed to the constructor.
    # __sslots__: per-instance list of slot values (specslot.unbound = unset).
    # _is_dirty: set by dirty(), cleared by sessfrozen().
    __slots__ = ["session", "__sslots__", "_is_dirty"]
    
    def __specinit__(self):
        """Hook for subclasses to initialise slots on a fresh instance."""
        pass

    @staticmethod
    def __new__(cls, req, sess):
        # Setup happens in __new__ rather than __init__, so that objects
        # rebuilt by pickle via __getnewargs__ also get a slot list.
        # `req` is not used here; unpickling passes None for it.
        self = super().__new__(cls)
        self.session = sess
        self.__sslots__ = [specslot.unbound] * len(cls.__sslots_a__)
        self.__specinit__()
        # Slot writes in __specinit__ may have called dirty(); a freshly
        # constructed object starts clean.
        self._is_dirty = False
        return self

    def __getnewargs__(self):
        # Arguments pickle passes to __new__ on load (req=None, sess).
        return (None, self.session)

    def dirty(self):
        """Mark this object as modified."""
        self._is_dirty = True

    def sessfrozen(self):
        """Clear the modified flag."""
        self._is_dirty = False

    def sessdirty(self):
        """Return whether the object was modified since the last freeze."""
        return self._is_dirty

    def __getstate__(self):
        # Map each slot name to (bound, value); unset slots become
        # (False, None).
        ret = {}
        for nm, val in zip(type(self).__sslots_a__, specslot.slist(self)):
            if val is specslot.unbound:
                ret[nm] = False, None
            else:
                ret[nm] = True, val
        return ret

    def __setstate__(self, st):
        # Restore slots by name.  Names absent from `st` become unbound;
        # extra names in `st` are ignored.  Note that this pops entries
        # from `st`, mutating the dict that was passed in.
        ss = specslot.slist(self)
        for i, nm in enumerate(type(self).__sslots_a__):
            bound, val = st.pop(nm, (False, None))
            if not bound:
                ss[i] = specslot.unbound
            else:
                ss[i] = val
252
def datecheck(req, mtime):
    """Handle an ``If-Modified-Since`` conditional request against *mtime*.

    If the client's date is not older than *mtime* (truncated to whole
    seconds), ``resp.unmodified()`` is raised.  Otherwise the
    ``Last-Modified`` response header is set from *mtime*.  *mtime* is
    presumably a POSIX timestamp such as ``os.stat().st_mtime``; confirm
    against callers.
    """
    header = "If-Modified-Since"
    if header in req.ihead:
        client_time = proto.phttpdate(req.ihead[header])
        # HTTP dates carry whole seconds only, hence the floor.
        if client_time >= math.floor(mtime):
            raise resp.unmodified()
    req.ohead["Last-Modified"] = proto.httpdate(mtime)