X-Git-Url: http://dolda2000.com/gitweb/?p=didex.git;a=blobdiff_plain;f=didex%2Findex.py;h=0608318d3ca52db7c0dc30419202fad36c9d5fc8;hp=30dcfdf3a5c4c8b87fb7dd51dab4decc7972a53a;hb=947dfab3c174ecce6bd1ff18bdc4df7e0e4087c1;hpb=6efe4e234312927239701a865725786625407410

diff --git a/didex/index.py b/didex/index.py
index 30dcfdf..0608318 100644
--- a/didex/index.py
+++ b/didex/index.py
@@ -2,7 +2,7 @@ import struct, contextlib, math
 from . import db, lib
 from .db import bd, txnfun, dloopfun
 
-__all__ = ["maybe", "t_int", "t_uint", "t_float", "t_str", "ordered"]
+__all__ = ["maybe", "t_bool", "t_int", "t_uint", "t_dbid", "t_float", "t_str", "t_casestr", "ordered"]
 
 deadlock = bd.DBLockDeadlockError
 notfound = bd.DBNotFoundError
@@ -29,6 +29,14 @@ class simpletype(object):
         return cls(lambda ob: struct.pack(fmt, ob),
                    lambda dat: struct.unpack(fmt, dat)[0])
 
+class foldtype(simpletype):
+    def __init__(self, encode, decode, fold):
+        super().__init__(encode, decode)
+        self.fold = fold
+
+    def compare(self, a, b):
+        return super().compare(self.fold(a), self.fold(b))
+
 class maybe(object):
     def __init__(self, bk):
         self.bk = bk
@@ -53,36 +61,66 @@ class compound(object):
     def __init__(self, *parts):
         self.parts = parts
 
+    small = object()
+    large = object()
+    def minim(self, *parts):
+        return parts + tuple([self.small] * (len(self.parts) - len(parts)))
+    def maxim(self, *parts):
+        return parts + tuple([self.large] * (len(self.parts) - len(parts)))
+
     def encode(self, obs):
         if len(obs) != len(self.parts):
             raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + len(self.parts))
         buf = bytearray()
         for ob, part in zip(obs, self.parts):
-            dat = part.encode(ob)
-            if len(dat) < 128:
-                buf.append(0x80 | len(dat))
-                buf.extend(dat)
+            if ob is self.small:
+                buf.append(0x01)
+            elif ob is self.large:
+                buf.append(0x02)
             else:
-                buf.extend(struct.pack(">i", len(dat)))
-                buf.extend(dat)
+                dat = part.encode(ob)
+                if len(dat) < 128:
+                    buf.append(0x80 | len(dat))
+                    buf.extend(dat)
+                else:
+                    buf.extend(struct.pack(">BI", 0, len(dat)))
+                    buf.extend(dat)
         return bytes(buf)
     def decode(self, dat):
         ret = []
         off = 0
         for part in self.parts:
-            if dat[off] & 0x80:
-                ln = dat[off] & 0x7f
-                off += 1
+            fl = dat[off]
+            off += 1
+            if fl & 0x80:
+                ln = fl & 0x7f
+            elif fl == 0x01:
+                ret.append(self.small)
+                continue
+            elif fl == 0x02:
+                ret.append(self.large)
+                continue
             else:
-                ln = struct.unpack(">i", dat[off:off + 4])[0]
+                ln = struct.unpack(">I", dat[off:off + 4])[0]
                 off += 4
-            ret.append(part.decode(dat[off:off + len]))
-            off += len
+            ret.append(part.decode(dat[off:off + ln]))
+            off += ln
         return tuple(ret)
     def compare(self, al, bl):
         if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
             raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + len(self.parts))
         for a, b, part in zip(al, bl, self.parts):
+            if a in (self.small, self.large) or b in (self.small, self.large):
+                if a is b:
+                    return 0
+                if a is self.small:
+                    return -1
+                elif b is self.small:
+                    return 1
+                elif a is self.large:
+                    return 1
+                elif b is self.large:
+                    return -1
             c = part.compare(a, b)
             if c != 0:
                 return c
@@ -102,11 +140,15 @@ def floatcmp(a, b):
     else:
         return 0
 
+t_bool = simpletype((lambda ob: b"\x01" if ob else b"\x00"), (lambda dat: False if dat == b"x\00" else True))
 t_int = simpletype.struct(">q")
 t_uint = simpletype.struct(">Q")
+t_dbid = t_uint
 t_float = simpletype.struct(">d")
 t_float.compare = floatcmp
 t_str = simpletype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")))
+t_casestr = foldtype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")),
+                     (lambda st: st.lower()))
 
 class index(object):
     def __init__(self, db, name, datatype):
@@ -117,9 +159,9 @@ class index(object):
 missing = object()
 
 class ordered(index, lib.closable):
-    def __init__(self, db, name, datatype, create=True):
+    def __init__(self, db, name, datatype, create=True, *, tx=None):
         super().__init__(db, name, datatype)
-        fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
+        fl = bd.DB_THREAD
         if create: fl |= bd.DB_CREATE
         def initdb(db):
             def compare(a, b):
@@ -127,7 +169,7 @@ class ordered(index, lib.closable):
                 return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
             db.set_flags(bd.DB_DUPSORT)
             db.set_bt_compare(compare)
-        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
+        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb, tx=tx)
         self.bk.set_get_returns_none(False)
 
     def close(self):
@@ -155,14 +197,14 @@ class ordered(index, lib.closable):
 
         def _decode(self, d):
             k, v = d
-            k = self.type.decode(k)
+            k = self.typ.decode(k)
             v = struct.unpack(">Q", v)[0]
             return k, v
 
         @dloopfun
         def first(self):
             try:
-                if self.fd is None:
+                if self.fd is missing:
                     self.item = self._decode(self.cur.first())
                 else:
                     k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
@@ -176,16 +218,20 @@ class ordered(index, lib.closable):
         @dloopfun
         def last(self):
             try:
-                if self.fd is None:
+                if self.ld is missing:
                     self.item = self._decode(self.cur.last())
                 else:
-                    k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
-                    if self.fi:
-                        while self.typ.compare(k, self.fd) == 0:
+                    try:
+                        k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
+                    except notfound:
+                        k, v = self._decode(self.cur.last())
+                    if self.li:
+                        while self.typ.compare(k, self.ld) == 0:
                             k, v = self._decode(self.cur.next())
-                        k, v = self._decode(self.cur.prev())
+                        while self.typ.compare(k, self.ld) > 0:
+                            k, v = self._decode(self.cur.prev())
                     else:
-                        while self.typ.compare(k, self.fd) >= 0:
+                        while self.typ.compare(k, self.ld) >= 0:
                             k, v = self._decode(self.cur.prev())
                     self.item = k, v
             except notfound:
@@ -195,8 +241,9 @@ class ordered(index, lib.closable):
         def next(self):
             try:
                 k, v = self.item = self._decode(self.cur.next())
-                if ((self.li and self.typ.compare(k, self.ld) > 0) or
-                    (not self.li and self.typ.compare(k, self.ld) >= 0)):
+                if (self.ld is not missing and
+                    ((self.li and self.typ.compare(k, self.ld) > 0) or
+                     (not self.li and self.typ.compare(k, self.ld) >= 0))):
                     self.item = StopIteration
             except notfound:
                 self.item = StopIteration
@@ -205,20 +252,21 @@ class ordered(index, lib.closable):
         def prev(self):
             try:
                 self.item = self._decode(self.cur.prev())
-                if ((self.fi and self.typ.compare(k, self.fd) < 0) or
-                    (not self.fi and self.typ.compare(k, self.fd) <= 0)):
+                if (self.fd is not missing and
+                    ((self.fi and self.typ.compare(k, self.fd) < 0) or
+                     (not self.fi and self.typ.compare(k, self.fd) <= 0))):
                     self.item = StopIteration
             except notfound:
                 self.item = StopIteration
 
         def __next__(self):
-            if self.item is StopIteration:
-                raise StopIteration()
             if self.item is None:
                 if not self.rev:
                     self.next()
                 else:
                     self.prev()
+            if self.item is StopIteration:
+                raise StopIteration()
             ret, self.item = self.item, None
             return ret
 
@@ -231,7 +279,7 @@ class ordered(index, lib.closable):
 
     def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):
         if all:
-            cur = self.cursor(self, None, True, None, True, reverse)
+            cur = self.cursor(self, missing, True, missing, True, reverse)
         elif match is not missing:
             cur = self.cursor(self, match, True, match, True, reverse)
         elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
@@ -240,13 +288,13 @@ class ordered(index, lib.closable):
             elif gt is not missing:
                 fd, fi = gt, False
             else:
-                fd, fi = None, True
+                fd, fi = missing, True
             if le is not missing:
                 ld, li = le, True
             elif lt is not missing:
                 ld, li = lt, False
             else:
-                ld, li = None, True
+                ld, li = missing, True
             cur = self.cursor(self, fd, fi, ld, li, reverse)
         else:
             raise NameError("invalid get() specification")
@@ -262,7 +310,7 @@ class ordered(index, lib.closable):
             if not done:
                 cur.close()
 
-    @txnfun(lambda self: self.db.env.env)
+    @txnfun(lambda self: self.db.env)
     def put(self, key, id, *, tx):
         obid = struct.pack(">Q", id)
         if not self.db.ob.has_key(obid, txn=tx.tx):
@@ -273,7 +321,7 @@ class ordered(index, lib.closable):
             return False
         return True
 
-    @txnfun(lambda self: self.db.env.env)
+    @txnfun(lambda self: self.db.env)
     def remove(self, key, id, *, tx):
         obid = struct.pack(">Q", id)
         if not self.db.ob.has_key(obid, txn=tx.tx):