Added idlink.
[didex.git] / didex / values.py
1 import threading
2 from . import store, lib, index
3 from .store import storedesc
4
5 __all__ = ["simple", "multi", "compound", "idlink"]
6
7 class cursor(lib.closable):
8     def __init__(self, bk, st):
9         self.bk = bk
10         self.st = st
11
12     def close(self):
13         self.bk.close()
14
15     def __iter__(self):
16         return self
17
18     def __next__(self):
19         k, id = next(self.bk)
20         return k, self.st.get(id)
21
22     def skip(self, n=1):
23         self.bk.skip(n)
24
25 class base(storedesc):
26     def __init__(self, store, indextype, name, datatype):
27         self.store = store
28         self.indextype = indextype
29         self.name = name
30         self.typ = datatype
31         self.idx = None
32         self.lk = threading.Lock()
33
34     def index(self, tx):
35         with self.lk:
36             if self.idx is None:
37                 self.idx = self.indextype(self.store.db(), self.name, self.typ, tx=tx)
38             return self.idx
39
40     def get(self, **kwargs):
41         return cursor(self.index(None).get(**kwargs), self.store)
42
43     def get1(self, *, check=False, default=KeyError, **kwargs):
44         with self.get(**kwargs) as cursor:
45             try:
46                 k, v = next(cursor)
47             except StopIteration:
48                 if default is not KeyError:
49                     return default
50                 raise KeyError("no matches in " + self.name, kwargs)
51             if check:
52                 try:
53                     next(cursor)
54                 except StopIteration:
55                     pass
56                 else:
57                     raise ValueError("unexpected multiple matchies in " + self.name, kwargs)
58             return v
59
60     def list(self, **kwargs):
61         with self.get(**kwargs) as cursor:
62             return [v for k, v in cursor]
63
64 class descbase(base):
65     def __init__(self, store, indextype, name, datatype, default):
66         super().__init__(store, indextype, name, datatype)
67         self.default = default
68         self.mattr = "__idx_%s_new" % name
69         self.iattr = "__idx_%s_cur" % name
70
71     def __get__(self, obj, cls):
72         if obj is None: return self
73         return getattr(obj, self.mattr, self.default)
74
75     def __set__(self, obj, val):
76         setattr(obj, self.mattr, val)
77
78     def __delete__(self, obj):
79         delattr(obj, self.mattr)
80
81 class simple(descbase):
82     def __init__(self, store, indextype, name, datatype, default=None):
83         super().__init__(store, indextype, name, datatype, default)
84
85     def register(self, id, obj, tx):
86         val = self.__get__(obj, None)
87         self.index(tx).put(val, id, tx=tx)
88         tx.postcommit(lambda: setattr(obj, self.iattr, val))
89
90     def unregister(self, id, obj, tx):
91         self.index(tx).remove(getattr(obj, self.iattr), id, tx=tx)
92         tx.postcommit(lambda: delattr(obj, self.iattr))
93
94     def update(self, id, obj, tx):
95         val = self.__get__(obj, None)
96         ival = getattr(obj, self.iattr)
97         if val != ival:
98             idx = self.index(tx)
99             idx.remove(ival, id, tx=tx)
100             idx.put(val, id, tx=tx)
101             tx.postcommit(lambda: setattr(obj, self.iattr, val))
102
103 class multi(descbase):
104     def __init__(self, store, indextype, name, datatype):
105         super().__init__(store, indextype, name, datatype, ())
106
107     def register(self, id, obj, tx):
108         vals = frozenset(self.__get__(obj, None))
109         idx = self.index(tx)
110         for val in vals:
111             idx.put(val, id, tx=tx)
112         tx.postcommit(lambda: setattr(obj, self.iattr, vals))
113
114     def unregister(self, id, obj, tx):
115         idx = self.index(tx)
116         for val in getattr(obj, self.iattr):
117             idx.remove(val, id, tx=tx)
118         tx.postcommit(lambda: delattr(obj, self.iattr))
119
120     def update(self, id, obj, tx):
121         vals = frozenset(self.__get__(obj, None))
122         ivals = getattr(obj, self.iattr)
123         if vals != ivals:
124             idx = self.index(tx)
125             for val in ivals - vals:
126                 idx.remove(val, id, tx=tx)
127             for val in vals - ivals:
128                 idx.put(val, id, tx=tx)
129             tx.postcommit(lambda: setattr(obj, self.iattr, vals))
130
131 class compound(base):
132     def __init__(self, indextype, name, *parts):
133         super().__init__(parts[0].store, indextype, name, index.compound(*(part.typ for part in parts)))
134         self.parts = parts
135         self.iattr = "__idx_%s_cur" % name
136
137     def minim(self, *parts):
138         return self.typ.minim(*parts)
139     def maxim(self, *parts):
140         return self.typ.maxim(*parts)
141
142     def get(self, *, partial=None, **spec):
143         if partial is not None:
144             return super().get(ge=self.minim(*partial), le = self.maxim(*partial), **spec)
145         else:
146             return super().get(**spec)
147
148     def register(self, id, obj, tx):
149         val = tuple(part.__get__(obj, None) for part in self.parts)
150         self.index(tx).put(val, id, tx=tx)
151         tx.postcommit(lambda: setattr(obj, self.iattr, val))
152
153     def unregister(self, id, obj, tx):
154         self.index(tx).remove(getattr(obj, self.iattr), id, tx=tx)
155         tx.postcommit(lambda: delattr(obj, self.iattr))
156
157     def update(self, id, obj, tx):
158         val = tuple(part.__get__(obj, None) for part in self.parts)
159         ival = getattr(obj, self.iattr)
160         if val != ival:
161             idx = self.index(tx)
162             idx.remove(ival, id, tx=tx)
163             idx.put(val, id, tx=tx)
164             tx.postcommit(lambda: setattr(obj, self.iattr, val))
165
166 class idlink(object):
167     def __init__(self, name, atype):
168         self.atype = atype
169         self.battr = "__idlink_%s" % name
170
171     def __get__(self, obj, cls):
172         if obj is None: return self
173         ret = self.atype.store.get(getattr(obj, self.battr))
174         assert isinstance(ret, self.atype)
175         return ret
176
177     def __set__(self, obj, val):
178         assert isinstance(val, self.atype)
179         setattr(obj, self.battr, val.id)
180
181     def __delete__(self, obj):
182         delattr(obj, self.battr)