Fixed some index bugs.
[didex.git] / didex / index.py
CommitLineData
abb94f83 1import struct, contextlib, math
a95055e8 2from . import db, lib
6efe4e23 3from .db import bd, txnfun, dloopfun
a95055e8 4
cbf73d3a
FT
# Names exported by a star-import of this module.
__all__ = [
    "maybe",
    "t_int", "t_uint", "t_float", "t_str",
    "ordered",
]

# Short local aliases for the database exception classes used below.
deadlock = bd.DBLockDeadlockError
notfound = bd.DBNotFoundError
9
class simpletype(object):
    """Index datatype defined by a pair of encode/decode callables.

    Decoded values are ordered with their own ``<`` and ``>`` operators.
    """

    def __init__(self, encode, decode):
        # encode: value -> stored form; decode: stored form -> value.
        self.enc, self.dec = encode, decode

    def encode(self, ob):
        """Return `ob` converted by the encoder given at construction."""
        encoder = self.enc
        return encoder(ob)

    def decode(self, dat):
        """Return `dat` converted by the decoder given at construction."""
        decoder = self.dec
        return decoder(dat)

    def compare(self, a, b):
        """Three-way comparison: -1 if a < b, 1 if a > b, otherwise 0."""
        if a < b:
            return -1
        if a > b:
            return 1
        return 0

    @classmethod
    def struct(cls, fmt):
        """Build a type that stores one value using the struct format `fmt`."""
        def pack(ob):
            return struct.pack(fmt, ob)

        def unpack(dat):
            fields = struct.unpack(fmt, dat)
            return fields[0]

        return cls(pack, unpack)
31
class maybe(object):
    """Wrapper datatype that additionally accepts None as a value.

    None is stored as the empty byte string; any other value is stored as
    a single zero byte followed by the wrapped type's encoding.  When
    comparing, None sorts before every other value.
    """

    def __init__(self, bk):
        # bk: the wrapped datatype, providing encode/decode/compare.
        self.bk = bk

    def encode(self, ob):
        """Return the stored form of `ob`, which may be None."""
        if ob is None: return b""
        return b"\0" + self.bk.encode(ob)

    def decode(self, dat):
        """Return the value stored in `dat`, or None for the empty string."""
        if dat == b"": return None
        # Go through the public decode() method: only simpletype has a
        # `dec` attribute, other wrapped types do not.
        return self.bk.decode(dat[1:])

    def compare(self, a, b):
        """Three-way comparison of two decoded values, None sorting first."""
        if a is None and b is None:
            return 0
        elif a is None:
            return -1
        elif b is None:
            return 1
        else:
            # a and b are already decoded values, not encoded strings, so
            # they are handed to the wrapped type unchanged.
            return self.bk.compare(a, b)
51
bd14729f
FT
class compound(object):
    """Datatype for fixed-length tuples, one component datatype per element.

    Each component is stored with a length prefix: a single byte with the
    high bit set for encodings shorter than 128 bytes, or a 4-byte
    big-endian signed integer otherwise.
    """

    def __init__(self, *parts):
        # parts: the component datatypes, in tuple order.
        self.parts = parts

    def encode(self, obs):
        """Return the stored form of the sequence `obs`.

        Raises ValueError if `obs` does not have one element per part.
        """
        if len(obs) != len(self.parts):
            raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + str(len(self.parts)))
        buf = bytearray()
        for ob, part in zip(obs, self.parts):
            dat = part.encode(ob)
            if len(dat) < 128:
                # Short form: the length fits in seven bits, flagged by the
                # high bit of the prefix byte.
                buf.append(0x80 | len(dat))
            else:
                # Long form: the first byte of a non-negative ">i" value
                # always has its high bit clear, which tells the two forms
                # apart when decoding.
                buf.extend(struct.pack(">i", len(dat)))
            buf.extend(dat)
        return bytes(buf)

    def decode(self, dat):
        """Return the tuple of component values stored in `dat`."""
        ret = []
        off = 0
        for part in self.parts:
            if dat[off] & 0x80:
                ln = dat[off] & 0x7f
                off += 1
            else:
                ln = struct.unpack(">i", dat[off:off + 4])[0]
                off += 4
            ret.append(part.decode(dat[off:off + ln]))
            off += ln
        return tuple(ret)

    def compare(self, al, bl):
        """Three-way comparison of two decoded tuples, component by component.

        The first component that differs decides the result.  Raises
        ValueError if either tuple does not have one element per part.
        """
        if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
            raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + str(len(self.parts)))
        for a, b, part in zip(al, bl, self.parts):
            c = part.compare(a, b)
            if c != 0:
                return c
        return 0
90
abb94f83
FT
def floatcmp(a, b):
    """Three-way float comparison in which NaN sorts before all numbers.

    Two NaNs compare equal; returns -1, 0 or 1.
    """
    a_nan = math.isnan(a)
    b_nan = math.isnan(b)
    if a_nan or b_nan:
        # Both NaN -> 0, only a -> -1, only b -> 1.
        return int(b_nan) - int(a_nan)
    if a < b:
        return -1
    if a > b:
        return 1
    return 0
104
# Predefined datatypes.  The numeric ones are stored as big-endian 8-byte
# struct fields: signed integer, unsigned integer and double.
t_int, t_uint, t_float = (simpletype.struct(fmt) for fmt in (">q", ">Q", ">d"))
# Floats need a NaN-aware ordering instead of the plain operator-based one.
t_float.compare = floatcmp
# Strings are stored as their UTF-8 encoding.
t_str = simpletype(
    lambda ob: ob.encode("utf-8"),
    lambda dat: dat.decode("utf-8"),
)
a95055e8
FT
110
class index(object):
    """Common base of indices: records the database, name and datatype."""

    def __init__(self, db, name, datatype):
        # Stored under the short attribute names used by subclasses.
        self.db, self.nm, self.typ = db, name, datatype
116
# Unique sentinel meaning "no value given", so that None stays usable as
# an ordinary key or bound.
missing = object()
118
class ordered(index, lib.closable):
    """Index whose entries are kept sorted by the datatype's ordering.

    Each entry maps an encoded key to an object ID stored as a big-endian
    unsigned 64-bit integer.  The backing database is opened as DB_BTREE
    with the DB_DUPSORT flag, so one key may refer to several IDs.
    """

    def __init__(self, db, name, datatype, create=True):
        """Open the backing database "i-" + name, creating it if `create`."""
        super().__init__(db, name, datatype)
        fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
        if create: fl |= bd.DB_CREATE
        def initdb(db):
            def compare(a, b):
                # NOTE(review): presumably answers a probe call made with
                # two empty str arguments when the comparator is
                # installed -- confirm against the binding.
                if a == b == "": return 0
                return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
            db.set_flags(bd.DB_DUPSORT)
            db.set_bt_compare(compare)
        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
        self.bk.set_get_returns_none(False)

    def close(self):
        """Close the backing database."""
        self.bk.close()

    class cursor(lib.closable):
        """Iterator over the (key, id) pairs of an index within a key range.

        `fd`/`ld` are the lower and upper bound keys, or `missing` for an
        open end; `fi`/`li` tell whether the respective bound is
        inclusive.  With `reverse` set, iteration runs from the upper
        bound downwards.
        """

        def __init__(self, idx, fd, fi, ld, li, reverse):
            self.idx = idx
            self.typ = idx.typ
            self.cur = self.idx.bk.cursor()
            # Pending result: None means "advance before returning", and
            # the StopIteration class itself marks exhaustion.
            self.item = None
            self.fd = fd
            self.fi = fi
            self.ld = ld
            self.li = li
            self.rev = reverse

        def close(self):
            """Close the underlying database cursor; safe to call twice."""
            if self.cur is not None:
                self.cur.close()
                self.cur = None

        def __iter__(self):
            return self

        def _decode(self, d):
            """Turn a raw (key, data) pair into (decoded key, integer id)."""
            k, v = d
            k = self.typ.decode(k)
            v = struct.unpack(">Q", v)[0]
            return k, v

        def _above(self, k):
            """Tell whether key `k` lies beyond the upper bound."""
            if self.ld is missing:
                return False
            c = self.typ.compare(k, self.ld)
            return c > 0 if self.li else c >= 0

        def _below(self, k):
            """Tell whether key `k` lies before the lower bound."""
            if self.fd is missing:
                return False
            c = self.typ.compare(k, self.fd)
            return c < 0 if self.fi else c <= 0

        @dloopfun
        def first(self):
            """Position on the first entry of the range, if there is one."""
            try:
                if self.fd is missing:
                    k, v = self._decode(self.cur.first())
                else:
                    k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
                    if not self.fi:
                        # Exclusive bound: step over every entry equal to it.
                        while self.typ.compare(k, self.fd) == 0:
                            k, v = self._decode(self.cur.next())
                # The entry found may already lie past the upper bound,
                # e.g. when an exact match is requested for an absent key.
                if self._above(k):
                    self.item = StopIteration
                else:
                    self.item = k, v
            except notfound:
                self.item = StopIteration

        @dloopfun
        def last(self):
            """Position on the last entry of the range, if there is one."""
            try:
                if self.ld is missing:
                    k, v = self._decode(self.cur.last())
                else:
                    try:
                        k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
                    except notfound:
                        # Every key is smaller than the bound.
                        k, v = self._decode(self.cur.last())
                    if self.li:
                        try:
                            # Move past all duplicates of the bound key.
                            while self.typ.compare(k, self.ld) == 0:
                                k, v = self._decode(self.cur.next())
                        except notfound:
                            # The bound key is the greatest key in the
                            # index, so the very last entry is the one
                            # wanted.
                            k, v = self._decode(self.cur.last())
                    # Back up to the last entry within the upper bound.
                    while self._above(k):
                        k, v = self._decode(self.cur.prev())
                if self._below(k):
                    self.item = StopIteration
                else:
                    self.item = k, v
            except notfound:
                self.item = StopIteration

        @dloopfun
        def next(self):
            """Step forward, marking exhaustion past the upper bound."""
            try:
                k, v = self._decode(self.cur.next())
                if self._above(k):
                    self.item = StopIteration
                else:
                    self.item = k, v
            except notfound:
                self.item = StopIteration

        @dloopfun
        def prev(self):
            """Step backward, marking exhaustion before the lower bound."""
            try:
                k, v = self._decode(self.cur.prev())
                if self._below(k):
                    self.item = StopIteration
                else:
                    self.item = k, v
            except notfound:
                self.item = StopIteration

        def __next__(self):
            if self.item is None:
                if not self.rev:
                    self.next()
                else:
                    self.prev()
            if self.item is StopIteration:
                raise StopIteration()
            ret, self.item = self.item, None
            return ret

        def skip(self, n=1):
            """Discard up to `n` entries, stopping quietly at the end."""
            try:
                for i in range(n):
                    next(self)
            except StopIteration:
                return

    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):
        """Return a cursor over the entries selected by the keyword arguments.

        `all` selects every entry; `match` selects the entries with exactly
        that key; otherwise `ge`/`gt` give an inclusive/exclusive lower
        bound and `le`/`lt` an inclusive/exclusive upper bound.  With
        `reverse`, the cursor iterates from the upper end downwards.

        Raises NameError if no selection criterion is given.
        """
        if all:
            cur = self.cursor(self, missing, True, missing, True, reverse)
        elif match is not missing:
            cur = self.cursor(self, match, True, match, True, reverse)
        elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
            if ge is not missing:
                fd, fi = ge, True
            elif gt is not missing:
                fd, fi = gt, False
            else:
                fd, fi = missing, True
            if le is not missing:
                ld, li = le, True
            elif lt is not missing:
                ld, li = lt, False
            else:
                ld, li = missing, True
            cur = self.cursor(self, fd, fi, ld, li, reverse)
        else:
            raise NameError("invalid get() specification")
        # Do not leak the cursor if positioning fails.
        try:
            if not reverse:
                cur.first()
            else:
                cur.last()
        except BaseException:
            cur.close()
            raise
        return cur

    @txnfun(lambda self: self.db.env.env)
    def put(self, key, id, *, tx):
        """Add the entry (key, id) to the index.

        Returns True if the entry was added and False if it was already
        present.  Raises ValueError if `self.db.ob` holds no object with
        that ID.  `tx` is presumably supplied by the `txnfun` decorator --
        verify against its definition.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        try:
            self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
        except bd.DBKeyExistError:
            return False
        return True

    @txnfun(lambda self: self.db.env.env)
    def remove(self, key, id, *, tx):
        """Remove the entry (key, id) from the index.

        Returns True if the entry was removed and False if it was not
        present.  Raises ValueError if `self.db.ob` holds no object with
        that ID.  `tx` is presumably supplied by the `txnfun` decorator --
        verify against its definition.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        cur = self.bk.cursor(txn=tx.tx)
        try:
            try:
                cur.get_both(self.typ.encode(key), obid)
            except notfound:
                return False
            cur.delete()
        finally:
            cur.close()
        return True
296 return True