Hopefully fixed the ugliness in index-value duplication.
[didex.git] / didex / store.py
index fb5f2c6..11fe17a 100644 (file)
@@ -1,7 +1,9 @@
-import threading, pickle
+import threading, pickle, inspect, atexit, weakref
 from . import db, index, cache
 from .db import txnfun
 
+__all__ = ["environment", "datastore", "autostore"]
+
 class environment(object):
     def __init__(self, *, path=None, getpath=None, recover=False):
         if path is not None:
@@ -20,11 +22,13 @@ class environment(object):
                 if self.path is None:
                     self.path = self.getpath()
                 self.bk = db.environment(self.path, recover=self.recover)
+                atexit.register(self.close)
             return self.bk
 
     def close(self):
         with self.lk:
             if self.bk is not None:
+                atexit.unregister(self.close)
                 self.bk.close()
                 self.bk = None
 
@@ -33,17 +37,40 @@ class storedesc(object):
 
 def storedescs(obj):
     t = type(obj)
-    ret = getattr(t, "__didex_attr", None)
+    ret = t.__dict__.get("__didex_attr")
     if ret is None:
         ret = []
-        for nm, val in t.__dict__.items():
-            if isinstance(val, storedesc):
-                ret.append((nm, val))
+        for st in inspect.getmro(t):
+            for nm, val in st.__dict__.items():
+                if isinstance(val, storedesc):
+                    ret.append((nm, val))
         t.__didex_attr = ret
     return ret
 
-class store(object):
-    def __init__(self, name, *, env=None, path=".", ncache=None):
+class icache(object):
+    def __init__(self):
+        self.d = weakref.WeakKeyDictionary()
+
+    def __getitem__(self, key):
+        obj, idx = key
+        return self.d[obj][idx]
+    def __setitem__(self, key, val):
+        obj, idx = key
+        if obj in self.d:
+            self.d[obj][idx] = val
+        else:
+            self.d[obj] = {idx: val}
+    def __delitem__(self, key):
+        obj, idx = key
+        del self.d[obj][idx]
+    def get(self, key, default=None):
+        obj, idx = key
+        if obj not in self.d:
+            return default
+        return self.d[obj].get(idx, default)
+
+class datastore(object):
+    def __init__(self, name, *, env=None, path=".", ncache=None, codec=None):
         self.name = name
         self.lk = threading.Lock()
         if env:
@@ -53,8 +80,11 @@ class store(object):
         self._db = None
         if ncache is None:
             ncache = cache.cache()
+        if codec is not None:
+            self._encode, self._decode = codec
         self.cache = ncache
         self.cache.load = self._load
+        self.icache = icache()
 
     def db(self):
         with self.lk:
@@ -62,15 +92,24 @@ class store(object):
                 self._db = self.env().db(self.name)
             return self._db
 
-    def _load(self, id):
+    def _decode(self, data):
         try:
-            return pickle.loads(self.db().get(id))
+            return pickle.loads(data)
         except:
             raise KeyError(id, "could not unpickle data")
 
     def _encode(self, obj):
         return pickle.dumps(obj)
 
+    @txnfun(lambda self: self.db().env.env)
+    def _load(self, id, *, tx):
+        loaded = self._decode(self.db().get(id, tx=tx))
+        if hasattr(loaded, "__didex_loaded__"):
+            loaded.__didex_loaded__(self, id)
+        for nm, attr in storedescs(loaded):
+            attr.loaded(id, loaded, tx)
+        return loaded
+
     def get(self, id, *, load=True):
         return self.cache.get(id, load=load)
 
@@ -105,12 +144,17 @@ class autotype(type):
     def __call__(self, *args, **kwargs):
         new = super().__call__(*args, **kwargs)
         new.id = self.store.register(new)
+        # XXX? ID is not saved now, but relied upon to be __didex_loaded__ later.
         return new
 
 class autostore(object, metaclass=autotype):
     def __init__(self):
         self.id = None
 
+    def __didex_loaded__(self, store, id):
+        assert self.id is None or self.id == id
+        self.id = id
+
     def save(self):
         self.store.update(self.id, vfy=self)