Added support for partial compound key matches.
[didex.git] / didex / index.py
index 30dcfdf..4906b46 100644 (file)
@@ -53,36 +53,66 @@ class compound(object):
     def __init__(self, *parts):
         self.parts = parts
 
+    small = object()
+    large = object()
+    def minim(self, *parts):
+        return parts + tuple([self.small] * (len(self.parts) - len(parts)))
+    def maxim(self, *parts):
+        return parts + tuple([self.large] * (len(self.parts) - len(parts)))
+
     def encode(self, obs):
         if len(obs) != len(self.parts):
             raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + len(self.parts))
         buf = bytearray()
         for ob, part in zip(obs, self.parts):
-            dat = part.encode(ob)
-            if len(dat) < 128:
-                buf.append(0x80 | len(dat))
-                buf.extend(dat)
+            if ob is self.small:
+                buf.append(0x01)
+            elif ob is self.large:
+                buf.append(0x02)
             else:
-                buf.extend(struct.pack(">i", len(dat)))
-                buf.extend(dat)
+                dat = part.encode(ob)
+                if len(dat) < 128:
+                    buf.append(0x80 | len(dat))
+                    buf.extend(dat)
+                else:
+                    buf.extend(struct.pack(">BI", 0, len(dat)))
+                    buf.extend(dat)
         return bytes(buf)
     def decode(self, dat):
         ret = []
         off = 0
         for part in self.parts:
-            if dat[off] & 0x80:
-                ln = dat[off] & 0x7f
-                off += 1
+            fl = dat[off]
+            off += 1
+            if fl & 0x80:
+                ln = fl & 0x7f
+            elif fl == 0x01:
+                ret.append(self.small)
+                continue
+            elif fl == 0x02:
+                ret.append(self.large)
+                continue
             else:
-                ln = struct.unpack(">i", dat[off:off + 4])[0]
+                ln = struct.unpack(">I", dat[off:off + 4])[0]
                 off += 4
-            ret.append(part.decode(dat[off:off + len]))
-            off += len
+            ret.append(part.decode(dat[off:off + ln]))
+            off += ln
         return tuple(ret)
     def compare(self, al, bl):
         if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
             raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + len(self.parts))
         for a, b, part in zip(al, bl, self.parts):
+            if a in (self.small, self.large) or b in (self.small, self.large):
+                if a is b:
+                    return 0
+                if a is self.small:
+                    return -1
+                elif b is self.small:
+                    return 1
+                elif a is self.large:
+                    return 1
+                elif b is self.large:
+                    return -1
             c = part.compare(a, b)
             if c != 0:
                 return c
@@ -155,14 +185,14 @@ class ordered(index, lib.closable):
 
         def _decode(self, d):
             k, v = d
-            k = self.type.decode(k)
+            k = self.typ.decode(k)
             v = struct.unpack(">Q", v)[0]
             return k, v
 
         @dloopfun
         def first(self):
             try:
-                if self.fd is None:
+                if self.fd is missing:
                     self.item = self._decode(self.cur.first())
                 else:
                     k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
@@ -176,16 +206,20 @@ class ordered(index, lib.closable):
         @dloopfun
         def last(self):
             try:
-                if self.fd is None:
+                if self.ld is missing:
                     self.item = self._decode(self.cur.last())
                 else:
-                    k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
-                    if self.fi:
-                        while self.typ.compare(k, self.fd) == 0:
+                    try:
+                        k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
+                    except notfound:
+                        k, v = self._decode(self.cur.last())
+                    if self.li:
+                        while self.typ.compare(k, self.ld) == 0:
                             k, v = self._decode(self.cur.next())
-                        k, v = self._decode(self.cur.prev())
+                        while self.typ.compare(k, self.ld) > 0:
+                            k, v = self._decode(self.cur.prev())
                     else:
-                        while self.typ.compare(k, self.fd) >= 0:
+                        while self.typ.compare(k, self.ld) >= 0:
                             k, v = self._decode(self.cur.prev())
                     self.item = k, v
             except notfound:
@@ -195,8 +229,9 @@ class ordered(index, lib.closable):
         def next(self):
             try:
                 k, v = self.item = self._decode(self.cur.next())
-                if ((self.li and self.typ.compare(k, self.ld) > 0) or
-                    (not self.li and self.typ.compare(k, self.ld) >= 0)):
+                if (self.ld is not missing and
+                    ((self.li and self.typ.compare(k, self.ld) > 0) or
+                     (not self.li and self.typ.compare(k, self.ld) >= 0))):
                     self.item = StopIteration
             except notfound:
                 self.item = StopIteration
@@ -205,20 +240,21 @@ class ordered(index, lib.closable):
         def prev(self):
             try:
                 self.item = self._decode(self.cur.prev())
-                if ((self.fi and self.typ.compare(k, self.fd) < 0) or
-                    (not self.fi and self.typ.compare(k, self.fd) <= 0)):
+                if (self.fd is not missing and
+                    ((self.fi and self.typ.compare(k, self.fd) < 0) or
+                     (not self.fi and self.typ.compare(k, self.fd) <= 0))):
                     self.item = StopIteration
             except notfound:
                 self.item = StopIteration
 
         def __next__(self):
-            if self.item is StopIteration:
-                raise StopIteration()
             if self.item is None:
                 if not self.rev:
                     self.next()
                 else:
                     self.prev()
+            if self.item is StopIteration:
+                raise StopIteration()
             ret, self.item = self.item, None
             return ret
 
@@ -231,7 +267,7 @@ class ordered(index, lib.closable):
 
     def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):
         if all:
-            cur = self.cursor(self, None, True, None, True, reverse)
+            cur = self.cursor(self, missing, True, missing, True, reverse)
         elif match is not missing:
             cur = self.cursor(self, match, True, match, True, reverse)
         elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
@@ -240,13 +276,13 @@ class ordered(index, lib.closable):
             elif gt is not missing:
                 fd, fi = gt, False
             else:
-                fd, fi = None, True
+                fd, fi = missing, True
             if le is not missing:
                 ld, li = le, True
             elif lt is not missing:
                 ld, li = lt, False
             else:
-                ld, li = None, True
+                ld, li = missing, True
             cur = self.cursor(self, fd, fi, ld, li, reverse)
         else:
             raise NameError("invalid get() specification")