Initial import.
[didex.git] / didex / index.py
CommitLineData
a95055e8
FT
1import struct, contextlib
2from . import db, lib
3from .db import bd
4
# Short aliases for the Berkeley DB exceptions this module handles: index
# operations retry when they hit a lock deadlock and treat "not found" as an
# ordinary outcome.
deadlock, notfound = bd.DBLockDeadlockError, bd.DBNotFoundError
7
class simpletype(object):
    """Key datatype built from a pair of encode/decode callables.

    Values are ordered by compare() using the ordinary < and > operators on
    the decoded values.
    """

    def __init__(self, encode, decode):
        # The raw callables are kept as attributes; encode()/decode() below
        # simply delegate to them.
        self.enc, self.dec = encode, decode

    def encode(self, ob):
        """Return the stored (byte) form of *ob*."""
        encoder = self.enc
        return encoder(ob)

    def decode(self, dat):
        """Return the value stored in *dat*."""
        decoder = self.dec
        return decoder(dat)

    def compare(self, a, b):
        """Return -1, 0 or 1 as *a* sorts before, equal to, or after *b*."""
        if a < b:
            return -1
        return 1 if a > b else 0

    @classmethod
    def struct(cls, fmt):
        """Build a datatype that packs one value with struct format *fmt*.

        Decoding returns only the first field unpacked from the data.
        """
        def pack_value(ob):
            return struct.pack(fmt, ob)

        def unpack_value(dat):
            return struct.unpack(fmt, dat)[0]

        return cls(pack_value, unpack_value)
29
class maybe(object):
    """Datatype wrapper that adds None as a value of another datatype.

    None encodes to the empty byte string; any other value encodes as a NUL
    byte followed by the wrapped datatype's encoding. compare() orders None
    before every other value.
    """

    def __init__(self, bk):
        self.bk = bk  # wrapped datatype providing encode/decode/compare

    def encode(self, ob):
        """Return the encoding of *ob*, which may be None."""
        if ob is None:
            return b""
        return b"\0" + self.bk.encode(ob)

    def decode(self, dat):
        """Return the value encoded in *dat*; the empty string means None."""
        if dat == b"":
            return None
        # Go through the wrapped datatype's public decode(), mirroring
        # encode(). The raw `.dec` attribute is specific to simpletype, so
        # using it breaks nested maybe() and any other datatype.
        return self.bk.decode(dat[1:])

    def compare(self, a, b):
        """Order two decoded values, placing None before everything else.

        The arguments are already-decoded values (the index comparator
        decodes keys before calling compare()), not encoded bytes, so
        non-None values are handed to the wrapped compare() unchanged.
        """
        if a is None and b is None:
            return 0
        if a is None:
            return -1
        if b is None:
            return 1
        return self.bk.compare(a, b)
49
# Key datatype for unsigned 64-bit integers, packed big-endian (">Q").
t_int = simpletype.struct(fmt=">Q")
51
class index(object):
    """Common base for indices.

    Records the owning database (db), the index name (nm) and the key
    datatype (typ).
    """

    def __init__(self, db, name, datatype):
        self.typ = datatype
        self.nm = name
        self.db = db
57
# Sentinel default for ordered.get()'s keyword arguments. It distinguishes
# "argument not given" from every real key value, including None (which the
# maybe datatype can encode).
missing = object()


class ordered(index, lib.closable):
    """Sorted secondary index: a Berkeley DB B-tree mapping keys to object ids.

    Keys are encoded with the index datatype and ordered by the datatype's
    compare() through a custom B-tree comparator. Several ids may be stored
    under one key (DB_DUPSORT). Ids are stored as big-endian unsigned 64-bit
    integers (struct format ">Q").
    """

    def __init__(self, db, name, datatype, duplicates, create=True):
        """Open the index database named "i-" + *name*.

        db         -- owning database object; its _opendb() opens the B-tree.
        name       -- index name.
        datatype   -- key datatype providing encode/decode/compare.
        duplicates -- stored as self.dup (see the note below).
        create     -- if true, pass DB_CREATE so a missing database is created.
        """
        super().__init__(db, name, datatype)
        # NOTE(review): self.dup is not read anywhere in this class and
        # DB_DUPSORT is set unconditionally, so duplicates=False is not
        # enforced. Confirm the intended semantics before relying on it.
        self.dup = duplicates
        fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
        if create:
            fl |= bd.DB_CREATE

        def initdb(db):
            # Configures the DB handle handed in by _opendb() (presumably
            # before the database is opened -- confirm against _opendb).
            def compare(a, b):
                # Two empty keys compare equal without being decoded. Keys
                # may arrive as str or bytes (presumably the sanity probe the
                # bsddb binding makes from set_bt_compare() passes two empty
                # keys); testing emptiness covers both. It also keeps
                # fixed-size datatypes, whose decode() rejects empty input,
                # from failing here.
                if not a and not b:
                    return 0
                return self.typ.compare(self.typ.decode(a), self.typ.decode(b))

            db.set_flags(bd.DB_DUPSORT)
            db.set_bt_compare(compare)

        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
        # Make failed lookups raise DBNotFoundError (caught below as
        # `notfound`) instead of returning None.
        self.bk.set_get_returns_none(False)

    def close(self):
        """Close the underlying index database."""
        self.bk.close()

    class cursor(lib.closable):
        """Iterator over (key, object id) pairs read from a DB cursor.

        Iteration ends at the end of the database, or at the first decoded
        key for which the stop predicate returns true.
        """

        def __init__(self, idx, cur, item, stop):
            self.idx = idx    # owning index (for its datatype), or None
            self.cur = cur    # underlying DB cursor, or None when empty
            self.item = item  # current raw (key, data) pair; None = exhausted
            self.stop = stop  # predicate on a decoded key that ends iteration

        def close(self):
            """Close the underlying DB cursor; calling it again is a no-op."""
            if self.cur is not None:
                self.cur.close()
                # Drop the handle so a second close() does nothing and later
                # iteration ends cleanly instead of using a closed cursor.
                self.cur = None
                self.item = None

        def __iter__(self):
            return self

        def peek(self):
            """Return the current (key, id) pair without advancing.

            Raises StopIteration when the cursor is exhausted or the current
            key satisfies the stop predicate.
            """
            if self.item is None:
                raise StopIteration()
            rk, rv = self.item
            rk = self.idx.typ.decode(rk)
            rv = struct.unpack(">Q", rv)[0]
            if self.stop(rk):
                self.item = None
                raise StopIteration()
            return rk, rv

        def __next__(self):
            """Return the current (key, id) pair and advance the DB cursor."""
            rk, rv = self.peek()
            try:
                while True:
                    try:
                        self.item = self.cur.next()
                        break
                    except deadlock:
                        # Retry the step until it is not deadlocked.
                        continue
            except notfound:
                # Past the last record: the next peek() stops iteration.
                self.item = None
            return rk, rv

        def skip(self, n=1):
            """Advance past up to *n* pairs, stopping quietly at the end."""
            try:
                for _ in range(n):
                    next(self)
            except StopIteration:
                return

    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False):
        """Return a cursor over the (key, id) pairs selected by the arguments.

        match   -- only pairs whose key compares equal to *match*.
        all     -- every pair, starting from the first key.
        ge / gt -- start at the first key >= / > the bound (ge wins if both).
        lt / le -- stop before the first key >= / > the bound (lt wins if
                   both). They combine with ge/gt; on their own, iteration
                   starts at the first key.
        Precedence: match, then all, then the range arguments.

        When nothing matches, an empty cursor is returned. Otherwise the
        returned cursor owns an open DB cursor, and the caller should close it.

        Raises NameError if no selection argument is given.
        """
        while True:
            try:
                cur = self.bk.cursor()
                # Set once the DB cursor is handed to a returned iterator;
                # otherwise the finally clause below closes it.
                done = False
                try:
                    if match is not missing:
                        try:
                            k, v = cur.set(self.typ.encode(match))
                        except notfound:
                            return self.cursor(None, None, None, None)
                        else:
                            done = True
                            return self.cursor(self, cur, (k, v), lambda o: (self.typ.compare(o, match) != 0))
                    elif all:
                        try:
                            k, v = cur.first()
                        except notfound:
                            return self.cursor(None, None, None, None)
                        else:
                            done = True
                            return self.cursor(self, cur, (k, v), lambda o: False)
                    elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
                        skip = False
                        try:
                            if ge is not missing:
                                k, v = cur.set_range(self.typ.encode(ge))
                            elif gt is not missing:
                                # set_range lands on the first key >= gt;
                                # keys equal to gt are skipped further down.
                                k, v = cur.set_range(self.typ.encode(gt))
                                skip = True
                            else:
                                k, v = cur.first()
                        except notfound:
                            return self.cursor(None, None, None, None)
                        if lt is not missing:
                            stop = lambda o: self.typ.compare(o, lt) >= 0
                        elif le is not missing:
                            stop = lambda o: self.typ.compare(o, le) > 0
                        else:
                            stop = lambda o: False
                        ret = self.cursor(self, cur, (k, v), stop)
                        if skip:
                            try:
                                while self.typ.compare(ret.peek()[0], gt) == 0:
                                    next(ret)
                            except StopIteration:
                                pass
                        done = True
                        return ret
                    else:
                        raise NameError("invalid get() specification")
                finally:
                    if not done:
                        cur.close()
            except deadlock:
                # Start over with a fresh cursor after a deadlock.
                continue

    def put(self, key, id):
        """Associate object *id* with *key* in this index.

        Returns True if the pair was added and False if it was already
        present. Raises ValueError if *id* is not in the object database.
        The whole operation is retried in a new transaction on deadlock.
        """
        while True:
            try:
                with db.txn(self.db.env.env) as tx:
                    obid = struct.pack(">Q", id)
                    if not self.db.ob.has_key(obid, txn=tx.tx):
                        raise ValueError("no such object in database: " + str(id))
                    try:
                        # DB_NODUPDATA rejects an identical key/id pair.
                        self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
                    except bd.DBKeyExistError:
                        # Leaves without commit(); presumably db.txn aborts
                        # the transaction in that case -- confirm.
                        return False
                    tx.commit()
                    return True
            except deadlock:
                continue

    def remove(self, key, id):
        """Remove the association between *key* and object *id*.

        Returns True if the pair was removed and False if it was not present.
        Raises ValueError if *id* is not in the object database.
        The whole operation is retried in a new transaction on deadlock.
        """
        while True:
            try:
                with db.txn(self.db.env.env) as tx:
                    obid = struct.pack(">Q", id)
                    if not self.db.ob.has_key(obid, txn=tx.tx):
                        raise ValueError("no such object in database: " + str(id))
                    cur = self.bk.cursor(txn=tx.tx)
                    try:
                        try:
                            # Position on the exact key/id pair.
                            cur.get_both(self.typ.encode(key), obid)
                        except notfound:
                            return False
                        cur.delete()
                    finally:
                        cur.close()
                    tx.commit()
                    return True
            except deadlock:
                continue
215 continue