From: Fredrik Tolf Date: Sat, 21 Mar 2015 07:09:52 +0000 (+0100) Subject: Added support for compound indices. X-Git-Url: http://dolda2000.com/gitweb/?p=didex.git;a=commitdiff_plain;h=bd14729f305c077e65963fba4aeaab3baf8ee653 Added support for compound indices. --- diff --git a/didex/index.py b/didex/index.py index a2828ce..cdd4069 100644 --- a/didex/index.py +++ b/didex/index.py @@ -49,6 +49,45 @@ class maybe(object): else: return self.bk.compare(a[1:], b[1:]) +class compound(object): + def __init__(self, *parts): + self.parts = parts + + def encode(self, obs): + if len(obs) != len(self.parts): + raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + len(self.parts)) + buf = bytearray() + for ob, part in zip(obs, self.parts): + dat = part.encode(ob) + if len(dat) < 128: + buf.append(0x80 | len(dat)) + buf.extend(dat) + else: + buf.extend(struct.pack(">i", len(dat))) + buf.extend(dat) + return bytes(buf) + def decode(self, dat): + ret = [] + off = 0 + for part in self.parts: + if dat[off] & 0x80: + ln = dat[off] & 0x7f + off += 1 + else: + ln = struct.unpack(">i", dat[off:off + 4])[0] + off += 4 + ret.append(part.decode(dat[off:off + len])) + off += len + return tuple(ret) + def compare(self, al, bl): + if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)): + raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + len(self.parts)) + for a, b, part in zip(al, bl, self.parts): + c = part.compare(a, b) + if c != 0: + return c + return 0 + def floatcmp(a, b): if math.isnan(a) and math.isnan(b): return 0 diff --git a/didex/values.py b/didex/values.py index 090edd4..9a501eb 100644 --- a/didex/values.py +++ b/didex/values.py @@ -1,5 +1,5 @@ import threading -from . import store, lib +from . import store, lib, index from .store import storedesc __all__ = ["simple", "multi"] @@ -23,16 +23,13 @@ class cursor(lib.closable): self.bk.skip(n) class base(storedesc): - def __init__(self, store, indextype, name, datatype, default): + def __init__(self, store, indextype, name, datatype): self.store = store self.indextype = indextype self.name = name self.typ = datatype - self.default = default self.idx = None self.lk = threading.Lock() - self.mattr = "__idx_%s_new" % name - self.iattr = "__idx_%s_cur" % name def index(self): with self.lk: @@ -40,6 +37,16 @@ class base(storedesc): self.idx = self.indextype(self.store.db(), self.name, self.typ) return self.idx + def get(self, **kwargs): + return cursor(self.index().get(**kwargs), self.store) + +class descbase(base): + def __init__(self, store, indextype, name, datatype, default): + super().__init__(store, indextype, name, datatype) + self.default = default + self.mattr = "__idx_%s_new" % name + self.iattr = "__idx_%s_cur" % name + def __get__(self, obj, cls): if obj is None: return self return getattr(obj, self.mattr, self.default) @@ -50,10 +57,7 @@ class base(storedesc): def __delete__(self, obj): delattr(obj, self.mattr) - def get(self, **kwargs): - return cursor(self.index().get(**kwargs), self.store) - -class simple(base): +class simple(descbase): def __init__(self, store, indextype, name, datatype, default=None): super().__init__(store, indextype, name, datatype, default) @@ -75,7 +79,7 @@ class simple(base): idx.put(val, id, tx=tx) tx.postcommit(lambda: setattr(obj, self.iattr, val)) -class multi(base): +class multi(descbase): def __init__(self, store, indextype, name, datatype): super().__init__(store, indextype, name, datatype, ()) @@ -102,3 +106,27 @@ class multi(base): for val in vals - ivals: idx.put(val, id, tx=tx) tx.postcommit(lambda: setattr(obj, self.iattr, vals)) + +class compound(base): + def __init__(self, indextype, name, *parts): + super().__init__(parts[0].store, indextype, name, index.compound(*(part.typ for part in parts))) + self.parts = parts + self.iattr = "__idx_%s_cur" % name + + def register(self, id, obj, tx): + val = tuple(part.__get__(obj, None) for part in self.parts) + self.index().put(val, id, tx=tx) + tx.postcommit(lambda: setattr(obj, self.iattr, val)) + + def unregister(self, id, obj, tx): + self.index().remove(getattr(obj, self.iattr), id, tx=tx) + tx.postcommit(lambda: delattr(obj, self.iattr)) + + def update(self, id, obj, tx): + val = tuple(part.__get__(obj, None) for part in self.parts) + ival = getattr(obj, self.iattr) + if val != ival: + idx = self.index() + idx.remove(ival, id, tx=tx) + idx.put(val, id, tx=tx) + tx.postcommit(lambda: setattr(obj, self.iattr, val))