Fixed wrapping of util.stringwrap
[wrw.git] / wrw / util.py
1 import inspect, math
2 from . import req, dispatch, session, form, resp, proto
3
def wsgiwrap(callable):
    """Turn a request handler into a WSGI application callable.

    The returned function takes ``(env, startreq)`` and delegates the whole
    request to ``dispatch.handleenv`` with *callable* as the handler.
    ``__wrapped__`` points back at *callable* so helpers such as
    ``funplex.unwrap`` can reach the original function.
    """
    def wrapper(env, startreq):
        # All request/response plumbing is handled by the dispatch module.
        return dispatch.handleenv(env, startreq, callable)

    setattr(wrapper, "__wrapped__", callable)
    return wrapper
9
def stringwrap(charset):
    """Return a decorator that encodes every string produced by a handler.

    The decorated callable is invoked as soon as the wrapper is called, so any
    exception it raises propagates to the caller immediately rather than on
    first iteration. Its result is wrapped in an iterator that encodes each
    item with *charset* and forwards ``close()`` to the wrapped iterable when it
    has one. Consumers that close the result, as WSGI servers do per PEP 3333,
    therefore also close the underlying iterable.
    """
    def dec(callable):
        class encoder(object):
            # Iterator over the handler's result that yields encoded items.
            def __init__(self, bk):
                self.bk = bk
                self.bki = iter(bk)

            def __iter__(self):
                return self

            def __next__(self):
                return next(self.bki).encode(charset)

            def close(self):
                # Propagate close() so the wrapped iterable can release resources.
                if hasattr(self.bk, "close"):
                    self.bk.close()

        def wrapper(*args, **kwargs):
            return encoder(callable(*args, **kwargs))

        wrapper.__wrapped__ = callable
        return wrapper
    return dec
19
def formparams(callable):
    """Decorate a handler so that form fields are passed as keyword arguments.

    The wrapper reads ``form.formdata(req)`` and offers every field, plus
    ``req`` itself, as a keyword argument. Unless the handler accepts
    ``**kwargs``, only names it declares (positional-or-keyword or
    keyword-only) are passed.

    Raises resp.httperror (400) when a parameter without a default value is
    not supplied.
    """
    # getfullargspec replaces getargspec, which was removed in Python 3.11 and
    # rejected functions with keyword-only parameters or annotations.
    spec = inspect.getfullargspec(callable)
    accepted = set(spec.args) | set(spec.kwonlyargs)
    npositional = len(spec.args) - len(spec.defaults or ())
    kwdefaults = spec.kwonlydefaults or {}
    required = list(spec.args[:npositional])
    required.extend(arg for arg in spec.kwonlyargs if arg not in kwdefaults)

    def wrapper(req):
        data = form.formdata(req)
        args = dict(data.items())
        args["req"] = req
        if not spec.varkw:
            # Drop fields the handler cannot accept as keywords.
            args = {k: v for k, v in args.items() if k in accepted}
        for name in required:
            if name not in args:
                raise resp.httperror(400, "Missing parameter", ("The query parameter `", resp.h.code(name), "' is required but not supplied."))
        return callable(**args)

    wrapper.__wrapped__ = callable
    return wrapper
36
class funplex(object):
    """Multiplex requests to handlers selected by the first path segment.

    Handlers live in ``self.dir``, keyed by name:
    - Positional constructor arguments (and ``add``) are keyed by the
      ``__name__`` of their innermost function, reached by following
      ``__wrapped__``.
    - Keyword arguments (and ``name``) use the given name.

    A request for the bare directory ("/") goes to the ``"__index__"``
    handler.
    """

    def __init__(self, *funs, **nfuns):
        self.dir = {}
        for fun in funs:
            self.dir[self.unwrap(fun).__name__] = fun
        self.dir.update(nfuns)

    @staticmethod
    def unwrap(fun):
        """Follow the ``__wrapped__`` chain and return the innermost object."""
        inner = fun
        while hasattr(inner, "__wrapped__"):
            inner = inner.__wrapped__
        return inner

    def __call__(self, req):
        """Dispatch *req* to the handler named by its first path segment.

        An empty path info raises resp.redirect to ``req.uriname + "/"``. A
        path info that does not start with "/", or a segment with no
        registered handler, raises resp.notfound. The handler is called with
        ``req.shift(n)``, where n covers the leading "/" and the segment name.
        """
        path = req.pathinfo
        if path == "":
            raise resp.redirect(req.uriname + "/")
        if path[:1] != "/":
            raise resp.notfound()
        rest = path[1:]
        if rest == "":
            name, consumed = "__index__", 1
        else:
            name = rest.partition("/")[0]
            consumed = len(name) + 1
        if name not in self.dir:
            raise resp.notfound()
        return self.dir[name](req.shift(consumed))

    def add(self, fun):
        """Register *fun* under its unwrapped ``__name__``; usable as a decorator."""
        key = self.unwrap(fun).__name__
        self.dir[key] = fun
        return fun

    def name(self, name):
        """Return a decorator that registers a handler under *name*."""
        def register(fun):
            self.dir[name] = fun
            return fun
        return register
74
def persession(data=None):
    """Return a decorator that keeps one handler object per session.

    On first use within a session, the decorated *callable* is instantiated
    and stored in the session under the callable itself:
    - With no *data*, it is constructed with no arguments.
    - Otherwise a per-session ``data()`` object is created on first need,
      stored under *data*, and passed to *callable*.

    Each request is then forwarded to the stored object's ``handle(req)``.
    """
    def dec(callable):
        def wrapper(req):
            sess = session.get(req)
            if callable not in sess:
                if data is None:
                    sess[callable] = callable()
                else:
                    if data not in sess:
                        sess[data] = data()
                    # Pass the shared per-session instance, not the factory.
                    sess[callable] = callable(sess[data])
            return sess[callable].handle(req)
        wrapper.__wrapped__ = callable
        return wrapper
    return dec
90
class preiter(object):
    """Iterator wrapper that always holds the next item already fetched.

    The constructor pulls the first item from *real* immediately. Any work
    the underlying iterator does before producing its first item (including
    raising) therefore happens at construction time. ``close()`` is forwarded
    to the wrapped object when it provides one.
    """
    __slots__ = ["bk", "bki", "_next"]
    # Sentinel stored in _next once the underlying iterator is exhausted.
    end = object()

    def __init__(self, real):
        self.bk = real
        self.bki = iter(real)
        self._next = None
        # Prime the look-ahead; the placeholder None returned here is discarded.
        self.__next__()

    def __iter__(self):
        return self

    def __next__(self):
        current = self._next
        if current is self.end:
            raise StopIteration()
        self._next = next(self.bki, self.end)
        return current

    def close(self):
        if hasattr(self.bk, "close"):
            self.bk.close()
116
def pregen(callable):
    """Decorate an iterable-returning function so its result is a ``preiter``.

    The first item of the returned iterable is fetched as soon as the wrapper
    is called, rather than on first iteration.
    """
    def wrapper(*args, **kwargs):
        result = callable(*args, **kwargs)
        return preiter(result)

    setattr(wrapper, "__wrapped__", callable)
    return wrapper
122
class sessiondata(object):
    """Base class for per-session objects keyed by their own class.

    ``get`` looks the instance up in the session mapping under the class
    object. When missing, it creates one with ``cls(req, sess)``.
    """

    @classmethod
    def get(cls, req, create=True):
        """Return this class's object for the session of *req*.

        When no object exists yet, one is created and stored if *create* is
        true; otherwise None is returned. The lookup-or-create runs while
        holding ``sess.lock``.
        """
        sess = cls.sessdb().get(req)
        with sess.lock:
            try:
                existing = sess[cls]
            except KeyError:
                pass
            else:
                return existing
            if not create:
                return None
            instance = cls(req, sess)
            sess[cls] = instance
            return instance

    @classmethod
    def sessdb(cls):
        """Return the session database to use (``session.default.val``)."""
        return session.default.val
140
class autodirty(sessiondata):
    """Session data that marks itself dirty on any attribute assignment or deletion.

    The ``_is_dirty`` flag is kept directly in the instance ``__dict__``. It is
    created by ``get``; until it exists, attribute changes are not tracked.
    """

    @classmethod
    def get(cls, req, create=True):
        """Return the session's instance, initialising the dirty flag if needed.

        *create* is passed through to ``sessiondata.get``. When it is false
        and no instance exists, None is returned.
        """
        ret = super().get(req, create=create)
        if ret is not None and "_is_dirty" not in ret.__dict__:
            ret.__dict__["_is_dirty"] = False
        return ret

    def sessfrozen(self):
        # Write through __dict__ so resetting the flag does not itself mark dirty.
        self.__dict__["_is_dirty"] = False

    def sessdirty(self):
        return self._is_dirty

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True

    def __delattr__(self, name):
        # object.__delattr__ takes only the attribute name.
        super().__delattr__(name)
        if "_is_dirty" in self.__dict__:
            self.__dict__["_is_dirty"] = True
164
class manudirty(object):
    """Mixin whose dirty state changes only through explicit ``dirty()`` calls.

    The flag is stored under the name-mangled attribute ``_manudirty__dirty``.
    It starts out false after the cooperative ``super().__init__`` chain has
    run.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dirty = False

    def dirty(self):
        """Flag this object as modified."""
        self.__dirty = True

    def sessdirty(self):
        """Return whether ``dirty()`` was called since the last freeze."""
        return self.__dirty

    def sessfrozen(self):
        """Clear the modified flag."""
        self.__dirty = False
178
class specslot(object):
    """Data descriptor for one saved slot of a ``specdirty`` instance.

    Each descriptor refers to position *idx* in the instance's ``__sslots__``
    list. *dirty* selects whether assigning or deleting the slot marks the
    instance dirty.
    """
    __slots__ = ["nm", "idx", "dirty"]
    # Marker for a slot that currently holds no value.
    unbound = object()

    def __init__(self, nm, idx, dirty):
        self.nm = nm
        self.idx = idx
        self.dirty = dirty

    @staticmethod
    def slist(ins):
        # Avoid calling __getattribute__: fetch the slot storage through the
        # member descriptor directly.
        return specdirty.__sslots__.__get__(ins, type(ins))

    def __get__(self, ins, cls):
        if ins is None:
            # Class-level access (help(), hasattr(cls, name), ...) yields the descriptor.
            return self
        val = self.slist(ins)[self.idx]
        if val is specslot.unbound:
            raise AttributeError("specslot %r is unbound" % self.nm)
        return val

    def __set__(self, ins, val):
        self.slist(ins)[self.idx] = val
        if self.dirty:
            ins.dirty()

    def __delete__(self, ins):
        self.slist(ins)[self.idx] = specslot.unbound
        # Same dirtiness rule as __set__.
        if self.dirty:
            ins.dirty()
207
class specclass(type):
    """Metaclass that installs ``specslot`` descriptors for saved slots.

    Slot names come from every ``__saveslots__`` along the MRO. For each
    class, ``__dirtyslots__`` defaults to that class's ``__saveslots__``.

    - ``__sslots_l__`` lists the saved names.
    - ``__sslots_a__`` lists saved and dirty names together. Its order follows
      set iteration, so indices can differ between processes; state is
      pickled by name.
    """

    def __init__(cls, name, bases, tdict):
        super().__init__(name, bases, tdict)
        saved = set()
        dirtying = set()
        for klass in cls.__mro__:
            own = klass.__dict__
            declared = own.get("__saveslots__", ())
            saved.update(declared)
            dirtying.update(own.get("__dirtyslots__", declared))
        cls.__sslots_l__ = list(saved)
        cls.__sslots_a__ = list(saved | dirtying)
        for index, slotname in enumerate(cls.__sslots_a__):
            setattr(cls, slotname, specslot(slotname, index, slotname in dirtying))
221
class specdirty(sessiondata, metaclass=specclass):
    """Session data whose saved state lives in ``specslot`` descriptors.

    Slot values are kept in the per-instance ``__sslots__`` list, indexed as
    laid out by ``specclass``. Writes to dirty slots call ``dirty()``.
    Pickling stores slots by name via ``__getstate__``/``__setstate__``.
    """
    __slots__ = ["session", "__sslots__", "_is_dirty"]

    def __specinit__(self):
        """Hook for subclasses to set initial slot values at construction."""
        pass

    @staticmethod
    def __new__(cls, req, sess):
        inst = super().__new__(cls)
        inst.session = sess
        inst.__sslots__ = [specslot.unbound for _ in range(len(cls.__sslots_a__))]
        inst.__specinit__()
        # Reset after __specinit__ so its initial assignments do not count as changes.
        inst._is_dirty = False
        return inst

    def __getnewargs__(self):
        # Unpickling calls __new__(cls, None, session).
        return (None, self.session)

    def sessdirty(self):
        return self._is_dirty

    def sessfrozen(self):
        self._is_dirty = False

    def dirty(self):
        self._is_dirty = True

    def __getstate__(self):
        # Map each slot name to (is_bound, value).
        names = type(self).__sslots_a__
        values = specslot.slist(self)
        return {nm: ((False, None) if val is specslot.unbound else (True, val))
                for nm, val in zip(names, values)}

    def __setstate__(self, st):
        # Names absent from the state are left unbound; consumed entries are popped from st.
        ss = specslot.slist(self)
        for i, nm in enumerate(type(self).__sslots_a__):
            bound, val = st.pop(nm, (False, None))
            ss[i] = val if bound else specslot.unbound
266
def datecheck(req, mtime):
    """Apply If-Modified-Since handling for a resource last modified at *mtime*.

    Raises resp.unmodified when the request's If-Modified-Since time is at or
    after *mtime* rounded down to a whole second. Rounding down is presumably
    because HTTP dates have one-second resolution. Otherwise the
    Last-Modified response header is set from *mtime*.
    """
    header = "If-Modified-Since"
    if header in req.ihead:
        clienttime = proto.phttpdate(req.ihead[header])
        if clienttime >= math.floor(mtime):
            raise resp.unmodified()
    req.ohead["Last-Modified"] = proto.httpdate(mtime)