Always use the latest pickle protocol when freezing sessions.
[wrw.git] / wrw / session.py
index 9e91ad9..c45090a 100644 (file)
@@ -110,6 +110,7 @@ class db(object):
     def fetch(self, req):
         now = int(time.time())
         sessid = cookie.get(req, self.cookiename)
+        new = False
         with self.lock:
             if self.cthread is None:
                 self.cthread = threading.Thread(target = self.cleanloop)
@@ -128,21 +129,20 @@ class db(object):
                 sess.atime = now
             except KeyError:
                 sess = session()
-                self.live[sess.id] = sess
-                sess.new = True
-        req.oncommit(self.ckfreeze)
-        return sess
+                new = True
 
-    def ckfreeze(self, req):
-        sess = req.item(self.fetch)
-        if sess.dirty():
-            try:
-                if getattr(sess, "new", False):
+        def ckfreeze(req):
+            if sess.dirty():
+                if new:
                     cookie.add(req, self.cookiename, sess.id, self.path)
-                    del sess.new
-                self.freeze(sess)
-            except:
-                pass
+                    with self.lock:
+                        self.live[sess.id] = sess
+                try:
+                    self.freeze(sess)
+                except:
+                    pass
+        req.oncommit(ckfreeze)
+        return sess
 
     def thaw(self, sessid):
         raise KeyError()
@@ -150,6 +150,9 @@ class db(object):
     def freeze(self, sess):
         raise TypeError()
 
+    def get(self, req):
+        return req.item(self.fetch)
+
 class backeddb(db):
     def __init__(self, backdb, *args, **kw):
         super(backeddb, self).__init__(*args, **kw)
@@ -163,7 +166,7 @@ class backeddb(db):
             raise KeyError()
 
     def freeze(self, sess):
-        self.backdb[sess.id] = pickle.dumps(sess)
+        self.backdb[sess.id] = pickle.dumps(sess, -1)
         sess.frozen()
 
 class dirback(object):
@@ -186,4 +189,4 @@ class dirback(object):
 default = backeddb(dirback(os.path.join("/tmp", "wrwsess-" + str(os.getuid()))))
 
 def get(req):
-    return req.item(default.fetch)
+    return default.get(req)