Fixed cache bug.
[didex.git] / didex / index.py
1 import struct, contextlib
2 from . import db, lib
3 from .db import bd, txnfun
4
5 deadlock = bd.DBLockDeadlockError
6 notfound = bd.DBNotFoundError
7
8 class simpletype(object):
9     def __init__(self, encode, decode):
10         self.enc = encode
11         self.dec = decode
12
13     def encode(self, ob):
14         return self.enc(ob)
15     def decode(self, dat):
16         return self.dec(dat)
17     def compare(self, a, b):
18         if a < b:
19             return -1
20         elif a > b:
21             return 1
22         else:
23             return 0
24
25     @classmethod
26     def struct(cls, fmt):
27         return cls(lambda ob: struct.pack(fmt, ob),
28                    lambda dat: struct.unpack(fmt, dat)[0])
29
30 class maybe(object):
31     def __init__(self, bk):
32         self.bk = bk
33
34     def encode(self, ob):
35         if ob is None: return b""
36         return b"\0" + self.bk.encode(ob)
37     def decode(self, dat):
38         if dat == b"": return None
39         return self.bk.dec(dat[1:])
40     def compare(self, a, b):
41         if a is b is None:
42             return 0
43         elif a is None:
44             return -1
45         elif b is None:
46             return 1
47         else:
48             return self.bk.compare(a[1:], b[1:])
49
50 t_int = simpletype.struct(">Q")
51
52 class index(object):
53     def __init__(self, db, name, datatype):
54         self.db = db
55         self.nm = name
56         self.typ = datatype
57
58 missing = object()
59
60 class ordered(index, lib.closable):
61     def __init__(self, db, name, datatype, create=True):
62         super().__init__(db, name, datatype)
63         fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
64         if create: fl |= bd.DB_CREATE
65         def initdb(db):
66             def compare(a, b):
67                 if a == b == "": return 0
68                 return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
69             db.set_flags(bd.DB_DUPSORT)
70             db.set_bt_compare(compare)
71         self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
72         self.bk.set_get_returns_none(False)
73
74     def close(self):
75         self.bk.close()
76
77     class cursor(lib.closable):
78         def __init__(self, idx, cur, item, stop):
79             self.idx = idx
80             self.cur = cur
81             self.item = item
82             self.stop = stop
83
84         def close(self):
85             if self.cur is not None:
86                 self.cur.close()
87
88         def __iter__(self):
89             return self
90
91         def peek(self):
92             if self.item is None:
93                 raise StopIteration()
94             rk, rv = self.item
95             rk = self.idx.typ.decode(rk)
96             rv = struct.unpack(">Q", rv)[0]
97             if self.stop(rk):
98                 self.item = None
99                 raise StopIteration()
100             return rk, rv
101
102         def __next__(self):
103             rk, rv = self.peek()
104             try:
105                 while True:
106                     try:
107                         self.item = self.cur.next()
108                         break
109                     except deadlock:
110                         continue
111             except notfound:
112                 self.item = None
113             return rk, rv
114
115         def skip(self, n=1):
116             try:
117                 for i in range(n):
118                     next(self)
119             except StopIteration:
120                 return
121
122     def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False):
123         while True:
124             try:
125                 cur = self.bk.cursor()
126                 done = False
127                 try:
128                     if match is not missing:
129                         try:
130                             k, v = cur.set(self.typ.encode(match))
131                         except notfound:
132                             return self.cursor(None, None, None, None)
133                         else:
134                             done = True
135                             return self.cursor(self, cur, (k, v), lambda o: (self.typ.compare(o, match) != 0))
136                     elif all:
137                         try:
138                             k, v = cur.first()
139                         except notfound:
140                             return self.cursor(None, None, None, None)
141                         else:
142                             done = True
143                             return self.cursor(self, cur, (k, v), lambda o: False)
144                     elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
145                         skip = False
146                         try:
147                             if ge is not missing:
148                                 k, v = cur.set_range(self.typ.encode(ge))
149                             elif gt is not missing:
150                                 k, v = cur.set_range(self.typ.encode(gt))
151                                 skip = True
152                             else:
153                                 k, v = cur.first()
154                         except notfound:
155                             return self.cursor(None, None, None, None)
156                         if lt is not missing:
157                             stop = lambda o: self.typ.compare(o, lt) >= 0
158                         elif le is not missing:
159                             stop = lambda o: self.typ.compare(o, le) > 0
160                         else:
161                             stop = lambda o: False
162                         ret = self.cursor(self, cur, (k, v), stop)
163                         if skip:
164                             try:
165                                 while self.typ.compare(ret.peek()[0], gt) == 0:
166                                     next(ret)
167                             except StopIteration:
168                                 pass
169                         done = True
170                         return ret
171                     else:
172                         raise NameError("invalid get() specification")
173                 finally:
174                     if not done:
175                         cur.close()
176             except deadlock:
177                 continue
178
179     @txnfun(lambda self: self.db.env.env)
180     def put(self, key, id, *, tx):
181         obid = struct.pack(">Q", id)
182         if not self.db.ob.has_key(obid, txn=tx.tx):
183             raise ValueError("no such object in database: " + str(id))
184         try:
185             self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
186         except bd.DBKeyExistError:
187             return False
188         return True
189
190     @txnfun(lambda self: self.db.env.env)
191     def remove(self, key, id, *, tx):
192         obid = struct.pack(">Q", id)
193         if not self.db.ob.has_key(obid, txn=tx.tx):
194             raise ValueError("no such object in database: " + str(id))
195         cur = self.bk.cursor(txn=tx.tx)
196         try:
197             try:
198                 cur.get_both(self.typ.encode(key), obid)
199             except notfound:
200                 return False
201             cur.delete()
202         finally:
203             cur.close()
204         return True