Added more indexed types.
[didex.git] / didex / index.py
1 import struct, contextlib, math
2 from . import db, lib
3 from .db import bd, txnfun
4
5 deadlock = bd.DBLockDeadlockError
6 notfound = bd.DBNotFoundError
7
8 class simpletype(object):
9     def __init__(self, encode, decode):
10         self.enc = encode
11         self.dec = decode
12
13     def encode(self, ob):
14         return self.enc(ob)
15     def decode(self, dat):
16         return self.dec(dat)
17     def compare(self, a, b):
18         if a < b:
19             return -1
20         elif a > b:
21             return 1
22         else:
23             return 0
24
25     @classmethod
26     def struct(cls, fmt):
27         return cls(lambda ob: struct.pack(fmt, ob),
28                    lambda dat: struct.unpack(fmt, dat)[0])
29
30 class maybe(object):
31     def __init__(self, bk):
32         self.bk = bk
33
34     def encode(self, ob):
35         if ob is None: return b""
36         return b"\0" + self.bk.encode(ob)
37     def decode(self, dat):
38         if dat == b"": return None
39         return self.bk.dec(dat[1:])
40     def compare(self, a, b):
41         if a is b is None:
42             return 0
43         elif a is None:
44             return -1
45         elif b is None:
46             return 1
47         else:
48             return self.bk.compare(a[1:], b[1:])
49
50 def floatcmp(a, b):
51     if math.isnan(a) and math.isnan(b):
52         return 0
53     elif math.isnan(a):
54         return -1
55     elif math.isnan(b):
56         return 1
57     elif a < b:
58         return -1
59     elif a > b:
60         return 1
61     else:
62         return 0
63
64 t_int = simpletype.struct(">q")
65 t_uint = simpletype.struct(">Q")
66 t_float = simpletype.struct(">d")
67 t_float.compare = floatcmp
68 t_str = simpletype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")))
69
70 class index(object):
71     def __init__(self, db, name, datatype):
72         self.db = db
73         self.nm = name
74         self.typ = datatype
75
76 missing = object()
77
78 class ordered(index, lib.closable):
79     def __init__(self, db, name, datatype, create=True):
80         super().__init__(db, name, datatype)
81         fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
82         if create: fl |= bd.DB_CREATE
83         def initdb(db):
84             def compare(a, b):
85                 if a == b == "": return 0
86                 return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
87             db.set_flags(bd.DB_DUPSORT)
88             db.set_bt_compare(compare)
89         self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
90         self.bk.set_get_returns_none(False)
91
92     def close(self):
93         self.bk.close()
94
95     class cursor(lib.closable):
96         def __init__(self, idx, cur, item, stop):
97             self.idx = idx
98             self.cur = cur
99             self.item = item
100             self.stop = stop
101
102         def close(self):
103             if self.cur is not None:
104                 self.cur.close()
105
106         def __iter__(self):
107             return self
108
109         def peek(self):
110             if self.item is None:
111                 raise StopIteration()
112             rk, rv = self.item
113             rk = self.idx.typ.decode(rk)
114             rv = struct.unpack(">Q", rv)[0]
115             if self.stop(rk):
116                 self.item = None
117                 raise StopIteration()
118             return rk, rv
119
120         def __next__(self):
121             rk, rv = self.peek()
122             try:
123                 while True:
124                     try:
125                         self.item = self.cur.next()
126                         break
127                     except deadlock:
128                         continue
129             except notfound:
130                 self.item = None
131             return rk, rv
132
133         def skip(self, n=1):
134             try:
135                 for i in range(n):
136                     next(self)
137             except StopIteration:
138                 return
139
140     def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False):
141         while True:
142             try:
143                 cur = self.bk.cursor()
144                 done = False
145                 try:
146                     if match is not missing:
147                         try:
148                             k, v = cur.set(self.typ.encode(match))
149                         except notfound:
150                             return self.cursor(None, None, None, None)
151                         else:
152                             done = True
153                             return self.cursor(self, cur, (k, v), lambda o: (self.typ.compare(o, match) != 0))
154                     elif all:
155                         try:
156                             k, v = cur.first()
157                         except notfound:
158                             return self.cursor(None, None, None, None)
159                         else:
160                             done = True
161                             return self.cursor(self, cur, (k, v), lambda o: False)
162                     elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
163                         skip = False
164                         try:
165                             if ge is not missing:
166                                 k, v = cur.set_range(self.typ.encode(ge))
167                             elif gt is not missing:
168                                 k, v = cur.set_range(self.typ.encode(gt))
169                                 skip = True
170                             else:
171                                 k, v = cur.first()
172                         except notfound:
173                             return self.cursor(None, None, None, None)
174                         if lt is not missing:
175                             stop = lambda o: self.typ.compare(o, lt) >= 0
176                         elif le is not missing:
177                             stop = lambda o: self.typ.compare(o, le) > 0
178                         else:
179                             stop = lambda o: False
180                         ret = self.cursor(self, cur, (k, v), stop)
181                         if skip:
182                             try:
183                                 while self.typ.compare(ret.peek()[0], gt) == 0:
184                                     next(ret)
185                             except StopIteration:
186                                 pass
187                         done = True
188                         return ret
189                     else:
190                         raise NameError("invalid get() specification")
191                 finally:
192                     if not done:
193                         cur.close()
194             except deadlock:
195                 continue
196
197     @txnfun(lambda self: self.db.env.env)
198     def put(self, key, id, *, tx):
199         obid = struct.pack(">Q", id)
200         if not self.db.ob.has_key(obid, txn=tx.tx):
201             raise ValueError("no such object in database: " + str(id))
202         try:
203             self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
204         except bd.DBKeyExistError:
205             return False
206         return True
207
208     @txnfun(lambda self: self.db.env.env)
209     def remove(self, key, id, *, tx):
210         obid = struct.pack(">Q", id)
211         if not self.db.ob.has_key(obid, txn=tx.tx):
212             raise ValueError("no such object in database: " + str(id))
213         cur = self.bk.cursor(txn=tx.tx)
214         try:
215             try:
216                 cur.get_both(self.typ.encode(key), obid)
217             except notfound:
218                 return False
219             cur.delete()
220         finally:
221             cur.close()
222         return True