Attempt to make the session database properly parallel.
authorFredrik Tolf <fredrik@dolda2000.com>
Fri, 29 Jun 2012 01:01:57 +0000 (03:01 +0200)
committerFredrik Tolf <fredrik@dolda2000.com>
Fri, 29 Jun 2012 01:01:57 +0000 (03:01 +0200)
wrw/session.py

index cf2d792..428c93d 100644 (file)
@@ -16,10 +16,10 @@ def gennonce(length):
     return nonce
 
 class session(object):
-    def __init__(self, expire = 86400 * 7):
+    def __init__(self, lock, expire = 86400 * 7):
         self.id = hexencode(gennonce(16))
         self.dict = {}
-        self.lock = threading.RLock()
+        self.lock = lock
         self.ctime = self.atime = self.mtime = int(time.time())
         self.expire = expire
         self.dctl = set()
@@ -68,7 +68,7 @@ class session(object):
     def __setstate__(self, st):
         for k, v in st:
             self.__dict__[k] = v
-        self.lock = threading.Lock()
+        # The proper lock is set by the thawer
 
 class db(object):
     def __init__(self, backdb = None, cookiename = "wrwsess", path = "/"):
@@ -83,19 +83,34 @@ class db(object):
     def clean(self):
         now = int(time.time())
         with self.lock:
-            dlist = []
-            for sess in self.live.itervalues():
-                if sess.atime + self.freezetime < now:
-                    try:
-                        if sess.dirty():
-                            self.freeze(sess)
-                    except:
-                        if sess.atime + sess.expire < now:
-                            dlist.append(sess)
-                    else:
-                        dlist.append(sess)
-            for sess in dlist:
-                del self.live[sess.id]
+            clist = self.live.keys()
+        for sessid in clist:
+            with self.lock:
+                try:
+                    entry = self.live[sessid]
+                except KeyError:
+                    continue
+            with entry[0]:
+                rm = False
+                if entry[1] == "retired":
+                    pass
+                elif entry[1] is None:
+                    pass
+                else:
+                    sess = entry[1]
+                    if sess.atime + self.freezetime < now:
+                        try:
+                            if sess.dirty():
+                                self.freeze(sess)
+                        except:
+                            if sess.atime + sess.expire < now:
+                                rm = True
+                        else:
+                            rm = True
+                if rm:
+                    entry[1] = "retired"
+                    with self.lock:
+                        del self.live[sessid]
 
     def cleanloop(self):
         try:
@@ -108,6 +123,37 @@ class db(object):
             with self.lock:
                 self.cthread = None
 
+    def _fetch(self, sessid):
+        while True:
+            now = int(time.time())
+            with self.lock:
+                if sessid in self.live:
+                    entry = self.live[sessid]
+                else:
+                    entry = self.live[sessid] = [threading.RLock(), None]
+            with entry[0]:
+                if isinstance(entry[1], session):
+                    entry[1].atime = now
+                    return entry[1]
+                elif entry[1] == "retired":
+                    continue
+                elif entry[1] is None:
+                    try:
+                        thawed = self.thaw(sessid)
+                        if thawed.atime + thawed.expire < now:
+                            raise KeyError()
+                        thawed.lock = entry[0]
+                        thawed.atime = now
+                        entry[1] = thawed
+                        return thawed
+                    finally:
+                        if entry[1] is None:
+                            entry[1] = "retired"
+                            with self.lock:
+                                del self.live[sessid]
+                else:
+                    raise Exception("Illegal session entry: " + repr(entry[1]))
+
     def fetch(self, req):
         now = int(time.time())
         sessid = cookie.get(req, self.cookiename)
@@ -117,27 +163,20 @@ class db(object):
                 self.cthread = threading.Thread(target = self.cleanloop)
                 self.cthread.setDaemon(True)
                 self.cthread.start()
-            try:
-                if sessid is None:
-                    raise KeyError()
-                elif sessid in self.live:
-                    sess = self.live[sessid]
-                else:
-                    sess = self.thaw(sessid)
-                    self.live[sessid] = sess
-                if sess.atime + sess.expire < now:
-                    raise KeyError()
-                sess.atime = now
-            except KeyError:
-                sess = session()
-                new = True
+        try:
+            if sessid is None:
+                raise KeyError()
+            sess = self._fetch(sessid)
+        except KeyError:
+            sess = session(threading.RLock())
+            new = True
 
         def ckfreeze(req):
             if sess.dirty():
                 if new:
                     cookie.add(req, self.cookiename, sess.id, self.path)
                     with self.lock:
-                        self.live[sess.id] = sess
+                        self.live[sess.id] = [sess.lock, sess]
                 try:
                     self.freeze(sess)
                 except:
@@ -157,7 +196,9 @@ class db(object):
     def freeze(self, sess):
         if self.backdb is None:
             raise TypeError()
-        self.backdb[sess.id] = pickle.dumps(sess, -1)
+        with sess.lock:
+            data = pickle.dumps(sess, -1)
+        self.backdb[sess.id] = data
         sess.frozen()
 
     def get(self, req):