Fix keyword-parameter handling bug in formparams.
[wrw.git] / wrw / session.py
index 41f2b5d..8141827 100644 (file)
@@ -1,25 +1,16 @@
 import threading, time, pickle, random, os
-from . import cookie
+from . import cookie, env, proto
 
 __all__ = ["db", "get"]
 
-def hexencode(str):
-    ret = ""
-    for byte in str:
-        ret += "%02X" % (ord(byte),)
-    return ret
-
 def gennonce(length):
-    nonce = ""
-    for i in xrange(length):
-        nonce += chr(random.randint(0, 255))
-    return nonce
+    return os.urandom(length)
 
 class session(object):
-    def __init__(self, expire = 86400 * 7):
-        self.id = hexencode(gennonce(16))
+    def __init__(self, lock, expire=86400 * 7):
+        self.id = proto.enhex(gennonce(16))
         self.dict = {}
-        self.lock = threading.Lock()
+        self.lock = lock
         self.ctime = self.atime = self.mtime = int(time.time())
         self.expire = expire
         self.dctl = set()
@@ -39,7 +30,7 @@ class session(object):
     def __getitem__(self, key):
         return self.dict[key]
 
-    def get(self, key, default = None):
+    def get(self, key, default=None):
         return self.dict.get(key, default)
 
     def __setitem__(self, key, value):
@@ -68,33 +59,52 @@ class session(object):
     def __setstate__(self, st):
         for k, v in st:
             self.__dict__[k] = v
-        self.lock = threading.Lock()
+        # The proper lock is set by the thawer
+
+    def __repr__(self):
+        return "<session %s>" % self.id
 
 class db(object):
-    def __init__(self, cookiename = "wrwsess", path = "/"):
+    def __init__(self, backdb=None, cookiename="wrwsess", path="/"):
         self.live = {}
         self.cookiename = cookiename
         self.path = path
         self.lock = threading.Lock()
         self.cthread = None
         self.freezetime = 3600
+        self.backdb = backdb
 
     def clean(self):
         now = int(time.time())
         with self.lock:
-            dlist = []
-            for sess in self.live.itervalues():
-                if sess.atime + self.freezetime < now:
-                    try:
-                        if sess.dirty():
-                            self.freeze(sess)
-                    except:
-                        if sess.atime + sess.expire < now:
-                            dlist.append(sess)
-                    else:
-                        dlist.append(sess)
-            for sess in dlist:
-                del self.live[sess.id]
+            clist = list(self.live.keys())
+        for sessid in clist:
+            with self.lock:
+                try:
+                    entry = self.live[sessid]
+                except KeyError:
+                    continue
+            with entry[0]:
+                rm = False
+                if entry[1] == "retired":
+                    pass
+                elif entry[1] is None:
+                    pass
+                else:
+                    sess = entry[1]
+                    if sess.atime + self.freezetime < now:
+                        try:
+                            if sess.dirty():
+                                self.freeze(sess)
+                        except:
+                            if sess.atime + sess.expire < now:
+                                rm = True
+                        else:
+                            rm = True
+                if rm:
+                    entry[1] = "retired"
+                    with self.lock:
+                        del self.live[sessid]
 
     def cleanloop(self):
         try:
@@ -107,67 +117,98 @@ class db(object):
             with self.lock:
                 self.cthread = None
 
-    def fetch(self, req):
-        now = int(time.time())
-        sessid = cookie.get(req, self.cookiename)
-        new = False
+    def _fetch(self, sessid):
+        while True:
+            now = int(time.time())
+            with self.lock:
+                if sessid in self.live:
+                    entry = self.live[sessid]
+                else:
+                    entry = self.live[sessid] = [threading.RLock(), None]
+            with entry[0]:
+                if isinstance(entry[1], session):
+                    entry[1].atime = now
+                    return entry[1]
+                elif entry[1] == "retired":
+                    continue
+                elif entry[1] is None:
+                    try:
+                        thawed = self.thaw(sessid)
+                        if thawed.atime + thawed.expire < now:
+                            raise KeyError()
+                        thawed.lock = entry[0]
+                        thawed.atime = now
+                        entry[1] = thawed
+                        return thawed
+                    finally:
+                        if entry[1] is None:
+                            entry[1] = "retired"
+                            with self.lock:
+                                del self.live[sessid]
+                else:
+                    raise Exception("Illegal session entry: " + repr(entry[1]))
+
+    def checkclean(self):
         with self.lock:
             if self.cthread is None:
                 self.cthread = threading.Thread(target = self.cleanloop)
                 self.cthread.setDaemon(True)
                 self.cthread.start()
-            try:
-                if sessid is None:
-                    raise KeyError()
-                elif sessid in self.live:
-                    sess = self.live[sessid]
-                else:
-                    sess = self.thaw(sessid)
-                    self.live[sessid] = sess
-                if sess.atime + sess.expire < now:
-                    raise KeyError()
-                sess.atime = now
-            except KeyError:
-                sess = session()
-                self.live[sess.id] = sess
-                new = True
+
+    def mksession(self, req):
+        return session(threading.RLock())
+
+    def mkcookie(self, req, sess):
+        cookie.add(req, self.cookiename, sess.id,
+                   path=self.path,
+                   expires=cookie.cdate(time.time() + sess.expire))
+
+    def fetch(self, req):
+        now = int(time.time())
+        sessid = cookie.get(req, self.cookiename)
+        new = False
+        try:
+            if sessid is None:
+                raise KeyError()
+            sess = self._fetch(sessid)
+        except KeyError:
+            sess = self.mksession(req)
+            new = True
 
         def ckfreeze(req):
             if sess.dirty():
+                if new:
+                    self.mkcookie(req, sess)
+                    with self.lock:
+                        self.live[sess.id] = [sess.lock, sess]
                 try:
-                    if new:
-                        cookie.add(req, self.cookiename, sess.id, self.path)
                     self.freeze(sess)
                 except:
                     pass
+                self.checkclean()
         req.oncommit(ckfreeze)
         return sess
 
     def thaw(self, sessid):
-        raise KeyError()
-
-    def freeze(self, sess):
-        raise TypeError()
-
-    def get(self, req):
-        return req.item(self.fetch)
-
-class backeddb(db):
-    def __init__(self, backdb, *args, **kw):
-        super(backeddb, self).__init__(*args, **kw)
-        self.backdb = backdb
-
-    def thaw(self, sessid):
+        if self.backdb is None:
+            raise KeyError()
         data = self.backdb[sessid]
         try:
             return pickle.loads(data)
-        except Exception, e:
+        except:
             raise KeyError()
 
     def freeze(self, sess):
-        self.backdb[sess.id] = pickle.dumps(sess)
+        if self.backdb is None:
+            raise TypeError()
+        with sess.lock:
+            data = pickle.dumps(sess, -1)
+        self.backdb[sess.id] = data
         sess.frozen()
 
+    def get(self, req):
+        return req.item(self.fetch)
+
 class dirback(object):
     def __init__(self, path):
         self.path = path
@@ -185,7 +226,7 @@ class dirback(object):
         with open(os.path.join(self.path, key), "w") as out:
             out.write(value)
 
-default = backeddb(dirback(os.path.join("/tmp", "wrwsess-" + str(os.getuid()))))
+default = env.var(db(backdb=dirback(os.path.join("/tmp", "wrwsess-" + str(os.getuid())))))
 
 def get(req):
-    return default.get(req)
+    return default.val.get(req)