Added some convenient __all__ imports.
[didex.git] / didex / index.py
Commit Line Data
abb94f83 1import struct, contextlib, math
a95055e8 2from . import db, lib
8950191c 3from .db import bd, txnfun
a95055e8 4
cbf73d3a
FT
# Public names exported by ``from <this module> import *``: the key codecs
# and the ordered index class.
__all__ = [
    "maybe",
    "t_int",
    "t_uint",
    "t_float",
    "t_str",
    "ordered",
]
6
a95055e8
FT
# Short aliases for the two Berkeley DB exceptions this module handles:
# lock deadlocks (operations are retried) and missing keys.
deadlock, notfound = bd.DBLockDeadlockError, bd.DBNotFoundError
9
class simpletype(object):
    """Key codec built from a plain pair of encode/decode callables.

    ``encode`` converts a Python value to its stored form and ``decode``
    reverses it. Ordering is the values' own ``<`` / ``>`` ordering.
    """

    def __init__(self, encode, decode):
        # Kept under short names so the public encode()/decode() methods
        # can reuse the parameter names.
        self.enc, self.dec = encode, decode

    def encode(self, ob):
        """Return the stored form of *ob*."""
        encoder = self.enc
        return encoder(ob)

    def decode(self, dat):
        """Return the Python value encoded in *dat*."""
        decoder = self.dec
        return decoder(dat)

    def compare(self, a, b):
        """Three-way compare: -1 if a < b, 1 if a > b, otherwise 0."""
        if a < b:
            return -1
        return 1 if a > b else 0

    @classmethod
    def struct(cls, fmt):
        """Build a codec packing one value with struct format *fmt*.

        Decoding returns the first field unpacked from the data.
        """
        def pack_one(ob):
            return struct.pack(fmt, ob)

        def unpack_first(dat):
            return struct.unpack(fmt, dat)[0]

        return cls(pack_one, unpack_first)
31
class maybe(object):
    """Codec wrapper that adds ``None`` support to another codec *bk*.

    ``None`` is encoded as the empty byte string. Any other value is encoded
    as a single NUL byte followed by ``bk``'s encoding, so a non-None value
    always has a non-empty encoding and the two cases cannot collide.
    In comparisons ``None`` sorts before every other value.
    """

    def __init__(self, bk):
        self.bk = bk

    def encode(self, ob):
        """Return b"" for None, else a NUL marker byte plus bk's encoding."""
        if ob is None:
            return b""
        return b"\0" + self.bk.encode(ob)

    def decode(self, dat):
        """Inverse of encode(): b"" decodes to None."""
        if dat == b"":
            return None
        # Use the backend's public decode() rather than the simpletype-only
        # ``dec`` attribute, so any codec (e.g. a nested maybe) works as bk.
        return self.bk.decode(dat[1:])

    def compare(self, a, b):
        """Three-way compare two *decoded* values (None or bk values).

        Callers pass values returned by decode(), not encoded bytes, so the
        non-None values are handed to the backend unchanged. The NUL marker
        only exists in the encoded form.
        """
        if a is b is None:
            return 0
        elif a is None:
            return -1
        elif b is None:
            return 1
        else:
            return self.bk.compare(a, b)
51
abb94f83
FT
def floatcmp(a, b):
    """Three-way compare two floats, ordering NaN below every other value.

    Returns -1, 0 or 1. Two NaNs compare equal to each other.
    """
    a_nan = math.isnan(a)
    b_nan = math.isnan(b)
    if a_nan or b_nan:
        # Both NaN -> 0; only a is NaN -> -1; only b is NaN -> 1.
        return int(b_nan) - int(a_nan)
    if a < b:
        return -1
    if a > b:
        return 1
    return 0
65
# Built-in key codecs. The numeric ones are fixed-width big-endian struct
# encodings: signed 64-bit int, unsigned 64-bit int and IEEE double.
t_int, t_uint, t_float = [simpletype.struct(fmt) for fmt in (">q", ">Q", ">d")]
# Floats use a NaN-aware ordering instead of plain < / >.
t_float.compare = floatcmp
# Text keys are stored as UTF-8 bytes.
t_str = simpletype(lambda ob: ob.encode("utf-8"), lambda dat: dat.decode("utf-8"))
a95055e8
FT
71
class index(object):
    """State shared by index implementations.

    Holds the database object the index belongs to, the index name, and
    the datatype codec used for its keys.
    """

    def __init__(self, db, name, datatype):
        self.db, self.nm, self.typ = db, name, datatype
77
# Unique "argument not supplied" sentinel for ordered.get(). It lets an
# omitted bound be told apart from an explicit None value.
missing = object()
79
class ordered(index, lib.closable):
    """Sorted index mapping keys to object ids.

    Backed by a Berkeley DB B-tree (``DB_BTREE``) with sorted duplicates
    (``DB_DUPSORT``), so a key may map to several ids. Keys are stored in
    the datatype codec's encoding. Ids are stored as big-endian unsigned
    64-bit integers. B-tree ordering is delegated to the datatype's
    ``compare`` on decoded keys.
    """

    def __init__(self, db, name, datatype, create=True):
        """Open (and, if *create*, create) the backing database "i-<name>"."""
        super().__init__(db, name, datatype)
        fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
        if create:
            fl |= bd.DB_CREATE

        def initdb(db):
            def compare(a, b):
                # Two empty keys are equal. Test emptiness rather than
                # ``== ""``: on Python 3 the raw keys are bytes, and b"" never
                # equals "", so the old guard never fired. Fixed-width codecs
                # such as t_int cannot decode b"". NOTE(review): the bsddb3
                # bindings probe the callback with two empty keys when it is
                # installed.
                if not a and not b:
                    return 0
                return self.typ.compare(self.typ.decode(a), self.typ.decode(b))

            db.set_flags(bd.DB_DUPSORT)
            db.set_bt_compare(compare)

        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
        # Lookups must raise DBNotFoundError (caught below as ``notfound``)
        # rather than return None.
        self.bk.set_get_returns_none(False)

    def close(self):
        """Close the backing database handle."""
        self.bk.close()

    class cursor(lib.closable):
        """Iterator over (key, id) pairs read from a database cursor.

        *item* is the raw (key, value) pair the cursor currently sits on, or
        None when there is nothing left. *stop* is a predicate on the decoded
        key; iteration ends once it returns true.
        """

        def __init__(self, idx, cur, item, stop):
            self.idx = idx
            self.cur = cur
            self.item = item
            self.stop = stop

        def close(self):
            """Close the underlying cursor; calling close() again is a no-op."""
            if self.cur is not None:
                self.cur.close()
                # Drop the handle and the pending item so that a repeated
                # close() does nothing and a later next() just stops instead
                # of touching a closed cursor.
                self.cur = None
                self.item = None

        def __iter__(self):
            return self

        def peek(self):
            """Return the current (key, id) pair without advancing.

            Raises StopIteration when exhausted or when the current key
            satisfies the stop predicate.
            """
            if self.item is None:
                raise StopIteration()
            rk, rv = self.item
            rk = self.idx.typ.decode(rk)
            rv = struct.unpack(">Q", rv)[0]
            if self.stop(rk):
                # Latch the end so later calls stop immediately.
                self.item = None
                raise StopIteration()
            return rk, rv

        def __next__(self):
            """Return the current (key, id) pair and step the cursor forward."""
            rk, rv = self.peek()
            try:
                while True:
                    try:
                        self.item = self.cur.next()
                        break
                    except deadlock:
                        # Retry the step after a lock deadlock.
                        continue
            except notfound:
                # Stepped past the last entry.
                self.item = None
            return rk, rv

        def skip(self, n=1):
            """Advance past up to *n* entries, stopping early at the end."""
            try:
                for _ in range(n):
                    next(self)
            except StopIteration:
                return

    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False):
        """Return a cursor over the (key, id) pairs selected by the arguments.

        The first applicable mode wins:

        * ``match``: entries whose key compares equal to *match*.
        * ``all``: every entry, starting at the first key.
        * ``ge``/``gt``/``lt``/``le``: a key range. The lower bound is *ge*
          (inclusive), otherwise *gt* (exclusive); the upper bound is *lt*
          (exclusive), otherwise *le* (inclusive). Absent bounds are open.

        If nothing can be positioned on, an empty cursor is returned. The
        caller owns the returned cursor and should close it. Positioning is
        retried from scratch on a lock deadlock.

        Raises NameError if no selection argument is given.
        """
        while True:
            try:
                cur = self.bk.cursor()
                # True once ``cur`` has been handed to a returned cursor object;
                # otherwise it is closed in the ``finally`` below.
                done = False
                try:
                    if match is not missing:
                        try:
                            k, v = cur.set(self.typ.encode(match))
                        except notfound:
                            return self.cursor(None, None, None, None)
                        else:
                            done = True
                            return self.cursor(self, cur, (k, v), lambda o: (self.typ.compare(o, match) != 0))
                    elif all:
                        try:
                            k, v = cur.first()
                        except notfound:
                            return self.cursor(None, None, None, None)
                        else:
                            done = True
                            return self.cursor(self, cur, (k, v), lambda o: False)
                    elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
                        skip = False
                        try:
                            if ge is not missing:
                                k, v = cur.set_range(self.typ.encode(ge))
                            elif gt is not missing:
                                k, v = cur.set_range(self.typ.encode(gt))
                                # set_range may land on gt itself; skip those below.
                                skip = True
                            else:
                                k, v = cur.first()
                        except notfound:
                            return self.cursor(None, None, None, None)
                        if lt is not missing:
                            stop = lambda o: self.typ.compare(o, lt) >= 0
                        elif le is not missing:
                            stop = lambda o: self.typ.compare(o, le) > 0
                        else:
                            stop = lambda o: False
                        ret = self.cursor(self, cur, (k, v), stop)
                        if skip:
                            try:
                                while self.typ.compare(ret.peek()[0], gt) == 0:
                                    next(ret)
                            except StopIteration:
                                pass
                        done = True
                        return ret
                    else:
                        # NOTE(review): TypeError would be the conventional type,
                        # but callers may already catch NameError.
                        raise NameError("invalid get() specification")
                finally:
                    if not done:
                        cur.close()
            except deadlock:
                continue

    @txnfun(lambda self: self.db.env.env)
    def put(self, key, id, *, tx):
        """Add the pair (*key*, *id*) to the index within transaction *tx*.

        *tx* is provided through the ``txnfun`` decorator (NOTE(review):
        presumably it begins/commits the transaction; confirm in db.py).

        Returns True if the pair was added, False if it was already present.
        Raises ValueError if *id* is not an object in ``self.db.ob``.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        try:
            # DB_NODUPDATA rejects an identical key/data pair rather than
            # storing a second copy.
            self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
        except bd.DBKeyExistError:
            return False
        return True

    @txnfun(lambda self: self.db.env.env)
    def remove(self, key, id, *, tx):
        """Delete the pair (*key*, *id*) from the index within transaction *tx*.

        Returns True if the pair was found and deleted, False if it was not
        present. Raises ValueError if *id* is not an object in ``self.db.ob``.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        cur = self.bk.cursor(txn=tx.tx)
        try:
            try:
                # Position exactly on this key/data pair, if present.
                cur.get_both(self.typ.encode(key), obid)
            except notfound:
                return False
            cur.delete()
        finally:
            cur.close()
        return True