Improved cursor functionality.
authorFredrik Tolf <fredrik@dolda2000.com>
Mon, 30 Mar 2015 00:56:21 +0000 (02:56 +0200)
committerFredrik Tolf <fredrik@dolda2000.com>
Mon, 30 Mar 2015 00:56:21 +0000 (02:56 +0200)
didex/db.py
didex/index.py

index 7435166..33eb0d8 100644 (file)
@@ -94,6 +94,15 @@ class txn(object):
     def postcommit(self, fun):
         self.pcommit.add(fun)
 
+def dloopfun(fun):
+    def wrapper(self, *args, **kwargs):
+        while True:
+            try:
+                return fun(self, *args, **kwargs)
+            except deadlock:
+                continue
+    return wrapper
+
 def txnfun(envfun):
     def fxf(fun):
         def wrapper(self, *args, tx=None, **kwargs):
index cdd4069..30dcfdf 100644 (file)
@@ -1,6 +1,6 @@
 import struct, contextlib, math
 from . import db, lib
-from .db import bd, txnfun
+from .db import bd, txnfun, dloopfun
 
 __all__ = ["maybe", "t_int", "t_uint", "t_float", "t_str", "ordered"]
 
@@ -134,42 +134,93 @@ class ordered(index, lib.closable):
         self.bk.close()
 
     class cursor(lib.closable):
-        def __init__(self, idx, cur, item, stop):
+        def __init__(self, idx, fd, fi, ld, li, reverse):
             self.idx = idx
-            self.cur = cur
-            self.item = item
-            self.stop = stop
+            self.typ = idx.typ
+            self.cur = self.idx.bk.cursor()
+            self.item = None
+            self.fd = fd
+            self.fi = fi
+            self.ld = ld
+            self.li = li
+            self.rev = reverse
 
         def close(self):
             if self.cur is not None:
                 self.cur.close()
+                self.cur = None
 
         def __iter__(self):
             return self
 
-        def peek(self):
-            if self.item is None:
-                raise StopIteration()
-            rk, rv = self.item
-            rk = self.idx.typ.decode(rk)
-            rv = struct.unpack(">Q", rv)[0]
-            if self.stop(rk):
-                self.item = None
-                raise StopIteration()
-            return rk, rv
+        def _decode(self, d):
+            k, v = d
+            k = self.type.decode(k)
+            v = struct.unpack(">Q", v)[0]
+            return k, v
 
-        def __next__(self):
-            rk, rv = self.peek()
+        @dloopfun
+        def first(self):
+            try:
+                if self.fd is None:
+                    self.item = self._decode(self.cur.first())
+                else:
+                    k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
+                    if not self.fi:
+                        while self.typ.compare(k, self.fd) == 0:
+                            k, v = self._decode(self.cur.next())
+                    self.item = k, v
+            except notfound:
+                self.item = StopIteration
+
+        @dloopfun
+        def last(self):
+            try:
+                if self.fd is None:
+                    self.item = self._decode(self.cur.last())
+                else:
+                    k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
+                    if self.fi:
+                        while self.typ.compare(k, self.fd) == 0:
+                            k, v = self._decode(self.cur.next())
+                        k, v = self._decode(self.cur.prev())
+                    else:
+                        while self.typ.compare(k, self.fd) >= 0:
+                            k, v = self._decode(self.cur.prev())
+                    self.item = k, v
+            except notfound:
+                self.item = StopIteration
+
+        @dloopfun
+        def next(self):
+            try:
+                k, v = self.item = self._decode(self.cur.next())
+                if ((self.li and self.typ.compare(k, self.ld) > 0) or
+                    (not self.li and self.typ.compare(k, self.ld) >= 0)):
+                    self.item = StopIteration
+            except notfound:
+                self.item = StopIteration
+
+        @dloopfun
+        def prev(self):
             try:
-                while True:
-                    try:
-                        self.item = self.cur.next()
-                        break
-                    except deadlock:
-                        continue
+                self.item = self._decode(self.cur.prev())
+                if ((self.fi and self.typ.compare(k, self.fd) < 0) or
+                    (not self.fi and self.typ.compare(k, self.fd) <= 0)):
+                    self.item = StopIteration
             except notfound:
-                self.item = None
-            return rk, rv
+                self.item = StopIteration
+
+        def __next__(self):
+            if self.item is StopIteration:
+                raise StopIteration()
+            if self.item is None:
+                if not self.rev:
+                    self.next()
+                else:
+                    self.prev()
+            ret, self.item = self.item, None
+            return ret
 
         def skip(self, n=1):
             try:
@@ -178,62 +229,38 @@ class ordered(index, lib.closable):
             except StopIteration:
                 return
 
-    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False):
-        while True:
-            try:
-                cur = self.bk.cursor()
-                done = False
-                try:
-                    if match is not missing:
-                        try:
-                            k, v = cur.set(self.typ.encode(match))
-                        except notfound:
-                            return self.cursor(None, None, None, None)
-                        else:
-                            done = True
-                            return self.cursor(self, cur, (k, v), lambda o: (self.typ.compare(o, match) != 0))
-                    elif all:
-                        try:
-                            k, v = cur.first()
-                        except notfound:
-                            return self.cursor(None, None, None, None)
-                        else:
-                            done = True
-                            return self.cursor(self, cur, (k, v), lambda o: False)
-                    elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
-                        skip = False
-                        try:
-                            if ge is not missing:
-                                k, v = cur.set_range(self.typ.encode(ge))
-                            elif gt is not missing:
-                                k, v = cur.set_range(self.typ.encode(gt))
-                                skip = True
-                            else:
-                                k, v = cur.first()
-                        except notfound:
-                            return self.cursor(None, None, None, None)
-                        if lt is not missing:
-                            stop = lambda o: self.typ.compare(o, lt) >= 0
-                        elif le is not missing:
-                            stop = lambda o: self.typ.compare(o, le) > 0
-                        else:
-                            stop = lambda o: False
-                        ret = self.cursor(self, cur, (k, v), stop)
-                        if skip:
-                            try:
-                                while self.typ.compare(ret.peek()[0], gt) == 0:
-                                    next(ret)
-                            except StopIteration:
-                                pass
-                        done = True
-                        return ret
-                    else:
-                        raise NameError("invalid get() specification")
-                finally:
-                    if not done:
-                        cur.close()
-            except deadlock:
-                continue
+    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):
+        if all:
+            cur = self.cursor(self, None, True, None, True, reverse)
+        elif match is not missing:
+            cur = self.cursor(self, match, True, match, True, reverse)
+        elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
+            if ge is not missing:
+                fd, fi = ge, True
+            elif gt is not missing:
+                fd, fi = gt, False
+            else:
+                fd, fi = None, True
+            if le is not missing:
+                ld, li = le, True
+            elif lt is not missing:
+                ld, li = lt, False
+            else:
+                ld, li = None, True
+            cur = self.cursor(self, fd, fi, ld, li, reverse)
+        else:
+            raise NameError("invalid get() specification")
+        done = False
+        try:
+            if not reverse:
+                cur.first()
+            else:
+                cur.last()
+            done = True
+            return cur
+        finally:
+            if not done:
+                cur.close()
 
     @txnfun(lambda self: self.db.env.env)
     def put(self, key, id, *, tx):