5d206ffc7c4bba0e6982fddd2d259289aef5e696
[didex.git] / didex / index.py
1 import struct, contextlib
2 from . import db, lib
3 from .db import bd
4
# Short aliases for the `bd` exception classes this module catches:
# `deadlock` triggers a retry of the surrounding operation, `notfound`
# signals that a cursor or lookup found no (further) entry.
deadlock = bd.DBLockDeadlockError
notfound = bd.DBNotFoundError
7
class simpletype(object):
    """Datatype assembled from a pair of encode/decode callables.

    Values are ordered by their natural ``<`` / ``>`` comparison.
    """

    def __init__(self, encode, decode):
        self.enc, self.dec = encode, decode

    def encode(self, ob):
        """Return the encoded form of `ob` as produced by the encoder."""
        return self.enc(ob)

    def decode(self, dat):
        """Return the value the decoder reconstructs from `dat`."""
        return self.dec(dat)

    def compare(self, a, b):
        """Three-way comparison: -1 if a < b, 1 if a > b, otherwise 0."""
        if a < b:
            return -1
        if a > b:
            return 1
        return 0

    @classmethod
    def struct(cls, fmt):
        """Build a datatype holding one value packed with struct format `fmt`."""
        def pack(ob):
            return struct.pack(fmt, ob)

        def unpack(dat):
            # struct.unpack always returns a tuple; the single field is the value.
            return struct.unpack(fmt, dat)[0]

        return cls(pack, unpack)
29
class maybe(object):
    """Datatype wrapper that makes None a valid value of another datatype.

    None is encoded as the empty byte string; every other value is encoded
    as a NUL byte followed by the wrapped datatype's encoding.  In
    comparisons None sorts before any other value.

    `bk` is the wrapped datatype; it must provide encode(), decode() and
    compare().
    """

    def __init__(self, bk):
        self.bk = bk

    def encode(self, ob):
        """Return b"" for None, else a NUL byte plus the wrapped encoding."""
        if ob is None:
            return b""
        return b"\0" + self.bk.encode(ob)

    def decode(self, dat):
        """Inverse of encode(): b"" yields None, otherwise the wrapped value."""
        if dat == b"":
            return None
        # Use the datatype's public decode() rather than the `dec` attribute,
        # which only simpletype has; this keeps any datatype (including a
        # nested maybe) usable as the wrapped type.
        return self.bk.decode(dat[1:])

    def compare(self, a, b):
        """Three-way comparison of two decoded values, with None ordered first."""
        if a is b is None:
            return 0
        elif a is None:
            return -1
        elif b is None:
            return 1
        else:
            # a and b are already-decoded values, not encoded byte strings,
            # so they are handed to the wrapped datatype unchanged (no
            # marker byte to strip here).
            return self.bk.compare(a, b)
49
# Ready-made datatype for integers packed with struct format ">Q"
# (big-endian unsigned 64-bit).
t_int = simpletype.struct(">Q")
51
class index(object):
    """Common base of indices: remembers the database, name and datatype."""

    def __init__(self, db, name, datatype):
        # db: owning database object; nm: index name; typ: key datatype.
        self.db, self.nm, self.typ = db, name, datatype
57
# Unique sentinel used as a keyword-argument default meaning "not supplied",
# so that callers are tested with `is missing` rather than against None.
missing = object()
59
class ordered(index, lib.closable):
    """Index kept in key order in a B-tree database named "i-" + name.

    Keys are stored in the datatype's encoded form and ordered by the
    datatype's compare() applied to the decoded keys.  The data of every
    entry is an object id packed as a big-endian unsigned 64-bit integer.
    """

    def __init__(self, db, name, datatype, duplicates, create=True):
        super().__init__(db, name, datatype)
        self.dup = duplicates
        flags = bd.DB_THREAD | bd.DB_AUTO_COMMIT
        if create:
            flags |= bd.DB_CREATE

        def keycmp(a, b):
            # A pair of empty strings is reported equal without decoding.
            if a == b == "":
                return 0
            return self.typ.compare(self.typ.decode(a), self.typ.decode(b))

        def setup(handle):
            # Passed to db._opendb to configure the database handle.
            handle.set_flags(bd.DB_DUPSORT)
            handle.set_bt_compare(keycmp)

        self.bk = db._opendb("i-" + name, bd.DB_BTREE, flags, setup)
        self.bk.set_get_returns_none(False)

    def close(self):
        """Close the underlying database handle."""
        self.bk.close()

    class cursor(lib.closable):
        """Iterator yielding (decoded key, object id) pairs.

        `item` is the raw (key, data) pair the cursor currently rests on, or
        None when exhausted; `stop` is a predicate on decoded keys that ends
        the iteration as soon as it returns true.
        """

        def __init__(self, idx, cur, item, stop):
            self.idx, self.cur = idx, cur
            self.item, self.stop = item, stop

        def close(self):
            """Close the database cursor, if this object holds one."""
            if self.cur is None:
                return
            self.cur.close()

        def __iter__(self):
            return self

        def peek(self):
            """Return the current pair without advancing.

            Raises StopIteration when there is no current item or when the
            stop predicate accepts the current key.
            """
            if self.item is None:
                raise StopIteration()
            rawkey, rawid = self.item
            key = self.idx.typ.decode(rawkey)
            (obid,) = struct.unpack(">Q", rawid)
            if self.stop(key):
                self.item = None
                raise StopIteration()
            return key, obid

        def __next__(self):
            """Return the current pair and move on to the following entry."""
            key, obid = self.peek()
            while True:
                try:
                    self.item = self.cur.next()
                except deadlock:
                    # Retry the cursor step until it gets through.
                    continue
                except notfound:
                    # Ran off the end of the database.
                    self.item = None
                break
            return key, obid

        def skip(self, n=1):
            """Advance past up to `n` items, stopping quietly at the end."""
            try:
                for _ in range(n):
                    self.__next__()
            except StopIteration:
                pass

    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False):
        """Return a cursor over the entries selected by the keyword arguments.

        `match` selects entries whose key compares equal to the given value,
        `all` selects every entry; otherwise `ge`/`gt` give the lower bound
        and `lt`/`le` the upper bound of a range.  `match` takes precedence
        over `all`, which takes precedence over the range arguments.  When
        the starting position does not exist an empty cursor is returned.
        Raises NameError if no selection is given.  Deadlocks restart the
        whole lookup.
        """
        def nothing():
            return self.cursor(None, None, None, None)

        while True:
            try:
                dbcur = self.bk.cursor()
                # The database cursor is closed on the way out unless it has
                # been handed over to the returned cursor object.
                keep = False
                try:
                    if match is not missing:
                        try:
                            rawkey, rawid = dbcur.set(self.typ.encode(match))
                        except notfound:
                            return nothing()
                        keep = True
                        return self.cursor(self, dbcur, (rawkey, rawid),
                                           lambda o: (self.typ.compare(o, match) != 0))

                    if all:
                        try:
                            rawkey, rawid = dbcur.first()
                        except notfound:
                            return nothing()
                        keep = True
                        return self.cursor(self, dbcur, (rawkey, rawid), lambda o: False)

                    if ge is missing and gt is missing and lt is missing and le is missing:
                        raise NameError("invalid get() specification")

                    # Range query: position at the lower bound, or at the
                    # first entry when only an upper bound was given.
                    drop_equal = False
                    try:
                        if ge is not missing:
                            rawkey, rawid = dbcur.set_range(self.typ.encode(ge))
                        elif gt is not missing:
                            rawkey, rawid = dbcur.set_range(self.typ.encode(gt))
                            drop_equal = True
                        else:
                            rawkey, rawid = dbcur.first()
                    except notfound:
                        return nothing()

                    if lt is not missing:
                        def stop(o):
                            return self.typ.compare(o, lt) >= 0
                    elif le is not missing:
                        def stop(o):
                            return self.typ.compare(o, le) > 0
                    else:
                        def stop(o):
                            return False

                    found = self.cursor(self, dbcur, (rawkey, rawid), stop)
                    if drop_equal:
                        # `gt` is exclusive: step over entries equal to it.
                        try:
                            while self.typ.compare(found.peek()[0], gt) == 0:
                                next(found)
                        except StopIteration:
                            pass
                    keep = True
                    return found
                finally:
                    if not keep:
                        dbcur.close()
            except deadlock:
                continue

    def put(self, key, id):
        """Add the entry (key, id) to the index.

        Returns True once the entry has been committed, False if the
        database reports that the pair already exists.  Raises ValueError
        if `id` is not present in the object store.  Deadlocks restart the
        transaction.
        """
        while True:
            try:
                with db.txn(self.db.env.env) as trans:
                    obid = struct.pack(">Q", id)
                    known = self.db.ob.has_key(obid, txn=trans.tx)
                    if not known:
                        raise ValueError("no such object in database: " + str(id))
                    try:
                        self.bk.put(self.typ.encode(key), obid, txn=trans.tx, flags=bd.DB_NODUPDATA)
                    except bd.DBKeyExistError:
                        return False
                    else:
                        trans.commit()
                        return True
            except deadlock:
                continue

    def remove(self, key, id):
        """Delete the entry (key, id) from the index.

        Returns True once the deletion has been committed, False if no such
        entry exists.  Raises ValueError if `id` is not present in the
        object store.  Deadlocks restart the transaction.
        """
        while True:
            try:
                with db.txn(self.db.env.env) as trans:
                    obid = struct.pack(">Q", id)
                    known = self.db.ob.has_key(obid, txn=trans.tx)
                    if not known:
                        raise ValueError("no such object in database: " + str(id))
                    dbcur = self.bk.cursor(txn=trans.tx)
                    try:
                        dbcur.get_both(self.typ.encode(key), obid)
                    except notfound:
                        return False
                    else:
                        dbcur.delete()
                    finally:
                        # The cursor is closed before the commit below.
                        dbcur.close()
                    trans.commit()
                    return True
            except deadlock:
                continue
214             except deadlock:
215                 continue