Merge branch 'master' into python3
[wrw.git] / wrw / util.py
1 import inspect
2 from . import req, dispatch, session, form
3
def wsgiwrap(callable):
    """Wrap a handler so it can be invoked with (env, startreq).

    The returned function forwards its two arguments, together with
    `callable`, to dispatch.handleenv and returns that call's result.
    """
    def wrapper(env, startreq):
        response = dispatch.handleenv(env, startreq, callable)
        return response
    return wrapper
8
def stringwrap(charset):
    """Build a decorator that encodes every string a function produces.

    The decorated function becomes a generator: the wrapped callable is
    only invoked once iteration starts, and each item of its iterable
    result is yielded after being encoded with `charset`.
    """
    def dec(callable):
        def wrapper(*args, **kwargs):
            for chunk in callable(*args, **kwargs):
                yield chunk.encode(charset)
        return wrapper
    return dec
17
def formparams(callable):
    """Call `callable` with form fields passed as keyword arguments.

    The returned wrapper takes a request, reads its form data through
    form.formdata, and passes each field as a keyword argument named
    after the field, plus the request itself as `req`.  Unless
    `callable` accepts arbitrary keyword arguments (**kwargs), fields
    that do not match one of its named parameters are dropped.
    """
    def wrapper(req):
        data = form.formdata(req)
        # inspect.getargspec was removed in Python 3.11 and never supported
        # keyword-only parameters or annotations; getfullargspec does.
        spec = inspect.getfullargspec(callable)
        args = dict(data.items())
        args["req"] = req
        if not spec.varkw:
            # Keyword-only parameters are valid targets as well.
            accepted = set(spec.args) | set(spec.kwonlyargs)
            args = {name: value for name, value in args.items()
                    if name in accepted}
        return callable(**args)
    return wrapper
30
def persession(data = None):
    """Build a decorator that keeps one handler object per session.

    The decorated `callable` is used as a factory and as the key under
    which its product is stored in the session obtained from
    session.get(req).  Requests are answered by the stored object's
    handle(req) method.  When `data` is given, an instance of it is
    also stored in the session (keyed by `data`) before the handler is
    created.
    """
    def dec(callable):
        def wrapper(req):
            sess = session.get(req)
            if callable in sess:
                return sess[callable].handle(req)
            if data is None:
                sess[callable] = callable()
            else:
                if data not in sess:
                    sess[data] = data()
                # NOTE(review): the factory `data` itself is passed here,
                # not the stored sess[data] instance -- confirm intended.
                sess[callable] = callable(data)
            return sess[callable].handle(req)
        return wrapper
    return dec
45
class preiter(object):
    """Iterator that always holds the next item of an iterable in advance.

    Construction already pulls the first item from the underlying
    iterable, so a wrapped generator starts running immediately.  Each
    step hands out the held item and fetches the following one.
    """
    __slots__ = ["bk", "bki", "_next"]
    # Sentinel stored in _next once the underlying iterator is exhausted.
    end = object()

    def __init__(self, real):
        self.bk = real
        self.bki = iter(real)
        self._next = None
        # Prime the look-ahead; the placeholder None returned is discarded.
        self.__next__()

    def __iter__(self):
        return self

    def __next__(self):
        """Return the held item and advance the look-ahead by one."""
        current = self._next
        if current is self.end:
            raise StopIteration()
        self._next = next(self.bki, self.end)
        return current

    def close(self):
        """Call close() on the wrapped iterable if it provides one."""
        try:
            closer = self.bk.close
        except AttributeError:
            return
        closer()
71
def pregen(callable):
    """Decorate `callable` so its iterable result is wrapped in preiter.

    The wrapper passes all arguments through unchanged and returns a
    preiter built around whatever `callable` returns.
    """
    def wrapper(*args, **kwargs):
        produced = callable(*args, **kwargs)
        return preiter(produced)
    return wrapper
76
class sessiondata(object):
    """Base class for objects stored in a session, keyed by their class."""

    @classmethod
    def get(cls, req, create = True):
        """Return the session's instance of `cls` for the request `req`.

        The session is looked up through cls.sessdb().get(req) and is
        accessed while holding its `lock`.  When no instance is stored
        yet, one is built as cls(req, sess) and stored, unless `create`
        is false, in which case None is returned.
        """
        sess = cls.sessdb().get(req)
        with sess.lock:
            try:
                existing = sess[cls]
            except KeyError:
                if create:
                    fresh = cls(req, sess)
                    sess[cls] = fresh
                    return fresh
                return None
            else:
                return existing

    @classmethod
    def sessdb(cls):
        """Return the session database to use (session.default.val)."""
        return session.default.val
94
class autodirty(sessiondata):
    """Session data that marks itself dirty on any attribute change.

    The flag is kept directly in the instance __dict__ under
    "_is_dirty".  Attribute changes only set the flag once it exists,
    so assignments made before get() has installed it (for instance
    inside __init__) do not mark the object dirty.
    """

    @classmethod
    def get(cls, req, create = True):
        """Fetch the instance like sessiondata.get and install the flag.

        Accepts the same `create` argument as sessiondata.get; when that
        call returns None, None is returned unchanged.
        """
        ret = super().get(req, create)
        if ret is not None and "_is_dirty" not in ret.__dict__:
            ret.__dict__["_is_dirty"] = False
        return ret

    def sessfrozen(self):
        """Reset the dirty flag."""
        self.__dict__["_is_dirty"] = False

    def sessdirty(self):
        """Return the current dirty flag."""
        return self._is_dirty

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True

    def __delattr__(self, name):
        # object.__delattr__ takes only the attribute name; the stray
        # `value` argument previously raised NameError on every delete.
        super().__delattr__(name)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True
118
class manudirty(object):
    """Mixin providing a dirty flag that is only set by calling dirty().

    The flag lives in the name-mangled private attribute `__dirty`.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative init: pass everything on, then start out clean.
        super().__init__(*args, **kwargs)
        self.__dirty = False

    def sessfrozen(self):
        """Clear the dirty flag."""
        self.__dirty = False

    def sessdirty(self):
        """Return the current dirty flag."""
        return self.__dirty

    def dirty(self):
        """Set the dirty flag."""
        self.__dirty = True
132
class specslot(object):
    """Descriptor exposing one entry of an instance's slot list.

    Each descriptor records the attribute name (`nm`), the position in
    the per-instance list held in specdirty's `__sslots__` slot (`idx`),
    and whether assignment should mark the owner dirty (`dirty`).
    """
    __slots__ = ["nm", "idx", "dirty"]
    # Marker stored in the slot list for attributes that have no value.
    unbound = object()

    def __init__(self, nm, idx, dirty):
        self.nm = nm
        self.idx = idx
        self.dirty = dirty

    @staticmethod
    def slist(ins):
        """Return the slot list of `ins`."""
        # Avoid calling __getattribute__
        return specdirty.__sslots__.__get__(ins, type(ins))

    def __get__(self, ins, cls):
        """Return the slot's value; raise AttributeError when unbound.

        Accessed on the class itself (`ins` is None), the descriptor is
        returned instead.
        """
        if ins is None:
            # Class-level access has no slot list to index; without this
            # guard the lookup failed with TypeError.
            return self
        val = self.slist(ins)[self.idx]
        if val is specslot.unbound:
            raise AttributeError("specslot %r is unbound" % self.nm)
        return val

    def __set__(self, ins, val):
        """Store `val`, marking the owner dirty if this slot asks for it."""
        self.slist(ins)[self.idx] = val
        if self.dirty:
            ins.dirty()

    def __delete__(self, ins):
        """Reset the slot to unbound and mark the owner dirty."""
        self.slist(ins)[self.idx] = specslot.unbound
        # NOTE(review): unlike __set__, this marks the owner dirty even
        # when self.dirty is false -- confirm that is intended.
        ins.dirty()
161
class specclass(type):
    """Metaclass that installs specslot descriptors for declared slots.

    Slot names are collected from the `__saveslots__` and
    `__dirtyslots__` attributes defined directly on every class in the
    MRO.  The results are stored on the new class as `__sslots_l__`
    (save slots) and `__sslots_a__` (save and dirty slots combined), and
    each name in the latter is bound to a specslot whose index is its
    position in that list.
    """
    def __init__(self, name, bases, tdict):
        super().__init__(name, bases, tdict)
        saved = set()
        dirtying = set()
        for klass in self.__mro__:
            declared = klass.__dict__.get("__saveslots__", ())
            saved.update(declared)
            # A class without its own __dirtyslots__ contributes its
            # save slots as dirty slots.
            dirtying.update(klass.__dict__.get("__dirtyslots__", declared))
        self.__sslots_l__ = list(saved)
        self.__sslots_a__ = list(saved | dirtying)
        for index, slotname in enumerate(self.__sslots_a__):
            descriptor = specslot(slotname, index, slotname in dirtying)
            setattr(self, slotname, descriptor)
175
class specdirty(sessiondata, metaclass=specclass):
    """Session data whose declared attributes are stored in a slot list.

    The specclass metaclass binds a specslot descriptor for every
    declared slot name; the values themselves live in the per-instance
    list kept in `__sslots__`.  The dirty flag is `_is_dirty`.
    """
    __slots__ = ["session", "__sslots__", "_is_dirty"]

    def __specinit__(self):
        """Hook called from __new__ once slot storage exists; no-op here."""

    @staticmethod
    def __new__(cls, req, sess):
        obj = super().__new__(cls)
        obj.session = sess
        # Every slot starts out unbound.
        obj.__sslots__ = [specslot.unbound] * len(cls.__sslots_a__)
        obj.__specinit__()
        # Reset last, so assignments made by __specinit__ leave the
        # new object clean.
        obj._is_dirty = False
        return obj

    def __getnewargs__(self):
        """Return the (req, sess) arguments for __new__: (None, session)."""
        return (None, self.session)

    def dirty(self):
        """Set the dirty flag."""
        self._is_dirty = True

    def sessfrozen(self):
        """Clear the dirty flag."""
        self._is_dirty = False

    def sessdirty(self):
        """Return the current dirty flag."""
        return self._is_dirty

    def __getstate__(self):
        """Return a dict mapping each slot name to a (bound, value) pair.

        Unbound slots are recorded as (False, None).
        """
        return {
            nm: ((False, None) if val is specslot.unbound else (True, val))
            for nm, val in zip(type(self).__sslots_a__, specslot.slist(self))
        }

    def __setstate__(self, st):
        """Restore slot values from a dict produced by __getstate__.

        Entries are popped from `st`; slots that are missing or recorded
        as not bound become unbound.
        """
        slots = specslot.slist(self)
        for idx, nm in enumerate(type(self).__sslots_a__):
            bound, val = st.pop(nm, (False, None))
            slots[idx] = val if bound else specslot.unbound