Ensure that environment maintenance runs regularly.
[didex.git] / didex / index.py
index 2f3fdef..0608318 100644 (file)
@@ -1,6 +1,8 @@
-import struct, contextlib
+import struct, contextlib, math
 from . import db, lib
-from .db import bd, txnfun
+from .db import bd, txnfun, dloopfun
+
+__all__ = ["maybe", "t_bool", "t_int", "t_uint", "t_dbid", "t_float", "t_str", "t_casestr", "ordered"]
 
 deadlock = bd.DBLockDeadlockError
 notfound = bd.DBNotFoundError
@@ -27,6 +29,14 @@ class simpletype(object):
         return cls(lambda ob: struct.pack(fmt, ob),
                    lambda dat: struct.unpack(fmt, dat)[0])
 
+class foldtype(simpletype):
+    def __init__(self, encode, decode, fold):
+        super().__init__(encode, decode)
+        self.fold = fold
+
+    def compare(self, a, b):
+        return super().compare(self.fold(a), self.fold(b))
+
 class maybe(object):
     def __init__(self, bk):
         self.bk = bk
@@ -47,7 +57,98 @@ class maybe(object):
         else:
             return self.bk.compare(a[1:], b[1:])
 
-t_int = simpletype.struct(">Q")
+class compound(object):  # tuple datatype built from one sub-datatype per position
+    def __init__(self, *parts):
+        self.parts = parts
+
+    small = object()  # sentinel that compares below every real value of a part
+    large = object()  # sentinel that compares above every real value of a part
+    def minim(self, *parts):
+        return parts + tuple([self.small] * (len(self.parts) - len(parts)))  # pad the prefix with "small"
+    def maxim(self, *parts):
+        return parts + tuple([self.large] * (len(self.parts) - len(parts)))  # pad the prefix with "large"
+
+    def encode(self, obs):
+        if len(obs) != len(self.parts):
+            raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + str(len(self.parts)))
+        buf = bytearray()
+        for ob, part in zip(obs, self.parts):
+            if ob is self.small:
+                buf.append(0x01)  # one-byte marker for the "small" sentinel
+            elif ob is self.large:
+                buf.append(0x02)  # one-byte marker for the "large" sentinel
+            else:
+                dat = part.encode(ob)
+                if len(dat) < 128:
+                    buf.append(0x80 | len(dat))  # short form: high bit set, length in the low 7 bits
+                    buf.extend(dat)
+                else:
+                    buf.extend(struct.pack(">BI", 0, len(dat)))  # long form: zero flag byte, 32-bit big-endian length
+                    buf.extend(dat)
+        return bytes(buf)
+    def decode(self, dat):
+        ret = []
+        off = 0
+        for part in self.parts:
+            fl = dat[off]
+            off += 1
+            if fl & 0x80:
+                ln = fl & 0x7f
+            elif fl == 0x01:
+                ret.append(self.small)
+                continue
+            elif fl == 0x02:
+                ret.append(self.large)
+                continue
+            else:
+                ln = struct.unpack(">I", dat[off:off + 4])[0]  # long form: length follows the zero flag byte
+                off += 4
+            ret.append(part.decode(dat[off:off + ln]))
+            off += ln
+        return tuple(ret)
+    def compare(self, al, bl):
+        if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
+            raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + str(len(self.parts)))
+        for a, b, part in zip(al, bl, self.parts):
+            if a in (self.small, self.large) or b in (self.small, self.large):
+                if a is b:
+                    continue  # same sentinel on both sides ties here; later parts still decide
+                if a is self.small:
+                    return -1
+                elif b is self.small:
+                    return 1
+                elif a is self.large:
+                    return 1
+                elif b is self.large:
+                    return -1
+            c = part.compare(a, b)
+            if c != 0:
+                return c
+        return 0
+
+def floatcmp(a, b):
+    if math.isnan(a) and math.isnan(b):
+        return 0
+    elif math.isnan(a):
+        return -1
+    elif math.isnan(b):
+        return 1
+    elif a < b:
+        return -1
+    elif a > b:
+        return 1
+    else:
+        return 0
+
+t_bool = simpletype((lambda ob: b"\x01" if ob else b"\x00"), (lambda dat: False if dat == b"\x00" else True))  # b"\x00" is the encoding of False
+t_int = simpletype.struct(">q")
+t_uint = simpletype.struct(">Q")
+t_dbid = t_uint
+t_float = simpletype.struct(">d")
+t_float.compare = floatcmp
+t_str = simpletype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")))
+t_casestr = foldtype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")),
+                     (lambda st: st.lower()))
 
 class index(object):
     def __init__(self, db, name, datatype):
@@ -58,10 +159,9 @@ class index(object):
 missing = object()
 
 class ordered(index, lib.closable):
-    def __init__(self, db, name, datatype, duplicates, create=True):
+    def __init__(self, db, name, datatype, create=True, *, tx=None):
         super().__init__(db, name, datatype)
-        self.dup = duplicates
-        fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
+        fl = bd.DB_THREAD
         if create: fl |= bd.DB_CREATE
         def initdb(db):
             def compare(a, b):
@@ -69,49 +169,106 @@ class ordered(index, lib.closable):
                 return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
             db.set_flags(bd.DB_DUPSORT)
             db.set_bt_compare(compare)
-        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
+        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb, tx=tx)
         self.bk.set_get_returns_none(False)
 
     def close(self):
         self.bk.close()
 
     class cursor(lib.closable):
-        def __init__(self, idx, cur, item, stop):
+        def __init__(self, idx, fd, fi, ld, li, reverse):
             self.idx = idx
-            self.cur = cur
-            self.item = item
-            self.stop = stop
+            self.typ = idx.typ
+            self.cur = self.idx.bk.cursor()
+            self.item = None
+            self.fd = fd
+            self.fi = fi
+            self.ld = ld
+            self.li = li
+            self.rev = reverse
 
         def close(self):
             if self.cur is not None:
                 self.cur.close()
+                self.cur = None
 
         def __iter__(self):
             return self
 
-        def peek(self):
-            if self.item is None:
-                raise StopIteration()
-            rk, rv = self.item
-            rk = self.idx.typ.decode(rk)
-            rv = struct.unpack(">Q", rv)[0]
-            if self.stop(rk):
-                self.item = None
-                raise StopIteration()
-            return rk, rv
+        def _decode(self, d):
+            k, v = d
+            k = self.typ.decode(k)
+            v = struct.unpack(">Q", v)[0]
+            return k, v
 
-        def __next__(self):
-            rk, rv = self.peek()
+        @dloopfun
+        def first(self):
+            try:
+                if self.fd is missing:
+                    self.item = self._decode(self.cur.first())
+                else:
+                    k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
+                    if not self.fi:
+                        while self.typ.compare(k, self.fd) == 0:
+                            k, v = self._decode(self.cur.next())
+                    self.item = k, v
+            except notfound:
+                self.item = StopIteration
+
+        @dloopfun
+        def last(self):
             try:
-                while True:
+                if self.ld is missing:
+                    self.item = self._decode(self.cur.last())
+                else:
                     try:
-                        self.item = self.cur.next()
-                        break
-                    except deadlock:
-                        continue
+                        k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
+                    except notfound:
+                        k, v = self._decode(self.cur.last())
+                    if self.li:
+                        while self.typ.compare(k, self.ld) == 0:
+                            k, v = self._decode(self.cur.next())
+                        while self.typ.compare(k, self.ld) > 0:
+                            k, v = self._decode(self.cur.prev())
+                    else:
+                        while self.typ.compare(k, self.ld) >= 0:
+                            k, v = self._decode(self.cur.prev())
+                    self.item = k, v
             except notfound:
-                self.item = None
-            return rk, rv
+                self.item = StopIteration
+
+        @dloopfun
+        def next(self):
+            try:
+                k, v = self.item = self._decode(self.cur.next())
+                if (self.ld is not missing and
+                    ((self.li and self.typ.compare(k, self.ld) > 0) or
+                     (not self.li and self.typ.compare(k, self.ld) >= 0))):
+                    self.item = StopIteration
+            except notfound:
+                self.item = StopIteration
+
+        @dloopfun
+        def prev(self):
+            try:
+                k, v = self.item = self._decode(self.cur.prev())  # bind k for the lower-bound check, as next() does
+                if (self.fd is not missing and
+                    ((self.fi and self.typ.compare(k, self.fd) < 0) or
+                     (not self.fi and self.typ.compare(k, self.fd) <= 0))):
+                    self.item = StopIteration  # stepped below the requested range
+            except notfound:
+                self.item = StopIteration
+
+        def __next__(self):
+            if self.item is None:
+                if not self.rev:
+                    self.next()
+                else:
+                    self.prev()
+            if self.item is StopIteration:
+                raise StopIteration()
+            ret, self.item = self.item, None
+            return ret
 
         def skip(self, n=1):
             try:
@@ -120,64 +277,40 @@ class ordered(index, lib.closable):
             except StopIteration:
                 return
 
-    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False):
-        while True:
-            try:
-                cur = self.bk.cursor()
-                done = False
-                try:
-                    if match is not missing:
-                        try:
-                            k, v = cur.set(self.typ.encode(match))
-                        except notfound:
-                            return self.cursor(None, None, None, None)
-                        else:
-                            done = True
-                            return self.cursor(self, cur, (k, v), lambda o: (self.typ.compare(o, match) != 0))
-                    elif all:
-                        try:
-                            k, v = cur.first()
-                        except notfound:
-                            return self.cursor(None, None, None, None)
-                        else:
-                            done = True
-                            return self.cursor(self, cur, (k, v), lambda o: False)
-                    elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
-                        skip = False
-                        try:
-                            if ge is not missing:
-                                k, v = cur.set_range(self.typ.encode(ge))
-                            elif gt is not missing:
-                                k, v = cur.set_range(self.typ.encode(gt))
-                                skip = True
-                            else:
-                                k, v = cur.first()
-                        except notfound:
-                            return self.cursor(None, None, None, None)
-                        if lt is not missing:
-                            stop = lambda o: self.typ.compare(o, lt) >= 0
-                        elif le is not missing:
-                            stop = lambda o: self.typ.compare(o, le) > 0
-                        else:
-                            stop = lambda o: False
-                        ret = self.cursor(self, cur, (k, v), stop)
-                        if skip:
-                            try:
-                                while self.typ.compare(ret.peek()[0], gt) == 0:
-                                    next(ret)
-                            except StopIteration:
-                                pass
-                        done = True
-                        return ret
-                    else:
-                        raise NameError("invalid get() specification")
-                finally:
-                    if not done:
-                        cur.close()
-            except deadlock:
-                continue
+    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):
+        if all:
+            cur = self.cursor(self, missing, True, missing, True, reverse)
+        elif match is not missing:
+            cur = self.cursor(self, match, True, match, True, reverse)
+        elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
+            if ge is not missing:
+                fd, fi = ge, True
+            elif gt is not missing:
+                fd, fi = gt, False
+            else:
+                fd, fi = missing, True
+            if le is not missing:
+                ld, li = le, True
+            elif lt is not missing:
+                ld, li = lt, False
+            else:
+                ld, li = missing, True
+            cur = self.cursor(self, fd, fi, ld, li, reverse)
+        else:
+            raise NameError("invalid get() specification")
+        done = False
+        try:
+            if not reverse:
+                cur.first()
+            else:
+                cur.last()
+            done = True
+            return cur
+        finally:
+            if not done:
+                cur.close()
 
-    @txnfun(lambda self: self.db.env.env)
+    @txnfun(lambda self: self.db.env)
     def put(self, key, id, *, tx):
         obid = struct.pack(">Q", id)
         if not self.db.ob.has_key(obid, txn=tx.tx):
@@ -188,7 +321,7 @@ class ordered(index, lib.closable):
             return False
         return True
 
-    @txnfun(lambda self: self.db.env.env)
+    @txnfun(lambda self: self.db.env)
     def remove(self, key, id, *, tx):
         obid = struct.pack(">Q", id)
         if not self.db.ob.has_key(obid, txn=tx.tx):