Added more indexed types.
[didex.git] / didex / index.py
CommitLineData
abb94f83 1import struct, contextlib, math
a95055e8 2from . import db, lib
8950191c 3from .db import bd, txnfun
a95055e8
FT
4
# Short local aliases for the two Berkeley DB exceptions this module
# catches: lock conflicts (retried) and missing keys / cursor past the end.
deadlock, notfound = bd.DBLockDeadlockError, bd.DBNotFoundError
7
class simpletype(object):
    """A key type assembled from a pair of plain encode/decode callables.

    Decoded values are ordered with their natural ``<`` / ``>`` operators.
    """

    def __init__(self, encode, decode):
        # The raw callables are kept on short attribute names.
        self.enc = encode
        self.dec = decode

    def encode(self, ob):
        """Serialize *ob* with the configured encoder."""
        encoder = self.enc
        return encoder(ob)

    def decode(self, dat):
        """Deserialize *dat* with the configured decoder."""
        decoder = self.dec
        return decoder(dat)

    def compare(self, a, b):
        """Three-way compare: -1 if a < b, 1 if a > b, otherwise 0."""
        if a < b:
            return -1
        if a > b:
            return 1
        return 0

    @classmethod
    def struct(cls, fmt):
        """Build a type that packs/unpacks one value with struct format *fmt*."""
        def pack(ob):
            return struct.pack(fmt, ob)

        def unpack(dat):
            # Only the first unpacked field is the value.
            return struct.unpack(fmt, dat)[0]

        return cls(pack, unpack)
29
class maybe(object):
    """Wrap another key type so that ``None`` is also a valid value.

    ``None`` is encoded as the empty byte string.  Any other value is
    encoded as a single ``b"\\0"`` marker byte followed by the wrapped
    type's encoding, so the two cases never collide.  ``None`` sorts
    before every non-``None`` value.
    """

    def __init__(self, bk):
        # The wrapped type that handles all non-None values.
        self.bk = bk

    def encode(self, ob):
        """Encode *ob*, or ``b""`` for ``None``."""
        if ob is None:
            return b""
        return b"\0" + self.bk.encode(ob)

    def decode(self, dat):
        """Decode *dat*; ``b""`` decodes to ``None``."""
        if dat == b"":
            return None
        # Strip the marker byte and go through the wrapped type's public
        # decode() so any key type works, not only those with a raw ``dec``.
        return self.bk.decode(dat[1:])

    def compare(self, a, b):
        """Three-way compare two *decoded* values; ``None`` sorts first."""
        if a is None and b is None:
            return 0
        if a is None:
            return -1
        if b is None:
            return 1
        # a and b are already decoded (callers decode before comparing), so
        # they go to the wrapped comparator unchanged -- no marker to strip.
        return self.bk.compare(a, b)
49
abb94f83
FT
def floatcmp(a, b):
    """Three-way compare two floats, placing NaN below every other value.

    Two NaNs compare equal to each other, so NaN keys get a consistent
    position instead of being unordered.
    """
    a_nan, b_nan = math.isnan(a), math.isnan(b)
    if a_nan or b_nan:
        # Both NaN -> 0, only a NaN -> -1, only b NaN -> 1.
        return int(b_nan) - int(a_nan)
    return (a > b) - (a < b)
63
# Built-in key types.  Numbers use fixed-width big-endian struct packing:
# signed 64-bit, unsigned 64-bit and IEEE double respectively.
t_int, t_uint, t_float = (simpletype.struct(fmt) for fmt in (">q", ">Q", ">d"))
# Plain < / > leaves NaN unordered; use the NaN-aware comparison instead.
t_float.compare = floatcmp
# Strings are stored as UTF-8 bytes.
t_str = simpletype(
    encode=lambda text: text.encode("utf-8"),
    decode=lambda raw: raw.decode("utf-8"),
)
a95055e8
FT
69
class index(object):
    """Common base for indexes: remembers the database, name and key type.

    ``datatype`` is the key type object; subclasses use it to encode,
    decode and compare keys.
    """

    def __init__(self, db, name, datatype):
        self.db, self.nm, self.typ = db, name, datatype
75
# Sentinel for "keyword argument not supplied", distinct from None and from
# any real key value.
missing = object()
77
class ordered(index, lib.closable):
    """An ordered (B-tree) index mapping keys to object ids.

    Keys are encoded with ``self.typ`` and sorted by its ``compare`` method.
    A key may map to several object ids (the database uses sorted
    duplicates).  Object ids are stored as 8-byte big-endian unsigned ints.
    """

    def __init__(self, db, name, datatype, create=True):
        """Open the backing Berkeley DB B-tree named ``"i-" + name``.

        :param db: the owning database object.
        :param name: index name.
        :param datatype: key type providing encode/decode/compare.
        :param create: when true (default), ``DB_CREATE`` is added to the
            open flags.
        """
        super().__init__(db, name, datatype)
        fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
        if create:
            fl |= bd.DB_CREATE

        def initdb(db):
            def compare(a, b):
                # Berkeley DB hands the comparator raw key *bytes*; bsddb3
                # also probes it with two empty keys and requires 0.  The
                # previous ``a == b == ""`` test compared bytes against a
                # str, which is never equal on Python 3, so empty keys fell
                # through to decode(b"") and failed for struct-based types.
                if not a and not b:
                    return 0
                return self.typ.compare(self.typ.decode(a), self.typ.decode(b))

            db.set_flags(bd.DB_DUPSORT)
            db.set_bt_compare(compare)

        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
        # Misses raise DBNotFoundError (caught below as ``notfound``)
        # instead of returning None.
        self.bk.set_get_returns_none(False)

    def close(self):
        """Close the backing database handle."""
        self.bk.close()

    class cursor(lib.closable):
        """Iterator over ``(key, id)`` pairs produced by ``ordered.get()``.

        ``item`` is the current raw (key, value) pair, or None once
        exhausted.  ``stop`` is a predicate on the decoded key; iteration
        ends as soon as it returns true.
        """

        def __init__(self, idx, cur, item, stop):
            self.idx = idx
            self.cur = cur
            self.item = item
            self.stop = stop

        def close(self):
            """Close the underlying DB cursor, if there is one."""
            if self.cur is not None:
                self.cur.close()

        def __iter__(self):
            return self

        def peek(self):
            """Return the current ``(key, id)`` without advancing.

            Raises StopIteration when exhausted, or when ``stop`` accepts
            the current key (which also marks the cursor exhausted).
            """
            if self.item is None:
                raise StopIteration()
            rk, rv = self.item
            rk = self.idx.typ.decode(rk)
            rv = struct.unpack(">Q", rv)[0]
            if self.stop(rk):
                self.item = None
                raise StopIteration()
            return rk, rv

        def __next__(self):
            """Return the current pair and move the DB cursor forward."""
            rk, rv = self.peek()
            try:
                while True:
                    try:
                        self.item = self.cur.next()
                        break
                    except deadlock:
                        # Lock conflict: retry the same move.
                        continue
            except notfound:
                # Moved past the last entry.
                self.item = None
            return rk, rv

        def skip(self, n=1):
            """Advance past up to *n* entries, stopping quietly at the end."""
            try:
                for _ in range(n):
                    next(self)
            except StopIteration:
                return

    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False):
        """Return a cursor over ``(key, id)`` pairs selected by the arguments.

        The first applicable selection mode is used:

        * ``match=k``: entries whose key compares equal to ``k``;
        * ``all=True``: every entry, in key order;
        * any of ``ge``/``gt`` (lower bound) and ``lt``/``le`` (upper bound):
          a key range.  ``ge`` wins over ``gt``, and ``lt`` over ``le``,
          when both of a pair are given.

        An empty cursor is returned when nothing matches.  A deadlock while
        positioning the cursor restarts the whole lookup.

        :raises NameError: if no selection argument is given.
        """
        while True:
            try:
                cur = self.bk.cursor()
                # Set once the cursor is handed to the caller; otherwise the
                # finally clause closes it.
                done = False
                try:
                    if match is not missing:
                        try:
                            k, v = cur.set(self.typ.encode(match))
                        except notfound:
                            return self.cursor(None, None, None, None)
                        else:
                            done = True
                            return self.cursor(self, cur, (k, v), lambda o: (self.typ.compare(o, match) != 0))
                    elif all:
                        try:
                            k, v = cur.first()
                        except notfound:
                            return self.cursor(None, None, None, None)
                        else:
                            done = True
                            return self.cursor(self, cur, (k, v), lambda o: False)
                    elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
                        skip = False
                        try:
                            if ge is not missing:
                                k, v = cur.set_range(self.typ.encode(ge))
                            elif gt is not missing:
                                # set_range lands on the first key >= gt;
                                # entries equal to gt are skipped below.
                                k, v = cur.set_range(self.typ.encode(gt))
                                skip = True
                            else:
                                k, v = cur.first()
                        except notfound:
                            return self.cursor(None, None, None, None)
                        if lt is not missing:
                            stop = lambda o: self.typ.compare(o, lt) >= 0
                        elif le is not missing:
                            stop = lambda o: self.typ.compare(o, le) > 0
                        else:
                            stop = lambda o: False
                        ret = self.cursor(self, cur, (k, v), stop)
                        if skip:
                            try:
                                while self.typ.compare(ret.peek()[0], gt) == 0:
                                    next(ret)
                            except StopIteration:
                                pass
                        done = True
                        return ret
                    else:
                        raise NameError("invalid get() specification")
                finally:
                    if not done:
                        cur.close()
            except deadlock:
                continue

    @txnfun(lambda self: self.db.env.env)
    def put(self, key, id, *, tx):
        """Add the entry ``key -> id`` to the index.

        ``tx`` is supplied via the ``txnfun`` decorator; ``tx.tx`` is
        passed to Berkeley DB as the transaction.

        :returns: True if added, False if that exact pair already existed.
        :raises ValueError: if ``id`` is not present in ``self.db.ob``.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        try:
            self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
        except bd.DBKeyExistError:
            return False
        return True

    @txnfun(lambda self: self.db.env.env)
    def remove(self, key, id, *, tx):
        """Delete the entry ``key -> id`` from the index.

        :returns: True if the pair was found and deleted, False otherwise.
        :raises ValueError: if ``id`` is not present in ``self.db.ob``.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        cur = self.bk.cursor(txn=tx.tx)
        try:
            try:
                # Position on the exact key/id pair among the duplicates.
                cur.get_both(self.typ.encode(key), obid)
            except notfound:
                return False
            cur.delete()
        finally:
            cur.close()
        return True