Fixed multi-index bug.
[didex.git] / didex / values.py
1 import threading
2 from . import store, lib, index
3 from .store import storedesc
4
5 __all__ = ["simple", "multi", "compound", "idlink"]
6
7 class cursor(lib.closable):
8     def __init__(self, bk, st):
9         self.bk = bk
10         self.st = st
11
12     def close(self):
13         self.bk.close()
14
15     def __iter__(self):
16         return self
17
18     def __next__(self):
19         k, id = next(self.bk)
20         return k, self.st.get(id)
21
22     def skip(self, n=1):
23         self.bk.skip(n)
24
25 class base(storedesc):
26     def __init__(self, store, indextype, name, datatype):
27         self.store = store
28         self.indextype = indextype
29         self.name = name
30         self.typ = datatype
31         self.idx = None
32         self.lk = threading.Lock()
33
34     def index(self, tx):
35         with self.lk:
36             if self.idx is None:
37                 self.idx = self.indextype(self.store.db(), self.name, self.typ, tx=tx)
38             return self.idx
39
40     def get(self, **kwargs):
41         return cursor(self.index(None).get(**kwargs), self.store)
42
43     def get1(self, *, check=False, default=KeyError, **kwargs):
44         with self.get(**kwargs) as cursor:
45             try:
46                 k, v = next(cursor)
47             except StopIteration:
48                 if default is not KeyError:
49                     return default
50                 raise KeyError("no matches in " + self.name, kwargs)
51             if check:
52                 try:
53                     next(cursor)
54                 except StopIteration:
55                     pass
56                 else:
57                     raise ValueError("unexpected multiple matchies in " + self.name, kwargs)
58             return v
59
60     def list(self, **kwargs):
61         with self.get(**kwargs) as cursor:
62             return [v for k, v in cursor]
63
64 class descbase(base):
65     def __init__(self, store, indextype, name, datatype, default):
66         super().__init__(store, indextype, name, datatype)
67         self.default = default
68         self.mattr = "__ival_%s" % name
69
70     def __get__(self, obj, cls):
71         if obj is None: return self
72         return getattr(obj, self.mattr, self.default)
73
74     def __set__(self, obj, val):
75         setattr(obj, self.mattr, val)
76
77     def __delete__(self, obj):
78         delattr(obj, self.mattr)
79
80 class simple(descbase):
81     def __init__(self, store, indextype, name, datatype, default=None):
82         super().__init__(store, indextype, name, datatype, default)
83
84     def register(self, id, obj, tx):
85         val = self.__get__(obj, None)
86         self.index(tx).put(val, id, tx=tx)
87         tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
88
89     def unregister(self, id, obj, tx):
90         self.index(tx).remove(self.store.icache[obj, self], id, tx=tx)
91         tx.postcommit(lambda: self.store.icache.__delitem__((obj, self)))
92
93     def update(self, id, obj, tx):
94         val = self.__get__(obj, None)
95         ival = self.store.icache[obj, self]
96         if val != ival:
97             idx = self.index(tx)
98             idx.remove(ival, id, tx=tx)
99             idx.put(val, id, tx=tx)
100             tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
101
102     def loaded(self, id, obj, tx):
103         val = self.__get__(obj, None)
104         tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
105
106 class multi(descbase):
107     def __init__(self, store, indextype, name, datatype):
108         super().__init__(store, indextype, name, datatype, ())
109
110     def register(self, id, obj, tx):
111         vals = frozenset(self.__get__(obj, None))
112         idx = self.index(tx)
113         for val in vals:
114             idx.put(val, id, tx=tx)
115         tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), vals))
116
117     def unregister(self, id, obj, tx):
118         idx = self.index(tx)
119         for val in self.store.icache[obj, self]:
120             idx.remove(val, id, tx=tx)
121         tx.postcommit(lambda: self.store.icache.__delitem__((obj, self)))
122
123     def update(self, id, obj, tx):
124         vals = frozenset(self.__get__(obj, None))
125         ivals = self.store.icache[obj, self]
126         if vals != ivals:
127             idx = self.index(tx)
128             for val in ivals - vals:
129                 idx.remove(val, id, tx=tx)
130             for val in vals - ivals:
131                 idx.put(val, id, tx=tx)
132             tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), vals))
133
134     def loaded(self, id, obj, tx):
135         vals = frozenset(self.__get__(obj, None))
136         tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), vals))
137
138 class compound(base):
139     def __init__(self, indextype, name, *parts):
140         super().__init__(parts[0].store, indextype, name, index.compound(*(part.typ for part in parts)))
141         self.parts = parts
142
143     def minim(self, *parts):
144         return self.typ.minim(*parts)
145     def maxim(self, *parts):
146         return self.typ.maxim(*parts)
147
148     def get(self, *, partial=None, **spec):
149         if partial is not None:
150             return super().get(ge=self.minim(*partial), le = self.maxim(*partial), **spec)
151         else:
152             return super().get(**spec)
153
154     def register(self, id, obj, tx):
155         val = tuple(part.__get__(obj, None) for part in self.parts)
156         self.index(tx).put(val, id, tx=tx)
157         tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
158
159     def unregister(self, id, obj, tx):
160         self.index(tx).remove(self.store.icache[obj, self], id, tx=tx)
161         tx.postcommit(lambda: self.store.icache.__delitem__((obj, self)))
162
163     def update(self, id, obj, tx):
164         val = tuple(part.__get__(obj, None) for part in self.parts)
165         ival = self.store.icache[obj, self]
166         if val != ival:
167             idx = self.index(tx)
168             idx.remove(ival, id, tx=tx)
169             idx.put(val, id, tx=tx)
170             tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
171
172     def loaded(self, id, obj, tx):
173         val = tuple(part.__get__(obj, None) for part in self.parts)
174         tx.postcommit(lambda: self.store.icache.__setitem__((obj, self), val))
175
176 class idlink(object):
177     def __init__(self, name, atype):
178         self.atype = atype
179         self.battr = "__idlink_%s" % name
180
181     def __get__(self, obj, cls):
182         if obj is None: return self
183         ret = self.atype.store.get(getattr(obj, self.battr))
184         assert isinstance(ret, self.atype)
185         return ret
186
187     def __set__(self, obj, val):
188         assert isinstance(val, self.atype)
189         setattr(obj, self.battr, val.id)
190
191     def __delete__(self, obj):
192         delattr(obj, self.battr)