1 import struct, contextlib, math
3 from .db import bd, txnfun
# NOTE(review): this excerpt has lost its indentation and many original
# lines, and every line carries an original line-number prefix; it is not
# runnable as-is.
# Names exported by `from ... import *`: the key-type helpers and the
# `ordered` index class.
5 __all__ = ["maybe", "t_int", "t_uint", "t_float", "t_str", "ordered"]
# Module-level aliases for Berkeley DB exception classes taken from `bd`
# (imported from `.db`); presumably so callers can catch deadlocks and
# missing keys without importing `bd` themselves.
7 deadlock = bd.DBLockDeadlockError
8 notfound = bd.DBNotFoundError
# Key type defined by a pair of callables, `encode` and `decode`. As used in
# this file, encode maps a value to bytes (e.g. `t_str` uses UTF-8
# `str.encode`) and decode maps bytes back to a value.
# The bodies of __init__, decode and compare are not in this excerpt.
10 class simpletype(object):
11 def __init__(self, encode, decode):
17 def decode(self, dat):
19 def compare(self, a, b):
# Tail of an alternate constructor whose header is not in this excerpt
# (presumably a classmethod `struct(cls, fmt)`, given the module-level calls
# `simpletype.struct(">q")` etc.). It builds a type that packs/unpacks a
# single value with the `struct` format string `fmt`; unpacking raises
# struct.error if `dat` has the wrong length for `fmt`.
29 return cls(lambda ob: struct.pack(fmt, ob),
30 lambda dat: struct.unpack(fmt, dat)[0])
# Presumably the `maybe` wrapper exported in __all__ (its class header is
# not in this excerpt). It makes an inner type `bk` nullable:
# - None encodes to b"";
# - any other value encodes to a b"\0" marker byte followed by the inner
#   type's encoding.
33 def __init__(self, bk):
37 if ob is None: return b""
38 return b"\0" + self.bk.encode(ob)
# Inverse of the encoding above: b"" -> None, otherwise strip the marker
# byte and decode the remainder with the inner type.
39 def decode(self, dat):
40 if dat == b"": return None
# NOTE(review): encode() and compare() call the inner type's public
# `encode`/`compare`, but this calls `dec`. `simpletype` exposes `decode`;
# unless the inner type also has a `dec` attribute (simpletype's __init__
# body is not shown), this should be `self.bk.decode(dat[1:])` -- confirm.
41 return self.bk.dec(dat[1:])
42 def compare(self, a, b):
# NOTE(review): elsewhere in this file `compare` receives *decoded* values.
# `ordered`'s B-tree comparator calls `typ.compare(typ.decode(a),
# typ.decode(b))`, and compound.compare forwards per-part decoded values.
# decode() above has already stripped the marker byte, so slicing `[1:]`
# here looks wrong for non-None values (e.g. it fails on ints). The likely
# intent is `self.bk.compare(a, b)` -- confirm against the None handling
# that precedes this line in the full source.
50 return self.bk.compare(a[1:], b[1:])
# Key type made of a fixed sequence of part types. A compound value is a
# sequence with exactly one element per part, and each element is handled
# by the corresponding part type.
52 class compound(object):
53 def __init__(self, *parts):
# For each part, writes a length prefix for that part's encoding `dat`.
# The prefix is either one byte `0x80 | length` or a 4-byte big-endian
# `>i` length. The condition choosing between them, and the rest of the
# loop (presumably appending `dat` itself), are not in this excerpt;
# presumably the one-byte form is used for short encodings.
56 def encode(self, obs):
57 if len(obs) != len(self.parts):
# NOTE(review): `len(self.parts)` is an int, so this `str + int`
# concatenation raises TypeError instead of the intended ValueError;
# wrap it in str(...).
58 raise ValueError("invalid length of compound data: " + str(len(obs)) + ", rather than " + len(self.parts))
60 for ob, part in zip(obs, self.parts):
63 buf.append(0x80 | len(dat))
66 buf.extend(struct.pack(">i", len(dat)))
# Presumably the inverse of encode(): for each part it reads the length
# prefix, decodes that slice with the part's type and collects the
# result in `ret`.
69 def decode(self, dat):
72 for part in self.parts:
77 ln = struct.unpack(">i", dat[off:off + 4])[0]
# NOTE(review): `len` here is the builtin function, not the length `ln`
# read above, so `off + len` raises TypeError. This should presumably be
# `dat[off:off + ln]` -- confirm against the one-byte-prefix branch, which
# is not in this excerpt.
79 ret.append(part.decode(dat[off:off + len]))
# Compares two decoded compound values part by part, using each part's
# compare(). Presumably it returns the first non-zero result (the rest of
# the loop is not in this excerpt).
82 def compare(self, al, bl):
83 if (len(al) != len(self.parts)) or (len(bl) != len(self.parts)):
# NOTE(review): same `str + int` TypeError as in encode(); wrap
# `len(self.parts)` in str(...).
84 raise ValueError("invalid length of compound data: " + str(len(al)) + ", " + str(len(bl)) + ", rather than " + len(self.parts))
85 for a, b, part in zip(al, bl, self.parts):
86 c = part.compare(a, b)
# Fragment of a float comparator. Its `def` line and the rest of its body
# are not in this excerpt; it is presumably `floatcmp`, which is assigned to
# `t_float.compare`. This branch handles the case where both operands are
# NaN, presumably treating them as equal.
92 if math.isnan(a) and math.isnan(b):
# Ready-made key types built from `struct` formats:
#   ">q" = big-endian signed 64-bit integer
#   ">Q" = big-endian unsigned 64-bit integer
#   ">d" = big-endian IEEE 754 double
105 t_int = simpletype.struct(">q")
106 t_uint = simpletype.struct(">Q")
107 t_float = simpletype.struct(">d")
# Override float ordering with `floatcmp`. This is an instance attribute,
# so it is called as a plain function, floatcmp(a, b), with no `self`.
108 t_float.compare = floatcmp
# Text keys, stored as UTF-8 bytes.
109 t_str = simpletype((lambda ob: ob.encode("utf-8")), (lambda dat: dat.decode("utf-8")))
# Constructor of the base index class. Its class header and body are not in
# this excerpt; it is presumably `index`, which `ordered` subclasses and
# initialises via `super().__init__(db, name, datatype)`. It presumably
# stores `db` and `datatype` as `self.db` / `self.typ`, which `ordered`
# reads -- confirm.
112 def __init__(self, db, name, datatype):
# Ordered (B-tree) index from encoded key values of `datatype` to object
# ids. `lib` is not imported in the visible import lines; presumably it is
# imported on a line missing from this excerpt.
119 class ordered(index, lib.closable):
# Opens the backing B-tree database "i-<name>" through `db._opendb`, and
# creates it if `create` is true.
120 def __init__(self, db, name, datatype, create=True):
121 super().__init__(db, name, datatype)
# `fl` is passed to `_opendb`, presumably as the Berkeley DB open flags:
# - DB_THREAD requests a free-threaded handle;
# - DB_AUTO_COMMIT makes the open (and later modifications through this
#   handle) transaction-protected.
122 fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT
123 if create: fl |= bd.DB_CREATE
# The lines through `set_bt_compare` presumably form the body of a nested
# `initdb(db)` callback, which is handed to `_opendb` below. That body in
# turn contains a nested key comparator `compare(a, b)`. The `def` lines of
# both are not in this excerpt.
# NOTE(review): stored keys are bytes (typ.encode returns bytes, e.g.
# b"\0"-prefixed or struct-packed), so under Python 3 `a == b == ""`
# compares bytes with str and is never true. Presumably `b""` was intended,
# to short-circuit comparisons of empty keys -- confirm.
126 if a == b == "": return 0
# Order keys by decoding them and comparing the decoded values with the
# datatype's compare(), rather than by raw byte order.
127 return self.typ.compare(self.typ.decode(a), self.typ.decode(b))
# DB_DUPSORT allows several sorted data items (object ids) per key;
# put() relies on this when it passes DB_NODUPDATA.
128 db.set_flags(bd.DB_DUPSORT)
129 db.set_bt_compare(compare)
130 self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb)
# In bsddb3 (presumably what `bd` is), this makes get/cursor lookups that
# miss raise DBNotFoundError instead of returning None.
131 self.bk.set_get_returns_none(False)
# Iterator over index entries. It is presumably nested in `ordered`, since
# get() builds instances via `self.cursor(...)`. Constructor arguments:
# - `idx`: the index;
# - `cur`: the underlying DB cursor;
# - `item`: the current raw (key, value) pair;
# - `stop`: a predicate on decoded keys that ends the iteration.
# get() passes all-None arguments to represent an empty result.
136 class cursor(lib.closable):
137 def __init__(self, idx, cur, item, stop):
# Presumably closes the DB cursor if one is still held.
144 if self.cur is not None:
# Iteration step (its `def` line is not in this excerpt). The cursor is
# exhausted once `item` is None. Otherwise the raw key is decoded with the
# index's type, and the value is unpacked as a big-endian unsigned 64-bit
# object id.
151 if self.item is None:
152 raise StopIteration()
154 rk = self.idx.typ.decode(rk)
155 rv = struct.unpack(">Q", rv)[0]
# Presumably raised when `stop` is true for the decoded key.
158 raise StopIteration()
# Advance the DB cursor to the next entry.
166 self.item = self.cur.next()
178 except StopIteration:
# Query the index and return a `cursor` iterator. Keyword-only; the
# arguments are checked in this order:
#   match=v       -> entries whose key compares equal to v;
#   (presumably) all=True -> every entry (its stop predicate never fires);
#   ge/gt, lt/le  -> a range. `ge` wins over `gt` and `lt` wins over `le`
#                    when both are given.
# Any other combination raises NameError.
# NOTE(review): ValueError/TypeError would be more conventional than
# NameError, but callers may already catch NameError.
# `missing` is a sentinel defined outside this excerpt. It distinguishes
# "not given" from None, which is a valid key for `maybe` types. The name
# `all` shadows the builtin but is part of the public signature.
181 def get(self, *, match=missing, ge=missing, gt=missing, lt=missing, le=missing, all=False):
184 cur = self.bk.cursor()
187 if match is not missing:
# Position exactly on the encoded key; a miss presumably yields the empty
# cursor below.
189 k, v = cur.set(self.typ.encode(match))
191 return self.cursor(None, None, None, None)
194 return self.cursor(self, cur, (k, v), lambda o: (self.typ.compare(o, match) != 0))
199 return self.cursor(None, None, None, None)
202 return self.cursor(self, cur, (k, v), lambda o: False)
203 elif ge is not missing or gt is not missing or lt is not missing or le is not missing:
# set_range positions on the first key >= the bound. For `gt`, keys equal
# to the bound are skipped further down. With only lt/le given, the start
# position is set on lines not in this excerpt (presumably the first
# entry).
206 if ge is not missing:
207 k, v = cur.set_range(self.typ.encode(ge))
208 elif gt is not missing:
209 k, v = cur.set_range(self.typ.encode(gt))
214 return self.cursor(None, None, None, None)
# Upper bound: stop at the first decoded key >= lt (exclusive) or > le
# (inclusive); with no upper bound, never stop early.
215 if lt is not missing:
216 stop = lambda o: self.typ.compare(o, lt) >= 0
217 elif le is not missing:
218 stop = lambda o: self.typ.compare(o, le) > 0
220 stop = lambda o: False
221 ret = self.cursor(self, cur, (k, v), stop)
# Exclusive lower bound: drop leading entries whose key equals `gt`. This
# presumably uses peek() to look at the next (key, id) without consuming
# it. Running out of entries ends the skip.
224 while self.typ.compare(ret.peek()[0], gt) == 0:
226 except StopIteration:
231 raise NameError("invalid get() specification")
# Add an index entry mapping `key` to object id `id`. (`id` shadows the
# builtin but is part of the public signature.) `tx` is a keyword-only
# transaction wrapper whose `.tx` is passed as the Berkeley DB `txn`. It is
# presumably supplied by the `txnfun` decorator, which receives a function
# returning the environment handle.
238 @txnfun(lambda self: self.db.env.env)
239 def put(self, key, id, *, tx):
# Object ids are stored as 8-byte big-endian unsigned integers.
240 obid = struct.pack(">Q", id)
# Refuse to index an id that has no record in `self.db.ob` (presumably
# the object store), checked within the same transaction.
241 if not self.db.ob.has_key(obid, txn=tx.tx):
242 raise ValueError("no such object in database: " + str(id))
# With DB_NODUPDATA, storing an already-present (key, id) pair raises
# DBKeyExistError instead of adding a duplicate. The handler's body is not
# in this excerpt.
244 self.bk.put(self.typ.encode(key), obid, txn=tx.tx, flags=bd.DB_NODUPDATA)
245 except bd.DBKeyExistError:
249 @txnfun(lambda self: self.db.env.env)
250 def remove(self, key, id, *, tx):
251 obid = struct.pack(">Q", id)
252 if not self.db.ob.has_key(obid, txn=tx.tx):
253 raise ValueError("no such object in database: " + str(id))
254 cur = self.bk.cursor(txn=tx.tx)
257 cur.get_both(self.typ.encode(key), obid)