Fixed index bugs.
[didex.git] / didex / index.py
index cdd4069..2844de0 100644 (file)
@@ -1,6 +1,6 @@
 import struct, contextlib, math
 from . import db, lib
-from .db import bd, txnfun
+from .db import bd, txnfun, dloopfun
 
 __all__ = ["maybe", "t_int", "t_uint", "t_float", "t_str", "ordered"]
 
@@ -134,42 +134,93 @@ class ordered(index, lib.closable):
         self.bk.close()
 
     class cursor(lib.closable):
-        def __init__(self, idx, cur, item, stop):
+        def __init__(self, idx, fd, fi, ld, li, reverse):
             self.idx = idx
-            self.cur = cur
-            self.item = item
-            self.stop = stop
+            self.typ = idx.typ
+            self.cur = self.idx.bk.cursor()
+            self.item = None
+            self.fd = fd
+            self.fi = fi
+            self.ld = ld
+            self.li = li
+            self.rev = reverse
 
         def close(self):
             if self.cur is not None:
                 self.cur.close()
+                self.cur = None
 
         def __iter__(self):
             return self
 
-        def peek(self):
-            if self.item is None:
-                raise StopIteration()
-            rk, rv = self.item
-            rk = self.idx.typ.decode(rk)
-            rv = struct.unpack(">Q", rv)[0]
-            if self.stop(rk):
-                self.item = None
-                raise StopIteration()
-            return rk, rv
+        def _decode(self, d):
+            k, v = d
+            k = self.typ.decode(k)
+            v = struct.unpack(">Q", v)[0]
+            return k, v
 
-        def __next__(self):
-            rk, rv = self.peek()
+        @dloopfun
+        def first(self):
+            try:
+                if self.fd is None:
+                    self.item = self._decode(self.cur.first())
+                else:
+                    k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
+                    if not self.fi:
+                        while self.typ.compare(k, self.fd) == 0:
+                            k, v = self._decode(self.cur.next())
+                    self.item = k, v
+            except notfound:
+                self.item = StopIteration
+
+        @dloopfun
+        def last(self):
+            try:
+                if self.fd is None:
+                    self.item = self._decode(self.cur.last())
+                else:
+                    k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
+                    if self.fi:
+                        while self.typ.compare(k, self.fd) == 0:
+                            k, v = self._decode(self.cur.next())
+                        k, v = self._decode(self.cur.prev())
+                    else:
+                        while self.typ.compare(k, self.fd) >= 0:
+                            k, v = self._decode(self.cur.prev())
+                    self.item = k, v
+            except notfound:
+                self.item = StopIteration
+
+        @dloopfun
+        def next(self):
+            try:
+                k, v = self.item = self._decode(self.cur.next())
+                if ((self.li and self.typ.compare(k, self.ld) > 0) or
+                    (not self.li and self.typ.compare(k, self.ld) >= 0)):
+                    self.item = StopIteration
+            except notfound:
+                self.item = StopIteration
+
+        @dloopfun
+        def prev(self):
             try:
-                while True:
-                    try:
-                        self.item = self.cur.next()
-                        break
-                    except deadlock:
-                        continue
+                self.item = self._decode(self.cur.prev())
+                if ((self.fi and self.typ.compare(k, self.fd) < 0) or
+                    (not self.fi and self.typ.compare(k, self.fd) <= 0)):
+                    self.item = StopIteration
             except notfound:
-                self.item = None
-            return rk, rv
+                self.item = StopIteration
+
+        def __next__(self):
+            if self.item is None:
+                if not self.rev:
+                    self.next()
+                else:
+                    self.prev()
+            if self.item is StopIteration:
+                raise StopIteration()
+            ret, self.item = self.item, None
+            return ret
 
         def skip(self, n=1):
             try:
@@ -178,62 +229,38 @@ class ordered(index, lib.closable):
             except StopIteration:
                 return
 
-    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False):
-        while True:
-            try:
-                cur = self.bk.cursor()
-                done = False
-                try:
-                    if match is not missing:
-                        try:
-                            k, v = cur.set(self.typ.encode(match))
-                        except notfound:
-                            return self.cursor(None, None, None, None)
-                        else:
-                            done = True
-                            return self.cursor(self, cur, (k, v), lambda o: (self.typ.compare(o, match) != 0))
-                    elif all:
-                        try:
-                            k, v = cur.first()
-                        except notfound:
-                            return self.cursor(None, None, None, None)
-                        else:
-                            done = True
-                            return self.cursor(self, cur, (k, v), lambda o: False)
-                    elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
-                        skip = False
-                        try:
-                            if ge is not missing:
-                                k, v = cur.set_range(self.typ.encode(ge))
-                            elif gt is not missing:
-                                k, v = cur.set_range(self.typ.encode(gt))
-                                skip = True
-                            else:
-                                k, v = cur.first()
-                        except notfound:
-                            return self.cursor(None, None, None, None)
-                        if lt is not missing:
-                            stop = lambda o: self.typ.compare(o, lt) >= 0
-                        elif le is not missing:
-                            stop = lambda o: self.typ.compare(o, le) > 0
-                        else:
-                            stop = lambda o: False
-                        ret = self.cursor(self, cur, (k, v), stop)
-                        if skip:
-                            try:
-                                while self.typ.compare(ret.peek()[0], gt) == 0:
-                                    next(ret)
-                            except StopIteration:
-                                pass
-                        done = True
-                        return ret
-                    else:
-                        raise NameError("invalid get() specification")
-                finally:
-                    if not done:
-                        cur.close()
-            except deadlock:
-                continue
+    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):
+        if all:
+            cur = self.cursor(self, None, True, None, True, reverse)
+        elif match is not missing:
+            cur = self.cursor(self, match, True, match, True, reverse)
+        elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
+            if ge is not missing:
+                fd, fi = ge, True
+            elif gt is not missing:
+                fd, fi = gt, False
+            else:
+                fd, fi = None, True
+            if le is not missing:
+                ld, li = le, True
+            elif lt is not missing:
+                ld, li = lt, False
+            else:
+                ld, li = None, True
+            cur = self.cursor(self, fd, fi, ld, li, reverse)
+        else:
+            raise NameError("invalid get() specification")
+        done = False
+        try:
+            if not reverse:
+                cur.first()
+            else:
+                cur.last()
+            done = True
+            return cur
+        finally:
+            if not done:
+                cur.close()
 
     @txnfun(lambda self: self.db.env.env)
     def put(self, key, id, *, tx):