X-Git-Url: http://dolda2000.com/gitweb/?p=didex.git;a=blobdiff_plain;f=didex%2Fvalues.py;h=bff26e7c4eae145fc71268695eaf7bcc2861dcf5;hp=c2b6d1ff2e82f1e58ef1cb68356bbe77a8647713;hb=9ef33548b36d539eac18cc28a4fc6835ff2d6ee6;hpb=177fbee6346cbe47e3ac689814fc34f8b75e186a diff --git a/didex/values.py b/didex/values.py index c2b6d1f..bff26e7 100644 --- a/didex/values.py +++ b/didex/values.py @@ -2,7 +2,7 @@ import threading from . import store, lib, index from .store import storedesc -__all__ = ["simple", "multi", "compound"] +__all__ = ["simple", "multi", "compound", "idlink"] class cursor(lib.closable): def __init__(self, bk, st): @@ -31,14 +31,35 @@ class base(storedesc): self.idx = None self.lk = threading.Lock() - def index(self): + def index(self, tx): with self.lk: if self.idx is None: - self.idx = self.indextype(self.store.db(), self.name, self.typ) + self.idx = self.indextype(self.store.db(), self.name, self.typ, tx=tx) return self.idx def get(self, **kwargs): - return cursor(self.index().get(**kwargs), self.store) + return cursor(self.index(None).get(**kwargs), self.store) + + def get1(self, *, check=False, default=KeyError, **kwargs): + with self.get(**kwargs) as cursor: + try: + k, v = next(cursor) + except StopIteration: + if default is not KeyError: + return default + raise KeyError("no matches in " + self.name, kwargs) + if check: + try: + next(cursor) + except StopIteration: + pass + else: + raise ValueError("unexpected multiple matchies in " + self.name, kwargs) + return v + + def list(self, **kwargs): + with self.get(**kwargs) as cursor: + return [v for k, v in cursor] class descbase(base): def __init__(self, store, indextype, name, datatype, default): @@ -63,18 +84,18 @@ class simple(descbase): def register(self, id, obj, tx): val = self.__get__(obj, None) - self.index().put(val, id, tx=tx) + self.index(tx).put(val, id, tx=tx) tx.postcommit(lambda: setattr(obj, self.iattr, val)) def unregister(self, id, obj, tx): - self.index().remove(getattr(obj, self.iattr), id, tx=tx) + self.index(tx).remove(getattr(obj, self.iattr), id, tx=tx) tx.postcommit(lambda: delattr(obj, self.iattr)) def update(self, id, obj, tx): val = self.__get__(obj, None) ival = getattr(obj, self.iattr) if val != ival: - idx = self.index() + idx = self.index(tx) idx.remove(ival, id, tx=tx) idx.put(val, id, tx=tx) tx.postcommit(lambda: setattr(obj, self.iattr, val)) @@ -85,13 +106,13 @@ class multi(descbase): def register(self, id, obj, tx): vals = frozenset(self.__get__(obj, None)) - idx = self.index() + idx = self.index(tx) for val in vals: idx.put(val, id, tx=tx) tx.postcommit(lambda: setattr(obj, self.iattr, vals)) def unregister(self, id, obj, tx): - idx = self.index() + idx = self.index(tx) for val in getattr(obj, self.iattr): idx.remove(val, id, tx=tx) tx.postcommit(lambda: delattr(obj, self.iattr)) @@ -100,7 +121,7 @@ class multi(descbase): vals = frozenset(self.__get__(obj, None)) ivals = getattr(obj, self.iattr) if vals != ivals: - idx = self.index() + idx = self.index(tx) for val in ivals - vals: idx.remove(val, id, tx=tx) for val in vals - ivals: @@ -126,18 +147,36 @@ class compound(base): def register(self, id, obj, tx): val = tuple(part.__get__(obj, None) for part in self.parts) - self.index().put(val, id, tx=tx) + self.index(tx).put(val, id, tx=tx) tx.postcommit(lambda: setattr(obj, self.iattr, val)) def unregister(self, id, obj, tx): - self.index().remove(getattr(obj, self.iattr), id, tx=tx) + self.index(tx).remove(getattr(obj, self.iattr), id, tx=tx) tx.postcommit(lambda: delattr(obj, self.iattr)) def update(self, id, obj, tx): val = tuple(part.__get__(obj, None) for part in self.parts) ival = getattr(obj, self.iattr) if val != ival: - idx = self.index() + idx = self.index(tx) idx.remove(ival, id, tx=tx) idx.put(val, id, tx=tx) tx.postcommit(lambda: setattr(obj, self.iattr, val)) + +class idlink(object): + def __init__(self, name, atype): + self.atype = atype + self.battr = "__idlink_%s" % name + + def __get__(self, obj, cls): + if obj is None: return self + ret = self.atype.store.get(getattr(obj, self.battr)) + assert isinstance(ret, self.atype) + return ret + + def __set__(self, obj, val): + assert isinstance(val, self.atype) + setattr(obj, self.battr, val.id) + + def __delete__(self, obj): + delattr(obj, self.battr)