Hopefully fixed the ugliness in index-value duplication.
[didex.git] / didex / store.py
index 15d2eca..11fe17a 100644 (file)
@@ -1,22 +1,34 @@
-import threading, pickle
+import threading, pickle, inspect, atexit, weakref
 from . import db, index, cache
 from .db import txnfun
 
+__all__ = ["environment", "datastore", "autostore"]
+
 class environment(object):
-    def __init__(self, path):
-        self.path = path
+    def __init__(self, *, path=None, getpath=None, recover=False):  # all arguments are keyword-only
+        if path is not None:  # an explicit path takes precedence over getpath
+            self.path = path
+            self.getpath = None
+        else:
+            self.path = None
+            self.getpath = getpath  # called with no arguments on first use to obtain the path
+        self.recover = recover  # passed through to db.environment() when the backend is opened
         self.lk = threading.Lock()
         self.bk = None
 
     def __call__(self):
         with self.lk:
             if self.bk is None:
-                self.bk = db.environment(self.path)
+                if self.path is None:  # no explicit path was given: resolve it lazily
+                    self.path = self.getpath()
+                self.bk = db.environment(self.path, recover=self.recover)
+                atexit.register(self.close)  # close the backend at interpreter exit
             return self.bk
 
     def close(self):
         with self.lk:
             if self.bk is not None:
+                atexit.unregister(self.close)  # closing now, so drop the exit hook
                 self.bk.close()
                 self.bk = None
 
@@ -25,28 +37,54 @@ class storedesc(object):
 
 def storedescs(obj):
     t = type(obj)
-    ret = getattr(t, "__didex_attr", None)
+    ret = t.__dict__.get("__didex_attr")  # look only at t's own dict, so a base class's cached list is not reused
     if ret is None:
         ret = []
-        for nm, val in t.__dict__.items():
-            if isinstance(val, storedesc):
-                ret.append((nm, val))
+        for st in inspect.getmro(t):  # walk t and all of its base classes
+            for nm, val in st.__dict__.items():
+                if isinstance(val, storedesc):
+                    ret.append((nm, val))
         t.__didex_attr = ret
     return ret
 
-class store(object):
-    def __init__(self, name, *, env=None, path=".", ncache=None):
+class icache(object):  # two-level mapping addressed by (obj, idx) pairs
+    def __init__(self):
+        self.d = weakref.WeakKeyDictionary()  # obj -> {idx: val}; obj is held only weakly
+
+    def __getitem__(self, key):  # raises KeyError if obj or idx is missing
+        obj, idx = key
+        return self.d[obj][idx]
+    def __setitem__(self, key, val):
+        obj, idx = key
+        if obj in self.d:
+            self.d[obj][idx] = val
+        else:
+            self.d[obj] = {idx: val}  # first value for obj: create its inner dict
+    def __delitem__(self, key):  # raises KeyError if obj or idx is missing
+        obj, idx = key
+        del self.d[obj][idx]
+    def get(self, key, default=None):  # like dict.get(): default when obj or idx is missing
+        obj, idx = key
+        if obj not in self.d:
+            return default
+        return self.d[obj].get(idx, default)
+
+class datastore(object):
+    def __init__(self, name, *, env=None, path=".", ncache=None, codec=None):
         self.name = name
         self.lk = threading.Lock()
         if env:
             self.env = env
         else:
-            self.env = environment(path)
+            self.env = environment(path=path)  # environment() only accepts keyword arguments
         self._db = None
         if ncache is None:
             ncache = cache.cache()
+        if codec is not None:  # codec is an (encode, decode) pair replacing the pickle-based methods
+            self._encode, self._decode = codec
         self.cache = ncache
         self.cache.load = self._load
+        self.icache = icache()  # (obj, idx)-keyed cache; presumably for index code, not referenced in this diff
 
     def db(self):
         with self.lk:
@@ -54,15 +92,24 @@ class store(object):
                 self._db = self.env().db(self.name)
             return self._db
 
-    def _load(self, id):
+    def _decode(self, data):
         try:
-            return pickle.loads(self.db().get(id))
+            return pickle.loads(data)
         except:
             raise KeyError(id, "could not unpickle data")
 
     def _encode(self, obj):
         return pickle.dumps(obj)
 
+    @txnfun(lambda self: self.db().env.env)
+    def _load(self, id, *, tx):  # installed as self.cache.load by __init__
+        loaded = self._decode(self.db().get(id, tx=tx))
+        if hasattr(loaded, "__didex_loaded__"):  # optional hook telling the object its store and ID
+            loaded.__didex_loaded__(self, id)
+        for nm, attr in storedescs(loaded):  # notify every storedesc found on the object's class
+            attr.loaded(id, loaded, tx)
+        return loaded
+
     def get(self, id, *, load=True):
         return self.cache.get(id, load=load)
 
@@ -75,16 +122,42 @@ class store(object):
         return id
 
     @txnfun(lambda self: self.db().env.env)
-    def unregister(self, id, *, tx):
+    def unregister(self, id, *, vfy=None, tx):
         obj = self.get(id)
+        if vfy is not None and obj is not vfy:
+            raise RuntimeError("object identity crisis: " + str(vfy) + " is not cached object " + obj)
         for nm, attr in storedescs(obj):
             attr.unregister(id, obj, tx)
         self.db().remove(id, tx=tx)
         self.cache.remove(id)
 
     @txnfun(lambda self: self.db().env.env)
-    def update(self, id, *, tx):
+    def update(self, id, *, vfy=None, tx):
         obj = self.get(id, load=False)
+        if vfy is not None and obj is not vfy:
+            raise RuntimeError("object identity crisis: " + str(vfy) + " is not cached object " + obj)
         for nm, attr, in storedescs(obj):
             attr.update(id, obj, tx)
         self.db().replace(id, self._encode(obj), tx=tx)
+
+class autotype(type):  # metaclass that registers every new instance with the class's store
+    def __call__(self, *args, **kwargs):
+        new = super().__call__(*args, **kwargs)  # normal construction first
+        new.id = self.store.register(new)  # `store` is looked up on the class being instantiated
+        # XXX? The ID is not saved with the object now; it is relied upon to be set by __didex_loaded__ later.
+        return new
+
+class autostore(object, metaclass=autotype):  # base class whose instances get registered on creation
+    def __init__(self):
+        self.id = None  # overwritten by autotype.__call__ once construction returns
+
+    def __didex_loaded__(self, store, id):  # hook looked up via hasattr() in datastore._load
+        assert self.id is None or self.id == id
+        self.id = id
+
+    def save(self):
+        self.store.update(self.id, vfy=self)  # vfy makes the store check it caches this very object
+
+    def remove(self):
+        self.store.unregister(self.id, vfy=self)
+        self.id = None  # forget the ID once the object has been unregistered