-import os, threading, select
+import sys, os, errno, threading, select, traceback
+
+class epoller(object):
+ exc_handler = None
-class pool(object):
def __init__(self):
- self.clients = set()
+ self.registered = {}
self.lock = threading.RLock()
+ self.ep = None
self.th = None
- self.ipipe = -1
-
- def add(self, cl):
- with self.lock:
- self.clients.add(cl)
- self._ckrun()
- cl.registered = self
- self._interrupt()
+ self._daemon = True
- def __iter__(self):
- with self.lock:
- return iter([cl for cl in self.clients if not cl.closed])
-
- def broadcast(self, data, eof=False):
- with self.lock:
- for cl in self:
- cl.obuf.extend(data)
- if eof:
- cl.closed = True
- self._interrupt()
+ @staticmethod
+ def _evsfor(ch):
+ return ((select.EPOLLIN if ch.readable else 0) |
+ (select.EPOLLOUT if ch.writable else 0))
def _ckrun(self):
- if self.clients and self.th is None:
- th = threading.Thread(target=self._run, name="Async watcher thread")
+ if self.registered and self.th is None:
+ th = threading.Thread(target=self._run, name="Async epoll thread")
+ th.daemon = self._daemon
th.start()
self.th = th
- def _interrupt(self):
- fd = self.ipipe
- if fd >= 0 and threading.current_thread() != self.th:
- os.write(fd, b"a")
+ def exception(self, ch, *exc):
+ self.remove(ch)
+ if self.exc_handler is None:
+ traceback.print_exception(exc)
+ else:
+ self.exc_handler(ch, *exc)
- def _remove(self, cl):
- self.clients.remove(cl)
- cl.registered = None
- cl._doclose()
+ def _cb(self, ch, nm):
+ try:
+ m = getattr(ch, nm, None)
+ if m is None:
+ raise AttributeError("%r has no %s method" % (ch, nm))
+ m()
+ except Exception as exc:
+ self.exception(ch, *sys.exc_info())
def _run(self):
- ipr, ipw = None, None
+ ep = select.epoll()
try:
- ipr, ipw = os.pipe()
- self.ipipe = ipw
- while True:
- with self.lock:
- for cl in list(self.clients):
- if cl.closed and not cl.writable:
- self._remove(cl)
- if not self.clients:
- break
- rsk = [cl for cl in self.clients if not cl.closed] + [ipr]
- wsk = [cl for cl in self.clients if cl.writable]
- # XXX: Switch to epoll.
- rsk, wsk, esk = select.select(rsk, wsk, [])
- for sk in rsk:
- if sk == ipr:
- os.read(ipr, 1024)
- elif sk in self.clients:
- sk._doread()
- for sk in wsk:
- if sk in self.clients:
- sk._dowrite()
+ with self.lock:
+ for fd, (ob, evs) in self.registered.items():
+ ep.register(fd, evs)
+ self.ep = ep
+
+ while self.registered:
+ try:
+ evlist = ep.poll(10)
+ except IOError as exc:
+ if exc.errno == errno.EINTR:
+ continue
+ raise
+ for fd, evs in evlist:
+ with self.lock:
+ if fd not in self.registered:
+ continue
+ ch, cevs = self.registered[fd]
+ if fd in self.registered and evs & (select.EPOLLIN | select.EPOLLHUP | select.EPOLLERR):
+ self._cb(ch, "read")
+ if fd in self.registered and evs & select.EPOLLOUT:
+ self._cb(ch, "write")
+ if fd in self.registered:
+ nevs = self._evsfor(ch)
+ if nevs == 0:
+ del self.registered[fd]
+ ep.unregister(fd)
+ self._cb(ch, "close")
+ elif nevs != cevs:
+ self.registered[fd] = ch, nevs
+ ep.modify(fd, nevs)
+
finally:
with self.lock:
self.th = None
- self.ipipe = -1
+ self.ep = None
self._ckrun()
- if ipr is not None:
- try: os.close(ipr)
- except: pass
- if ipw is not None:
- try: os.close(ipw)
- except: pass
-
-class client(object):
- pool = None
-
- def __init__(self, sock):
- self.sk = sock
+ ep.close()
+
+ @property
+ def daemon(self): return self._daemon
+ @daemon.setter
+ def daemon(self, value):
+ self._daemon = bool(value)
+ with self.lock:
+ if self.th is not None:
+ self.th = daemon = self._daemon
+
+ def add(self, ch):
+ with self.lock:
+ fd = ch.fileno()
+ if fd in self.registered:
+ raise KeyError("fd %i is already registered" % fd)
+ evs = self._evsfor(ch)
+ if evs == 0:
+ ch.close()
+ return
+ ch.watcher = self
+ self.registered[fd] = (ch, evs)
+ if self.ep:
+ self.ep.register(fd, evs)
+ self._ckrun()
+
+ def remove(self, ch, ignore=False):
+ with self.lock:
+ fd = ch.fileno()
+ if fd not in self.registered:
+ if ignore:
+ return
+ raise KeyError("fd %i is not registered" % fd)
+ pch, cevs = self.registered[fd]
+ if pch is not ch:
+ raise ValueError("fd %i registered via object %r, cannot remove with %r" % (pch, ch))
+ del self.registered[fd]
+ if self.ep:
+ self.ep.unregister(fd)
+ ch.close()
+
+ def update(self, ch, ignore=False):
+ with self.lock:
+ fd = ch.fileno()
+ if fd not in self.registered:
+ if ignore:
+ return
+ raise KeyError("fd %i is not registered" % fd)
+ pch, cevs = self.registered[fd]
+ if pch is not ch:
+ raise ValueError("fd %i registered via object %r, cannot update with %r" % (pch, ch))
+ evs = self._evsfor(ch)
+ if evs == 0:
+ del self.registered[fd]
+ if self.ep:
+ self.ep.unregister(fd)
+ ch.close()
+ elif evs != cevs:
+ self.registered[fd] = ch, evs
+ if self.ep:
+ self.ep.modify(fd, evs)
+
+def watcher():
+ return epoller()
+
+class sockbuffer(object):
+ def __init__(self, sk):
+ self.sk = sk
+ self.eof = False
self.obuf = bytearray()
- self.closed = False
- self.registered = None
- p = self.pool
- if p is not None:
- p.add(self)
+ self.watcher = None
def fileno(self):
return self.sk.fileno()
def close(self):
- self.closed = True
- if self.registered:
- self.registered._interrupt()
-
- def write(self, data):
- self.obuf.extend(data)
- if self.registered:
- self.registered._interrupt()
-
- @property
- def writable(self):
- return bool(self.obuf)
+ self.sk.close()
def gotdata(self, data):
if data == b"":
- self.close()
+ self.eof = True
- def _doread(self):
- try:
- ret = self.sk.recv(1024)
- except IOError:
- self.close()
- self.gotdata(ret)
+ def send(self, data, eof=False):
+ self.obuf.extend(data)
+ if eof:
+ self.eof = True
+ if self.watcher is not None:
+ self.watcher.update(self, True)
- def _dowrite(self):
+ @property
+ def readable(self):
+ return not self.eof
+ def read(self):
try:
- if self.obuf:
- ret = self.sk.send(self.obuf)
- self.obuf[:ret] = b""
+ data = self.sk.recv(1024)
+ self.gotdata(data)
except IOError:
- self.close()
+ self.obuf[:] = b""
+ self.eof = True
- def _doclose(self):
+ @property
+ def writable(self):
+ return bool(self.obuf);
+ def write(self):
try:
- self.sk.close()
+ ret = self.sk.send(self.obuf)
+ self.obuf[:ret] = b""
except IOError:
- pass
+ self.obuf[:] = b""
+ self.eof = True
+
+class callbuffer(object):
+ def __init__(self):
+ self.queue = []
+ self.rp, self.wp = os.pipe()
+ self.lock = threading.Lock()
+ self.eof = False
+
+ def fileno(self):
+ return self.rp
+
+ def close(self):
+ with self.lock:
+ try:
+ if self.wp >= 0:
+ os.close(self.wp)
+ self.wp = -1
+ finally:
+ if self.rp >= 0:
+ os.close(self.rp)
+ self.rp = -1
+
+ @property
+ def readable(self):
+ return not self.eof
+ def read(self):
+ with self.lock:
+ try:
+ data = os.read(self.rp, 1024)
+ if data == b"":
+ self.eof = True
+ except IOError:
+ self.eof = True
+ cbs = list(self.queue)
+ self.queue[:] = []
+ for cb in cbs:
+ cb()
+
+ writable = False
+
+ def call(self, cb):
+ with self.lock:
+ if self.wp < 0:
+ raise Exception("stopped")
+ self.queue.append(cb)
+ os.write(self.wp, b"a")
+
+ def stop(self):
+ with self.lock:
+ if self.wp >= 0:
+ os.close(self.wp)
+ self.wp = -1