Fixed all obvious byte/str errors.
[pdm.git] / pdm / srv.py
1 """Management for daemon processes
2
3 This module contains a utility to listen for management commands on a
4 socket, lending itself to managing daemon processes.
5 """
6
7 import os, sys, socket, threading, grp, select
8 import types, pprint, traceback
9 import pickle, struct
10
11 __all__ = ["listener", "unixlistener", "tcplistener", "listen"]
12
13 protocols = {}
14
15 class repl(object):
16     def __init__(self, cl):
17         self.cl = cl
18         self.mod = types.ModuleType("repl")
19         self.mod.echo = self.echo
20         self.printer = pprint.PrettyPrinter(indent = 4, depth = 6)
21         cl.send(b"+REPL\n")
22
23     def sendlines(self, text):
24         for line in text.split("\n"):
25             self.cl.send(b" " + line.encode("utf-8") + b"\n")
26
27     def echo(self, ob):
28         self.sendlines(self.printer.pformat(ob))
29
30     def command(self, cmd):
31         cmd = cmd.decode("utf-8")
32         try:
33             try:
34                 ccode = compile(cmd, "PDM Input", "eval")
35             except SyntaxError:
36                 ccode = compile(cmd, "PDM Input", "exec")
37                 exec(ccode, self.mod.__dict__)
38                 self.cl.send(b"+OK\n")
39             else:
40                 self.echo(eval(ccode, self.mod.__dict__))
41                 self.cl.send(b"+OK\n")
42         except:
43             for line in traceback.format_exception(*sys.exc_info()):
44                 self.cl.send(b" " + line.encode("utf-8"))
45             self.cl.send(b"+EXC\n")
46
47     def handle(self, buf):
48         p = buf.find(b"\n\n")
49         if p < 0:
50             return buf
51         cmd = buf[:p + 1]
52         self.command(cmd)
53         return buf[p + 2:]
54 protocols["repl"] = repl
55
56 class perf(object):
57     def __init__(self, cl):
58         self.cl = cl
59         self.odtab = {}
60         cl.send(b"+PERF1\n")
61         self.buf = ""
62         self.lock = threading.Lock()
63         self.subscribed = {}
64
65     def closed(self):
66         for id, recv in self.subscribed.items():
67             ob = self.odtab[id]
68             if ob is None: continue
69             ob, protos = ob
70             try:
71                 ob.unsubscribe(recv)
72             except: pass
73
74     def send(self, *args):
75         self.lock.acquire()
76         try:
77             buf = pickle.dumps(args)
78             buf = struct.pack(">l", len(buf)) + buf
79             self.cl.send(buf)
80         finally:
81             self.lock.release()
82
83     def bindob(self, id, ob):
84         if not hasattr(ob, "pdm_protocols"):
85             raise ValueError("Object does not support PDM introspection")
86         try:
87             proto = ob.pdm_protocols()
88         except Exception as exc:
89             raise ValueError("PDM introspection failed", exc)
90         self.odtab[id] = ob, proto
91         return proto
92
93     def bind(self, id, module, obnm):
94         resmod = sys.modules.get(module)
95         if resmod is None:
96             self.send("-", ImportError("No such module: %s" % module))
97             return
98         try:
99             ob = getattr(resmod, obnm)
100         except AttributeError:
101             self.send("-", AttributeError("No such object: %s" % obnm))
102             return
103         try:
104             proto = self.bindob(id, ob)
105         except Exception as exc:
106             self.send("-", exc)
107             return
108         self.send("+", proto)
109
110     def getob(self, id, proto):
111         ob = self.odtab.get(id)
112         if ob is None:
113             self.send("-", ValueError("No such bound ID: %r" % id))
114             return None
115         ob, protos = ob
116         if proto not in protos:
117             self.send("-", ValueError("Object does not support that protocol"))
118             return None
119         return ob
120
121     def lookup(self, tgtid, srcid, obnm):
122         src = self.getob(srcid, "dir")
123         if src is None:
124             return
125         try:
126             ob = src.lookup(obnm)
127         except KeyError as exc:
128             self.send("-", exc)
129             return
130         try:
131             proto = self.bindob(tgtid, ob)
132         except Exception as exc:
133             self.send("-", exc)
134             return
135         self.send("+", proto)
136
137     def unbind(self, id):
138         ob = self.odtab.get(id)
139         if ob is None:
140             self.send("-", KeyError("No such name bound: %r" % id))
141             return
142         ob, protos = ob
143         del self.odtab[id]
144         recv = self.subscribed.get(id)
145         if recv is not None:
146             ob.unsubscribe(recv)
147             del self.subscribed[id]
148         self.send("+")
149
150     def listdir(self, id):
151         ob = self.getob(id, "dir")
152         if ob is None:
153             return
154         self.send("+", ob.listdir())
155
156     def readattr(self, id):
157         ob = self.getob(id, "attr")
158         if ob is None:
159             return
160         try:
161             ret = ob.readattr()
162         except Exception as exc:
163             self.send("-", Exception("Could not read attribute"))
164             return
165         self.send("+", ret)
166
167     def attrinfo(self, id):
168         ob = self.getob(id, "attr")
169         if ob is None:
170             return
171         self.send("+", ob.attrinfo())
172
173     def invoke(self, id, method, args, kwargs):
174         ob = self.getob(id, "invoke")
175         if ob is None:
176             return
177         try:
178             self.send("+", ob.invoke(method, *args, **kwargs))
179         except Exception as exc:
180             self.send("-", exc)
181
182     def event(self, id, ob, ev):
183         self.send("*", id, ev)
184
185     def subscribe(self, id):
186         ob = self.getob(id, "event")
187         if ob is None:
188             return
189         if id in self.subscribed:
190             self.send("-", ValueError("Already subscribed"))
191         def recv(ev):
192             self.event(id, ob, ev)
193         ob.subscribe(recv)
194         self.subscribed[id] = recv
195         self.send("+")
196
197     def unsubscribe(self, id):
198         ob = self.getob(id, "event")
199         if ob is None:
200             return
201         recv = self.subscribed.get(id)
202         if recv is None:
203             self.send("-", ValueError("Not subscribed"))
204         ob.unsubscribe(recv)
205         del self.subscribed[id]
206         self.send("+")
207
208     def command(self, data):
209         cmd = data[0]
210         if cmd == "bind":
211             self.bind(*data[1:])
212         elif cmd == "unbind":
213             self.unbind(*data[1:])
214         elif cmd == "lookup":
215             self.lookup(*data[1:])
216         elif cmd == "ls":
217             self.listdir(*data[1:])
218         elif cmd == "readattr":
219             self.readattr(*data[1:])
220         elif cmd == "attrinfo":
221             self.attrinfo(*data[1:])
222         elif cmd == "invoke":
223             self.invoke(*data[1:])
224         elif cmd == "subs":
225             self.subscribe(*data[1:])
226         elif cmd == "unsubs":
227             self.unsubscribe(*data[1:])
228         else:
229             self.send("-", Exception("Unknown command: %r" % (cmd,)))
230
231     def handle(self, buf):
232         if len(buf) < 4:
233             return buf
234         dlen = struct.unpack(">l", buf[:4])[0]
235         if len(buf) < dlen + 4:
236             return buf
237         data = pickle.loads(buf[4:dlen + 4])
238         self.command(data)
239         return buf[dlen + 4:]
240         
241 protocols["perf"] = perf
242
243 class client(threading.Thread):
244     def __init__(self, sk):
245         super(client, self).__init__(name = "Management client")
246         self.setDaemon(True)
247         self.sk = sk
248         self.handler = self
249
250     def send(self, data):
251         return self.sk.send(data)
252
253     def choose(self, proto):
254         try:
255             proto = proto.decode("ascii")
256         except UnicodeError:
257             proto = None
258         if proto in protocols:
259             self.handler = protocols[proto](self)
260         else:
261             self.send("-ERR Unknown protocol: %s\n" % proto)
262             raise Exception()
263
264     def handle(self, buf):
265         p = buf.find(b"\n")
266         if p >= 0:
267             proto = buf[:p]
268             buf = buf[p + 1:]
269             self.choose(proto)
270         return buf
271
272     def run(self):
273         try:
274             buf = b""
275             self.send(b"+PDM1\n")
276             while True:
277                 ret = self.sk.recv(1024)
278                 if ret == b"":
279                     return
280                 buf += ret
281                 while True:
282                     try:
283                         nbuf = self.handler.handle(buf)
284                     except:
285                         #for line in traceback.format_exception(*sys.exc_info()):
286                         #    print(line)
287                         return
288                     if nbuf == buf:
289                         break
290                     buf = nbuf
291         finally:
292             try:
293                 self.sk.close()
294             finally:
295                 if hasattr(self.handler, "closed"):
296                     self.handler.closed()
297             
298
299 class listener(threading.Thread):
300     def __init__(self):
301         super(listener, self).__init__(name = "Management listener")
302         self.setDaemon(True)
303
304     def listen(self, sk):
305         self.running = True
306         while self.running:
307             rfd, wfd, efd = select.select([sk], [], [sk], 1)
308             for fd in rfd:
309                 if fd == sk:
310                     nsk, addr = sk.accept()
311                     self.accept(nsk, addr)
312
313     def stop(self):
314         self.running = False
315         self.join()
316
317     def accept(self, sk, addr):
318         cl = client(sk)
319         cl.start()
320
321 class unixlistener(listener):
322     def __init__(self, name, mode = 0o600, group = None):
323         super(unixlistener, self).__init__()
324         self.name = name
325         self.mode = mode
326         self.group = group
327
328     def run(self):
329         sk = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
330         ul = False
331         try:
332             if os.path.exists(self.name) and os.path.stat.S_ISSOCK(os.stat(self.name).st_mode):
333                 os.unlink(self.name)
334             sk.bind(self.name)
335             ul = True
336             os.chmod(self.name, self.mode)
337             if self.group is not None:
338                 os.chown(self.name, os.getuid(), grp.getgrnam(self.group).gr_gid)
339             sk.listen(16)
340             self.listen(sk)
341         finally:
342             sk.close()
343             if ul:
344                 os.unlink(self.name)
345
346 class tcplistener(listener):
347     def __init__(self, port, bindaddr = "127.0.0.1"):
348         super(tcplistener, self).__init__()
349         self.port = port
350         self.bindaddr = bindaddr
351
352     def run(self):
353         sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
354         try:
355             sk.bind((self.bindaddr, self.port))
356             sk.listen(16)
357             self.listen(sk)
358         finally:
359             sk.close()
360
361 def listen(spec):
362     if ":" in spec:
363         first = spec[:spec.index(":")]
364         last = spec[spec.rindex(":") + 1:]
365     else:
366         first = spec
367         last = spec
368     if "/" in first:
369         parts = spec.split(":")
370         mode = 0o600
371         group = None
372         if len(parts) > 1:
373             mode = int(parts[1], 8)
374         if len(parts) > 2:
375             group = parts[2]
376         ret = unixlistener(parts[0], mode = mode, group = group)
377         ret.start()
378         return ret
379     if last.isdigit():
380         p = spec.rindex(":")
381         host = spec[:p]
382         port = int(spec[p + 1:])
383         ret = tcplistener(port, bindaddr = host)
384         ret.start()
385         return ret
386     raise ValueError("Unparsable listener specification: %r" % spec)
387
388 import pdm.perf