65b1b12c1cff197dced8c25a7e4b73267b5d089b
[didex.git] / didex / index.py
1 import struct, contextlib, math
2 from . import db, lib
3 from .db import bd, txnfun, dloopfun
4
5 __all__ = ["maybe", "t_int", "t_uint", "t_dbid", "t_float", "t_str", "ordered"]
6
7 deadlock = bd.DBLockDeadlockError
8 notfound = bd.DBNotFoundError
9
10 class simpletype(object):
11     def __init__(self, encode, decode):
12         self.enc = encode
13         self.dec = decode
14
15     def encode(self, ob):
16         return self.enc(ob)
17     def decode(self, dat):
18         return self.dec(dat)
19     def compare(self, a, b):
20         if a < b:
21             return -1
22         elif a > b:
23             return 1
24         else:
25             return 0
26
27     @classmethod
28     def struct(cls, fmt):
29         return cls(lambda ob: struct.pack(fmt, ob),
30                    lambda dat: struct.unpack(fmt, dat)[0])
31
32 class maybe(object):
33     def __init__(self, bk):
34         self.bk = bk
35
36     def encode(self, ob):
37         if ob is None: return b""
38         return b"\0" + self.bk.encode(ob)
39     def decode(self, dat):
40         if dat == b"": return None
41         return self.bk.dec(dat[1:])
42     def compare(self, a, b):
43         if a is b is None:
44             return 0
45         elif a is None:
46             return -1
47         elif b is None:
48             return 1
49         else:
50             return self.bk.compare(a[1:], b[1:])
51
52 class compound(object):
53     def __init__(self, *parts):
54         self.parts = parts
55
56     small = object()
57     large = object()
58     def minim(self, *parts):
59         return parts + tuple([self.small] * (len(self.parts) - len(parts)))
60     def maxim(self, *parts):
61         return parts + tuple([self.large] * (len(self.parts) - len(parts)))
62
63     def encode(self, obs):
64         if len(obs) != len(self.parts):
65             raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + len(self.parts))
66         buf = bytearray()
67         for ob, part in zip(obs, self.parts):
68             if ob is self.small:
69                 buf.append(0x01)
70             elif ob is self.large:
71                 buf.append(0x02)
72             else:
73                 dat = part.encode(ob)
74                 if len(dat) < 128:
75                     buf.append(0x80 | len(dat))
76                     buf.extend(dat)
77                 else:
78                     buf.extend(struct.pack(">BI", 0, len(dat)))
79                     buf.extend(dat)
80         return bytes(buf)
81     def decode(self, dat):
82         ret = []
83         off = 0
84         for part in self.parts:
85             fl = dat[off]
86             off += 1
87             if fl & 0x80:
88                 ln = fl & 0x7f
89             elif fl == 0x01:
90                 ret.append(self.small)
91                 continue
92             elif fl == 0x02:
93                 ret.append(self.large)
94                 continue
95             else:
96                 ln = struct.unpack(">I", dat[off:off + 4])[0]
97                 off += 4
98             ret.append(part.decode(dat[off:off + ln]))
99             off += ln
100         return tuple(ret)
101     def compare(self, al, bl):
102         if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
103             raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + len(self.parts))
104         for a, b, part in zip(al, bl, self.parts):
105             if a in (self.small, self.large) or b in (self.small, self.large):
106                 if a is b:
107                     return 0
108                 if a is self.small:
109                     return -1
110                 elif b is self.small:
111                     return 1
112                 elif a is self.large:
113                     return 1
114                 elif b is self.large:
115                     return -1
116             c = part.compare(a, b)
117             if c != 0:
118                 return c
119         return 0
120
121 def floatcmp(a, b):
122     if math.isnan(a) and math.isnan(b):
123         return 0
124     elif math.isnan(a):
125         return -1
126     elif math.isnan(b):
127         return 1
128     elif a < b:
129         return -1
130     elif a > b:
131         return 1
132     else:
133         return 0
134
135 t_int = simpletype.struct(">q")
136 t_uint = simpletype.struct(">Q")
137 t_dbid = t_uint
138 t_float = simpletype.struct(">d")
139 t_float.compare = floatcmp
140 t_str = simpletype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")))
141
142 class index(object):
143     def __init__(self, db, name, datatype):
144         self.db = db
145         self.nm = name
146         self.typ = datatype
147
148 missing = object()
149
150 class ordered(index, lib.closable):
151     def __init__(self, db, name, datatype, create=True):
152         super().__init__(db, name, datatype)
153         fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
154         if create: fl |= bd.DB_CREATE
155         def initdb(db):
156             def compare(a, b):
157                 if a == b == "": return 0
158                 return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
159             db.set_flags(bd.DB_DUPSORT)
160             db.set_bt_compare(compare)
161         self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
162         self.bk.set_get_returns_none(False)
163
164     def close(self):
165         self.bk.close()
166
167     class cursor(lib.closable):
168         def __init__(self, idx, fd, fi, ld, li, reverse):
169             self.idx = idx
170             self.typ = idx.typ
171             self.cur = self.idx.bk.cursor()
172             self.item = None
173             self.fd = fd
174             self.fi = fi
175             self.ld = ld
176             self.li = li
177             self.rev = reverse
178
179         def close(self):
180             if self.cur is not None:
181                 self.cur.close()
182                 self.cur = None
183
184         def __iter__(self):
185             return self
186
187         def _decode(self, d):
188             k, v = d
189             k = self.typ.decode(k)
190             v = struct.unpack(">Q", v)[0]
191             return k, v
192
193         @dloopfun
194         def first(self):
195             try:
196                 if self.fd is missing:
197                     self.item = self._decode(self.cur.first())
198                 else:
199                     k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
200                     if not self.fi:
201                         while self.typ.compare(k, self.fd) == 0:
202                             k, v = self._decode(self.cur.next())
203                     self.item = k, v
204             except notfound:
205                 self.item = StopIteration
206
207         @dloopfun
208         def last(self):
209             try:
210                 if self.ld is missing:
211                     self.item = self._decode(self.cur.last())
212                 else:
213                     try:
214                         k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
215                     except notfound:
216                         k, v = self._decode(self.cur.last())
217                     if self.li:
218                         while self.typ.compare(k, self.ld) == 0:
219                             k, v = self._decode(self.cur.next())
220                         while self.typ.compare(k, self.ld) > 0:
221                             k, v = self._decode(self.cur.prev())
222                     else:
223                         while self.typ.compare(k, self.ld) >= 0:
224                             k, v = self._decode(self.cur.prev())
225                     self.item = k, v
226             except notfound:
227                 self.item = StopIteration
228
229         @dloopfun
230         def next(self):
231             try:
232                 k, v = self.item = self._decode(self.cur.next())
233                 if (self.ld is not missing and
234                     ((self.li and self.typ.compare(k, self.ld) > 0) or
235                      (not self.li and self.typ.compare(k, self.ld) >= 0))):
236                     self.item = StopIteration
237             except notfound:
238                 self.item = StopIteration
239
240         @dloopfun
241         def prev(self):
242             try:
243                 self.item = self._decode(self.cur.prev())
244                 if (self.fd is not missing and
245                     ((self.fi and self.typ.compare(k, self.fd) < 0) or
246                      (not self.fi and self.typ.compare(k, self.fd) <= 0))):
247                     self.item = StopIteration
248             except notfound:
249                 self.item = StopIteration
250
251         def __next__(self):
252             if self.item is None:
253                 if not self.rev:
254                     self.next()
255                 else:
256                     self.prev()
257             if self.item is StopIteration:
258                 raise StopIteration()
259             ret, self.item = self.item, None
260             return ret
261
262         def skip(self, n=1):
263             try:
264                 for i in range(n):
265                     next(self)
266             except StopIteration:
267                 return
268
269     def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):
270         if all:
271             cur = self.cursor(self, missing, True, missing, True, reverse)
272         elif match is not missing:
273             cur = self.cursor(self, match, True, match, True, reverse)
274         elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
275             if ge is not missing:
276                 fd, fi = ge, True
277             elif gt is not missing:
278                 fd, fi = gt, False
279             else:
280                 fd, fi = missing, True
281             if le is not missing:
282                 ld, li = le, True
283             elif lt is not missing:
284                 ld, li = lt, False
285             else:
286                 ld, li = missing, True
287             cur = self.cursor(self, fd, fi, ld, li, reverse)
288         else:
289             raise NameError("invalid get() specification")
290         done = False
291         try:
292             if not reverse:
293                 cur.first()
294             else:
295                 cur.last()
296             done = True
297             return cur
298         finally:
299             if not done:
300                 cur.close()
301
302     @txnfun(lambda self: self.db.env.env)
303     def put(self, key, id, *, tx):
304         obid = struct.pack(">Q", id)
305         if not self.db.ob.has_key(obid, txn=tx.tx):
306             raise ValueError("no such object in database: " + str(id))
307         try:
308             self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
309         except bd.DBKeyExistError:
310             return False
311         return True
312
313     @txnfun(lambda self: self.db.env.env)
314     def remove(self, key, id, *, tx):
315         obid = struct.pack(">Q", id)
316         if not self.db.ob.has_key(obid, txn=tx.tx):
317             raise ValueError("no such object in database: " + str(id))
318         cur = self.bk.cursor(txn=tx.tx)
319         try:
320             try:
321                 cur.get_both(self.typ.encode(key), obid)
322             except notfound:
323                 return False
324             cur.delete()
325         finally:
326             cur.close()
327         return True