Merge branch 'master' into python3
[wrw.git] / wrw / util.py
1 import inspect
2 from . import req, dispatch, session, form
3
def wsgiwrap(callable):
    """Expose a wrw handler as a WSGI application.

    The returned ``wrapper(env, startreq)`` builds a request object from the
    WSGI environ with ``req.origrequest`` and passes it, together with
    *startreq* and the handler, to ``dispatch.handle``, returning its result.
    """
    handler = callable

    def wrapper(env, startreq):
        request = req.origrequest(env)
        return dispatch.handle(request, startreq, handler)

    return wrapper
8
def stringwrap(charset):
    """Decorator factory: encode every string a handler yields as *charset*.

    The decorated callable is invoked lazily, on the first iteration of the
    returned generator.  Each item it produces is passed through
    ``str.encode(charset)``.  If the iterable it returned has a ``close()``
    method, that method is called once iteration ends (normally, by error,
    or because this generator was closed).
    """
    def dec(callable):
        def wrapper(*args, **kwargs):
            bk = callable(*args, **kwargs)
            try:
                for string in bk:
                    yield string.encode(charset)
            finally:
                # A WSGI server only closes the iterable it was handed (this
                # generator); forward close() so the inner body is released.
                close = getattr(bk, "close", None)
                if close is not None:
                    close()
        return wrapper
    return dec
17
def formparams(callable):
    """Call *callable* with the request's form fields as keyword arguments.

    Fields from ``form.formdata(req)`` are passed by name, plus ``req`` for
    the request itself.  Unless *callable* accepts ``**kwargs``, any name
    that is not one of its positional-or-keyword or keyword-only parameters
    is dropped before the call.
    """
    # inspect.getargspec() was removed in Python 3.11 (and raised ValueError
    # for annotated or keyword-only signatures before that).  getfullargspec()
    # reports **kwargs as ``varkw``.  The signature cannot change between
    # requests, so inspect it once here instead of on every call.
    spec = inspect.getfullargspec(callable)
    accepted = set(spec.args) | set(spec.kwonlyargs)

    def wrapper(req):
        args = dict(form.formdata(req).items())
        args["req"] = req
        if not spec.varkw:
            args = {name: val for name, val in args.items() if name in accepted}
        return callable(**args)

    return wrapper
30
def persession(data = None):
    """Decorator factory: keep one instance of the decorated class per session.

    The session comes from ``session.get(req)`` and the instance is stored
    there under the decorated class itself.  It is created on first use, and
    every request returns ``instance.handle(req)``.  When *data* is given, an
    instance ``data()`` is also stored in the session under the key *data*
    (if not already present) before the handler is created.
    """
    def dec(handler):
        def build(store):
            # Construct the per-session handler instance.
            if data is None:
                return handler()
            if data not in store:
                store[data] = data()
            # NOTE(review): the handler receives the class *data* itself, not
            # the instance just stored in store[data]; confirm this is intended.
            return handler(data)

        def wrapper(req):
            store = session.get(req)
            if handler not in store:
                store[handler] = build(store)
            return store[handler].handle(req)

        return wrapper
    return dec
45
class sessiondata(object):
    """Base class for per-session state objects keyed by their own class.

    Instances are stored in the session mapping under the class object, so
    ``get`` hands back the same instance for later requests in that session.
    """
    @classmethod
    def get(cls, req, create = True):
        """Return this class's instance from *req*'s session.

        If none is stored and *create* is true, build one as
        ``cls(req, sess)``, store it and return it; otherwise return None.
        The session's ``lock`` is held for the whole lookup-or-create.
        """
        sess = cls.sessdb().get(req)
        with sess.lock:
            try:
                found = sess[cls]
            except KeyError:
                pass
            else:
                return found
            if not create:
                return None
            made = cls(req, sess)
            sess[cls] = made
            return made

    @classmethod
    def sessdb(cls):
        """Return the session store used by ``get``: ``session.default.val``."""
        return session.default.val
63
class autodirty(sessiondata):
    """Session data that marks itself dirty on any attribute set or delete.

    The flag is kept directly in the instance ``__dict__`` as ``_is_dirty``
    and is created by ``get``.  Until then (e.g. while ``__init__`` runs
    from ``sessiondata.get``), attribute changes do not mark the object dirty.
    """
    @classmethod
    def get(cls, req, create = True):
        """Return the session's instance of *cls*, with its dirty flag set up.

        *create* is forwarded to ``sessiondata.get``.  When it is false and no
        instance exists, None is returned.
        """
        ret = super().get(req, create)
        if ret is None:
            return None
        if "_is_dirty" not in ret.__dict__:
            ret.__dict__["_is_dirty"] = False
        return ret

    def sessfrozen(self):
        """Clear the dirty flag, bypassing ``__setattr__`` so it stays clear."""
        self.__dict__["_is_dirty"] = False

    def sessdirty(self):
        """Return True if an attribute was set or deleted since the flag was cleared."""
        return self._is_dirty

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        # Only track changes once get() has installed the flag.
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True

    def __delattr__(self, name):
        # object.__delattr__ takes only the attribute name.
        super().__delattr__(name)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True
87
class manudirty(object):
    """Mixin for session data whose dirtiness is flagged explicitly.

    Call ``dirty()`` after changing state that must be saved.  ``sessdirty()``
    reports the flag and ``sessfrozen()`` clears it.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assigned after the cooperative __init__ chain, so any dirty() call
        # made during that chain is discarded and the object starts clean.
        self.__dirty = False

    def dirty(self):
        """Flag this object as modified."""
        self.__dirty = True

    def sessdirty(self):
        """Return whether ``dirty()`` was called since the flag was last cleared."""
        return self.__dirty

    def sessfrozen(self):
        """Clear the modified flag."""
        self.__dirty = False
101
class specslot(object):
    """Data descriptor backing one named slot of a ``specdirty`` instance.

    The value lives at position ``idx`` of the instance's ``__sslots__``
    list.  ``specslot.unbound`` marks an entry that holds no value.  When
    ``dirty`` is true, assigning through the descriptor calls ``ins.dirty()``.
    """
    __slots__ = ["nm", "idx", "dirty"]
    # Sentinel stored in the slot list for slots that currently hold no value.
    unbound = object()

    def __init__(self, nm, idx, dirty):
        self.nm = nm
        self.idx = idx
        self.dirty = dirty

    @staticmethod
    def slist(ins):
        """Return the backing value list (``__sslots__``) of *ins*."""
        # Avoid calling __getattribute__: read specdirty's slot member
        # descriptor directly, presumably so that any __getattribute__
        # override on the instance's class is not involved.
        return specdirty.__sslots__.__get__(ins, type(ins))

    def __get__(self, ins, cls):
        # Class-level access (e.g. from help() or inspect) passes ins=None.
        # Return the descriptor itself instead of failing inside slist().
        if ins is None:
            return self
        val = self.slist(ins)[self.idx]
        if val is specslot.unbound:
            raise AttributeError("specslot %r is unbound" % self.nm)
        return val

    def __set__(self, ins, val):
        self.slist(ins)[self.idx] = val
        if self.dirty:
            ins.dirty()

    def __delete__(self, ins):
        self.slist(ins)[self.idx] = specslot.unbound
        # NOTE(review): unlike __set__, this marks the instance dirty even
        # when self.dirty is false; confirm whether that asymmetry is intended.
        ins.dirty()
130
class specclass(type):
    """Metaclass that installs a ``specslot`` descriptor for every declared slot.

    It walks the MRO and collects each class's own ``__saveslots__`` and
    ``__dirtyslots__``; a class's dirty slots default to its own save slots.
    The results are recorded as ``__sslots_l__`` (save slots) and
    ``__sslots_a__`` (save and dirty slots).  Each name in ``__sslots_a__``
    is then bound to a ``specslot`` whose index is its position in that list.
    """
    def __init__(self, name, bases, tdict):
        super().__init__(name, bases, tdict)
        saved = set()
        dirtying = set()
        for klass in self.__mro__:
            own = vars(klass).get("__saveslots__", ())
            saved.update(own)
            dirtying.update(vars(klass).get("__dirtyslots__", own))
        self.__sslots_l__ = list(saved)
        self.__sslots_a__ = list(saved | dirtying)
        for position, slotname in enumerate(self.__sslots_a__):
            descriptor = specslot(slotname, position, slotname in dirtying)
            setattr(self, slotname, descriptor)
144
class specdirty(sessiondata, metaclass=specclass):
    """Session data whose persistent state lives in declared slots.

    Subclasses name their slots in ``__saveslots__`` (and optionally
    ``__dirtyslots__``).  The ``specclass`` metaclass installs a ``specslot``
    descriptor for each one, with values kept in the per-instance
    ``__sslots__`` list.  Assigning to a dirty slot calls ``dirty()``.
    """
    __slots__ = ["session", "__sslots__", "_is_dirty"]
    
    def __specinit__(self):
        """Hook for subclasses to initialise slots; called from ``__new__``.

        It runs before ``_is_dirty`` is reset to False, so slot assignments
        made here do not leave a new instance marked dirty.
        """
        pass

    @staticmethod
    def __new__(cls, req, sess):
        # Setup happens in __new__ rather than __init__, so it also runs on
        # unpickling.  pickle calls __new__ with __getnewargs__()'s values
        # and then __setstate__, but not __init__.  *req* is accepted to match
        # the cls(req, sess) call in sessiondata.get but is not used here.
        self = super().__new__(cls)
        self.session = sess
        # Every slot starts unbound; the list is sized for all slots.
        self.__sslots__ = [specslot.unbound] * len(cls.__sslots_a__)
        self.__specinit__()
        self._is_dirty = False
        return self

    def __getnewargs__(self):
        """Return the ``__new__`` arguments used when unpickling: no request, same session."""
        return (None, self.session)

    def dirty(self):
        """Mark the instance as modified; ``specslot`` calls this on assignment."""
        self._is_dirty = True

    def sessfrozen(self):
        """Clear the modified flag."""
        self._is_dirty = False

    def sessdirty(self):
        """Return whether the instance was marked modified since the flag was cleared."""
        return self._is_dirty

    def __getstate__(self):
        """Return ``{slot name: (bound, value)}`` for every slot in ``__sslots_a__``.

        Unbound slots are recorded as ``(False, None)``.
        """
        ret = {}
        for nm, val in zip(type(self).__sslots_a__, specslot.slist(self)):
            if val is specslot.unbound:
                ret[nm] = False, None
            else:
                ret[nm] = True, val
        return ret

    def __setstate__(self, st):
        """Restore slot values from a ``__getstate__``-style dict.

        Slots missing from *st* (or recorded as unbound) become unbound.
        Keys that are not current slot names are ignored.  Entries for known
        slots are popped, so *st* is mutated.
        """
        ss = specslot.slist(self)
        for i, nm in enumerate(type(self).__sslots_a__):
            bound, val = st.pop(nm, (False, None))
            if not bound:
                ss[i] = specslot.unbound
            else:
                ss[i] = val
187             else:
188                 ss[i] = val