Error out more usefully from formparams when required parameters are missing.
[wrw.git] / wrw / util.py
1 import inspect
2 import req, dispatch, session, form, resp
3
def wsgiwrap(callable):
    """Turn *callable* into a two-argument application function.

    The returned function takes ``(env, startreq)`` -- presumably the WSGI
    environ dict and start_response callable -- and hands both, together
    with *callable*, to ``dispatch.handleenv``, returning its result.
    """
    def wrapper(env, startreq):
        result = dispatch.handleenv(env, startreq, callable)
        return result
    return wrapper
8
def formparams(callable):
    """Wrap *callable* so it is invoked with form fields as keyword arguments.

    The returned ``wrapper(req)`` reads the request's form data with
    ``form.formdata(req)`` and calls *callable* with every field as a keyword
    argument, plus ``req=req``. Unless *callable* accepts ``**kwargs``, any
    field that does not name one of its parameters is dropped first.

    Raises (from the wrapper):
        resp.httperror: status 400 when a parameter of *callable* that has no
        default value is absent from the collected arguments.
    """
    # The signature cannot change between requests, so inspect it once
    # instead of on every call.
    spec = inspect.getargspec(callable)
    # getargspec reports defaults as None (not an empty tuple) when no
    # parameter has a default; treat that as zero defaults instead of
    # crashing on len(None).
    ndefaults = len(spec.defaults or ())
    # Parameters with defaults are always the trailing ones, so everything
    # before them is mandatory.
    required = spec.args[:len(spec.args) - ndefaults]
    def wrapper(req):
        data = form.formdata(req)
        args = dict(data.items())
        args["req"] = req
        if not spec.keywords:
            # No **kwargs: keep only fields that match a named parameter.
            args = dict((nm, val) for nm, val in args.items() if nm in spec.args)
        for nm in required:
            if nm not in args:
                raise resp.httperror(400, "Missing parameter", ("The query parameter `", resp.h.code(nm), "' is required but not supplied."))
        return callable(**args)
    return wrapper
24
def persession(data = None):
    """Decorator factory that keeps one handler object per session.

    The decorated class/factory ``callable`` is instantiated at most once per
    session (keyed by ``callable`` itself in the session returned by
    ``session.get(req)``). Each request is then forwarded to that object's
    ``handle(req)``.

    When *data* is given, ``sess[data] = data()`` is created first if absent,
    and the handler is constructed as ``callable(data)``.
    NOTE(review): the handler receives the factory ``data`` itself, not the
    instance just stored in ``sess[data]`` -- confirm this is intended.
    """
    def dec(callable):
        def wrapper(req):
            sess = session.get(req)
            if callable not in sess:
                if data is not None and data not in sess:
                    sess[data] = data()
                sess[callable] = callable() if data is None else callable(data)
            return sess[callable].handle(req)
        return wrapper
    return dec
39
40 class preiter(object):
41     __slots__ = ["bk", "bki", "_next"]
42     end = object()
43     def __init__(self, real):
44         self.bk = real
45         self.bki = iter(real)
46         self._next = None
47         self.next()
48
49     def __iter__(self):
50         return self
51
52     def next(self):
53         if self._next is self.end:
54             raise StopIteration()
55         ret = self._next
56         try:
57             self._next = next(self.bki)
58         except StopIteration:
59             self._next = self.end
60         return ret
61
62     def close(self):
63         if hasattr(self.bk, "close"):
64             self.bk.close()
65
def pregen(callable):
    """Decorate an iterable-returning function so its result is a ``preiter``.

    The wrapper forwards all positional and keyword arguments to *callable*
    and wraps whatever it returns, which fetches the first item right away.
    """
    def wrapper(*args, **kwargs):
        produced = callable(*args, **kwargs)
        return preiter(produced)
    return wrapper
70
class sessiondata(object):
    """Base for objects stored at most once per session, keyed by their class."""

    @classmethod
    def get(cls, req, create = True):
        """Return this class's instance from the request's session.

        Looks up ``sess[cls]`` while holding ``sess.lock``. If it is missing
        and *create* is true, a new instance ``cls(req, sess)`` is stored and
        returned; if *create* is false, None is returned instead.
        """
        sess = cls.sessdb().get(req)
        with sess.lock:
            try:
                existing = sess[cls]
            except KeyError:
                pass
            else:
                return existing
            if not create:
                return None
            fresh = cls(req, sess)
            sess[cls] = fresh
            return fresh

    @classmethod
    def sessdb(cls):
        """Return the session store used by ``get`` (``session.default.val``).

        Being a classmethod, a subclass can override it to use another store.
        """
        return session.default.val
88
class autodirty(sessiondata):
    """Session data that marks itself dirty whenever an attribute changes.

    The dirty flag lives in the instance ``__dict__`` under ``_is_dirty``.
    Until ``get`` has added that key, assignments and deletions are not
    tracked, so attributes set while the object is being constructed do not
    make it dirty.
    """

    @classmethod
    def get(cls, req, create = True):
        """Fetch (or create) the session instance and ensure it has a dirty flag.

        *create* is passed through to ``sessiondata.get``; when that returns
        None (nothing stored and *create* false), None is returned.
        """
        ret = super(autodirty, cls).get(req, create)
        if ret is None:
            return None
        if "_is_dirty" not in ret.__dict__:
            # Written via __dict__ so __setattr__ does not mark it dirty.
            ret.__dict__["_is_dirty"] = False
        return ret

    def sessfrozen(self):
        # Bypass __setattr__, which would otherwise set the flag right back.
        self.__dict__["_is_dirty"] = False

    def sessdirty(self):
        return self._is_dirty

    def __setattr__(self, name, value):
        super(autodirty, self).__setattr__(name, value)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True

    def __delattr__(self, name):
        # object.__delattr__ takes only the attribute name; the original
        # passed an undefined `value` here and raised NameError on every del.
        super(autodirty, self).__delattr__(name)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True
112
class manudirty(object):
    """Mixin giving a manually managed dirty flag.

    Call ``dirty()`` after changing state that must be saved; ``sessdirty()``
    reports the flag and ``sessfrozen()`` clears it. The flag is stored in the
    name-mangled attribute ``_manudirty__dirty``.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative: let the rest of the MRO initialise first, then start clean.
        super(manudirty, self).__init__(*args, **kwargs)
        self.__dirty = False

    def sessfrozen(self):
        self.__dirty = False

    def sessdirty(self):
        flag = self.__dirty
        return flag

    def dirty(self):
        self.__dirty = True
126
class specslot(object):
    """Data descriptor for one saved attribute of a ``specdirty`` instance.

    The value is stored at position *idx* of the instance's ``__sslots__``
    list rather than in a ``__dict__``. An unset slot holds the
    ``specslot.unbound`` sentinel and reads raise AttributeError. When
    *dirty* is true, changing the slot calls ``ins.dirty()``.
    """
    __slots__ = ["nm", "idx", "dirty"]
    # Sentinel marking a slot that currently has no value.
    unbound = object()

    def __init__(self, nm, idx, dirty):
        self.nm = nm
        self.idx = idx
        self.dirty = dirty

    @staticmethod
    def slist(ins):
        # Avoid calling __getattribute__: read the list through specdirty's
        # own member descriptor directly.
        return specdirty.__sslots__.__get__(ins, type(ins))

    def __get__(self, ins, cls):
        if ins is None:
            # Accessed on the class (e.g. by introspection or help()): return
            # the descriptor itself instead of indexing a non-existent list.
            return self
        val = self.slist(ins)[self.idx]
        if val is specslot.unbound:
            raise AttributeError("specslot %r is unbound" % self.nm)
        return val

    def __set__(self, ins, val):
        self.slist(ins)[self.idx] = val
        if self.dirty:
            ins.dirty()

    def __delete__(self, ins):
        self.slist(ins)[self.idx] = specslot.unbound
        # Consistent with __set__: only dirty-tracked slots mark the instance.
        if self.dirty:
            ins.dirty()
155
class specclass(type):
    """Metaclass that installs ``specslot`` descriptors on new classes.

    Gathers ``__saveslots__`` and ``__dirtyslots__`` from every class in the
    MRO. A class without ``__dirtyslots__`` treats all of its own save slots
    as dirty-tracking. It records ``__sslots_l__`` (save slots) and
    ``__sslots_a__`` (save plus dirty slots), and binds one ``specslot`` per
    entry of ``__sslots_a__`` using that entry's list index.
    """
    def __init__(self, name, bases, tdict):
        super(specclass, self).__init__(name, bases, tdict)
        saved = set()
        dirtying = set()
        for klass in self.__mro__:
            own = klass.__dict__.get("__saveslots__", ())
            saved.update(own)
            dirtying.update(klass.__dict__.get("__dirtyslots__", own))
        allslots = saved | dirtying
        self.__sslots_l__ = list(saved)
        self.__sslots_a__ = list(allslots)
        for idx, slot in enumerate(self.__sslots_a__):
            setattr(self, slot, specslot(slot, idx, slot in dirtying))
170 class specdirty(sessiondata):
171     __metaclass__ = specclass
172     __slots__ = ["session", "__sslots__", "_is_dirty"]
173     
174     def __specinit__(self):
175         pass
176
177     @staticmethod
178     def __new__(cls, req, sess):
179         self = super(specdirty, cls).__new__(cls)
180         self.session = sess
181         self.__sslots__ = [specslot.unbound] * len(cls.__sslots_a__)
182         self.__specinit__()
183         self._is_dirty = False
184         return self
185
186     def __getnewargs__(self):
187         return (None, self.session)
188
189     def dirty(self):
190         self._is_dirty = True
191
192     def sessfrozen(self):
193         self._is_dirty = False
194
195     def sessdirty(self):
196         return self._is_dirty
197
198     def __getstate__(self):
199         ret = {}
200         for nm, val in zip(type(self).__sslots_a__, specslot.slist(self)):
201             if val is specslot.unbound:
202                 ret[nm] = False, None
203             else:
204                 ret[nm] = True, val
205         return ret
206
207     def __setstate__(self, st):
208         ss = specslot.slist(self)
209         for i, nm in enumerate(type(self).__sslots_a__):
210             bound, val = st.pop(nm, (False, None))
211             if not bound:
212                 ss[i] = specslot.unbound
213             else:
214                 ss[i] = val