+import threading
+from . import store, lib
+from .store import storedesc
+
+class cursor(lib.closable):
+ def __init__(self, bk, st):
+ self.bk = bk
+ self.st = st
+
+ def close(self):
+ self.bk.close()
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ k, id = next(self.bk)
+ return k, self.st.get(id)
+
+ def skip(self, n=1):
+ self.bk.skip(n)
+
+class base(storedesc):
+ def __init__(self, store, indextype, name, datatype, default):
+ self.store = store
+ self.indextype = indextype
+ self.name = name
+ self.typ = datatype
+ self.default = default
+ self.idx = None
+ self.lk = threading.Lock()
+ self.mattr = "__idx_%s_new" % name
+ self.iattr = "__idx_%s_cur" % name
+
+ def index(self):
+ with self.lk:
+ if self.idx is None:
+ self.idx = self.indextype(self.store.db(), self.name, self.typ)
+ return self.idx
+
+ def __get__(self, obj, cls):
+ if obj is None: return self
+ return getattr(obj, self.mattr, self.default)
+
+ def __set__(self, obj, val):
+ setattr(obj, self.mattr, val)
+
+ def __delete__(self, obj):
+ delattr(obj, self.mattr)
+
+ def get(self, **kwargs):
+ return cursor(self.index().get(**kwargs), self.store)
+
+class simple(base):
+ def __init__(self, store, indextype, name, datatype, default=None):
+ super().__init__(store, indextype, name, datatype, default)
+
+ def register(self, id, obj, tx):
+ val = self.__get__(obj, None)
+ self.index().put(val, id, tx=tx)
+ tx.postcommit(lambda: setattr(obj, self.iattr, val))
+
+ def unregister(self, id, obj, tx):
+ self.index().remove(getattr(obj, self.iattr), id, tx=tx)
+ tx.postcommit(lambda: delattr(obj, self.iattr))
+
+ def update(self, id, obj, tx):
+ val = self.__get__(obj, None)
+ ival = getattr(obj, self.iattr)
+ if val != ival:
+ idx = self.index()
+ idx.remove(ival, id, tx=tx)
+ idx.put(val, id, tx=tx)
+ tx.postcommit(lambda: setattr(obj, self.iattr, val))
+
+class multi(base):
+ def __init__(self, store, indextype, name, datatype):
+ super().__init__(store, indextype, name, datatype, ())
+
+ def register(self, id, obj, tx):
+ vals = frozenset(self.__get__(obj, None))
+ idx = self.index()
+ for val in vals:
+ idx.put(val, id, tx=tx)
+ tx.postcommit(lambda: setattr(obj, self.iattr, vals))
+
+ def unregister(self, id, obj, tx):
+ idx = self.index()
+ for val in getattr(obj, self.iattr):
+ idx.remove(val, id, tx=tx)
+ tx.postcommit(lambda: delattr(obj, self.iattr))
+
+ def update(self, id, obj, tx):
+ vals = frozenset(self.__get__(obj, None))
+ ivals = getattr(obj, self.iattr)
+ if vals != ivals:
+ idx = self.index()
+ for val in ivals - vals:
+ idx.remove(val, id, tx=tx)
+ for val in vals - ivals:
+ idx.put(val, id, tx=tx)
+ tx.postcommit(lambda: setattr(obj, self.iattr, vals))