Initial import
[ldd.git] / ldd / server.py
1 #    ldd - DNS implementation in Python
2 #    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
3 #
4 #    This program is free software; you can redistribute it and/or modify
5 #    it under the terms of the GNU General Public License as published by
6 #    the Free Software Foundation; either version 2 of the License, or
7 #    (at your option) any later version.
8 #
9 #    This program is distributed in the hope that it will be useful,
10 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
11 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 #    GNU General Public License for more details.
13 #
14 #    You should have received a copy of the GNU General Public License
15 #    along with this program; if not, write to the Free Software
16 #    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
17
18 import socket
19 import threading
20 import select
21 import errno
22 import logging
23 import time
24
25 import proto
26 import rec
27 import dn
28 import resolver
29
30 logger = logging.getLogger("ldd.server")
31
class dnsserver:
    """Multi-threaded UDP DNS server.

    A listener thread reads requests from every registered socket and queues
    them; a pool of dispatcher threads takes requests off the queue, answers
    them through handle() and sends the reply back on the socket the request
    arrived on.  A monitor thread checks once per second and adds a
    dispatcher whenever the oldest queued request has waited more than one
    second.
    """

    class socklistener(threading.Thread):
        """Thread that polls the server's sockets and queues decoded requests."""

        def __init__(self, server):
            threading.Thread.__init__(self)
            self.server = server
            self.alive = True

        class sender:
            """Reply channel bound to the requester's address and receiving socket."""

            def __init__(self, addr, sk):
                self.addr = addr
                self.sk = sk

            def send(self, pkt):
                """Encode *pkt* and send it back to the requester."""
                logger.debug("sending response to %04x", pkt.qid)
                self.sk.sendto(pkt.encode(), self.addr)

        def run(self):
            while self.alive:
                # Rebuilt every round so sockets added with addsock() are picked up.
                p = select.poll()
                for af, sk in self.server.sockets:
                    p.register(sk.fileno(), select.POLLIN)
                try:
                    # 1 s timeout so that stop() is noticed without traffic.
                    fds = p.poll(1000)
                except select.error as e:
                    if e.args[0] == errno.EINTR:
                        continue
                    raise
                for fd, event in fds:
                    if not event & select.POLLIN:
                        continue
                    for af, sk in self.server.sockets:
                        if sk.fileno() == fd:
                            break
                    else:
                        continue
                    try:
                        req, addr = sk.recvfrom(65536)
                    except socket.error as e:
                        # A failed read must not kill the listener thread.
                        logger.warning("receive failed: %s", e)
                        continue
                    try:
                        pkt = proto.decodepacket(req)
                    except proto.malformedpacket as inst:
                        resp = proto.packet(inst.qid, ["resp"])
                        resp.rescode = proto.FORMERR
                        try:
                            sk.sendto(resp.encode(), addr)
                        except socket.error as e:
                            logger.warning("could not send FORMERR to %s: %s", addr[0], e)
                    else:
                        logger.debug("got request (%04x) from %s", pkt.qid, addr[0])
                        # Record the address family alongside the peer address.
                        pkt.addr = (af,) + addr
                        self.server.queuereq(pkt, dnsserver.socklistener.sender(addr, sk))

        def stop(self):
            """Ask the thread to exit; takes effect within one poll timeout."""
            self.alive = False

    class dispatcher(threading.Thread):
        """Worker thread that answers queued requests until its alive flag is cleared."""

        def __init__(self, server):
            threading.Thread.__init__(self)
            self.server = server
            self.alive = True

        def run(self):
            while self.alive:
                req = self.server.dequeuereq()
                if req is None:
                    # Woken without a request (e.g. by stop()); re-check alive.
                    continue
                pkt, sender = req
                try:
                    resp = self.server.handle(pkt)
                except Exception:
                    # A failing handler must not kill the worker; answer SERVFAIL instead.
                    logger.exception("error while handling request %04x", pkt.qid)
                    resp = None
                try:
                    if resp is None:
                        resp = proto.responsefor(pkt, proto.SERVFAIL)
                    sender.send(resp)
                except Exception:
                    logger.exception("error while sending response %04x", pkt.qid)

    class queuemonitor(threading.Thread):
        """Thread that grows the dispatcher pool while requests wait too long."""

        def __init__(self, server):
            threading.Thread.__init__(self)
            self.server = server

        def run(self):
            server = self.server
            while server.running:
                with server.queuelock:
                    if server.queue:
                        oldest = server.queue[0][0]
                    else:
                        oldest = None
                # NOTE(review): the pool only ever grows and has no upper bound.
                if oldest is not None and time.time() - oldest > 1:
                    newdsp = dnsserver.dispatcher(server)
                    server.dispatchers += [newdsp]
                    newdsp.start()
                    logger.debug("starting new dispatcher, there are now %i", len(server.dispatchers))
                time.sleep(1)

    def __init__(self):
        self.sockets = []       # (address family, socket) pairs
        self.queue = []         # (enqueue time, request packet, sender), oldest first
        self.zones = []
        self.listener = None
        self.dispatchers = []
        self.monitor = None
        self.running = False
        # Set by stop() while holding queuelock so idle dispatchers stop waiting.
        self.stopping = False
        self.queuelock = threading.Condition()
        self.knownkeys = []

    def handle(self, pkt):
        """Answer request *pkt* from the configured zones.

        Each question in pkt.qlist is given to the zone whose origin
        `in`-matches the question name, preferring the longest origin.  The
        non-None per-question answers are merged into the first one.

        When knownkeys is non-empty, the request is first passed to
        dnssec.tsigverify().  A response carrying a TSIG context that is not
        yet signed is signed before it is returned.

        Returns None when some question matches no zone or no zone answered;
        the dispatcher turns None into SERVFAIL.
        """
        if len(self.knownkeys) > 0:
            import dnssec
            dnssec.tsigverify(pkt, self.knownkeys)

        resp = None
        for query in pkt.qlist:
            match = None
            for candidate in self.zones:
                if query.name in candidate.origin:
                    if match is None or len(candidate.origin) > len(match.origin):
                        match = candidate
            if match is None:
                return None
            curresp = match.handle(query, pkt)
            if curresp is None:
                # Skip unanswered questions (previously merge(None) was attempted).
                continue
            if resp is None:
                resp = curresp
            else:
                resp.merge(curresp)

        if resp is not None and resp.tsigctx is not None and not resp.signed:
            resp.tsigctx.signpkt(resp)

        return resp

    def addsock(self, af, socket):
        """Register a bound datagram socket of address family *af* for listening."""
        self.sockets += [(af, socket)]

    def addzone(self, zone):
        """Register *zone* for answering queries."""
        self.zones += [zone]

    def queuereq(self, req, sender):
        """Queue request *req* with its reply channel and wake one dispatcher."""
        with self.queuelock:
            self.queue += [(time.time(), req, sender)]
            logger.debug("queue length+: %i", len(self.queue))
            self.queuelock.notify()

    def dequeuereq(self):
        """Take the oldest request, waiting if the queue is empty.

        Returns a (request, sender) tuple, or None when woken without a
        request (e.g. during shutdown).
        """
        with self.queuelock:
            # stopping is checked under the lock: otherwise a dispatcher could
            # start waiting after stop() already sent its notification.
            if len(self.queue) == 0 and not self.stopping:
                self.queuelock.wait()
            if len(self.queue) > 0:
                ret = self.queue.pop(0)
            else:
                ret = None
            logger.debug("queue length-: %i", len(self.queue))
        if ret is None:
            return None
        return ret[1:]

    def start(self):
        """Start the listener, ten dispatchers and the queue monitor.

        Raises Exception("already running") if the server is running.
        """
        if self.running:
            raise Exception("already running")
        self.stopping = False
        lst = dnsserver.socklistener(self)
        self.listener = lst
        lst.start()
        for i in range(10):
            newdsp = dnsserver.dispatcher(self)
            self.dispatchers += [newdsp]
            newdsp.start()
        self.running = True
        self.monitor = dnsserver.queuemonitor(self)
        self.monitor.start()

    def stop(self):
        """Stop all server threads and wait for them to finish.

        Requests still queued are not answered.
        """
        # Stop the monitor first so it cannot add a dispatcher that the
        # shutdown loop below would never see.
        self.running = False
        if self.monitor is not None:
            self.monitor.join()
            self.monitor = None
        self.listener.stop()
        self.listener.join()
        self.listener = None
        with self.queuelock:
            self.stopping = True
            for dsp in self.dispatchers:
                dsp.alive = False
            self.queuelock.notify_all()
        for dsp in list(self.dispatchers):
            dsp.join()
            self.dispatchers.remove(dsp)

    def resolver(self, addr = None):
        """Return a resolver object whose resolve() answers through this server.

        Each resolved packet is flagged "internal" and, when *addr* is given,
        gets *addr* as its source address before handle() is called; the
        caller's packet is modified in place.
        """
        class myres(resolver.resolver):
            # NOTE(review): resolver.resolver.__init__ is not called — confirm
            # the base class needs no initialisation.
            def __init__(self, server, addr):
                self.server = server
                self.addr = addr
            def resolve(self, packet):
                if self.addr is not None:
                    packet.addr = self.addr
                packet.setflags(["internal"])
                return self.server.handle(packet)
        return myres(self, addr)
227
class zone:
    """A DNS zone: an origin name plus the handler that answers for it."""

    def __init__(self, origin, handler):
        # A plain string origin is parsed into a rooted name; anything else is
        # taken to be an already-built name object and stored as given.
        if type(origin) is not str:
            self.origin = origin
        else:
            parsed = dn.fromstring(origin)
            parsed.rooted = True
            self.origin = parsed
        self.handler = handler

    def handle(self, query, pkt):
        """Pass *query* (from request *pkt*) and the zone origin to the handler.

        Returns whatever the handler returns (None when it has no answer).
        """
        return self.handler.handle(query, pkt, self.origin)
240
class authzone(zone):
    """Zone this server is authoritative for.

    *aurecres* is the resolver used to complete answers to external requests;
    the remaining arguments are passed unchanged to zone.__init__.
    """

    def __init__(self, aurecres, *args):
        self.aurecres = aurecres
        zone.__init__(self, *args)

    def handle(self, query, pkt):
        """Answer *query* from request *pkt* authoritatively.

        For requests not flagged "internal":
          * if the handler has no answer, an NXDOMAIN response is built, with
            the zone's SOA records (when the handler has any) in the authority
            section;
          * otherwise resolver.resolvecnames() is applied to the answer, and
            the zone's NS records are added to the authority section, with
            resolver.resolveadditional() called for each of them.
        "internal" requests get the handler's answer without these additions,
        or None when it has no answer.  Every returned response gets the
        "auth" flag.
        """
        resp = zone.handle(self, query, pkt)
        if "internal" in pkt.flags:
            if resp is None:
                return None
        elif resp is None:
            # NOTE(review): a None from the handler may mean either "no such
            # name" or "no data of this type"; both are reported as NXDOMAIN.
            resp = proto.responsefor(pkt)
            soa = zone.handle(self, rec.rrhead(self.origin, "SOA"), pkt)
            # The handler may have no SOA for this origin; don't dereference None.
            if soa is not None:
                resp.aulist += soa.anlist
            resp.rescode = proto.NXDOMAIN
        else:
            resolver.resolvecnames(resp, self.aurecres)
            nsrecs = zone.handle(self, rec.rrhead(self.origin, "NS"), pkt)
            if nsrecs is not None:
                resp.aulist += nsrecs.anlist
                for rr in nsrecs.anlist:
                    resolver.resolveadditional(resp, rr, self.aurecres)
        resp.setflags(["auth"])
        return resp
266
class handler:
    """Base class for the objects that answer queries on behalf of a zone."""

    def handle(self, query, pkt, origin):
        """Return a response for *query*, or None to decline.

        *pkt* is the request the query came from and *origin* is the origin
        of the zone doing the asking.  This base implementation always
        declines.
        """
        response = None
        return response
270
class forwarder(handler):
    """Handler that relays the whole request to an upstream name server over UDP.

    *nameserver* is a tuple whose first element is the socket address family
    and whose remaining elements form the destination address passed to
    sendto() (e.g. (socket.AF_INET, "192.0.2.1", 53)).  *timeout* is the wait
    per attempt in milliseconds (it is given to poll()), and *retries* is how
    many times the request is sent before giving up.
    """

    def __init__(self, nameserver, timeout = 2000, retries = 3):
        self.nameserver = nameserver
        self.timeout = timeout
        self.retries = retries

    def handle(self, query, pkt, origin):
        """Forward *pkt* upstream and return the decoded reply, or None.

        *query* and *origin* are unused: the full request is forwarded.
        A reply is only accepted if it decodes and carries the request's qid.
        None is returned when no acceptable reply arrives within *retries*
        attempts, or when a socket error occurs.  Returning None lets an
        enclosing chain fall through to another handler.
        """
        data = pkt.encode()  # encoded once, resent on every attempt
        sk = socket.socket(self.nameserver[0], socket.SOCK_DGRAM)
        try:
            sk.bind(("", 0))
            p = select.poll()
            p.register(sk.fileno(), select.POLLIN)
            for _ in range(self.retries):
                sk.sendto(data, self.nameserver[1:])
                fds = p.poll(self.timeout)
                if not any(event & select.POLLIN for fd, event in fds):
                    continue
                try:
                    resp = proto.decodepacket(sk.recv(65536))
                except proto.malformedpacket:
                    continue
                if resp.qid != pkt.qid:
                    # Not an answer to this request; try again.
                    continue
                return resp
            return None
        except socket.error:
            return None
        finally:
            # Previously the socket was never closed, leaking one fd per query.
            sk.close()
292     
class recurser(handler):
    """Handler that answers by passing the whole request to a resolver object."""

    def __init__(self, resolver):
        # This parameter shadows the module-level `resolver` only inside __init__.
        self.resolver = resolver

    def handle(self, query, pkt, origin):
        """Return self.resolver.resolve(pkt); a resolver.error yields None instead.

        *query* and *origin* are unused.
        """
        try:
            return self.resolver.resolve(pkt)
        except resolver.error:
            return None
303
class chain(handler):
    """Handler that consults a sequence of handlers in order.

    The first non-None response is returned; if every handler declines (or
    the chain is empty), None is returned.
    """

    def __init__(self, chain = None):
        # Copy the given handlers. add() extends self.chain in place, so
        # storing the caller's list would also mutate that list (and any other
        # chain built from it).
        if chain is None:
            chain = []
        self.chain = list(chain)

    def add(self, handler):
        """Append *handler* to the end of the chain."""
        self.chain += [handler]

    def handle(self, *args):
        """Call each handler's handle(*args) in turn until one returns non-None."""
        for h in self.chain:
            resp = h.handle(*args)
            if resp is not None:
                return resp
        return None
317