Added support for compound indices.
[didex.git] / didex / index.py
index 5d206ff..cdd4069 100644 (file)
@@ -1,6 +1,8 @@
-import struct, contextlib
+import struct, contextlib, math
 from . import db, lib
-from .db import bd
+from .db import bd, txnfun
+
+__all__ = ["maybe", "t_int", "t_uint", "t_float", "t_str", "ordered"]  # names exported by a star-import; NOTE(review): compound is not listed - confirm that is intentional
 
 deadlock = bd.DBLockDeadlockError
 notfound = bd.DBNotFoundError
@@ -47,7 +49,64 @@ class maybe(object):
         else:
             return self.bk.compare(a[1:], b[1:])
 
-t_int = simpletype.struct(">Q")
+class compound(object):  # codec for fixed-arity tuples: one sub-codec per position
+    def __init__(self, *parts):  # parts: objects providing encode(), decode() and compare()
+        self.parts = parts
+    # Encoded form: the fields in order, each preceded by its own length prefix.
+    def encode(self, obs):  # obs must hold exactly len(self.parts) items; returns bytes
+        if len(obs) != len(self.parts):
+            raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + str(len(self.parts)))
+        buf = bytearray()
+        for ob, part in zip(obs, self.parts):
+            dat = part.encode(ob)
+            if len(dat) < 128:
+                buf.append(0x80 | len(dat))  # short form: a single byte, high bit set, low 7 bits hold the length
+                buf.extend(dat)
+            else:
+                buf.extend(struct.pack(">i", len(dat)))  # long form: 4-byte big-endian signed length, so the high bit stays clear
+                buf.extend(dat)
+        return bytes(buf)
+    def decode(self, dat):  # inverse of encode(); returns a tuple with one decoded item per part
+        ret = []
+        off = 0
+        for part in self.parts:
+            if dat[off] & 0x80:
+                ln = dat[off] & 0x7f
+                off += 1
+            else:
+                ln = struct.unpack(">i", dat[off:off + 4])[0]
+                off += 4
+            ret.append(part.decode(dat[off:off + ln]))  # slice by the parsed length ln (the builtin len is not a number)
+            off += ln
+        return tuple(ret)
+    def compare(self, al, bl):  # field-by-field comparison; the first non-zero sub-result decides
+        if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
+            raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + str(len(self.parts)))
+        for a, b, part in zip(al, bl, self.parts):
+            c = part.compare(a, b)
+            if c != 0:
+                return c
+        return 0
+
+def floatcmp(a, b):
+    # Three-way float comparison: NaN sorts before every other value, two NaNs are equal.
+    a_nan = math.isnan(a)
+    b_nan = math.isnan(b)
+    if a_nan or b_nan:
+        # both NaN -> 0, only a -> -1, only b -> 1
+        return int(b_nan) - int(a_nan)
+    if a < b:
+        return -1
+    if a > b:
+        return 1
+    # neither smaller nor larger
+    return 0
+
+t_int = simpletype.struct(">q")  # struct format ">q": big-endian signed 64-bit integer
+t_uint = simpletype.struct(">Q")  # ">Q": big-endian unsigned 64-bit integer (the format t_int used before this change)
+t_float = simpletype.struct(">d")  # ">d": big-endian 8-byte double
+t_float.compare = floatcmp  # replace the comparison on this one instance so NaN values get a defined position
+t_str = simpletype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")))  # str <-> UTF-8 bytes
 
 class index(object):
     def __init__(self, db, name, datatype):
@@ -58,9 +117,8 @@ class index(object):
 missing = object()
 
 class ordered(index, lib.closable):
-    def __init__(self, db, name, datatype, duplicates, create=True):
+    def __init__(self, db, name, datatype, create=True):
         super().__init__(db, name, datatype)
-        self.dup = duplicates
         fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
         if create: fl |= bd.DB_CREATE
         def initdb(db):
@@ -177,39 +235,29 @@ class ordered(index, lib.closable):
             except deadlock:
                 continue
 
-    def put(self, key, id):
-        while True:
-            try:
-                with db.txn(self.db.env.env) as tx:
-                    obid = struct.pack(">Q", id)
-                    if not self.db.ob.has_key(obid, txn=tx.tx):
-                        raise ValueError("no such object in database: " + str(id))
-                    try:
-                        self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
-                    except bd.DBKeyExistError:
-                        return False
-                    tx.commit()
-                    return True
-            except deadlock:
-                continue
+    @txnfun(lambda self: self.db.env.env)  # NOTE(review): presumably txnfun supplies tx and handles commit/deadlock retry, as the removed loop did - confirm in db
+    def put(self, key, id, *, tx):  # returns True if the entry was stored, False if the database reported it as already existing
+        obid = struct.pack(">Q", id)  # object id packed as a big-endian unsigned 64-bit value
+        if not self.db.ob.has_key(obid, txn=tx.tx):
+            raise ValueError("no such object in database: " + str(id))
+        try:
+            self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
+        except bd.DBKeyExistError:
+            return False
+        return True
 
-    def remove(self, key, id):
-        while True:
+    @txnfun(lambda self: self.db.env.env)  # NOTE(review): presumably txnfun supplies tx and handles commit/deadlock retry, as the removed loop did - confirm in db
+    def remove(self, key, id, *, tx):  # returns False if no entry for this key/id pair is found, True after deleting it
+        obid = struct.pack(">Q", id)  # object id packed as a big-endian unsigned 64-bit value
+        if not self.db.ob.has_key(obid, txn=tx.tx):
+            raise ValueError("no such object in database: " + str(id))
+        cur = self.bk.cursor(txn=tx.tx)
+        try:
             try:
-                with db.txn(self.db.env.env) as tx:
-                    obid = struct.pack(">Q", id)
-                    if not self.db.ob.has_key(obid, txn=tx.tx):
-                        raise ValueError("no such object in database: " + str(id))
-                    cur = self.bk.cursor(txn=tx.tx)
-                    try:
-                        try:
-                            cur.get_both(self.typ.encode(key), obid)
-                        except notfound:
-                            return False
-                        cur.delete()
-                    finally:
-                        cur.close()
-                    tx.commit()
-                    return True
-            except deadlock:
-                continue
+                cur.get_both(self.typ.encode(key), obid)  # position the cursor on this exact key/id pair
+            except notfound:
+                return False
+            cur.delete()  # deletes the entry the cursor was just positioned on
+        finally:
+            cur.close()  # the cursor is closed on every exit path, including the early return
+        return True