X-Git-Url: http://dolda2000.com/gitweb/?p=ldd.git;a=blobdiff_plain;f=ldd%2Fserver.py;fp=ldd%2Fserver.py;h=86fd88591cec4fc79324f20f946275ab363885a9;hp=0000000000000000000000000000000000000000;hb=769e7ed964e3720cf25825dd5390af5fb0bf4851;hpb=2e783944bffb349dff8667dab0ba0c48b21c9504 diff --git a/ldd/server.py b/ldd/server.py new file mode 100644 index 0000000..86fd885 --- /dev/null +++ b/ldd/server.py @@ -0,0 +1,317 @@ +# ldd - DNS implementation in Python +# Copyright (C) 2006 Fredrik Tolf +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA + +import socket +import threading +import select +import errno +import logging +import time + +import proto +import rec +import dn +import resolver + +logger = logging.getLogger("ldd.server") + +class dnsserver: + class socklistener(threading.Thread): + def __init__(self, server): + threading.Thread.__init__(self) + self.server = server + self.alive = True + + class sender: + def __init__(self, addr, sk): + self.addr = addr + self.sk = sk + + def send(self, pkt): + logger.debug("sending response to %04x", pkt.qid) + self.sk.sendto(pkt.encode(), self.addr) + + def run(self): + while self.alive: + p = select.poll() + for af, sk in self.server.sockets: + p.register(sk.fileno(), select.POLLIN) + try: + fds = p.poll(1000) + except select.error, e: + if e[0] == errno.EINTR: + continue + raise + for fd, event in fds: + if event & select.POLLIN == 0: + continue + for af, sk in self.server.sockets: + if sk.fileno() == fd: + break + else: + continue + req, addr = sk.recvfrom(65536) + try: + pkt = proto.decodepacket(req) + except proto.malformedpacket, inst: + resp = proto.packet(inst.qid, ["resp"]) + resp.rescode = proto.FORMERR + sk.sendto(resp.encode(), addr) + else: + logger.debug("got request (%04x) from %s", pkt.qid, addr[0]) + pkt.addr = (af,) + addr + self.server.queuereq(pkt, dnsserver.socklistener.sender(addr, sk)) + + def stop(self): + self.alive = False + + class dispatcher(threading.Thread): + def __init__(self, server): + threading.Thread.__init__(self) + self.server = server + self.alive = True + + def run(self): + while self.alive: + req = self.server.dequeuereq() + if req is not None: + pkt, sender = req + resp = self.server.handle(pkt) + if resp is None: + resp = proto.responsefor(pkt, proto.SERVFAIL) + sender.send(resp) + + class queuemonitor(threading.Thread): + def __init__(self, server): + threading.Thread.__init__(self) + self.server = server + + def run(self): + while(self.server.running): + self.server.queuelock.acquire() + if len(self.server.queue) > 0: + peeked = self.server.queue[0] + else: + peeked = None + self.server.queuelock.release() + if peeked is not None: + if time.time() - peeked[0] > 1: + newdsp = dnsserver.dispatcher(self.server) + self.server.dispatchers += [newdsp] + newdsp.start() + logger.debug("starting new dispatcher, there are now %i", len(self.server.dispatchers)) + time.sleep(1) + + def __init__(self): + self.sockets = [] + self.queue = [] + self.zones = [] + self.listener = None + self.dispatchers = [] + self.running = False + self.queuelock = threading.Condition() + self.knownkeys = [] + + def handle(self, pkt): + resp = None + + if len(self.knownkeys) > 0: + import dnssec + dnssec.tsigverify(pkt, self.knownkeys) + + for query in pkt.qlist: + match = None + for zone in self.zones: + if query.name in zone.origin: + if match is None: + match = zone + elif len(zone.origin) > len(match.origin): + match = zone + if match is None: + return None + else: + curresp = match.handle(query, pkt) + if resp is None: + resp = curresp + else: + resp.merge(curresp) + + if resp is not None and resp.tsigctx is not None and not resp.signed: + resp.tsigctx.signpkt(resp) + + return resp + + def addsock(self, af, socket): + self.sockets += [(af, socket)] + + def addzone(self, zone): + self.zones += [zone] + + def queuereq(self, req, sender): + self.queuelock.acquire() + self.queue += [(time.time(), req, sender)] + logger.debug("queue length+: %i", len(self.queue)) + self.queuelock.notify() + self.queuelock.release() + + def dequeuereq(self): + self.queuelock.acquire() + if len(self.queue) == 0: + self.queuelock.wait() + if len(self.queue) > 0: + ret = self.queue[0] + self.queue = self.queue[1:] + else: + ret = None + logger.debug("queue length-: %i", len(self.queue)) + self.queuelock.release() + if ret is None: + return None + else: + return ret[1:] + + def start(self): + if self.running: + raise Exception("already running") + lst = dnsserver.socklistener(self) + self.listener = lst + lst.start() + for i in xrange(10): + newdsp = dnsserver.dispatcher(self) + self.dispatchers += [newdsp] + newdsp.start() + self.running = True + self.monitor = dnsserver.queuemonitor(self) + self.monitor.start() + + def stop(self): + self.listener.stop() + self.listener.join() + self.listener = None + for dsp in self.dispatchers: + dsp.alive = False + self.queuelock.acquire() + self.queuelock.notifyAll() + self.queuelock.release() + for dsp in self.dispatchers + []: + dsp.join() + self.dispatchers.remove(dsp) + self.running = False + self.monitor = None + + def resolver(self, addr = None): + class myres(resolver.resolver): + def __init__(self, server, addr): + self.server = server + self.addr = addr + def resolve(self, packet): + if self.addr is not None: + packet.addr = self.addr + packet.setflags(["internal"]) + return self.server.handle(packet) + return myres(self, addr) + +class zone: + def __init__(self, origin, handler): + if type(origin) == str: + self.origin = dn.fromstring(origin) + self.origin.rooted = True + else: + self.origin = origin + self.handler = handler + + def handle(self, query, pkt): + resp = self.handler.handle(query, pkt, self.origin) + return resp + +class authzone(zone): + def __init__(self, aurecres, *args): + self.aurecres = aurecres + zone.__init__(self, *args) + + def handle(self, query, pkt): + resp = zone.handle(self, query, pkt) + if not "internal" in pkt.flags: + if resp is None: + resp = proto.responsefor(pkt) + soa = zone.handle(self, rec.rrhead(self.origin, "SOA"), pkt) + resp.aulist += soa.anlist + resp.rescode = proto.NXDOMAIN + else: + resolver.resolvecnames(resp, self.aurecres) + nsrecs = zone.handle(self, rec.rrhead(self.origin, "NS"), pkt) + if nsrecs is not None: + resp.aulist += nsrecs.anlist + for rr in nsrecs.anlist: + resolver.resolveadditional(resp, rr, self.aurecres) + else: + if resp is None: + return None + resp.setflags(["auth"]) + return resp + +class handler: + def handle(self, query, pkt, origin): + return None + +class forwarder(handler): + def __init__(self, nameserver, timeout = 2000, retries = 3): + self.nameserver = nameserver + self.timeout = timeout + self.retries = retries + + def handle(self, query, pkt, origin): + sk = socket.socket(self.nameserver[0], socket.SOCK_DGRAM) + sk.bind(("", 0)) + p = select.poll() + p.register(sk.fileno(), select.POLLIN) + for i in range(self.retries): + sk.sendto(pkt.encode(), self.nameserver[1:]) + fds = p.poll(self.timeout) + if (sk.fileno(), select.POLLIN) in fds: + break + else: + return None + resp = sk.recv(65536) + resp = proto.decodepacket(resp) + return resp + +class recurser(handler): + def __init__(self, resolver): + self.resolver = resolver + + def handle(self, query, pkt, origin): + try: + resp = self.resolver.resolve(pkt) + except resolver.error: + return None + return resp + +class chain(handler): + def __init__(self, chain): + self.chain = chain + + def add(self, handler): + self.chain += [handler] + + def handle(self, *args): + for h in self.chain: + resp = h.handle(*args) + if resp is not None: + return resp + return None +