Tried to create a more advanced sessiondata class for more precise control.
[wrw.git] / wrw / util.py
1 import inspect
2 import req, dispatch, session, form
3
def wsgiwrap(callable):
    """Adapt a request handler into a WSGI-style ``(env, start)`` callable.

    The returned function wraps the WSGI environment with
    ``req.origrequest`` and hands that request, the start-response callable
    and *callable* to ``dispatch.handle``, returning whatever it returns.
    """
    def wrapper(env, startreq):
        request = req.origrequest(env)
        return dispatch.handle(request, startreq, callable)
    return wrapper
8
def formparams(callable):
    """Decorate a handler so that form fields arrive as keyword arguments.

    The returned wrapper takes a request, reads its form data with
    ``form.formdata(req)`` and calls *callable* with one keyword argument
    per form field plus ``req`` (the request itself, which overrides a form
    field of the same name).  Unless *callable* accepts ``**kwargs``, any
    argument it does not name in its signature is dropped.
    """
    # Inspect the signature once, at decoration time, instead of on every
    # request.  inspect.getargspec() no longer exists on Python 3.11+, so
    # prefer getfullargspec() (which calls the **kwargs field "varkw") when
    # it is available.
    if hasattr(inspect, "getfullargspec"):
        spec = inspect.getfullargspec(callable)
        acceptsall = spec.varkw is not None
    else:
        spec = inspect.getargspec(callable)
        acceptsall = spec.keywords is not None
    # Keyword-only parameters (Python 3) can be filled from the form too.
    names = frozenset(list(spec.args) + list(getattr(spec, "kwonlyargs", None) or ()))

    def wrapper(req):
        args = dict(form.formdata(req).items())
        args["req"] = req
        if not acceptsall:
            args = {k: v for k, v in args.items() if k in names}
        return callable(**args)
    return wrapper
21
def persession(data = None):
    """Return a decorator that keeps one handler object per session.

    On the first request of a session the decorated *callable* is invoked
    -- as ``callable()`` when *data* is None, otherwise as
    ``callable(data)`` after making sure the session also holds ``data()``
    under the key *data* -- and the result is stored in the session under
    the key *callable*.  Every request is answered by that stored object's
    ``handle(req)`` method.

    NOTE(review): the factory is given *data* itself, not the ``data()``
    instance stored in the session; confirm this is intended.
    """
    def dec(callable):
        def wrapper(req):
            sess = session.get(req)
            if callable not in sess:
                if data is not None and data not in sess:
                    sess[data] = data()
                sess[callable] = callable() if data is None else callable(data)
            return sess[callable].handle(req)
        return wrapper
    return dec
36
class sessiondata(object):
    """Base class for per-session objects, stored in the session under their class."""

    @classmethod
    def get(cls, req, create = True):
        """Return this class's object from the session belonging to *req*.

        The lookup (and any creation) happens while holding ``sess.lock``.
        If nothing is stored yet, a new object is built as
        ``cls(req, sess)`` and stored under ``cls`` -- unless *create* is
        false, in which case None is returned.
        """
        sess = cls.sessdb().get(req)
        with sess.lock:
            try:
                return sess[cls]
            except KeyError:
                pass
            if not create:
                return None
            obj = cls(req, sess)
            sess[cls] = obj
            return obj

    @classmethod
    def sessdb(cls):
        """Return the session database consulted by get().

        Defaults to ``session.default.val``; subclasses may override it.
        """
        return session.default.val
54
class autodirty(sessiondata):
    """Session data that flags itself dirty whenever an attribute changes.

    Dirty tracking starts the first time an instance is returned by get(),
    which stores ``_is_dirty = False`` directly in the instance __dict__.
    Attribute writes made before that (for instance in ``__init__``) do
    not mark the object dirty, and sessdirty() raises AttributeError until
    tracking has started.
    """
    @classmethod
    def get(cls, req, create = True):
        """Fetch (or create) the instance and make sure tracking is enabled.

        *create* is passed through to sessiondata.get(); when it is false
        and no instance exists yet, None is returned.
        """
        ret = super(autodirty, cls).get(req, create=create)
        if ret is not None and "_is_dirty" not in ret.__dict__:
            # Write through __dict__ so that __setattr__ does not flag it.
            ret.__dict__["_is_dirty"] = False
        return ret

    def sessfrozen(self):
        """Clear the dirty flag (bypassing __setattr__ so it stays clear)."""
        self.__dict__["_is_dirty"] = False

    def sessdirty(self):
        """Return True if an attribute was set or deleted since tracking
        started or since the last sessfrozen() call."""
        return self._is_dirty

    def __setattr__(self, name, value):
        super(autodirty, self).__setattr__(name, value)
        # Only flag once tracking has been enabled by get().
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True

    def __delattr__(self, name):
        # object.__delattr__ takes only the attribute name.
        super(autodirty, self).__delattr__(name)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True
78
class manudirty(object):
    """Mixin for session data whose dirty flag is managed by hand.

    Instances start clean.  dirty() raises the flag, sessfrozen() lowers
    it, and sessdirty() reports it (presumably consulted by the session
    store to decide whether to save the object -- confirm against caller).
    """
    def __init__(self, *args, **kwargs):
        # Cooperative: forward all constructor arguments along the MRO.
        super(manudirty, self).__init__(*args, **kwargs)
        self.__dirty = False

    def dirty(self):
        """Mark this object as modified."""
        self.__dirty = True

    def sessfrozen(self):
        """Mark this object as clean."""
        self.__dirty = False

    def sessdirty(self):
        """Report whether dirty() was called since the last sessfrozen()."""
        return self.__dirty
92
class specslot(object):
    """Data descriptor for one declared slot of a specdirty subclass.

    The value lives at position *idx* of the instance's ``__sslots__``
    list; ``specslot.unbound`` in that position means the slot holds no
    value.  Assigning or deleting a slot created with ``dirty=True`` calls
    ``ins.dirty()`` on the owning instance.
    """
    __slots__ = ["nm", "idx", "dirty"]
    # Sentinel stored in the slot list for slots that hold no value.
    unbound = object()

    def __init__(self, nm, idx, dirty):
        self.nm = nm
        self.idx = idx
        self.dirty = dirty

    @staticmethod
    def slist(ins):
        """Return the per-instance slot value list of *ins*."""
        # Read __sslots__ through specdirty's member descriptor directly so
        # that the instance's __getattribute__ is not invoked.
        return specdirty.__sslots__.__get__(ins, type(ins))

    def __get__(self, ins, cls):
        if ins is None:
            # Accessed on the class itself (help(), inspect, hasattr on the
            # class): return the descriptor, as built-in descriptors do.
            return self
        val = self.slist(ins)[self.idx]
        if val is specslot.unbound:
            raise AttributeError("specslot %r is unbound" % self.nm)
        return val

    def __set__(self, ins, val):
        self.slist(ins)[self.idx] = val
        if self.dirty:
            ins.dirty()

    def __delete__(self, ins):
        self.slist(ins)[self.idx] = specslot.unbound
        # Same rule as __set__: only slots flagged dirty mark the owner.
        if self.dirty:
            ins.dirty()
121
class specclass(type):
    """Metaclass that turns declared session slots into specslot descriptors.

    Across the whole MRO it collects the names listed in ``__saveslots__``
    and ``__dirtyslots__`` (a class without ``__dirtyslots__`` treats its
    own save slots as dirty slots).  It then records ``__sslots_l__`` (the
    save slots) and ``__sslots_a__`` (all slots), and installs one specslot
    per name in ``__sslots_a__``, indexed by its position in that list.
    """
    def __init__(self, name, bases, tdict):
        super(specclass, self).__init__(name, bases, tdict)
        saved, dirtying = set(), set()
        for klass in self.__mro__:
            own = klass.__dict__
            declared = own.get("__saveslots__", ())
            saved.update(declared)
            dirtying.update(own.get("__dirtyslots__", declared))
        self.__sslots_l__ = list(saved)
        self.__sslots_a__ = list(saved | dirtying)
        for idx, nm in enumerate(self.__sslots_a__):
            setattr(self, nm, specslot(nm, idx, nm in dirtying))
135
class specdirty(sessiondata):
    """Session data whose persistent attributes are declared per class.

    Subclasses name their attributes in ``__saveslots__`` and may name the
    ones whose assignment marks the object dirty in ``__dirtyslots__``.
    The specclass metaclass installs a specslot descriptor for each name;
    the values themselves live in the per-instance ``__sslots__`` list.
    """
    __metaclass__ = specclass
    __slots__ = ["session", "__sslots__", "_is_dirty"]

    def __specinit__(self):
        """Hook for subclasses to give slots initial values.

        Runs inside __new__ (also when an instance is unpickled) before the
        dirty flag is cleared, so assignments made here never leave the
        object dirty.
        """
        pass

    @staticmethod
    def __new__(cls, req, sess):
        self = super(specdirty, cls).__new__(cls)
        self.session = sess
        # Every declared slot starts out unbound.
        self.__sslots__ = [specslot.unbound] * len(cls.__sslots_a__)
        self.__specinit__()
        # Discard any dirtiness caused by __specinit__.
        self._is_dirty = False
        return self

    def __getnewargs__(self):
        # On unpickling, __new__ is called with req=None and the saved
        # session.
        return (None, self.session)

    def dirty(self):
        """Mark this object as modified."""
        self._is_dirty = True

    def sessfrozen(self):
        """Mark this object as clean."""
        self._is_dirty = False

    def sessdirty(self):
        """Report whether the object was marked dirty since it was last frozen."""
        return self._is_dirty

    def __getstate__(self):
        """Return the slot values as ``{name: (bound, value)}``.

        Unbound slots are recorded as ``(False, None)`` so the sentinel
        object itself is never pickled.
        """
        # NOTE(review): this saves every slot in __sslots_a__, including
        # dirty-only ones; __sslots_l__ (save slots only) is not consulted.
        ret = {}
        for nm, val in zip(type(self).__sslots_a__, specslot.slist(self)):
            ret[nm] = (False, None) if val is specslot.unbound else (True, val)
        return ret

    def __setstate__(self, st):
        """Restore slot values produced by __getstate__.

        Slots absent from *st* (e.g. added to the class after the state was
        saved) keep whatever value they already hold -- normally the one
        __specinit__ set when __new__ ran.  *st* is not modified.
        """
        ss = specslot.slist(self)
        for i, nm in enumerate(type(self).__sslots_a__):
            if nm not in st:
                continue
            bound, val = st[nm]
            ss[i] = val if bound else specslot.unbound