Improved sshsock error handling somewhat.
[pdm.git] / pdm / sshsock.py
1 import sys, os
2 import subprocess, socket, fcntl, select
3
class sshsocket(object):
    """Socket-like object tunnelling to a Unix socket on a remote host.

    It runs ``ssh`` to start ``python -m pdm.sshsock PATH`` on the remote
    host. That process (cli() in this module) connects to the Unix-domain
    socket at PATH and relays bytes between it and the ssh stdio. Only the
    subset of the socket API used here is provided: send(), recv(),
    fileno() and close().
    """

    def __init__(self, host, path, user = None, port = None):
        """Start the ssh tunnel and complete the handshake.

        host -- remote host name handed to ssh.
        path -- path of the Unix socket to connect to on the remote host.
        user -- optional remote login name (ssh ``-l``).
        port -- optional ssh port (ssh ``-p``); must be convertible to int.

        Raises socket.error if the remote side reports a failure or replies
        with anything other than the expected handshake; the ssh process is
        shut down before the error propagates.
        """
        # Set first so close()/__del__ are safe even if Popen itself fails.
        self.proc = None
        args = ["ssh"]
        if user is not None:
            # ssh takes the remote login name via -l; it has no -u option.
            args += ["-l", str(user)]
        if port is not None:
            args += ["-p", str(int(port))]
        args += [host]
        args += ["python", "-m", "pdm.sshsock", path]
        self.proc = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, close_fds=True)
        fcntl.fcntl(self.proc.stdout, fcntl.F_SETFL, fcntl.fcntl(self.proc.stdout, fcntl.F_GETFL) | os.O_NONBLOCK)
        ok = False
        try:
            self._handshake(host)
            ok = True
        finally:
            # Reap ssh on any handshake failure; the original error propagates.
            if not ok:
                self.close()

    def _handshake(self, host):
        """Consume the remote greeting: "SSOCK+\\n" or "SSOCK-<message>\\n"."""
        head = self._recvexact(5)
        if head != "SSOCK":
            raise socket.error("unexpected reply from %s: %r" % (host, head))
        head = self._recvexact(1)
        if head == "+":
            # The remote end writes "SSOCK+\n"; consume the newline too so it
            # is not returned to the caller as the first payload byte.
            nl = self._recvexact(1)
            if nl != "\n":
                raise socket.error("unexpected reply from %s: %r" % (host, nl))
            return
        elif head == "-":
            raise socket.error(self._recvline())
        else:
            raise socket.error("unexpected reply from %s: %r" % (host, head))

    def _recvexact(self, n):
        """Block until *n* bytes are read; return fewer only on EOF."""
        buf = self.recv(n)
        # A single read on a pipe may return a short chunk; keep reading.
        while 0 < len(buf) < n:
            more = self.recv(n - len(buf))
            if not more:
                break
            buf += more
        return buf

    def _recvline(self):
        """Read up to (not including) the next newline, or to EOF."""
        buf = ""
        while True:
            r = self.recv(1)
            if r in ("\n", ""):
                return buf
            buf += r

    def close(self):
        """Close both pipes and wait for ssh to exit. Safe to call repeatedly."""
        proc = getattr(self, "proc", None)
        if proc is None:
            return
        # Detach first so a failure below cannot lead to a second close.
        self.proc = None
        try:
            proc.stdin.close()
        finally:
            try:
                proc.stdout.close()
            finally:
                proc.wait()

    def send(self, data, flags = 0):
        """Write all of *data* to the tunnel and return len(data).

        *flags* is accepted for socket API compatibility and ignored.
        """
        self.proc.stdin.write(data)
        # A no-op on an unbuffered pipe, but required if the pipe is buffered.
        self.proc.stdin.flush()
        return len(data)

    def recv(self, buflen, flags = 0):
        """Read up to *buflen* bytes from the tunnel.

        Waits in select() for readability unless socket.MSG_DONTWAIT is set
        in *flags*; other flags are ignored.
        """
        if (flags & socket.MSG_DONTWAIT) == 0:
            select.select([self.proc.stdout], [], [])
        return self.proc.stdout.read(buflen)

    def fileno(self):
        """Return the readable descriptor, so the object works with select()."""
        return self.proc.stdout.fileno()

    def __del__(self):
        self.close()
53
def cli():
    """Remote half of sshsocket: relay between stdio and a Unix socket.

    Run as ``python -m pdm.sshsock PATH``. Connects to the Unix-domain socket
    at PATH (sys.argv[1]) and reports the outcome on stdout: "SSOCK+\\n" on
    success, or "SSOCK-connect: <error>\\n" on failure, after which it
    returns. On success it copies stdin to the socket and socket data to
    stdout until either side reaches EOF. Data already read from one side is
    delivered to the other before exiting.
    """
    # Non-blocking stdin, so read() returns what is available once select()
    # reports it readable.
    fcntl.fcntl(sys.stdin, fcntl.F_SETFL, fcntl.fcntl(sys.stdin, fcntl.F_GETFL) | os.O_NONBLOCK)
    sk = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        try:
            sk.connect(sys.argv[1])
        except socket.error as err:
            sys.stdout.write("SSOCK-connect: %s\n" % err)
            sys.stdout.flush()
            return
        sys.stdout.write("SSOCK+\n")
        # stdout is normally a pipe back to ssh and therefore block-buffered.
        # Without this flush the client would not see the handshake until
        # the first payload arrives from the socket.
        sys.stdout.flush()
        tosock = ""  # read from stdin, not yet sent to the socket
        toout = ""   # read from the socket, not yet written to stdout
        while True:
            wfd = []
            if tosock:
                wfd.append(sk)
            if toout:
                wfd.append(sys.stdout)
            rfd, wfd, efd = select.select([sk, sys.stdin], wfd, [])
            if sk in rfd:
                ret = sk.recv(65536)
                if not ret:
                    # Socket closed: hand over what was already received.
                    if toout:
                        sys.stdout.write(toout)
                        sys.stdout.flush()
                    break
                toout += ret
            if sys.stdin in rfd:
                ret = sys.stdin.read()
                if not ret:
                    # stdin closed: deliver pending input before quitting.
                    if tosock:
                        sk.sendall(tosock)
                    break
                # Append, since a partial send may have left data pending.
                tosock += ret
            if sk in wfd:
                sent = sk.send(tosock)
                tosock = tosock[sent:]
            if sys.stdout in wfd:
                sys.stdout.write(toout)
                sys.stdout.flush()
                toout = ""
    finally:
        sk.close()


if __name__ == "__main__":
    cli()