From 73761d103993f8ec26bb61d1d405edde9efb5ddd Mon Sep 17 00:00:00 2001 From: Fredrik Tolf Date: Mon, 3 Aug 2015 02:39:14 +0200 Subject: [PATCH] Fixed index opening thread-safety by including it in local transaction. --- didex/db.py | 13 +++++-------- didex/index.py | 6 +++--- didex/values.py | 24 ++++++++++++------------ 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/didex/db.py b/didex/db.py index 33eb0d8..539fc85 100644 --- a/didex/db.py +++ b/didex/db.py @@ -125,21 +125,18 @@ class database(object): self.env = env self.mode = mode self.fnm = name - fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT + fl = bd.DB_THREAD if create: fl |= bd.DB_CREATE self.cf = self._opendb("cf", bd.DB_HASH, fl) self.ob = self._opendb("ob", bd.DB_HASH, fl) - def _opendb(self, dnm, typ, fl, init=None): + @txnfun(lambda self: self.env.env) + def _opendb(self, dnm, typ, fl, init=None, *, tx): ret = bd.DB(self.env.env) if init: init(ret) - while True: - try: - ret.open(self.fnm, dnm, typ, fl, self.mode) - except deadlock: - continue - return ret + ret.open(self.fnm, dnm, typ, fl, self.mode, txn=tx.tx) + return ret @txnfun(lambda self: self.env.env) def _nextseq(self, *, tx): diff --git a/didex/index.py b/didex/index.py index c2c55b6..5b5a5bc 100644 --- a/didex/index.py +++ b/didex/index.py @@ -158,9 +158,9 @@ class index(object): missing = object() class ordered(index, lib.closable): - def __init__(self, db, name, datatype, create=True): + def __init__(self, db, name, datatype, create=True, *, tx=None): super().__init__(db, name, datatype) - fl = bd.DB_THREAD | bd.DB_AUTO_COMMIT + fl = bd.DB_THREAD if create: fl |= bd.DB_CREATE def initdb(db): def compare(a, b): @@ -168,7 +168,7 @@ class ordered(index, lib.closable): return self.typ.compare(self.typ.decode(a), self.typ.decode(b)) db.set_flags(bd.DB_DUPSORT) db.set_bt_compare(compare) - self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb) + self.bk = db._opendb("i-" + name, bd.DB_BTREE, fl, initdb, tx=tx) self.bk.set_get_returns_none(False) def close(self): diff --git a/didex/values.py b/didex/values.py index ecb78dd..a4473f0 100644 --- a/didex/values.py +++ b/didex/values.py @@ -31,14 +31,14 @@ class base(storedesc): self.idx = None self.lk = threading.Lock() - def index(self): + def index(self, tx): with self.lk: if self.idx is None: - self.idx = self.indextype(self.store.db(), self.name, self.typ) + self.idx = self.indextype(self.store.db(), self.name, self.typ, tx=tx) return self.idx def get(self, **kwargs): - return cursor(self.index().get(**kwargs), self.store) + return cursor(self.index(None).get(**kwargs), self.store) def get1(self, *, check=False, default=KeyError, **kwargs): with self.get(**kwargs) as cursor: @@ -84,18 +84,18 @@ class simple(descbase): def register(self, id, obj, tx): val = self.__get__(obj, None) - self.index().put(val, id, tx=tx) + self.index(tx).put(val, id, tx=tx) tx.postcommit(lambda: setattr(obj, self.iattr, val)) def unregister(self, id, obj, tx): - self.index().remove(getattr(obj, self.iattr), id, tx=tx) + self.index(tx).remove(getattr(obj, self.iattr), id, tx=tx) tx.postcommit(lambda: delattr(obj, self.iattr)) def update(self, id, obj, tx): val = self.__get__(obj, None) ival = getattr(obj, self.iattr) if val != ival: - idx = self.index() + idx = self.index(tx) idx.remove(ival, id, tx=tx) idx.put(val, id, tx=tx) tx.postcommit(lambda: setattr(obj, self.iattr, val)) @@ -106,13 +106,13 @@ class multi(descbase): def register(self, id, obj, tx): vals = frozenset(self.__get__(obj, None)) - idx = self.index() + idx = self.index(tx) for val in vals: idx.put(val, id, tx=tx) tx.postcommit(lambda: setattr(obj, self.iattr, vals)) def unregister(self, id, obj, tx): - idx = self.index() + idx = self.index(tx) for val in getattr(obj, self.iattr): idx.remove(val, id, tx=tx) tx.postcommit(lambda: delattr(obj, self.iattr)) @@ -121,7 +121,7 @@ class multi(descbase): vals = frozenset(self.__get__(obj, None)) ivals = getattr(obj, self.iattr) if vals != ivals: - idx = self.index() + idx = self.index(tx) for val in ivals - vals: idx.remove(val, id, tx=tx) for val in vals - ivals: @@ -147,18 +147,18 @@ class compound(base): def register(self, id, obj, tx): val = tuple(part.__get__(obj, None) for part in self.parts) - self.index().put(val, id, tx=tx) + self.index(tx).put(val, id, tx=tx) tx.postcommit(lambda: setattr(obj, self.iattr, val)) def unregister(self, id, obj, tx): - self.index().remove(getattr(obj, self.iattr), id, tx=tx) + self.index(tx).remove(getattr(obj, self.iattr), id, tx=tx) tx.postcommit(lambda: delattr(obj, self.iattr)) def update(self, id, obj, tx): val = tuple(part.__get__(obj, None) for part in self.parts) ival = getattr(obj, self.iattr) if val != ival: - idx = self.index() + idx = self.index(tx) idx.remove(ival, id, tx=tx) idx.put(val, id, tx=tx) tx.postcommit(lambda: setattr(obj, self.iattr, val)) -- 2.11.0