Merge branch 'master' into python3
[wrw.git] / wrw / util.py
1 import inspect
2 from . import req, dispatch, session, form, resp
3
def wsgiwrap(callable):
    """Turn a wrw request handler into a function with the WSGI call shape.

    The returned function takes ``(env, startreq)`` and hands both, together
    with ``callable``, to ``dispatch.handleenv``, returning whatever that
    produces.
    """
    def application(env, startreq):
        return dispatch.handleenv(env, startreq, callable)
    return application
8
def stringwrap(charset):
    """Return a decorator that encodes a text-producing callable's output.

    The decorated callable must return an iterable of ``str``.  The wrapper
    is a generator, so the callable is only invoked when iteration starts.
    Each item is yielded encoded with ``charset``.

    If the underlying iterable has a ``close()`` method, it is called once
    the wrapper is exhausted, closed early or discarded.  This follows the
    WSGI convention (PEP 3333) that a response iterable's ``close()`` must
    be called, so resources it holds are released promptly.
    """
    def dec(callable):
        def wrapper(*args, **kwargs):
            bk = callable(*args, **kwargs)
            try:
                for string in bk:
                    yield string.encode(charset)
            finally:
                # Without this, closing the wrapper would only close this
                # generator and leave the wrapped iterable open until GC.
                if hasattr(bk, "close"):
                    bk.close()
        return wrapper
    return dec
17
def formparams(callable):
    """Decorate a handler so that its parameters are filled from form data.

    The returned wrapper takes a request and reads its form data with
    ``form.formdata``.  It then calls ``callable`` with the form items as
    keyword arguments, plus ``req`` bound to the request itself (overriding
    any form field of that name).

    Unless ``callable`` accepts ``**kwargs``, fields that do not name one of
    its parameters (positional or keyword-only) are dropped.  If a parameter
    without a default is not supplied, ``resp.httperror(400, ...)`` is
    raised naming it.
    """
    # The signature cannot change between requests, so inspect it once.
    # getfullargspec replaces getargspec, which was removed in Python 3.11
    # and rejected handlers with annotations or keyword-only parameters.
    # An unsupported callable therefore now fails at decoration time.
    spec = inspect.getfullargspec(callable)
    accepted = set(spec.args) | set(spec.kwonlyargs)
    # spec.defaults is None (not an empty tuple) when no parameter has one.
    nrequired = len(spec.args) - len(spec.defaults or ())
    kwdefaults = spec.kwonlydefaults or {}
    required = spec.args[:nrequired] + [arg for arg in spec.kwonlyargs if arg not in kwdefaults]

    def wrapper(req):
        data = form.formdata(req)
        args = dict(data.items())
        args["req"] = req
        if not spec.varkw:
            args = {nm: val for nm, val in args.items() if nm in accepted}
        for arg in required:
            if arg not in args:
                raise resp.httperror(400, "Missing parameter", ("The query parameter `", resp.h.code(arg), "' is required but not supplied."))
        return callable(**args)
    return wrapper
33
def persession(data = None):
    """Return a decorator that keeps one handler object per session.

    On every request the session is fetched with ``session.get(req)``.  If it
    holds no entry keyed by the decorated ``callable``, one is created:

    - ``callable()`` when ``data`` is None;
    - otherwise ``callable(data)``, after first storing ``data()`` in the
      session under the key ``data`` if it is not there yet.

    The result of the stored object's ``handle(req)`` is returned.

    NOTE(review): ``callable`` receives ``data`` itself, not the per-session
    instance ``sess[data]``; confirm this is intended.
    """
    def dec(callable):
        def wrapper(req):
            sess = session.get(req)
            if callable not in sess:
                if data is None:
                    ctorargs = ()
                else:
                    if data not in sess:
                        sess[data] = data()
                    ctorargs = (data,)
                sess[callable] = callable(*ctorargs)
            return sess[callable].handle(req)
        return wrapper
    return dec
48
class preiter(object):
    """Iterator wrapper that stays one item ahead of its consumer.

    Construction pulls the first item from the wrapped iterable immediately.
    Any work the source does before producing that item (e.g. a generator's
    body up to its first ``yield``) therefore runs when the preiter is built.

    Because of the look-ahead, an exception raised while fetching item N+1
    propagates from the call that would have returned item N, and item N is
    lost.  ``close()`` is forwarded to the wrapped object when it has one.
    """
    __slots__ = ["bk", "bki", "_next"]
    # Sentinel held in _next once the wrapped iterator is exhausted.
    end = object()

    def __init__(self, real):
        self.bk = real
        self.bki = iter(real)
        self._next = None
        # Prime the look-ahead slot.
        self.__next__()

    def __iter__(self):
        return self

    def __next__(self):
        current = self._next
        if current is self.end:
            raise StopIteration()
        # next() with a default turns exhaustion into the end sentinel.
        self._next = next(self.bki, self.end)
        return current

    def close(self):
        if not hasattr(self.bk, "close"):
            return
        self.bk.close()
74
def pregen(callable):
    """Decorate a callable so that its returned iterable is wrapped in ``preiter``.

    The first item is fetched as soon as the decorated callable is called,
    rather than on first iteration.
    """
    def wrapper(*args, **kwargs):
        source = callable(*args, **kwargs)
        return preiter(source)
    return wrapper
79
class sessiondata(object):
    """Base for per-session singletons, stored in the session keyed by class."""

    @classmethod
    def get(cls, req, create = True):
        """Return this class's instance from the request's session.

        The lookup happens while holding ``sess.lock``.  When no instance is
        stored, one is built as ``cls(req, sess)`` and stored, unless
        ``create`` is false, in which case None is returned.
        """
        sess = cls.sessdb().get(req)
        with sess.lock:
            try:
                found = sess[cls]
            except KeyError:
                if not create:
                    return None
                found = cls(req, sess)
                sess[cls] = found
            return found

    @classmethod
    def sessdb(cls):
        """Return the session database used by get().

        Defaults to ``session.default.val``; subclasses may override.
        """
        return session.default.val
97
class autodirty(sessiondata):
    """Session data that marks itself dirty on any attribute change.

    The flag lives in the instance ``__dict__`` under ``_is_dirty``.  Only
    get() adds that key, and sessiondata.get builds the instance via
    ``cls(req, sess)`` before returning it.  Assignments made during
    construction therefore do not mark the object dirty.
    """
    @classmethod
    def get(cls, req, create = True):
        """Fetch this session's instance, creating it if *create* is true.

        *create* is forwarded to sessiondata.get.  When it is false and no
        instance exists, None is returned unchanged.
        """
        ret = super().get(req, create)
        if ret is not None and "_is_dirty" not in ret.__dict__:
            # Write through __dict__: going via __setattr__ would flag the
            # object dirty immediately once the key exists.
            ret.__dict__["_is_dirty"] = False
        return ret

    def sessfrozen(self):
        """Clear the dirty flag."""
        self.__dict__["_is_dirty"] = False

    def sessdirty(self):
        """Return whether an attribute changed since the flag was cleared."""
        return self._is_dirty

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True

    def __delattr__(self, name):
        # object.__delattr__ takes only the name; the previous code also
        # passed an undefined `value`, so every deletion raised NameError.
        super().__delattr__(name)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True
121
class manudirty(object):
    """Mixin whose dirty state is only changed by explicit calls.

    The flag is stored in the name-mangled attribute ``_manudirty__dirty``.
    Objects start clean; dirty() sets the flag and sessfrozen() clears it.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # A freshly constructed object has no unsaved changes.
        self.__dirty = False

    def dirty(self):
        """Flag the object as modified."""
        self.__dirty = True

    def sessdirty(self):
        """Report whether dirty() was called since the last sessfrozen()."""
        return self.__dirty

    def sessfrozen(self):
        """Reset the flag to clean."""
        self.__dirty = False
135
class specslot(object):
    """Data descriptor exposing one entry of a specdirty instance's slot list.

    Each descriptor records:

    - ``nm``: its slot name, used in error messages;
    - ``idx``: its position in the instance's ``__sslots__`` list;
    - ``dirty``: whether changing it should call the instance's dirty().

    An entry holding the ``unbound`` sentinel reads as a missing attribute.
    """
    __slots__ = ["nm", "idx", "dirty"]
    # Sentinel marking a slot that currently has no value.
    unbound = object()

    def __init__(self, nm, idx, dirty):
        self.nm = nm
        self.idx = idx
        self.dirty = dirty

    @staticmethod
    def slist(ins):
        # Avoid calling __getattribute__: read the list straight through
        # specdirty's own slot descriptor.
        return specdirty.__sslots__.__get__(ins, type(ins))

    def __get__(self, ins, cls):
        # Class-level access returns the descriptor itself, as is customary.
        # Without this guard, the slot list of None would be indexed and a
        # TypeError raised.
        if ins is None:
            return self
        val = self.slist(ins)[self.idx]
        if val is specslot.unbound:
            raise AttributeError("specslot %r is unbound" % self.nm)
        return val

    def __set__(self, ins, val):
        self.slist(ins)[self.idx] = val
        if self.dirty:
            ins.dirty()

    def __delete__(self, ins):
        self.slist(ins)[self.idx] = specslot.unbound
        # Honour the per-slot flag exactly as __set__ does: deleting a slot
        # that is not dirty-tracked must not mark the instance dirty.
        if self.dirty:
            ins.dirty()
164
class specclass(type):
    """Metaclass that lays out the spec slots of a class.

    It walks the MRO and collects each class's own ``__saveslots__`` and
    ``__dirtyslots__``.  A class that lists ``__saveslots__`` but no
    ``__dirtyslots__`` makes all of its saved slots dirty-tracked.

    The result is stored as ``__sslots_l__`` (saved names) and
    ``__sslots_a__`` (saved plus dirty names).  A specslot descriptor is
    installed for every name in ``__sslots_a__``, indexed by its position
    in that list.
    """
    def __init__(cls, name, bases, tdict):
        super().__init__(name, bases, tdict)
        saved, tracked = set(), set()
        for klass in cls.__mro__:
            own = klass.__dict__
            declared = own.get("__saveslots__", ())
            saved.update(declared)
            tracked.update(own.get("__dirtyslots__", declared))
        # Positions follow set iteration order, not declaration order, so
        # they may differ between interpreter runs.
        cls.__sslots_l__ = list(saved)
        cls.__sslots_a__ = list(saved | tracked)
        for index, slotname in enumerate(cls.__sslots_a__):
            setattr(cls, slotname, specslot(slotname, index, slotname in tracked))
178
class specdirty(sessiondata, metaclass=specclass):
    """Session data whose declared spec slots live in a per-instance list.

    Slot values are kept in ``__sslots__``, one entry per name in the
    class's ``__sslots_a__``, and accessed through specslot descriptors.
    Dirty state is a plain boolean slot, ``_is_dirty``.  Pickling stores the
    slots by name, together with the owning session.
    """
    __slots__ = ["session", "__sslots__", "_is_dirty"]

    def __specinit__(self):
        """Hook for subclasses, run from __new__ once the slot list exists."""

    @staticmethod
    def __new__(cls, req, sess):
        inst = super().__new__(cls)
        inst.session = sess
        inst.__sslots__ = [specslot.unbound] * len(cls.__sslots_a__)
        inst.__specinit__()
        # Clear the flag only after the hook ran, so values it assigns
        # are not counted as modifications.
        inst._is_dirty = False
        return inst

    def __getnewargs__(self):
        # Unpickling calls __new__(cls, None, session): no request exists then.
        return (None, self.session)

    def dirty(self):
        self._is_dirty = True

    def sessfrozen(self):
        self._is_dirty = False

    def sessdirty(self):
        return self._is_dirty

    def __getstate__(self):
        """Map every slot name to ``(bound, value)``; unbound -> ``(False, None)``."""
        names = type(self).__sslots_a__
        values = specslot.slist(self)
        return {nm: ((False, None) if val is specslot.unbound else (True, val))
                for nm, val in zip(names, values)}

    def __setstate__(self, st):
        """Restore slots by name; names absent from *st* become unbound."""
        values = specslot.slist(self)
        for idx, nm in enumerate(type(self).__sslots_a__):
            isbound, val = st.pop(nm, (False, None))
            values[idx] = val if isbound else specslot.unbound