Merge branch 'master' into python3
[wrw.git] / wrw / util.py
1 import inspect
2 from . import req, dispatch, session, form, resp
3
def wsgiwrap(callable):
    """Turn a wrw request handler into a WSGI application.

    The returned function takes the usual WSGI ``(environ, start_response)``
    pair and hands both, along with *callable*, to ``dispatch.handleenv``.
    """
    handler = callable

    def wrapper(env, startreq):
        # All request/response plumbing lives in the dispatcher.
        return dispatch.handleenv(env, startreq, handler)

    return wrapper
8
def stringwrap(charset):
    """Return a decorator that encodes every string a handler produces.

    The decorated callable must return an iterable of ``str``; the wrapper
    is a generator yielding each item encoded with *charset*. Once the
    wrapper has started, the wrapped iterable's ``close`` method (if it has
    one) is called when iteration finishes: on exhaustion, on an exception,
    or when the consumer closes the wrapper early.
    """
    def dec(callable):
        def wrapper(*args, **kwargs):
            bk = callable(*args, **kwargs)
            try:
                for string in bk:
                    yield string.encode(charset)
            finally:
                # WSGI servers call close() on the response iterable; forward
                # it so the wrapped iterable can release its resources too.
                if hasattr(bk, "close"):
                    bk.close()
        return wrapper
    return dec
17
def formparams(callable):
    """Decorate a handler so that form fields arrive as keyword arguments.

    The wrapper parses the request's form data with ``form.formdata`` and
    calls *callable* with those fields as keyword arguments, plus the
    request itself as ``req``. Unless *callable* accepts ``**kwargs``, any
    field it does not name is dropped. If a parameter without a default
    (positional-or-keyword or keyword-only) is missing, an HTTP 400 error
    is raised via ``resp.httperror``.
    """
    # Introspect once at decoration time instead of on every request.
    # getfullargspec replaces getargspec, which was removed in Python 3.11
    # and refused functions with keyword-only arguments or annotations.
    spec = inspect.getfullargspec(callable)
    accepted = set(spec.args) | set(spec.kwonlyargs)
    # spec.defaults is None (not an empty tuple) when nothing has a default.
    npositional = len(spec.args) - len(spec.defaults or ())
    kwdefaults = spec.kwonlydefaults or {}
    required = spec.args[:npositional] + [arg for arg in spec.kwonlyargs if arg not in kwdefaults]

    def wrapper(req):
        data = form.formdata(req)
        args = dict(data.items())
        args["req"] = req
        if not spec.varkw:
            # Without **kwargs, unknown names would make the call fail.
            args = {name: val for name, val in args.items() if name in accepted}
        for name in required:
            if name not in args:
                raise resp.httperror(400, "Missing parameter", ("The query parameter `", resp.h.code(name), "' is required but not supplied."))
        return callable(**args)
    return wrapper
33
def funplex(*funs, **nfuns):
    """Build a handler that dispatches on the first component of the path.

    Each positional function is registered under its ``__name__``. Keyword
    arguments register handlers under explicit names and win over a
    positional function of the same name. A request for the directory
    itself (``/``) goes to the ``__index__`` entry. An empty path redirects
    to the slash-terminated URI, and unknown names raise ``resp.notfound()``.
    The chosen handler receives ``req.shift(n)``, where ``n`` counts the
    leading slash plus the consumed component.
    """
    table = {fun.__name__: fun for fun in funs}
    table.update(nfuns)

    def handler(req):
        if req.pathinfo == "":
            raise resp.redirect(req.uriname + "/")
        if req.pathinfo[:1] != "/":
            raise resp.notfound()
        rest = req.pathinfo[1:]
        if rest == "":
            name, consumed = "__index__", 1
        else:
            name = rest.partition("/")[0]
            consumed = 1 + len(name)
        if name not in table:
            raise resp.notfound()
        return table[name](req.shift(consumed))

    return handler
54
def persession(data = None):
    """Return a decorator that keeps one handler object per session.

    On first use in a session, the decorated *callable* is instantiated and
    stored in the session under the key *callable*. Every request is then
    delegated to that object's ``handle(req)``. If *data* is given, a
    session-scoped ``data()`` instance is created (once, stored under the key
    *data*) and passed to *callable*'s constructor, so handlers sharing the
    same *data* share one instance per session.
    """
    def dec(callable):
        def wrapper(req):
            sess = session.get(req)
            if callable not in sess:
                if data is None:
                    sess[callable] = callable()
                else:
                    if data not in sess:
                        sess[data] = data()
                    # Hand over the session-scoped instance, not the factory;
                    # otherwise the object stored above would never be used.
                    sess[callable] = callable(sess[data])
            return sess[callable].handle(req)
        return wrapper
    return dec
69
class preiter(object):
    """Iterator wrapper that pulls the first item of *real* on construction.

    Because the constructor advances the wrapped iterator once, any
    exception raised while producing the first item propagates from
    ``preiter(...)`` itself. Items are then returned one step behind the
    underlying iterator, and ``close`` is forwarded to the wrapped object.
    """
    __slots__ = ["bk", "bki", "_next"]
    # Sentinel stored in _next once the underlying iterator is exhausted.
    end = object()

    def __init__(self, real):
        self.bk = real         # original iterable, kept so close() can reach it
        self.bki = iter(real)  # iterator actually consumed
        self._next = None
        # Prime the look-ahead slot; the placeholder None returned is dropped.
        next(self)

    def __iter__(self):
        return self

    def __next__(self):
        current = self._next
        if current is self.end:
            raise StopIteration()
        # The default argument turns the source's StopIteration into `end`.
        self._next = next(self.bki, self.end)
        return current

    def close(self):
        try:
            closer = self.bk.close
        except AttributeError:
            return
        closer()
95
def pregen(callable):
    """Decorate a callable returning an iterable so its result is pre-started.

    The result of each call is wrapped in ``preiter``, which fetches the
    first item immediately. Errors raised before that first item therefore
    surface when the decorated function is called, not on first iteration.
    """
    def wrapper(*args, **kwargs):
        produced = callable(*args, **kwargs)
        return preiter(produced)
    return wrapper
100
class sessiondata(object):
    """Base class for objects stored at most once per session.

    Instances live in the session mapping keyed by their own class and are
    created on demand as ``cls(req, sess)``.
    """
    @classmethod
    def get(cls, req, create = True):
        """Return this class's instance for *req*'s session.

        Lookup and creation happen while holding ``sess.lock``. If no
        instance exists yet, one is created and stored when *create* is
        true; otherwise ``None`` is returned.
        """
        store = cls.sessdb().get(req)
        with store.lock:
            try:
                found = store[cls]
            except KeyError:
                if create:
                    found = cls(req, store)
                    store[cls] = found
                else:
                    found = None
            return found

    @classmethod
    def sessdb(cls):
        """Return the session store to use (``session.default.val`` here).

        Subclasses may override this to use a different store.
        """
        return session.default.val
118
class autodirty(sessiondata):
    """Session data that marks itself dirty whenever an attribute changes.

    The ``_is_dirty`` flag is kept directly in the instance ``__dict__``. It
    is added there by :meth:`get` (or :meth:`sessfrozen`). Until then,
    attribute assignments and deletions are not tracked, so changes made
    while the object is being constructed do not mark it dirty.
    """
    @classmethod
    def get(cls, req, create = True):
        """Return the session instance, ensuring the dirty flag exists.

        *create* has the same meaning as in ``sessiondata.get``. When it is
        false and no instance exists yet, ``None`` is returned.
        """
        ret = super().get(req, create)
        if ret is None:
            return None
        if "_is_dirty" not in ret.__dict__:
            ret.__dict__["_is_dirty"] = False
        return ret

    def sessfrozen(self):
        # Written through __dict__ so that clearing the flag does not itself
        # go through __setattr__ and mark the object dirty again.
        self.__dict__["_is_dirty"] = False

    def sessdirty(self):
        return self._is_dirty

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True

    def __delattr__(self, name):
        # object.__delattr__ takes only the attribute name.
        super().__delattr__(name)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True
142
class manudirty(object):
    """Mixin for session data whose dirty flag is set explicitly.

    Call :meth:`dirty` after changing the object. :meth:`sessdirty` reports
    the flag and :meth:`sessfrozen` clears it.
    """
    def __init__(self, *args, **kwargs):
        # Cooperative constructor: every argument goes to the next class in the MRO.
        super().__init__(*args, **kwargs)
        self.__dirty = False

    def dirty(self):
        self.__dirty = True

    def sessdirty(self):
        return self.__dirty

    def sessfrozen(self):
        self.__dirty = False
156
class specslot(object):
    """Data descriptor for one saved slot of a ``specdirty`` instance.

    A slot's value is stored at position ``idx`` of the instance's
    ``__sslots__`` list. The sentinel ``specslot.unbound`` in that position
    means the slot has no value, and reading it raises AttributeError.
    """
    __slots__ = ["nm", "idx", "dirty"]
    # Sentinel placed in the slot list for slots that hold no value.
    unbound = object()

    def __init__(self, nm, idx, dirty):
        self.nm = nm        # attribute name, used in the unbound error message
        self.idx = idx      # position in the instance's __sslots__ list
        self.dirty = dirty  # whether assigning marks the instance dirty

    @staticmethod
    def slist(ins):
        # Avoid calling __getattribute__
        return specdirty.__sslots__.__get__(ins, type(ins))

    def __get__(self, ins, cls):
        if ins is None:
            # Accessed on the class itself. Return the descriptor, as standard
            # descriptors do, instead of failing in slist(None) with a
            # TypeError (which also broke hasattr() and introspection).
            return self
        val = self.slist(ins)[self.idx]
        if val is specslot.unbound:
            raise AttributeError("specslot %r is unbound" % self.nm)
        return val

    def __set__(self, ins, val):
        self.slist(ins)[self.idx] = val
        if self.dirty:
            ins.dirty()

    def __delete__(self, ins):
        self.slist(ins)[self.idx] = specslot.unbound
        # NOTE(review): unlike __set__, this marks the instance dirty even for
        # slots whose `dirty` flag is false. Confirm this is intended.
        ins.dirty()
185
class specclass(type):
    """Metaclass that installs a ``specslot`` descriptor for every saved slot.

    Slot names come from the ``__saveslots__`` and ``__dirtyslots__``
    entries in the body of each class in the MRO. A class without
    ``__dirtyslots__`` treats its own ``__saveslots__`` as dirtying.
    ``__sslots_l__`` lists the saved slots and ``__sslots_a__`` lists all
    slots (saved or dirtying). A slot's position in ``__sslots_a__`` is its
    storage index.
    """
    def __init__(self, name, bases, tdict):
        super().__init__(name, bases, tdict)
        saved = set()
        dirtying = set()
        for klass in self.__mro__:
            own = klass.__dict__
            ownsaved = own.get("__saveslots__", ())
            saved.update(ownsaved)
            dirtying.update(own.get("__dirtyslots__", ownsaved))
        self.__sslots_l__ = list(saved)
        self.__sslots_a__ = allslots = list(saved | dirtying)
        for position, slotname in enumerate(allslots):
            descriptor = specslot(slotname, position, slotname in dirtying)
            setattr(self, slotname, descriptor)
199
200 class specdirty(sessiondata, metaclass=specclass):
201     __slots__ = ["session", "__sslots__", "_is_dirty"]
202     
203     def __specinit__(self):
204         pass
205
206     @staticmethod
207     def __new__(cls, req, sess):
208         self = super().__new__(cls)
209         self.session = sess
210         self.__sslots__ = [specslot.unbound] * len(cls.__sslots_a__)
211         self.__specinit__()
212         self._is_dirty = False
213         return self
214
215     def __getnewargs__(self):
216         return (None, self.session)
217
218     def dirty(self):
219         self._is_dirty = True
220
221     def sessfrozen(self):
222         self._is_dirty = False
223
224     def sessdirty(self):
225         return self._is_dirty
226
227     def __getstate__(self):
228         ret = {}
229         for nm, val in zip(type(self).__sslots_a__, specslot.slist(self)):
230             if val is specslot.unbound:
231                 ret[nm] = False, None
232             else:
233                 ret[nm] = True, val
234         return ret
235
236     def __setstate__(self, st):
237         ss = specslot.slist(self)
238         for i, nm in enumerate(type(self).__sslots_a__):
239             bound, val = st.pop(nm, (False, None))
240             if not bound:
241                 ss[i] = specslot.unbound
242             else:
243                 ss[i] = val