4906b46a2e7ec260c2dc2a96c85c8f2b7a3f9a77
[didex.git] / didex / index.py
1 import struct, contextlib, math
2 from . import db, lib
3 from .db import bd, txnfun, dloopfun
4
5 __all__ = ["maybe", "t_int", "t_uint", "t_float", "t_str", "ordered"]
6
7 deadlock = bd.DBLockDeadlockError
8 notfound = bd.DBNotFoundError
9
10 class simpletype(object):
11     def __init__(self, encode, decode):
12         self.enc = encode
13         self.dec = decode
14
15     def encode(self, ob):
16         return self.enc(ob)
17     def decode(self, dat):
18         return self.dec(dat)
19     def compare(self, a, b):
20         if a < b:
21             return -1
22         elif a > b:
23             return 1
24         else:
25             return 0
26
27     @classmethod
28     def struct(cls, fmt):
29         return cls(lambda ob: struct.pack(fmt, ob),
30                    lambda dat: struct.unpack(fmt, dat)[0])
31
32 class maybe(object):
33     def __init__(self, bk):
34         self.bk = bk
35
36     def encode(self, ob):
37         if ob is None: return b""
38         return b"\0" + self.bk.encode(ob)
39     def decode(self, dat):
40         if dat == b"": return None
41         return self.bk.dec(dat[1:])
42     def compare(self, a, b):
43         if a is b is None:
44             return 0
45         elif a is None:
46             return -1
47         elif b is None:
48             return 1
49         else:
50             return self.bk.compare(a[1:], b[1:])
51
52 class compound(object):
53     def __init__(self, *parts):
54         self.parts = parts
55
56     small = object()
57     large = object()
58     def minim(self, *parts):
59         return parts + tuple([self.small] * (len(self.parts) - len(parts)))
60     def maxim(self, *parts):
61         return parts + tuple([self.large] * (len(self.parts) - len(parts)))
62
63     def encode(self, obs):
64         if len(obs) != len(self.parts):
65             raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + len(self.parts))
66         buf = bytearray()
67         for ob, part in zip(obs, self.parts):
68             if ob is self.small:
69                 buf.append(0x01)
70             elif ob is self.large:
71                 buf.append(0x02)
72             else:
73                 dat = part.encode(ob)
74                 if len(dat) < 128:
75                     buf.append(0x80 | len(dat))
76                     buf.extend(dat)
77                 else:
78                     buf.extend(struct.pack(">BI", 0, len(dat)))
79                     buf.extend(dat)
80         return bytes(buf)
81     def decode(self, dat):
82         ret = []
83         off = 0
84         for part in self.parts:
85             fl = dat[off]
86             off += 1
87             if fl & 0x80:
88                 ln = fl & 0x7f
89             elif fl == 0x01:
90                 ret.append(self.small)
91                 continue
92             elif fl == 0x02:
93                 ret.append(self.large)
94                 continue
95             else:
96                 ln = struct.unpack(">I", dat[off:off + 4])[0]
97                 off += 4
98             ret.append(part.decode(dat[off:off + ln]))
99             off += ln
100         return tuple(ret)
101     def compare(self, al, bl):
102         if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
103             raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + len(self.parts))
104         for a, b, part in zip(al, bl, self.parts):
105             if a in (self.small, self.large) or b in (self.small, self.large):
106                 if a is b:
107                     return 0
108                 if a is self.small:
109                     return -1
110                 elif b is self.small:
111                     return 1
112                 elif a is self.large:
113                     return 1
114                 elif b is self.large:
115                     return -1
116             c = part.compare(a, b)
117             if c != 0:
118                 return c
119         return 0
120
121 def floatcmp(a, b):
122     if math.isnan(a) and math.isnan(b):
123         return 0
124     elif math.isnan(a):
125         return -1
126     elif math.isnan(b):
127         return 1
128     elif a < b:
129         return -1
130     elif a > b:
131         return 1
132     else:
133         return 0
134
135 t_int = simpletype.struct(">q")
136 t_uint = simpletype.struct(">Q")
137 t_float = simpletype.struct(">d")
138 t_float.compare = floatcmp
139 t_str = simpletype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")))
140
141 class index(object):
142     def __init__(self, db, name, datatype):
143         self.db = db
144         self.nm = name
145         self.typ = datatype
146
147 missing = object()
148
149 class ordered(index, lib.closable):
150     def __init__(self, db, name, datatype, create=True):
151         super().__init__(db, name, datatype)
152         fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
153         if create: fl |= bd.DB_CREATE
154         def initdb(db):
155             def compare(a, b):
156                 if a == b == "": return 0
157                 return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
158             db.set_flags(bd.DB_DUPSORT)
159             db.set_bt_compare(compare)
160         self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
161         self.bk.set_get_returns_none(False)
162
163     def close(self):
164         self.bk.close()
165
166     class cursor(lib.closable):
167         def __init__(self, idx, fd, fi, ld, li, reverse):
168             self.idx = idx
169             self.typ = idx.typ
170             self.cur = self.idx.bk.cursor()
171             self.item = None
172             self.fd = fd
173             self.fi = fi
174             self.ld = ld
175             self.li = li
176             self.rev = reverse
177
178         def close(self):
179             if self.cur is not None:
180                 self.cur.close()
181                 self.cur = None
182
183         def __iter__(self):
184             return self
185
186         def _decode(self, d):
187             k, v = d
188             k = self.typ.decode(k)
189             v = struct.unpack(">Q", v)[0]
190             return k, v
191
192         @dloopfun
193         def first(self):
194             try:
195                 if self.fd is missing:
196                     self.item = self._decode(self.cur.first())
197                 else:
198                     k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
199                     if not self.fi:
200                         while self.typ.compare(k, self.fd) == 0:
201                             k, v = self._decode(self.cur.next())
202                     self.item = k, v
203             except notfound:
204                 self.item = StopIteration
205
206         @dloopfun
207         def last(self):
208             try:
209                 if self.ld is missing:
210                     self.item = self._decode(self.cur.last())
211                 else:
212                     try:
213                         k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
214                     except notfound:
215                         k, v = self._decode(self.cur.last())
216                     if self.li:
217                         while self.typ.compare(k, self.ld) == 0:
218                             k, v = self._decode(self.cur.next())
219                         while self.typ.compare(k, self.ld) > 0:
220                             k, v = self._decode(self.cur.prev())
221                     else:
222                         while self.typ.compare(k, self.ld) >= 0:
223                             k, v = self._decode(self.cur.prev())
224                     self.item = k, v
225             except notfound:
226                 self.item = StopIteration
227
228         @dloopfun
229         def next(self):
230             try:
231                 k, v = self.item = self._decode(self.cur.next())
232                 if (self.ld is not missing and
233                     ((self.li and self.typ.compare(k, self.ld) > 0) or
234                      (not self.li and self.typ.compare(k, self.ld) >= 0))):
235                     self.item = StopIteration
236             except notfound:
237                 self.item = StopIteration
238
239         @dloopfun
240         def prev(self):
241             try:
242                 self.item = self._decode(self.cur.prev())
243                 if (self.fd is not missing and
244                     ((self.fi and self.typ.compare(k, self.fd) < 0) or
245                      (not self.fi and self.typ.compare(k, self.fd) <= 0))):
246                     self.item = StopIteration
247             except notfound:
248                 self.item = StopIteration
249
250         def __next__(self):
251             if self.item is None:
252                 if not self.rev:
253                     self.next()
254                 else:
255                     self.prev()
256             if self.item is StopIteration:
257                 raise StopIteration()
258             ret, self.item = self.item, None
259             return ret
260
261         def skip(self, n=1):
262             try:
263                 for i in range(n):
264                     next(self)
265             except StopIteration:
266                 return
267
268     def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):
269         if all:
270             cur = self.cursor(self, missing, True, missing, True, reverse)
271         elif match is not missing:
272             cur = self.cursor(self, match, True, match, True, reverse)
273         elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
274             if ge is not missing:
275                 fd, fi = ge, True
276             elif gt is not missing:
277                 fd, fi = gt, False
278             else:
279                 fd, fi = missing, True
280             if le is not missing:
281                 ld, li = le, True
282             elif lt is not missing:
283                 ld, li = lt, False
284             else:
285                 ld, li = missing, True
286             cur = self.cursor(self, fd, fi, ld, li, reverse)
287         else:
288             raise NameError("invalid get() specification")
289         done = False
290         try:
291             if not reverse:
292                 cur.first()
293             else:
294                 cur.last()
295             done = True
296             return cur
297         finally:
298             if not done:
299                 cur.close()
300
301     @txnfun(lambda self: self.db.env.env)
302     def put(self, key, id, *, tx):
303         obid = struct.pack(">Q", id)
304         if not self.db.ob.has_key(obid, txn=tx.tx):
305             raise ValueError("no such object in database: " + str(id))
306         try:
307             self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
308         except bd.DBKeyExistError:
309             return False
310         return True
311
312     @txnfun(lambda self: self.db.env.env)
313     def remove(self, key, id, *, tx):
314         obid = struct.pack(">Q", id)
315         if not self.db.ob.has_key(obid, txn=tx.tx):
316             raise ValueError("no such object in database: " + str(id))
317         cur = self.bk.cursor(txn=tx.tx)
318         try:
319             try:
320                 cur.get_both(self.typ.encode(key), obid)
321             except notfound:
322                 return False
323             cur.delete()
324         finally:
325             cur.close()
326         return True