Added support for compound indices.
[didex.git] / didex / values.py
1 import threading
2 from . import store, lib, index
3 from .store import storedesc
4
5 __all__ = ["simple", "multi"]
6
7 class cursor(lib.closable):
8     def __init__(self, bk, st):
9         self.bk = bk
10         self.st = st
11
12     def close(self):
13         self.bk.close()
14
15     def __iter__(self):
16         return self
17
18     def __next__(self):
19         k, id = next(self.bk)
20         return k, self.st.get(id)
21
22     def skip(self, n=1):
23         self.bk.skip(n)
24
25 class base(storedesc):
26     def __init__(self, store, indextype, name, datatype):
27         self.store = store
28         self.indextype = indextype
29         self.name = name
30         self.typ = datatype
31         self.idx = None
32         self.lk = threading.Lock()
33
34     def index(self):
35         with self.lk:
36             if self.idx is None:
37                 self.idx = self.indextype(self.store.db(), self.name, self.typ)
38             return self.idx
39
40     def get(self, **kwargs):
41         return cursor(self.index().get(**kwargs), self.store)
42
43 class descbase(base):
44     def __init__(self, store, indextype, name, datatype, default):
45         super().__init__(store, indextype, name, datatype)
46         self.default = default
47         self.mattr = "__idx_%s_new" % name
48         self.iattr = "__idx_%s_cur" % name
49
50     def __get__(self, obj, cls):
51         if obj is None: return self
52         return getattr(obj, self.mattr, self.default)
53
54     def __set__(self, obj, val):
55         setattr(obj, self.mattr, val)
56
57     def __delete__(self, obj):
58         delattr(obj, self.mattr)
59
60 class simple(descbase):
61     def __init__(self, store, indextype, name, datatype, default=None):
62         super().__init__(store, indextype, name, datatype, default)
63
64     def register(self, id, obj, tx):
65         val = self.__get__(obj, None)
66         self.index().put(val, id, tx=tx)
67         tx.postcommit(lambda: setattr(obj, self.iattr, val))
68
69     def unregister(self, id, obj, tx):
70         self.index().remove(getattr(obj, self.iattr), id, tx=tx)
71         tx.postcommit(lambda: delattr(obj, self.iattr))
72
73     def update(self, id, obj, tx):
74         val = self.__get__(obj, None)
75         ival = getattr(obj, self.iattr)
76         if val != ival:
77             idx = self.index()
78             idx.remove(ival, id, tx=tx)
79             idx.put(val, id, tx=tx)
80             tx.postcommit(lambda: setattr(obj, self.iattr, val))
81
82 class multi(descbase):
83     def __init__(self, store, indextype, name, datatype):
84         super().__init__(store, indextype, name, datatype, ())
85
86     def register(self, id, obj, tx):
87         vals = frozenset(self.__get__(obj, None))
88         idx = self.index()
89         for val in vals:
90             idx.put(val, id, tx=tx)
91         tx.postcommit(lambda: setattr(obj, self.iattr, vals))
92
93     def unregister(self, id, obj, tx):
94         idx = self.index()
95         for val in getattr(obj, self.iattr):
96             idx.remove(val, id, tx=tx)
97         tx.postcommit(lambda: delattr(obj, self.iattr))
98
99     def update(self, id, obj, tx):
100         vals = frozenset(self.__get__(obj, None))
101         ivals = getattr(obj, self.iattr)
102         if vals != ivals:
103             idx = self.index()
104             for val in ivals - vals:
105                 idx.remove(val, id, tx=tx)
106             for val in vals - ivals:
107                 idx.put(val, id, tx=tx)
108             tx.postcommit(lambda: setattr(obj, self.iattr, vals))
109
110 class compound(base):
111     def __init__(self, indextype, name, *parts):
112         super().__init__(parts[0].store, indextype, name, index.compound(*(part.typ for part in parts)))
113         self.parts = parts
114         self.iattr = "__idx_%s_cur" % name
115
116     def register(self, id, obj, tx):
117         val = tuple(part.__get__(obj, None) for part in self.parts)
118         self.index().put(val, id, tx=tx)
119         tx.postcommit(lambda: setattr(obj, self.iattr, val))
120
121     def unregister(self, id, obj, tx):
122         self.index().remove(getattr(obj, self.iattr), id, tx=tx)
123         tx.postcommit(lambda: delattr(obj, self.iattr))
124
125     def update(self, id, obj, tx):
126         val = tuple(part.__get__(obj, None) for part in self.parts)
127         ival = getattr(obj, self.iattr)
128         if val != ival:
129             idx = self.index()
130             idx.remove(ival, id, tx=tx)
131             idx.put(val, id, tx=tx)
132             tx.postcommit(lambda: setattr(obj, self.iattr, val))