Added an explicit index type for object IDs.
[didex.git] / didex / index.py
CommitLineData
abb94f83 1import struct, contextlib, math
a95055e8 2from . import db, lib
6efe4e23 3from .db import bd, txnfun, dloopfun
a95055e8 4
# Public names exported by ``from ... import *``.
__all__ = [
    "maybe",
    "t_int",
    "t_uint",
    "t_dbid",
    "t_float",
    "t_str",
    "ordered",
]

# Short aliases for Berkeley DB exception types.
deadlock = bd.DBLockDeadlockError
notfound = bd.DBNotFoundError
9
class simpletype(object):
    """Key type built from a plain pair of encode/decode callables.

    Decoded values are ordered with Python's native ``<`` / ``>`` operators.
    """

    def __init__(self, encode, decode):
        self.enc, self.dec = encode, decode

    def encode(self, ob):
        """Serialize *ob* with the wrapped encoder."""
        encoder = self.enc
        return encoder(ob)

    def decode(self, dat):
        """Deserialize *dat* with the wrapped decoder."""
        decoder = self.dec
        return decoder(dat)

    def compare(self, a, b):
        """Three-way compare of two decoded values: -1, 0 or 1."""
        return -1 if a < b else (1 if a > b else 0)

    @classmethod
    def struct(cls, fmt):
        """Build a type that (un)packs a single value with struct format *fmt*."""
        def pack(ob):
            return struct.pack(fmt, ob)

        def unpack(dat):
            # Only the first unpacked field is the value.
            return struct.unpack(fmt, dat)[0]

        return cls(pack, unpack)
31
class maybe(object):
    """Optional wrapper around another key type *bk*, allowing ``None`` keys.

    ``None`` is encoded as the empty byte string; any other value as a NUL
    marker byte followed by *bk*'s encoding.  ``None`` sorts before every
    non-``None`` value.
    """

    def __init__(self, bk):
        self.bk = bk

    def encode(self, ob):
        if ob is None:
            return b""
        return b"\0" + self.bk.encode(ob)

    def decode(self, dat):
        if dat == b"":
            return None
        # Go through the public decode() protocol so any key type can be
        # wrapped, not only simpletype (the only one exposing ``.dec``).
        return self.bk.decode(dat[1:])

    def compare(self, a, b):
        """Three-way compare of two *decoded* values (``None`` or bk values)."""
        if a is None and b is None:
            return 0
        elif a is None:
            return -1
        elif b is None:
            return 1
        else:
            # Values are already decoded: there is no marker byte to strip.
            return self.bk.compare(a, b)
51
bd14729f
FT
class compound(object):
    """Key type made of several sub-keys, each handled by its own part type.

    Every part is serialized with a one-byte header:

    * ``0x80 | n``              -- inline length *n* (< 128), then *n* bytes;
    * ``0x00`` + 4-byte BE *n*  -- long form for parts of 128 bytes or more;
    * ``0x01`` / ``0x02``       -- the ``small`` / ``large`` sentinels, no payload.

    ``small`` and ``large`` compare below / above every real part value, so
    ``minim(*prefix)`` and ``maxim(*prefix)`` give the lowest / highest
    possible key starting with *prefix*.
    """

    def __init__(self, *parts):
        self.parts = parts

    # Sentinels that sort before / after every real value of a part.
    small = object()
    large = object()

    def minim(self, *parts):
        """Pad *parts* with ``small`` up to the full number of parts."""
        return parts + tuple([self.small] * (len(self.parts) - len(parts)))

    def maxim(self, *parts):
        """Pad *parts* with ``large`` up to the full number of parts."""
        return parts + tuple([self.large] * (len(self.parts) - len(parts)))

    def encode(self, obs):
        """Encode the tuple *obs*; raises ValueError on a length mismatch."""
        if len(obs) != len(self.parts):
            raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + str(len(self.parts)))
        buf = bytearray()
        for ob, part in zip(obs, self.parts):
            if ob is self.small:
                buf.append(0x01)
            elif ob is self.large:
                buf.append(0x02)
            else:
                dat = part.encode(ob)
                if len(dat) < 128:
                    buf.append(0x80 | len(dat))
                else:
                    buf.extend(struct.pack(">BI", 0, len(dat)))
                buf.extend(dat)
        return bytes(buf)

    def decode(self, dat):
        """Decode bytes produced by encode() back into a tuple.

        Raises ValueError on an unknown part header byte.
        """
        ret = []
        off = 0
        for part in self.parts:
            fl = dat[off]
            off += 1
            if fl == 0x01:
                ret.append(self.small)
                continue
            if fl == 0x02:
                ret.append(self.large)
                continue
            if fl & 0x80:
                ln = fl & 0x7f
            elif fl == 0x00:
                ln = struct.unpack(">I", dat[off:off + 4])[0]
                off += 4
            else:
                raise ValueError("invalid compound part header: " + hex(fl))
            ret.append(part.decode(dat[off:off + ln]))
            off += ln
        return tuple(ret)

    def compare(self, al, bl):
        """Lexicographic three-way compare of two decoded tuples."""
        if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
            raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + str(len(self.parts)))
        for a, b, part in zip(al, bl, self.parts):
            # Identity checks: sentinels must never be matched via __eq__.
            if (a is self.small or a is self.large or
                    b is self.small or b is self.large):
                if a is b:
                    return 0
                if a is self.small:
                    return -1
                elif b is self.small:
                    return 1
                elif a is self.large:
                    return 1
                elif b is self.large:
                    return -1
            c = part.compare(a, b)
            if c != 0:
                return c
        return 0
120
abb94f83
FT
def floatcmp(a, b):
    """Three-way compare two floats, with NaN ordered before everything.

    NaN compares equal to NaN and smaller than any other value; all other
    values use ordinary numeric ordering.  Returns -1, 0 or 1.
    """
    a_nan = math.isnan(a)
    b_nan = math.isnan(b)
    if a_nan or b_nan:
        # -1 when only a is NaN, 1 when only b is, 0 when both are.
        return int(b_nan) - int(a_nan)
    if a < b:
        return -1
    return 1 if a > b else 0
134
# Fixed-width, big-endian struct encodings for numeric keys.
t_int = simpletype.struct(">q")    # signed 64-bit integer
t_uint = simpletype.struct(">Q")   # unsigned 64-bit integer
# Object IDs are packed as ">Q" elsewhere in this module, so reuse t_uint.
t_dbid = t_uint
t_float = simpletype.struct(">d")  # 64-bit double
# Per-instance override: order NaNs via floatcmp instead of plain < / >.
t_float.compare = floatcmp
t_str = simpletype(
    encode=lambda ob: ob.encode("utf-8"),
    decode=lambda dat: dat.decode("utf-8"),
)
a95055e8
FT
141
class index(object):
    """Base class for indexes: keeps the database object, name and key type."""

    def __init__(self, db, name, datatype):
        # Tuple assignment binds left to right, like three separate statements.
        self.db, self.nm, self.typ = db, name, datatype
147
# Sentinel meaning "not supplied" / "unbounded" for ordered.get() and its
# cursors.  A private object is used rather than None so that None stays
# usable as a real key value (the maybe key type encodes None).
missing = object()


class ordered(index, lib.closable):
    """Ordered (B-tree) index mapping keys of one key type to object IDs.

    Entries live in the Berkeley DB database ``"i-" + name`` with sorted
    duplicates, so one key may map to several object IDs.  Keys are ordered
    by ``datatype.compare`` applied to decoded keys; object IDs are stored
    as big-endian unsigned 64-bit integers.
    """

    def __init__(self, db, name, datatype, create=True):
        super().__init__(db, name, datatype)
        fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
        if create:
            fl |= bd.DB_CREATE

        def initdb(db):
            # Configures the raw DB handle; presumably invoked by _opendb
            # before the database is opened -- TODO confirm.
            def compare(a, b):
                # Two empty keys compare equal without being decoded: the
                # bindings may call the comparator with empty keys (e.g. as
                # a sanity check), and empty data cannot be decoded by the
                # fixed-width struct types.  Testing emptiness rather than
                # ``== ""`` covers both str and bytes arguments.
                if not a and not b:
                    return 0
                return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
            db.set_flags(bd.DB_DUPSORT)
            db.set_bt_compare(compare)

        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
        # Lookups raise DBNotFoundError (``notfound``) instead of returning
        # None; the cursor code below relies on that.
        self.bk.set_get_returns_none(False)

    def close(self):
        """Close the underlying Berkeley DB handle."""
        self.bk.close()

    class cursor(lib.closable):
        """Iterator over ``(key, object_id)`` pairs of an ordered index.

        *fd* / *ld* are the lower / upper key bounds (``missing`` when
        unbounded) and *fi* / *li* tell whether each bound is inclusive.
        When *reverse* is true, entries come in descending key order.

        ``self.item`` holds the pending result: ``None`` means "advance on
        the next __next__ call", ``StopIteration`` means exhausted, and a
        tuple is an entry not yet handed out.
        """

        def __init__(self, idx, fd, fi, ld, li, reverse):
            self.idx = idx
            self.typ = idx.typ
            self.cur = self.idx.bk.cursor()
            self.item = None
            self.fd = fd
            self.fi = fi
            self.ld = ld
            self.li = li
            self.rev = reverse

        def close(self):
            """Close the DB cursor; safe to call more than once."""
            if self.cur is not None:
                self.cur.close()
                self.cur = None

        def __iter__(self):
            return self

        def _decode(self, d):
            """Turn a raw ``(key, data)`` record into ``(key, object_id)``."""
            k, v = d
            k = self.typ.decode(k)
            v = struct.unpack(">Q", v)[0]
            return k, v

        def _past_last(self, k):
            """Return True if decoded key *k* lies beyond the upper bound."""
            if self.ld is missing:
                return False
            c = self.typ.compare(k, self.ld)
            return c > 0 if self.li else c >= 0

        def _before_first(self, k):
            """Return True if decoded key *k* lies before the lower bound."""
            if self.fd is missing:
                return False
            c = self.typ.compare(k, self.fd)
            return c < 0 if self.fi else c <= 0

        @dloopfun
        def first(self):
            """Position on the lowest entry in range (forward iteration)."""
            try:
                if self.fd is missing:
                    k, v = self._decode(self.cur.first())
                else:
                    k, v = self._decode(self.cur.set_range(self.typ.encode(self.fd)))
                    if not self.fi:
                        # Exclusive lower bound: step over keys equal to fd.
                        while self.typ.compare(k, self.fd) == 0:
                            k, v = self._decode(self.cur.next())
            except notfound:
                self.item = StopIteration
                return
            # The landing key may already exceed the upper bound (e.g.
            # match= on an absent key, or an empty range).
            self.item = StopIteration if self._past_last(k) else (k, v)

        @dloopfun
        def last(self):
            """Position on the highest entry in range (reverse iteration)."""
            try:
                if self.ld is missing:
                    k, v = self._decode(self.cur.last())
                else:
                    try:
                        k, v = self._decode(self.cur.set_range(self.typ.encode(self.ld)))
                    except notfound:
                        # Every key is below ld: start from the very end.
                        k, v = self._decode(self.cur.last())
                    if self.li:
                        # Skip forward past keys equal to ld, then back up.
                        # Hitting the end of the database here just means
                        # the final entries equal ld, so fall back to last().
                        try:
                            while self.typ.compare(k, self.ld) == 0:
                                k, v = self._decode(self.cur.next())
                        except notfound:
                            k, v = self._decode(self.cur.last())
                        while self.typ.compare(k, self.ld) > 0:
                            k, v = self._decode(self.cur.prev())
                    else:
                        while self.typ.compare(k, self.ld) >= 0:
                            k, v = self._decode(self.cur.prev())
            except notfound:
                self.item = StopIteration
                return
            # The landing key may already lie below the lower bound.
            self.item = StopIteration if self._before_first(k) else (k, v)

        @dloopfun
        def next(self):
            """Advance one entry forward, stopping past the upper bound."""
            try:
                k, v = self._decode(self.cur.next())
            except notfound:
                self.item = StopIteration
                return
            self.item = StopIteration if self._past_last(k) else (k, v)

        @dloopfun
        def prev(self):
            """Step one entry backward, stopping before the lower bound."""
            try:
                k, v = self._decode(self.cur.prev())
            except notfound:
                self.item = StopIteration
                return
            self.item = StopIteration if self._before_first(k) else (k, v)

        def __next__(self):
            if self.item is None:
                if not self.rev:
                    self.next()
                else:
                    self.prev()
            if self.item is StopIteration:
                raise StopIteration()
            ret, self.item = self.item, None
            return ret

        def skip(self, n=1):
            """Discard up to *n* entries; stops quietly at the end."""
            try:
                for i in range(n):
                    next(self)
            except StopIteration:
                return

    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False, reverse=False):
        """Return a positioned cursor over ``(key, object_id)`` pairs.

        Selection is checked in this order: ``all=True`` (every entry),
        ``match=key`` (entries equal to *key*), then any combination of a
        lower bound (``ge`` preferred over ``gt``) and an upper bound
        (``le`` preferred over ``lt``).  ``reverse=True`` yields descending
        keys.  Raises NameError when no selection is given.
        """
        if all:
            cur = self.cursor(self, missing, True, missing, True, reverse)
        elif match is not missing:
            cur = self.cursor(self, match, True, match, True, reverse)
        elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
            if ge is not missing:
                fd, fi = ge, True
            elif gt is not missing:
                fd, fi = gt, False
            else:
                fd, fi = missing, True
            if le is not missing:
                ld, li = le, True
            elif lt is not missing:
                ld, li = lt, False
            else:
                ld, li = missing, True
            cur = self.cursor(self, fd, fi, ld, li, reverse)
        else:
            raise NameError("invalid get() specification")
        # Close the cursor if positioning fails; otherwise hand it over.
        done = False
        try:
            if not reverse:
                cur.first()
            else:
                cur.last()
            done = True
            return cur
        finally:
            if not done:
                cur.close()

    @txnfun(lambda self: self.db.env.env)
    def put(self, key, id, *, tx):
        """Map *key* to existing object *id*.

        Returns False if that exact key/id pair is already present, True
        otherwise.  Raises ValueError if no object with *id* exists.  *tx*
        is presumably supplied by the txnfun decorator -- TODO confirm.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        try:
            self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
        except bd.DBKeyExistError:
            return False
        return True

    @txnfun(lambda self: self.db.env.env)
    def remove(self, key, id, *, tx):
        """Remove the mapping of *key* to object *id*.

        Returns False if that pair was not present, True if it was deleted.
        Raises ValueError if no object with *id* exists.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        cur = self.bk.cursor(txn=tx.tx)
        try:
            try:
                cur.get_both(self.typ.encode(key), obid)
            except notfound:
                return False
            cur.delete()
        finally:
            cur.close()
        return True