Fixed some index bugs.
[didex.git] / didex / index.py
index 4d42506..951eb3f 100644 (file)
@@ -1,6 +1,8 @@
 import struct, contextlib, math
 from . import db, lib
-from .db import bd, txnfun
+from .db import bd, txnfun, dloopfun
+
+__all__ = ["maybe", "t_int", "t_uint", "t_float", "t_str", "ordered"]
 
 deadlock = bd.DBLockDeadlockError
 notfound = bd.DBNotFoundError
@@ -47,6 +49,45 @@ class maybe(object):
         else:
             return self.bk.compare(a[1:], b[1:])
 
+class compound(object):  # key type that packs a fixed-length sequence of values, one sub-type per position
+    def __init__(self, *parts):  # parts: sub-type objects; encode/decode/compare are called on each of them below
+        self.parts = parts
+
+    def encode(self, obs):  # -> bytes: every part's encoding preceded by a length prefix; ValueError on wrong arity
+        if len(obs) != len(self.parts):
+            raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + str(len(self.parts)))
+        buf = bytearray()
+        for ob, part in zip(obs, self.parts):
+            dat = part.encode(ob)
+            if len(dat) < 128:
+                buf.append(0x80 | len(dat))  # short form: a single byte with the high bit set, low 7 bits = length
+                buf.extend(dat)
+            else:
+                buf.extend(struct.pack(">i", len(dat)))  # long form: 4-byte big-endian signed length (high bit clear)
+                buf.extend(dat)
+        return bytes(buf)
+    def decode(self, dat):  # -> tuple of decoded parts; reads back the prefixes written by encode()
+        ret = []
+        off = 0
+        for part in self.parts:
+            if dat[off] & 0x80:  # high bit set marks the one-byte short form
+                ln = dat[off] & 0x7f
+                off += 1
+            else:
+                ln = struct.unpack(">i", dat[off:off + 4])[0]
+                off += 4
+            ret.append(part.decode(dat[off:off + ln]))  # slice by the parsed length `ln`, not the builtin len
+            off += ln
+        return tuple(ret)
+    def compare(self, al, bl):  # -> first non-zero per-part comparison result, or 0; ValueError on wrong arity
+        if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
+            raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + str(len(self.parts)))
+        for a, b, part in zip(al, bl, self.parts):
+            c = part.compare(a, b)
+            if c != 0:
+                return c
+        return 0
+
 def floatcmp(a, b):
     if math.isnan(a) and math.isnan(b):
         return 0
@@ -93,42 +134,99 @@ class ordered(index, lib.closable):
         self.bk.close()
 
     class cursor(lib.closable):
-        def __init__(self, idx, cur, item, stop):
+        def __init__(self, idx, fd, fi, ld, li, reverse):  # fd/ld: lower/upper bound key (or `missing`); fi/li: whether that bound is inclusive
             self.idx = idx
-            self.cur = cur
-            self.item = item
-            self.stop = stop
+            self.typ = idx.typ
+            self.cur = self.idx.bk.cursor()  # the cursor now opens its own backing cursor instead of being handed one
+            self.item = None  # buffered result: None = nothing fetched, StopIteration = range exhausted, else a (key, value) pair
+            self.fd = fd
+            self.fi = fi
+            self.ld = ld
+            self.li = li
+            self.rev = reverse  # when true, __next__ steps with prev() instead of next()
 
         def close(self):
             if self.cur is not None:
                 self.cur.close()
+                self.cur = None  # forget the handle so the guard above makes a repeated close() a no-op
 
         def __iter__(self):
             return self
 
-        def peek(self):
-            if self.item is None:
-                raise StopIteration()
-            rk, rv = self.item
-            rk = self.idx.typ.decode(rk)
-            rv = struct.unpack(">Q", rv)[0]
-            if self.stop(rk):
-                self.item = None
-                raise StopIteration()
-            return rk, rv
+        def _decode(self, d):  # convert a raw (key, value) record returned by the backing cursor into Python values
+            k, v = d
+            k = self.typ.decode(k)  # key is decoded by the index's key type
+            v = struct.unpack(">Q", v)[0]  # value is unpacked as one big-endian unsigned 64-bit integer
+            return k, v
 
-        def __next__(self):
-            rk, rv = self.peek()
+        @dloopfun
+        def first(self):  # position on the first entry allowed by the lower bound (fd/fi) and buffer it in self.item
             try:
-                while True:
+                if self.fd is missing:
+                    self.item = self._decode(self.cur.first())  # no lower bound: start at the first record
+                else:
+                    k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
+                    if not self.fi:
+                        while self.typ.compare(k, self.fd) == 0:  # exclusive bound: step past entries equal to fd
+                            k, v = self._decode(self.cur.next())
+                    self.item = k, v  # NOTE(review): k is not checked against ld/li in this method — confirm that is intended
+            except notfound:
+                self.item = StopIteration  # the StopIteration class itself is stored as the "exhausted" marker
+
+        @dloopfun
+        def last(self):  # position on the last entry allowed by the upper bound (ld/li) and buffer it in self.item
+            try:
+                if self.ld is missing:
+                    self.item = self._decode(self.cur.last())  # no upper bound: start at the last record
+                else:
                     try:
-                        self.item = self.cur.next()
-                        break
-                    except deadlock:
-                        continue
+                        k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
+                    except notfound:
+                        k, v = self._decode(self.cur.last())  # set_range found nothing: fall back to the last record
+                    if self.li:
+                        while self.typ.compare(k, self.ld) == 0:  # NOTE(review): a notfound raised by next() here reaches the outer handler and reports "empty" — confirm
+                            k, v = self._decode(self.cur.next())
+                        while self.typ.compare(k, self.ld) > 0:  # then step back onto the last entry not above ld
+                            k, v = self._decode(self.cur.prev())
+                    else:
+                        while self.typ.compare(k, self.ld) >= 0:  # exclusive bound: step back below ld
+                            k, v = self._decode(self.cur.prev())
+                    self.item = k, v
+            except notfound:
+                self.item = StopIteration  # the StopIteration class itself is stored as the "exhausted" marker
+
+        @dloopfun
+        def next(self):  # advance one record and buffer it in self.item, or mark exhaustion once past the upper bound
+            try:
+                k, v = self.item = self._decode(self.cur.next())  # chained assignment: buffers the pair and binds k for the check
+                if (self.ld is not missing and
+                    ((self.li and self.typ.compare(k, self.ld) > 0) or
+                     (not self.li and self.typ.compare(k, self.ld) >= 0))):
+                    self.item = StopIteration  # stepped beyond ld (or onto it when the bound is exclusive)
             except notfound:
-                self.item = None
-            return rk, rv
+                self.item = StopIteration  # the StopIteration class itself is stored as the "exhausted" marker
+
+        @dloopfun
+        def prev(self):  # step back one record and buffer it in self.item, or mark exhaustion once before the lower bound
+            try:
+                k, v = self.item = self._decode(self.cur.prev())  # bind k (as next() does) so the bound check below can use it
+                if (self.fd is not missing and
+                    ((self.fi and self.typ.compare(k, self.fd) < 0) or
+                     (not self.fi and self.typ.compare(k, self.fd) <= 0))):
+                    self.item = StopIteration  # stepped below fd (or onto it when the bound is exclusive)
+            except notfound:
+                self.item = StopIteration  # the StopIteration class itself is stored as the "exhausted" marker
+
+        def __next__(self):  # iterator protocol: hand out the buffered item, fetching one first when nothing is buffered
+            if self.item is None:  # None means no item is currently buffered
+                if not self.rev:
+                    self.next()
+                else:
+                    self.prev()
+            if self.item is StopIteration:  # marker stored by first()/last()/next()/prev() when the range is exhausted
+                raise StopIteration()
+            ret, self.item = self.item, None  # return the item and clear the buffer in one step
+            return ret
 
         def skip(self, n=1):
             try:
@@ -137,62 +235,38 @@ class ordered(index, lib.closable):
             except StopIteration:
                 return
 
-    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False):
-        while True:
-            try:
-                cur = self.bk.cursor()
-                done = False
-                try:
-                    if match is not missing:
-                        try:
-                            k, v = cur.set(self.typ.encode(match))
-                        except notfound:
-                            return self.cursor(None, None, None, None)
-                        else:
-                            done = True
-                            return self.cursor(self, cur, (k, v), lambda o: (self.typ.compare(o, match) != 0))
-                    elif all:
-                        try:
-                            k, v = cur.first()
-                        except notfound:
-                            return self.cursor(None, None, None, None)
-                        else:
-                            done = True
-                            return self.cursor(self, cur, (k, v), lambda o: False)
-                    elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
-                        skip = False
-                        try:
-                            if ge is not missing:
-                                k, v = cur.set_range(self.typ.encode(ge))
-                            elif gt is not missing:
-                                k, v = cur.set_range(self.typ.encode(gt))
-                                skip = True
-                            else:
-                                k, v = cur.first()
-                        except notfound:
-                            return self.cursor(None, None, None, None)
-                        if lt is not missing:
-                            stop = lambda o: self.typ.compare(o, lt) >= 0
-                        elif le is not missing:
-                            stop = lambda o: self.typ.compare(o, le) > 0
-                        else:
-                            stop = lambda o: False
-                        ret = self.cursor(self, cur, (k, v), stop)
-                        if skip:
-                            try:
-                                while self.typ.compare(ret.peek()[0], gt) == 0:
-                                    next(ret)
-                            except StopIteration:
-                                pass
-                        done = True
-                        return ret
-                    else:
-                        raise NameError("invalid get() specification")
-                finally:
-                    if not done:
-                        cur.close()
-            except deadlock:
-                continue
+    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):  # return a positioned cursor over the selected key range
+        if all:
+            cur = self.cursor(self, missing, True, missing, True, reverse)  # no bound at either end
+        elif match is not missing:
+            cur = self.cursor(self, match, True, match, True, reverse)  # both bounds inclusive at the same key
+        elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
+            if ge is not missing:  # ge wins over gt when both are given
+                fd, fi = ge, True
+            elif gt is not missing:
+                fd, fi = gt, False
+            else:
+                fd, fi = missing, True
+            if le is not missing:  # le wins over lt when both are given
+                ld, li = le, True
+            elif lt is not missing:
+                ld, li = lt, False
+            else:
+                ld, li = missing, True
+            cur = self.cursor(self, fd, fi, ld, li, reverse)
+        else:
+            raise NameError("invalid get() specification")  # none of all/match/ge/gt/lt/le was supplied
+        done = False
+        try:
+            if not reverse:
+                cur.first()  # forward iteration starts at the lower end
+            else:
+                cur.last()  # reverse iteration starts at the upper end
+            done = True
+            return cur
+        finally:
+            if not done:
+                cur.close()  # positioning raised: close the cursor instead of leaking it
 
     @txnfun(lambda self: self.db.env.env)
     def put(self, key, id, *, tx):