Typo fix.
[didex.git] / didex / values.py
index c2b6d1f..debd29e 100644 (file)
@@ -2,7 +2,7 @@ import threading
 from . import store, lib, index
 from .store import storedesc
 
-__all__ = ["simple", "multi", "compound"]
+__all__ = ["simple", "multi", "compound", "idlink"]
 
 class cursor(lib.closable):
     def __init__(self, bk, st):
@@ -31,21 +31,41 @@ class base(storedesc):
         self.idx = None
         self.lk = threading.Lock()
 
-    def index(self):
+    def index(self, tx):
         with self.lk:
             if self.idx is None:
-                self.idx = self.indextype(self.store.db(), self.name, self.typ)
+                self.idx = self.indextype(self.store.db(), self.name, self.typ, tx=tx)
             return self.idx
 
     def get(self, **kwargs):
-        return cursor(self.index().get(**kwargs), self.store)
+        return cursor(self.index(None).get(**kwargs), self.store)
+
+    def get1(self, *, check=False, default=KeyError, **kwargs):
+        with self.get(**kwargs) as cursor:
+            try:
+                k, v = next(cursor)
+            except StopIteration:
+                if default is not KeyError:
+                    return default
+                raise KeyError("no matches in " + self.name, kwargs)
+            if check:
+                try:
+                    next(cursor)
+                except StopIteration:
+                    pass
+                else:
+                    raise ValueError("unexpected multiple matchies in " + self.name, kwargs)
+            return v
+
+    def list(self, **kwargs):
+        with self.get(**kwargs) as cursor:
+            return [v for k, v in cursor]
 
 class descbase(base):
     def __init__(self, store, indextype, name, datatype, default):
         super().__init__(store, indextype, name, datatype)
         self.default = default
-        self.mattr = "__idx_%s_new" % name
-        self.iattr = "__idx_%s_cur" % name
+        self.mattr = "__ival_%s" % name
 
     def __get__(self, obj, cls):
         if obj is None: return self
@@ -63,21 +83,25 @@ class simple(descbase):
 
     def register(self, id, obj, tx):
         val = self.__get__(obj, None)
-        self.index().put(val, id, tx=tx)
-        tx.postcommit(lambda: setattr(obj, self.iattr, val))
+        self.index(tx).put(val, id, tx=tx)
+        tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
 
     def unregister(self, id, obj, tx):
-        self.index().remove(getattr(obj, self.iattr), id, tx=tx)
-        tx.postcommit(lambda: delattr(obj, self.iattr))
+        self.index(tx).remove(self.store.icache[obj, self], id, tx=tx)
+        tx.postcommit(lambda: self.store.icache.__delitem__((obj, self)))
 
     def update(self, id, obj, tx):
         val = self.__get__(obj, None)
-        ival = getattr(obj, self.iattr)
+        ival = self.store.icache[obj, self]
         if val != ival:
-            idx = self.index()
+            idx = self.index(tx)
             idx.remove(ival, id, tx=tx)
             idx.put(val, id, tx=tx)
-            tx.postcommit(lambda: setattr(obj, self.iattr, val))
+            tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
+
+    def loaded(self, id, obj, tx):
+        val = self.__get__(obj, None)
+        tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
 
 class multi(descbase):
     def __init__(self, store, indextype, name, datatype):
@@ -85,33 +109,36 @@ class multi(descbase):
 
     def register(self, id, obj, tx):
         vals = frozenset(self.__get__(obj, None))
-        idx = self.index()
+        idx = self.index(tx)
         for val in vals:
             idx.put(val, id, tx=tx)
-        tx.postcommit(lambda: setattr(obj, self.iattr, vals))
+        tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), vals))
 
     def unregister(self, id, obj, tx):
-        idx = self.index()
-        for val in getattr(obj, self.iattr):
+        idx = self.index(tx)
+        for val in self.store.icache[obj, self]:
             idx.remove(val, id, tx=tx)
-        tx.postcommit(lambda: delattr(obj, self.iattr))
+        tx.postcommit(lambda: self.store.icache.__delitem__((obj, self)))
 
     def update(self, id, obj, tx):
         vals = frozenset(self.__get__(obj, None))
-        ivals = getattr(obj, self.iattr)
+        ivals = self.store.icache[obj, self]
         if vals != ivals:
-            idx = self.index()
+            idx = self.index(tx)
             for val in ivals - vals:
                 idx.remove(val, id, tx=tx)
             for val in vals - ivals:
                 idx.put(val, id, tx=tx)
-            tx.postcommit(lambda: setattr(obj, self.iattr, vals))
+            tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
+
+    def loaded(self, id, obj, tx):
+        vals = frozenset(self.__get__(obj, None))
+        tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), vals))
 
 class compound(base):
     def __init__(self, indextype, name, *parts):
         super().__init__(parts[0].store, indextype, name, index.compound(*(part.typ for part in parts)))
         self.parts = parts
-        self.iattr = "__idx_%s_cur" % name
 
     def minim(self, *parts):
         return self.typ.minim(*parts)
@@ -126,18 +153,40 @@ class compound(base):
 
     def register(self, id, obj, tx):
         val = tuple(part.__get__(obj, None) for part in self.parts)
-        self.index().put(val, id, tx=tx)
-        tx.postcommit(lambda: setattr(obj, self.iattr, val))
+        self.index(tx).put(val, id, tx=tx)
+        tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
 
     def unregister(self, id, obj, tx):
-        self.index().remove(getattr(obj, self.iattr), id, tx=tx)
-        tx.postcommit(lambda: delattr(obj, self.iattr))
+        self.index(tx).remove(self.store.icache[obj, self], id, tx=tx)
+        tx.postcommit(lambda: self.store.icache.__delitem__((obj, self)))
 
     def update(self, id, obj, tx):
         val = tuple(part.__get__(obj, None) for part in self.parts)
-        ival = getattr(obj, self.iattr)
+        ival = self.store.icache[obj, self]
         if val != ival:
-            idx = self.index()
+            idx = self.index(tx)
             idx.remove(ival, id, tx=tx)
             idx.put(val, id, tx=tx)
-            tx.postcommit(lambda: setattr(obj, self.iattr, val))
+            tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
+
+    def loaded(self, id, obj, tx):
+        val = tuple(part.__get__(obj, None) for part in self.parts)
+        tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
+
+class idlink(object):
+    def __init__(self, name, atype):
+        self.atype = atype
+        self.battr = "__idlink_%s" % name
+
+    def __get__(self, obj, cls):
+        if obj is None: return self
+        ret = self.atype.store.get(getattr(obj, self.battr))
+        assert isinstance(ret, self.atype)
+        return ret
+
+    def __set__(self, obj, val):
+        assert isinstance(val, self.atype)
+        setattr(obj, self.battr, val.id)
+
+    def __delete__(self, obj):
+        delattr(obj, self.battr)