Fixed some index bugs.
[didex.git] / didex / index.py
1 import struct, contextlib, math
2 from . import db, lib
3 from .db import bd, txnfun, dloopfun
4
5 __all__ = ["maybe", "t_int", "t_uint", "t_float", "t_str", "ordered"]
6
7 deadlock = bd.DBLockDeadlockError
8 notfound = bd.DBNotFoundError
9
10 class simpletype(object):
11     def __init__(self, encode, decode):
12         self.enc = encode
13         self.dec = decode
14
15     def encode(self, ob):
16         return self.enc(ob)
17     def decode(self, dat):
18         return self.dec(dat)
19     def compare(self, a, b):
20         if a < b:
21             return -1
22         elif a > b:
23             return 1
24         else:
25             return 0
26
27     @classmethod
28     def struct(cls, fmt):
29         return cls(lambda ob: struct.pack(fmt, ob),
30                    lambda dat: struct.unpack(fmt, dat)[0])
31
32 class maybe(object):
33     def __init__(self, bk):
34         self.bk = bk
35
36     def encode(self, ob):
37         if ob is None: return b""
38         return b"\0" + self.bk.encode(ob)
39     def decode(self, dat):
40         if dat == b"": return None
41         return self.bk.dec(dat[1:])
42     def compare(self, a, b):
43         if a is b is None:
44             return 0
45         elif a is None:
46             return -1
47         elif b is None:
48             return 1
49         else:
50             return self.bk.compare(a[1:], b[1:])
51
52 class compound(object):
53     def __init__(self, *parts):
54         self.parts = parts
55
56     def encode(self, obs):
57         if len(obs) != len(self.parts):
58             raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + len(self.parts))
59         buf = bytearray()
60         for ob, part in zip(obs, self.parts):
61             dat = part.encode(ob)
62             if len(dat) < 128:
63                 buf.append(0x80 | len(dat))
64                 buf.extend(dat)
65             else:
66                 buf.extend(struct.pack(">i", len(dat)))
67                 buf.extend(dat)
68         return bytes(buf)
69     def decode(self, dat):
70         ret = []
71         off = 0
72         for part in self.parts:
73             if dat[off] & 0x80:
74                 ln = dat[off] & 0x7f
75                 off += 1
76             else:
77                 ln = struct.unpack(">i", dat[off:off + 4])[0]
78                 off += 4
79             ret.append(part.decode(dat[off:off + len]))
80             off += len
81         return tuple(ret)
82     def compare(self, al, bl):
83         if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
84             raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + len(self.parts))
85         for a, b, part in zip(al, bl, self.parts):
86             c = part.compare(a, b)
87             if c != 0:
88                 return c
89         return 0
90
91 def floatcmp(a, b):
92     if math.isnan(a) and math.isnan(b):
93         return 0
94     elif math.isnan(a):
95         return -1
96     elif math.isnan(b):
97         return 1
98     elif a < b:
99         return -1
100     elif a > b:
101         return 1
102     else:
103         return 0
104
105 t_int = simpletype.struct(">q")
106 t_uint = simpletype.struct(">Q")
107 t_float = simpletype.struct(">d")
108 t_float.compare = floatcmp
109 t_str = simpletype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")))
110
111 class index(object):
112     def __init__(self, db, name, datatype):
113         self.db = db
114         self.nm = name
115         self.typ = datatype
116
117 missing = object()
118
119 class ordered(index, lib.closable):
120     def __init__(self, db, name, datatype, create=True):
121         super().__init__(db, name, datatype)
122         fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
123         if create: fl |= bd.DB_CREATE
124         def initdb(db):
125             def compare(a, b):
126                 if a == b == "": return 0
127                 return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
128             db.set_flags(bd.DB_DUPSORT)
129             db.set_bt_compare(compare)
130         self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
131         self.bk.set_get_returns_none(False)
132
133     def close(self):
134         self.bk.close()
135
136     class cursor(lib.closable):
137         def __init__(self, idx, fd, fi, ld, li, reverse):
138             self.idx = idx
139             self.typ = idx.typ
140             self.cur = self.idx.bk.cursor()
141             self.item = None
142             self.fd = fd
143             self.fi = fi
144             self.ld = ld
145             self.li = li
146             self.rev = reverse
147
148         def close(self):
149             if self.cur is not None:
150                 self.cur.close()
151                 self.cur = None
152
153         def __iter__(self):
154             return self
155
156         def _decode(self, d):
157             k, v = d
158             k = self.typ.decode(k)
159             v = struct.unpack(">Q", v)[0]
160             return k, v
161
162         @dloopfun
163         def first(self):
164             try:
165                 if self.fd is missing:
166                     self.item = self._decode(self.cur.first())
167                 else:
168                     k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
169                     if not self.fi:
170                         while self.typ.compare(k, self.fd) == 0:
171                             k, v = self._decode(self.cur.next())
172                     self.item = k, v
173             except notfound:
174                 self.item = StopIteration
175
176         @dloopfun
177         def last(self):
178             try:
179                 if self.ld is missing:
180                     self.item = self._decode(self.cur.last())
181                 else:
182                     try:
183                         k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
184                     except notfound:
185                         k, v = self._decode(self.cur.last())
186                     if self.li:
187                         while self.typ.compare(k, self.ld) == 0:
188                             k, v = self._decode(self.cur.next())
189                         while self.typ.compare(k, self.ld) > 0:
190                             k, v = self._decode(self.cur.prev())
191                     else:
192                         while self.typ.compare(k, self.ld) >= 0:
193                             k, v = self._decode(self.cur.prev())
194                     self.item = k, v
195             except notfound:
196                 self.item = StopIteration
197
198         @dloopfun
199         def next(self):
200             try:
201                 k, v = self.item = self._decode(self.cur.next())
202                 if (self.ld is not missing and
203                     ((self.li and self.typ.compare(k, self.ld) > 0) or
204                      (not self.li and self.typ.compare(k, self.ld) >= 0))):
205                     self.item = StopIteration
206             except notfound:
207                 self.item = StopIteration
208
209         @dloopfun
210         def prev(self):
211             try:
212                 self.item = self._decode(self.cur.prev())
213                 if (self.fd is not missing and
214                     ((self.fi and self.typ.compare(k, self.fd) < 0) or
215                      (not self.fi and self.typ.compare(k, self.fd) <= 0))):
216                     self.item = StopIteration
217             except notfound:
218                 self.item = StopIteration
219
220         def __next__(self):
221             if self.item is None:
222                 if not self.rev:
223                     self.next()
224                 else:
225                     self.prev()
226             if self.item is StopIteration:
227                 raise StopIteration()
228             ret, self.item = self.item, None
229             return ret
230
231         def skip(self, n=1):
232             try:
233                 for i in range(n):
234                     next(self)
235             except StopIteration:
236                 return
237
238     def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):
239         if all:
240             cur = self.cursor(self, missing, True, missing, True, reverse)
241         elif match is not missing:
242             cur = self.cursor(self, match, True, match, True, reverse)
243         elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
244             if ge is not missing:
245                 fd, fi = ge, True
246             elif gt is not missing:
247                 fd, fi = gt, False
248             else:
249                 fd, fi = missing, True
250             if le is not missing:
251                 ld, li = le, True
252             elif lt is not missing:
253                 ld, li = lt, False
254             else:
255                 ld, li = missing, True
256             cur = self.cursor(self, fd, fi, ld, li, reverse)
257         else:
258             raise NameError("invalid get() specification")
259         done = False
260         try:
261             if not reverse:
262                 cur.first()
263             else:
264                 cur.last()
265             done = True
266             return cur
267         finally:
268             if not done:
269                 cur.close()
270
271     @txnfun(lambda self: self.db.env.env)
272     def put(self, key, id, *, tx):
273         obid = struct.pack(">Q", id)
274         if not self.db.ob.has_key(obid, txn=tx.tx):
275             raise ValueError("no such object in database: " + str(id))
276         try:
277             self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
278         except bd.DBKeyExistError:
279             return False
280         return True
281
282     @txnfun(lambda self: self.db.env.env)
283     def remove(self, key, id, *, tx):
284         obid = struct.pack(">Q", id)
285         if not self.db.ob.has_key(obid, txn=tx.tx):
286             raise ValueError("no such object in database: " + str(id))
287         cur = self.bk.cursor(txn=tx.tx)
288         try:
289             try:
290                 cur.get_both(self.typ.encode(key), obid)
291             except notfound:
292                 return False
293             cur.delete()
294         finally:
295             cur.close()
296         return True