a2828cedd421e6e8d4a7de8cc976296fd956058b
[didex.git] / didex / index.py
1 import struct, contextlib, math
2 from . import db, lib
3 from .db import bd, txnfun
4
5 __all__ = ["maybe", "t_int", "t_uint", "t_float", "t_str", "ordered"]
6
7 deadlock = bd.DBLockDeadlockError
8 notfound = bd.DBNotFoundError
9
10 class simpletype(object):
11     def __init__(self, encode, decode):
12         self.enc = encode
13         self.dec = decode
14
15     def encode(self, ob):
16         return self.enc(ob)
17     def decode(self, dat):
18         return self.dec(dat)
19     def compare(self, a, b):
20         if a < b:
21             return -1
22         elif a > b:
23             return 1
24         else:
25             return 0
26
27     @classmethod
28     def struct(cls, fmt):
29         return cls(lambda ob: struct.pack(fmt, ob),
30                    lambda dat: struct.unpack(fmt, dat)[0])
31
32 class maybe(object):
33     def __init__(self, bk):
34         self.bk = bk
35
36     def encode(self, ob):
37         if ob is None: return b""
38         return b"\0" + self.bk.encode(ob)
39     def decode(self, dat):
40         if dat == b"": return None
41         return self.bk.dec(dat[1:])
42     def compare(self, a, b):
43         if a is b is None:
44             return 0
45         elif a is None:
46             return -1
47         elif b is None:
48             return 1
49         else:
50             return self.bk.compare(a[1:], b[1:])
51
52 def floatcmp(a, b):
53     if math.isnan(a) and math.isnan(b):
54         return 0
55     elif math.isnan(a):
56         return -1
57     elif math.isnan(b):
58         return 1
59     elif a < b:
60         return -1
61     elif a > b:
62         return 1
63     else:
64         return 0
65
66 t_int = simpletype.struct(">q")
67 t_uint = simpletype.struct(">Q")
68 t_float = simpletype.struct(">d")
69 t_float.compare = floatcmp
70 t_str = simpletype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")))
71
72 class index(object):
73     def __init__(self, db, name, datatype):
74         self.db = db
75         self.nm = name
76         self.typ = datatype
77
78 missing = object()
79
80 class ordered(index, lib.closable):
81     def __init__(self, db, name, datatype, create=True):
82         super().__init__(db, name, datatype)
83         fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
84         if create: fl |= bd.DB_CREATE
85         def initdb(db):
86             def compare(a, b):
87                 if a == b == "": return 0
88                 return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
89             db.set_flags(bd.DB_DUPSORT)
90             db.set_bt_compare(compare)
91         self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
92         self.bk.set_get_returns_none(False)
93
94     def close(self):
95         self.bk.close()
96
97     class cursor(lib.closable):
98         def __init__(self, idx, cur, item, stop):
99             self.idx = idx
100             self.cur = cur
101             self.item = item
102             self.stop = stop
103
104         def close(self):
105             if self.cur is not None:
106                 self.cur.close()
107
108         def __iter__(self):
109             return self
110
111         def peek(self):
112             if self.item is None:
113                 raise StopIteration()
114             rk, rv = self.item
115             rk = self.idx.typ.decode(rk)
116             rv = struct.unpack(">Q", rv)[0]
117             if self.stop(rk):
118                 self.item = None
119                 raise StopIteration()
120             return rk, rv
121
122         def __next__(self):
123             rk, rv = self.peek()
124             try:
125                 while True:
126                     try:
127                         self.item = self.cur.next()
128                         break
129                     except deadlock:
130                         continue
131             except notfound:
132                 self.item = None
133             return rk, rv
134
135         def skip(self, n=1):
136             try:
137                 for i in range(n):
138                     next(self)
139             except StopIteration:
140                 return
141
142     def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False):
143         while True:
144             try:
145                 cur = self.bk.cursor()
146                 done = False
147                 try:
148                     if match is not missing:
149                         try:
150                             k, v = cur.set(self.typ.encode(match))
151                         except notfound:
152                             return self.cursor(None, None, None, None)
153                         else:
154                             done = True
155                             return self.cursor(self, cur, (k, v), lambda o: (self.typ.compare(o, match) != 0))
156                     elif all:
157                         try:
158                             k, v = cur.first()
159                         except notfound:
160                             return self.cursor(None, None, None, None)
161                         else:
162                             done = True
163                             return self.cursor(self, cur, (k, v), lambda o: False)
164                     elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
165                         skip = False
166                         try:
167                             if ge is not missing:
168                                 k, v = cur.set_range(self.typ.encode(ge))
169                             elif gt is not missing:
170                                 k, v = cur.set_range(self.typ.encode(gt))
171                                 skip = True
172                             else:
173                                 k, v = cur.first()
174                         except notfound:
175                             return self.cursor(None, None, None, None)
176                         if lt is not missing:
177                             stop = lambda o: self.typ.compare(o, lt) >= 0
178                         elif le is not missing:
179                             stop = lambda o: self.typ.compare(o, le) > 0
180                         else:
181                             stop = lambda o: False
182                         ret = self.cursor(self, cur, (k, v), stop)
183                         if skip:
184                             try:
185                                 while self.typ.compare(ret.peek()[0], gt) == 0:
186                                     next(ret)
187                             except StopIteration:
188                                 pass
189                         done = True
190                         return ret
191                     else:
192                         raise NameError("invalid get() specification")
193                 finally:
194                     if not done:
195                         cur.close()
196             except deadlock:
197                 continue
198
199     @txnfun(lambda self: self.db.env.env)
200     def put(self, key, id, *, tx):
201         obid = struct.pack(">Q", id)
202         if not self.db.ob.has_key(obid, txn=tx.tx):
203             raise ValueError("no such object in database: " + str(id))
204         try:
205             self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
206         except bd.DBKeyExistError:
207             return False
208         return True
209
210     @txnfun(lambda self: self.db.env.env)
211     def remove(self, key, id, *, tx):
212         obid = struct.pack(">Q", id)
213         if not self.db.ob.has_key(obid, txn=tx.tx):
214             raise ValueError("no such object in database: " + str(id))
215         cur = self.bk.cursor(txn=tx.tx)
216         try:
217             try:
218                 cur.get_both(self.typ.encode(key), obid)
219             except notfound:
220                 return False
221             cur.delete()
222         finally:
223             cur.close()
224         return True