Handle dictproxies.
[didex.git] / didex / index.py
index 951eb3f..c2c55b6 100644 (file)
@@ -2,7 +2,7 @@ import struct, contextlib, math
 from . import db, lib
 from .db import bd, txnfun, dloopfun
 
-__all__ = ["maybe", "t_int", "t_uint", "t_float", "t_str", "ordered"]
+__all__ = ["maybe", "t_int", "t_uint", "t_dbid", "t_float", "t_str", "t_casestr", "ordered"]
 
 deadlock = bd.DBLockDeadlockError
 notfound = bd.DBNotFoundError
@@ -29,6 +29,14 @@ class simpletype(object):
         return cls(lambda ob: struct.pack(fmt, ob),
                    lambda dat: struct.unpack(fmt, dat)[0])
 
+class foldtype(simpletype):
+    def __init__(self, encode, decode, fold):
+        super().__init__(encode, decode)
+        self.fold = fold
+
+    def compare(self, a, b):
+        return super().compare(self.fold(a), self.fold(b))
+
 class maybe(object):
     def __init__(self, bk):
         self.bk = bk
@@ -53,36 +61,66 @@ class compound(object):
     def __init__(self, *parts):
         self.parts = parts
 
+    small = object()
+    large = object()
+    def minim(self, *parts):
+        return parts + tuple([self.small] * (len(self.parts) - len(parts)))
+    def maxim(self, *parts):
+        return parts + tuple([self.large] * (len(self.parts) - len(parts)))
+
     def encode(self, obs):
         if len(obs) != len(self.parts):
             raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + len(self.parts))
         buf = bytearray()
         for ob, part in zip(obs, self.parts):
-            dat = part.encode(ob)
-            if len(dat) < 128:
-                buf.append(0x80 | len(dat))
-                buf.extend(dat)
+            if ob is self.small:
+                buf.append(0x01)
+            elif ob is self.large:
+                buf.append(0x02)
             else:
-                buf.extend(struct.pack(">i", len(dat)))
-                buf.extend(dat)
+                dat = part.encode(ob)
+                if len(dat) < 128:
+                    buf.append(0x80 | len(dat))
+                    buf.extend(dat)
+                else:
+                    buf.extend(struct.pack(">BI", 0, len(dat)))
+                    buf.extend(dat)
         return bytes(buf)
     def decode(self, dat):
         ret = []
         off = 0
         for part in self.parts:
-            if dat[off] & 0x80:
-                ln = dat[off] & 0x7f
-                off += 1
+            fl = dat[off]
+            off += 1
+            if fl & 0x80:
+                ln = fl & 0x7f
+            elif fl == 0x01:
+                ret.append(self.small)
+                continue
+            elif fl == 0x02:
+                ret.append(self.large)
+                continue
             else:
-                ln = struct.unpack(">i", dat[off:off + 4])[0]
+                ln = struct.unpack(">I", dat[off:off + 4])[0]
                 off += 4
-            ret.append(part.decode(dat[off:off + len]))
-            off += len
+            ret.append(part.decode(dat[off:off + ln]))
+            off += ln
         return tuple(ret)
     def compare(self, al, bl):
         if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
             raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + len(self.parts))
         for a, b, part in zip(al, bl, self.parts):
+            if a in (self.small, self.large) or b in (self.small, self.large):
+                if a is b:
+                    return 0
+                if a is self.small:
+                    return -1
+                elif b is self.small:
+                    return 1
+                elif a is self.large:
+                    return 1
+                elif b is self.large:
+                    return -1
             c = part.compare(a, b)
             if c != 0:
                 return c
@@ -104,9 +142,12 @@ def floatcmp(a, b):
 
 t_int = simpletype.struct(">q")
 t_uint = simpletype.struct(">Q")
+t_dbid = t_uint
 t_float = simpletype.struct(">d")
 t_float.compare = floatcmp
 t_str = simpletype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")))
+t_casestr = foldtype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")),
+                     (lambda st: st.lower()))
 
 class index(object):
     def __init__(self, db, name, datatype):