Added convenience get() wrappers.
[didex.git] / didex / values.py
CommitLineData
b080a59c 1import threading
bd14729f 2from . import store, lib, index
b080a59c
FT
3from .store import storedesc
4
# Names exported by ``from ... import *``: the indexed-attribute descriptor
# classes.  cursor, base and descbase are not listed here.
__all__ = [
    "simple",
    "multi",
    "compound",
]
cbf73d3a 6
b080a59c
FT
class cursor(lib.closable):
    """Iterator that turns index hits into stored objects.

    *bk* is a backend cursor whose items are ``(key, id)`` pairs; each id is
    resolved through ``st.get`` so iteration yields ``(key, st.get(id))``.
    The ``lib.closable`` base presumably supplies ``with`` support on top of
    ``close()`` -- confirm in lib.
    """

    def __init__(self, bk, st):
        self.bk, self.st = bk, st

    def close(self):
        """Close the wrapped backend cursor."""
        self.bk.close()

    def __iter__(self):
        return self

    def __next__(self):
        key, oid = next(self.bk)
        obj = self.st.get(oid)
        return key, obj

    def skip(self, n=1):
        """Advance the backend cursor by *n* entries without calling ``st.get``."""
        self.bk.skip(n)
24
class base(storedesc):
    """Common machinery for a named index that is opened lazily on a store.

    *indextype* is called as ``indextype(store.db(), name, datatype)`` the
    first time the index is needed.  *store* is also used to resolve the ids
    returned by index queries into objects (see ``get``).
    """

    def __init__(self, store, indextype, name, datatype):
        self.store = store
        self.indextype = indextype
        self.name = name
        self.typ = datatype
        # Index handle, created on first use by index(); guarded by self.lk.
        self.idx = None
        self.lk = threading.Lock()

    def index(self):
        """Return the underlying index object, creating it on the first call.

        The check-and-create runs under ``self.lk`` so concurrent callers
        never construct two index objects.
        """
        with self.lk:
            if self.idx is None:
                self.idx = self.indextype(self.store.db(), self.name, self.typ)
            return self.idx

    def get(self, **kwargs):
        """Query the index; keyword arguments go straight to the index's ``get``.

        Returns a ``cursor`` yielding ``(key, store.get(id))`` pairs.
        """
        return cursor(self.index().get(**kwargs), self.store)

    def get1(self, *, check=False, default=KeyError, **kwargs):
        """Return the object of the first match for the query *kwargs*.

        :param check: if true, raise ValueError when more than one entry
            matches.
        :param default: value returned when nothing matches.  The sentinel
            ``KeyError`` (the default) means "raise instead".
        :raises KeyError: no match and no *default* given.
        :raises ValueError: *check* is true and there are several matches.
        """
        # Named ``cur`` so the module-level ``cursor`` class is not shadowed.
        with self.get(**kwargs) as cur:
            try:
                _key, value = next(cur)
            except StopIteration:
                if default is not KeyError:
                    return default
                # The StopIteration is an internal detail; don't chain it.
                raise KeyError("no matches in " + self.name, kwargs) from None
            if check:
                try:
                    next(cur)
                except StopIteration:
                    pass
                else:
                    raise ValueError("unexpected multiple matches in " + self.name, kwargs)
            return value

    def list(self, **kwargs):
        """Return the objects of every match for the query *kwargs* as a list."""
        with self.get(**kwargs) as cur:
            return [value for _key, value in cur]
63
bd14729f
FT
class descbase(base):
    """Data-descriptor base for an indexed per-object attribute.

    Values assigned through the descriptor live on the instance under the
    attribute named by ``mattr``.  ``iattr`` names a second per-instance
    attribute that subclasses use to record the value last written to the
    index.
    """

    def __init__(self, store, indextype, name, datatype, default):
        super().__init__(store, indextype, name, datatype)
        self.default = default
        # Per-instance attribute names, derived from the index name.
        self.mattr = "__idx_%s_new" % name
        self.iattr = "__idx_%s_cur" % name

    def __get__(self, obj, cls):
        # Access through the class yields the descriptor itself.
        if obj is None:
            return self
        # Unassigned instances fall back to the configured default.
        return getattr(obj, self.mattr, self.default)

    def __set__(self, obj, val):
        setattr(obj, self.mattr, val)

    def __delete__(self, obj):
        delattr(obj, self.mattr)
80
class simple(descbase):
    """Single-valued indexed attribute; each object is indexed under one key.

    The attribute defaults to ``None`` unless *default* is given.
    """

    def __init__(self, store, indextype, name, datatype, default=None):
        super().__init__(store, indextype, name, datatype, default)

    def register(self, id, obj, tx):
        """Add *id* to the index under *obj*'s current attribute value."""
        current = self.__get__(obj, None)
        self.index().put(current, id, tx=tx)
        # Record the indexed value through the transaction's post-commit
        # hook (presumably run only if *tx* commits).
        tx.postcommit(lambda: setattr(obj, self.iattr, current))

    def unregister(self, id, obj, tx):
        """Remove *id* from the index under its last recorded value."""
        idx = self.index()
        idx.remove(getattr(obj, self.iattr), id, tx=tx)
        tx.postcommit(lambda: delattr(obj, self.iattr))

    def update(self, id, obj, tx):
        """Re-index *id* if the attribute changed since it was last recorded."""
        current = self.__get__(obj, None)
        recorded = getattr(obj, self.iattr)
        if current != recorded:
            idx = self.index()
            idx.remove(recorded, id, tx=tx)
            idx.put(current, id, tx=tx)
            tx.postcommit(lambda: setattr(obj, self.iattr, current))
102
class multi(descbase):
    """Multi-valued indexed attribute.

    An object is indexed once under each distinct element of the attribute's
    value.  That value defaults to the empty tuple.
    """

    def __init__(self, store, indextype, name, datatype):
        super().__init__(store, indextype, name, datatype, ())

    def register(self, id, obj, tx):
        """Add *id* to the index under every distinct element of the attribute."""
        current = frozenset(self.__get__(obj, None))
        idx = self.index()
        for elem in current:
            idx.put(elem, id, tx=tx)
        tx.postcommit(lambda: setattr(obj, self.iattr, current))

    def unregister(self, id, obj, tx):
        """Remove *id* from the index under every recorded element."""
        idx = self.index()
        for elem in getattr(obj, self.iattr):
            idx.remove(elem, id, tx=tx)
        tx.postcommit(lambda: delattr(obj, self.iattr))

    def update(self, id, obj, tx):
        """Apply only the difference between recorded and current elements."""
        current = frozenset(self.__get__(obj, None))
        recorded = getattr(obj, self.iattr)
        if current != recorded:
            idx = self.index()
            # Drop elements that disappeared, then add the new ones.
            for elem in recorded - current:
                idx.remove(elem, id, tx=tx)
            for elem in current - recorded:
                idx.put(elem, id, tx=tx)
            tx.postcommit(lambda: setattr(obj, self.iattr, current))
bd14729f
FT
130
class compound(base):
    """Index keyed on the tuple of several other attributes' values.

    Each of *parts* must provide ``store``, ``typ`` and ``__get__``.  The key
    type is ``index.compound`` over the parts' ``typ`` values.  Only the first
    part's store is used; presumably all parts share one store (verify).
    """

    def __init__(self, indextype, name, *parts):
        first_store = parts[0].store
        keytype = index.compound(*(p.typ for p in parts))
        super().__init__(first_store, indextype, name, keytype)
        self.parts = parts
        # Per-object attribute recording the key last written to the index.
        self.iattr = "__idx_%s_cur" % name

    def minim(self, *parts):
        """Delegate to the key type's ``minim`` (presumably prefix lower bound)."""
        return self.typ.minim(*parts)

    def maxim(self, *parts):
        """Delegate to the key type's ``maxim`` (presumably prefix upper bound)."""
        return self.typ.maxim(*parts)

    def get(self, *, partial=None, **spec):
        """Query the index.

        When *partial* is given, ``minim(*partial)`` and ``maxim(*partial)``
        are passed as ``ge``/``le``.  All other keywords pass through unchanged.
        """
        if partial is None:
            return super().get(**spec)
        lower = self.minim(*partial)
        upper = self.maxim(*partial)
        return super().get(ge=lower, le=upper, **spec)

    def register(self, id, obj, tx):
        """Add *id* to the index under the tuple of the parts' current values."""
        key = tuple(p.__get__(obj, None) for p in self.parts)
        self.index().put(key, id, tx=tx)
        tx.postcommit(lambda: setattr(obj, self.iattr, key))

    def unregister(self, id, obj, tx):
        """Remove *id* from the index under its last recorded key."""
        idx = self.index()
        idx.remove(getattr(obj, self.iattr), id, tx=tx)
        tx.postcommit(lambda: delattr(obj, self.iattr))

    def update(self, id, obj, tx):
        """Re-index *id* if any part's value changed since the last record."""
        key = tuple(p.__get__(obj, None) for p in self.parts)
        recorded = getattr(obj, self.iattr)
        if key != recorded:
            idx = self.index()
            idx.remove(recorded, id, tx=tx)
            idx.put(key, id, tx=tx)
            tx.postcommit(lambda: setattr(obj, self.iattr, key))