Ensure that environment maintenance runs regularly.
[didex.git] / didex / index.py
index 2844de0..0608318 100644 (file)
@@ -2,7 +2,7 @@ import struct, contextlib, math
 from . import db, lib
 from .db import bd, txnfun, dloopfun
 
-__all__ = ["maybe", "t_int", "t_uint", "t_float", "t_str", "ordered"]
+__all__ = ["maybe", "t_bool", "t_int", "t_uint", "t_dbid", "t_float", "t_str", "t_casestr", "ordered"]
 
 deadlock = bd.DBLockDeadlockError
 notfound = bd.DBNotFoundError
@@ -29,6 +29,14 @@ class simpletype(object):
         return cls(lambda ob: struct.pack(fmt, ob),
                    lambda dat: struct.unpack(fmt, dat)[0])
 
+class foldtype(simpletype):
+    def __init__(self, encode, decode, fold):
+        super().__init__(encode, decode)
+        self.fold = fold
+
+    def compare(self, a, b):
+        return super().compare(self.fold(a), self.fold(b))
+
 class maybe(object):
     def __init__(self, bk):
         self.bk = bk
@@ -53,36 +61,66 @@ class compound(object):
     def __init__(self, *parts):
         self.parts = parts
 
+    small = object()
+    large = object()
+    def minim(self, *parts):
+        return parts + tuple([self.small] * (len(self.parts) - len(parts)))
+    def maxim(self, *parts):
+        return parts + tuple([self.large] * (len(self.parts) - len(parts)))
+
     def encode(self, obs):
         if len(obs) != len(self.parts):
             raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + len(self.parts))
         buf = bytearray()
         for ob, part in zip(obs, self.parts):
-            dat = part.encode(ob)
-            if len(dat) < 128:
-                buf.append(0x80 | len(dat))
-                buf.extend(dat)
+            if ob is self.small:
+                buf.append(0x01)
+            elif ob is self.large:
+                buf.append(0x02)
             else:
-                buf.extend(struct.pack(">i", len(dat)))
-                buf.extend(dat)
+                dat = part.encode(ob)
+                if len(dat) < 128:
+                    buf.append(0x80 | len(dat))
+                    buf.extend(dat)
+                else:
+                    buf.extend(struct.pack(">BI", 0, len(dat)))
+                    buf.extend(dat)
         return bytes(buf)
     def decode(self, dat):
         ret = []
         off = 0
         for part in self.parts:
-            if dat[off] & 0x80:
-                ln = dat[off] & 0x7f
-                off += 1
+            fl = dat[off]
+            off += 1
+            if fl & 0x80:
+                ln = fl & 0x7f
+            elif fl == 0x01:
+                ret.append(self.small)
+                continue
+            elif fl == 0x02:
+                ret.append(self.large)
+                continue
             else:
-                ln = struct.unpack(">i", dat[off:off + 4])[0]
+                ln = struct.unpack(">I", dat[off:off + 4])[0]
                 off += 4
-            ret.append(part.decode(dat[off:off + len]))
-            off += len
+            ret.append(part.decode(dat[off:off + ln]))
+            off += ln
         return tuple(ret)
     def compare(self, al, bl):
         if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
             raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + len(self.parts))
         for a, b, part in zip(al, bl, self.parts):
+            if a in (self.small, self.large) or b in (self.small, self.large):
+                if a is b:
+                    return 0
+                if a is self.small:
+                    return -1
+                elif b is self.small:
+                    return 1
+                elif a is self.large:
+                    return 1
+                elif b is self.large:
+                    return -1
             c = part.compare(a, b)
             if c != 0:
                 return c
@@ -102,11 +140,15 @@ def floatcmp(a, b):
     else:
         return 0
 
+t_bool = simpletype((lambda ob: b"\x01" if ob else b"\x00"), (lambda dat: False if dat == b"\x00" else True))  # one-byte encoding; only b"\x00" decodes to False
 t_int = simpletype.struct(">q")
 t_uint = simpletype.struct(">Q")
+t_dbid = t_uint
 t_float = simpletype.struct(">d")
 t_float.compare = floatcmp
 t_str = simpletype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")))
+t_casestr = foldtype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")),
+                     (lambda st: st.lower()))
 
 class index(object):
     def __init__(self, db, name, datatype):
@@ -117,9 +159,9 @@ class index(object):
 missing = object()
 
 class ordered(index, lib.closable):
-    def __init__(self, db, name, datatype, create=True):
+    def __init__(self, db, name, datatype, create=True, *, tx=None):
         super().__init__(db, name, datatype)
-        fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
+        fl = bd.DB_THREAD
         if create: fl |= bd.DB_CREATE
         def initdb(db):
             def compare(a, b):
@@ -127,7 +169,7 @@ class ordered(index, lib.closable):
                 return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
             db.set_flags(bd.DB_DUPSORT)
             db.set_bt_compare(compare)
-        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
+        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb, tx=tx)
         self.bk.set_get_returns_none(False)
 
     def close(self):
@@ -162,7 +204,7 @@ class ordered(index, lib.closable):
         @dloopfun
         def first(self):
             try:
-                if self.fd is None:
+                if self.fd is missing:
                     self.item = self._decode(self.cur.first())
                 else:
                     k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
@@ -176,16 +218,20 @@ class ordered(index, lib.closable):
         @dloopfun
         def last(self):
             try:
-                if self.fd is None:
+                if self.ld is missing:
                     self.item = self._decode(self.cur.last())
                 else:
-                    k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
-                    if self.fi:
-                        while self.typ.compare(k, self.fd) == 0:
+                    try:
+                        k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
+                    except notfound:
+                        k, v = self._decode(self.cur.last())
+                    if self.li:
+                        while self.typ.compare(k, self.ld) == 0:
                             k, v = self._decode(self.cur.next())
-                        k, v = self._decode(self.cur.prev())
+                        while self.typ.compare(k, self.ld) > 0:
+                            k, v = self._decode(self.cur.prev())
                     else:
-                        while self.typ.compare(k, self.fd) >= 0:
+                        while self.typ.compare(k, self.ld) >= 0:
                             k, v = self._decode(self.cur.prev())
                     self.item = k, v
             except notfound:
@@ -195,8 +241,9 @@ class ordered(index, lib.closable):
         def next(self):
             try:
                 k, v = self.item = self._decode(self.cur.next())
-                if ((self.li and self.typ.compare(k, self.ld) > 0) or
-                    (not self.li and self.typ.compare(k, self.ld) >= 0)):
+                if (self.ld is not missing and
+                    ((self.li and self.typ.compare(k, self.ld) > 0) or
+                     (not self.li and self.typ.compare(k, self.ld) >= 0))):
                     self.item = StopIteration
             except notfound:
                 self.item = StopIteration
@@ -205,8 +252,9 @@ class ordered(index, lib.closable):
         def prev(self):
             try:
                 self.item = self._decode(self.cur.prev())
-                if ((self.fi and self.typ.compare(k, self.fd) < 0) or
-                    (not self.fi and self.typ.compare(k, self.fd) <= 0)):
+                if (self.fd is not missing and  # key is read from self.item: no local `k` is bound in prev()
+                    ((self.fi and self.typ.compare(self.item[0], self.fd) < 0) or
+                     (not self.fi and self.typ.compare(self.item[0], self.fd) <= 0))):
                     self.item = StopIteration
             except notfound:
                 self.item = StopIteration
@@ -231,7 +279,7 @@ class ordered(index, lib.closable):
 
     def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):
         if all:
-            cur = self.cursor(self, None, True, None, True, reverse)
+            cur = self.cursor(self, missing, True, missing, True, reverse)
         elif match is not missing:
             cur = self.cursor(self, match, True, match, True, reverse)
         elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
@@ -240,13 +288,13 @@ class ordered(index, lib.closable):
             elif gt is not missing:
                 fd, fi = gt, False
             else:
-                fd, fi = None, True
+                fd, fi = missing, True
             if le is not missing:
                 ld, li = le, True
             elif lt is not missing:
                 ld, li = lt, False
             else:
-                ld, li = None, True
+                ld, li = missing, True
             cur = self.cursor(self, fd, fi, ld, li, reverse)
         else:
             raise NameError("invalid get() specification")
@@ -262,7 +310,7 @@ class ordered(index, lib.closable):
             if not done:
                 cur.close()
 
-    @txnfun(lambda self: self.db.env.env)
+    @txnfun(lambda self: self.db.env)
     def put(self, key, id, *, tx):
         obid = struct.pack(">Q", id)
         if not self.db.ob.has_key(obid, txn=tx.tx):
@@ -273,7 +321,7 @@ class ordered(index, lib.closable):
             return False
         return True
 
-    @txnfun(lambda self: self.db.env.env)
+    @txnfun(lambda self: self.db.env)
     def remove(self, key, id, *, tx):
         obid = struct.pack(">Q", id)
         if not self.db.ob.has_key(obid, txn=tx.tx):