Commit | Line | Data |
---|---|---|
abb94f83 | 1 | import struct, contextlib, math |
a95055e8 | 2 | from . import db, lib |
8950191c | 3 | from .db import bd, txnfun |
a95055e8 | 4 | |
cbf73d3a FT |
5 | __all__ = ["maybe", "t_int", "t_uint", "t_float", "t_str", "ordered"] |
6 | ||
a95055e8 FT |
# Short aliases for the Berkeley DB exception classes that the code below
# catches in its `except` clauses.
deadlock = bd.DBLockDeadlockError
notfound = bd.DBNotFoundError
9 | ||
class simpletype(object):
    """Key type built from a pair of encoder/decoder callables.

    `encode` turns a Python object into its stored form and `decode`
    reverses it.  Ordering uses the natural `<` / `>` of the decoded
    values.
    """

    def __init__(self, encode, decode):
        self.enc, self.dec = encode, decode

    def encode(self, ob):
        """Return the stored form of `ob`."""
        return self.enc(ob)

    def decode(self, dat):
        """Return the object represented by the stored data `dat`."""
        return self.dec(dat)

    def compare(self, a, b):
        """Three-way compare two decoded values: -1, 0 or 1."""
        if a < b:
            return -1
        return 1 if a > b else 0

    @classmethod
    def struct(cls, fmt):
        """Build a type that stores one value using `struct` format `fmt`."""
        def pack_one(ob):
            return struct.pack(fmt, ob)

        def unpack_one(dat):
            fields = struct.unpack(fmt, dat)
            return fields[0]

        return cls(pack_one, unpack_one)
31 | ||
class maybe(object):
    """Wrap another key type so that `None` is also a valid value.

    `None` is stored as the empty byte string; any other value is stored
    as a single NUL byte followed by the backing type's encoding.  `None`
    orders before every other value.

    `bk` is the backing type; it must provide `encode`, `decode` and
    `compare` methods.
    """

    def __init__(self, bk):
        self.bk = bk

    def encode(self, ob):
        """Return the stored form of `ob` (empty bytes for `None`)."""
        if ob is None:
            return b""
        return b"\0" + self.bk.encode(ob)

    def decode(self, dat):
        """Return the object stored as `dat`, or `None` for empty data."""
        if dat == b"":
            return None
        # Go through the public decode() method rather than the `dec`
        # attribute, which only simpletype has; this keeps nested or
        # custom backing types working.
        return self.bk.decode(dat[1:])

    def compare(self, a, b):
        """Three-way compare two *decoded* values, sorting `None` first."""
        if a is None and b is None:
            return 0
        if a is None:
            return -1
        if b is None:
            return 1
        # Both operands are already decoded objects, so they are handed to
        # the backing type unchanged (no marker byte to strip here).
        return self.bk.compare(a, b)
51 | ||
abb94f83 FT |
def floatcmp(a, b):
    """Three-way compare two floats with a total order that includes NaN.

    NaN compares equal to NaN and sorts before every other value.
    Returns -1, 0 or 1.
    """
    a_nan = math.isnan(a)
    b_nan = math.isnan(b)
    if a_nan or b_nan:
        if a_nan and b_nan:
            return 0
        return -1 if a_nan else 1
    if a < b:
        return -1
    return 1 if a > b else 0
65 | ||
# Predefined key types.  The numeric ones use big-endian struct formats.
t_int = simpletype.struct(">q")    # signed 64-bit integer
t_uint = simpletype.struct(">Q")   # unsigned 64-bit integer
t_float = simpletype.struct(">d")  # 64-bit double
# Replace the default comparison on this one instance with the NaN-aware
# floatcmp.  As an instance attribute it is called without `self`, which
# matches floatcmp's (a, b) signature.
t_float.compare = floatcmp
# Text keys, stored as UTF-8.
t_str = simpletype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")))
a95055e8 FT |
71 | |
class index(object):
    """Common base for indices: remembers the owning database, the index
    name and the key datatype."""

    def __init__(self, db, name, datatype):
        self.db, self.nm, self.typ = db, name, datatype
77 | ||
# Unique sentinel used as the default for the search keywords of
# ordered.get(), which tests them with `is not missing`; this lets a
# caller pass None as an actual search value.
missing = object()
79 | ||
class ordered(index, lib.closable):
    """Index kept in a B-tree database, sorted by the datatype's ordering.

    Keys are encoded with `datatype`; the associated values are object ids
    packed as big-endian unsigned 64-bit integers.  The backing database is
    opened with DB_DUPSORT set.
    """

    def __init__(self, db, name, datatype, create=True):
        """Open the backing database, named "i-" + `name`.

        DB_CREATE is added to the open flags when `create` is true.
        """
        super().__init__(db, name, datatype)
        fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
        if create: fl |= bd.DB_CREATE
        def initdb(db):
            # Passed to db._opendb() to configure the database handle.
            def compare(a, b):
                # B-tree key comparison: decode both keys and let the
                # datatype order them.
                # NOTE(review): the empty-string shortcut presumably
                # answers a probe call made by the DB binding when the
                # comparator is installed -- confirm.
                if a == b == "": return 0
                return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
            db.set_flags(bd.DB_DUPSORT)
            db.set_bt_compare(compare)
        self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
        self.bk.set_get_returns_none(False)

    def close(self):
        """Close the backing database."""
        self.bk.close()

    class cursor(lib.closable):
        """Iterator over (key, id) pairs of an index.

        Starts at the pre-fetched raw `item` and ends when `stop(key)`
        returns true or the underlying cursor has no more records.
        get() builds an always-empty instance with all arguments None.
        """

        def __init__(self, idx, cur, item, stop):
            self.idx = idx      # owning index; its `typ` decodes keys
            self.cur = cur      # underlying database cursor, or None
            self.item = item    # raw (key, value) pair not yet returned, or None
            self.stop = stop    # predicate on a decoded key that ends iteration

        def close(self):
            """Close the underlying cursor, if there is one."""
            if self.cur is not None:
                self.cur.close()

        def __iter__(self):
            return self

        def peek(self):
            """Return the current (key, id) pair without advancing.

            Raises StopIteration when there is no current item or when the
            stop predicate matches its key; in the latter case the item is
            discarded so later calls stop immediately.
            """
            if self.item is None:
                raise StopIteration()
            rk, rv = self.item
            rk = self.idx.typ.decode(rk)
            rv = struct.unpack(">Q", rv)[0]
            if self.stop(rk):
                self.item = None
                raise StopIteration()
            return rk, rv

        def __next__(self):
            """Return the current (key, id) pair and fetch the following one."""
            rk, rv = self.peek()
            try:
                # Retry the cursor step for as long as it reports a deadlock.
                while True:
                    try:
                        self.item = self.cur.next()
                        break
                    except deadlock:
                        continue
            except notfound:
                # No more records: the next peek() raises StopIteration.
                self.item = None
            return rk, rv

        def skip(self, n=1):
            """Advance past up to `n` items, stopping quietly at the end."""
            try:
                for i in range(n):
                    next(self)
            except StopIteration:
                return

    def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False):
        """Return a cursor over the entries selected by the keyword arguments.

        The criteria are checked in this order:
          * `match`: entries whose key compares equal to `match`;
          * `all`: every entry;
          * any of `ge`/`gt` (lower bound, `ge` wins if both are given)
            and `lt`/`le` (upper bound, `lt` wins if both are given).

        An empty cursor is returned when the initial positioning finds
        nothing.  Raises NameError when no criterion is given.  The whole
        lookup is retried when a deadlock is reported.
        """
        while True:
            try:
                cur = self.bk.cursor()
                # `done` records that ownership of `cur` has passed to the
                # returned cursor object; otherwise `finally` closes it.
                done = False
                try:
                    if match is not missing:
                        try:
                            k, v = cur.set(self.typ.encode(match))
                        except notfound:
                            return self.cursor(None, None, None, None)
                        else:
                            done = True
                            return self.cursor(self, cur, (k, v), lambda o: (self.typ.compare(o, match) != 0))
                    elif all:
                        try:
                            k, v = cur.first()
                        except notfound:
                            return self.cursor(None, None, None, None)
                        else:
                            done = True
                            return self.cursor(self, cur, (k, v), lambda o: False)
                    elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
                        skip = False
                        try:
                            if ge is not missing:
                                k, v = cur.set_range(self.typ.encode(ge))
                            elif gt is not missing:
                                # Position as for `ge`; entries equal to
                                # `gt` are skipped further down.
                                k, v = cur.set_range(self.typ.encode(gt))
                                skip = True
                            else:
                                # Only an upper bound: start at the beginning.
                                k, v = cur.first()
                        except notfound:
                            return self.cursor(None, None, None, None)
                        if lt is not missing:
                            stop = lambda o: self.typ.compare(o, lt) >= 0
                        elif le is not missing:
                            stop = lambda o: self.typ.compare(o, le) > 0
                        else:
                            stop = lambda o: False
                        ret = self.cursor(self, cur, (k, v), stop)
                        if skip:
                            try:
                                while self.typ.compare(ret.peek()[0], gt) == 0:
                                    next(ret)
                            except StopIteration:
                                pass
                        done = True
                        return ret
                    else:
                        raise NameError("invalid get() specification")
                finally:
                    if not done:
                        cur.close()
            except deadlock:
                continue

    # NOTE(review): txnfun is defined elsewhere; presumably it supplies the
    # keyword-only `tx` argument from the environment returned by the
    # lambda -- confirm against .db.
    @txnfun(lambda self: self.db.env.env)
    def put(self, key, id, *, tx):
        """Associate `key` with the object `id`.

        Returns True if the pair was added and False if the database
        reported it as already present.  Raises ValueError when `id` is
        not a key in `self.db.ob`.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        try:
            self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
        except bd.DBKeyExistError:
            return False
        return True

    @txnfun(lambda self: self.db.env.env)
    def remove(self, key, id, *, tx):
        """Remove the association between `key` and the object `id`.

        Returns True if the pair was deleted and False if it was not
        found.  Raises ValueError when `id` is not a key in `self.db.ob`.
        """
        obid = struct.pack(">Q", id)
        if not self.db.ob.has_key(obid, txn=tx.tx):
            raise ValueError("no such object in database: " + str(id))
        cur = self.bk.cursor(txn=tx.tx)
        try:
            try:
                # Position the cursor on the exact (key, id) pair.
                cur.get_both(self.typ.encode(key), obid)
            except notfound:
                return False
            cur.delete()
        finally:
            cur.close()
        return True