Initial import master
authorfredrik <fredrik@959494ce-11ee-0310-bf91-de5d638817bd>
Sat, 20 Jan 2007 04:45:47 +0000 (04:45 +0000)
committerfredrik <fredrik@959494ce-11ee-0310-bf91-de5d638817bd>
Sat, 20 Jan 2007 04:45:47 +0000 (04:45 +0000)
git-svn-id: svn+ssh://svn.dolda2000.com/srv/svn/repos/src/ldd@832 959494ce-11ee-0310-bf91-de5d638817bd

16 files changed:
MANIFEST [new file with mode: 0644]
dnsdbtool [new file with mode: 0755]
ldd/__init__.py [new file with mode: 0644]
ldd/dbzone.py [new file with mode: 0644]
ldd/dn.py [new file with mode: 0644]
ldd/dnssec.py [new file with mode: 0644]
ldd/filters.py [new file with mode: 0644]
ldd/mdns.py [new file with mode: 0644]
ldd/proto.py [new file with mode: 0644]
ldd/rec.py [new file with mode: 0644]
ldd/rescache.py [new file with mode: 0644]
ldd/resolver.py [new file with mode: 0644]
ldd/server.py [new file with mode: 0644]
lddd [new file with mode: 0755]
resolve [new file with mode: 0755]
setup.py [new file with mode: 0644]

diff --git a/MANIFEST b/MANIFEST
new file mode 100644 (file)
index 0000000..336b4a9
--- /dev/null
+++ b/MANIFEST
@@ -0,0 +1,12 @@
+setup.py
+pydns/__init__.py
+pydns/dbzone.py
+pydns/dn.py
+pydns/dnssec.py
+pydns/filters.py
+pydns/mdns.py
+pydns/proto.py
+pydns/rec.py
+pydns/rescache.py
+pydns/resolver.py
+pydns/server.py
diff --git a/dnsdbtool b/dnsdbtool
new file mode 100755 (executable)
index 0000000..3d4ef1e
--- /dev/null
+++ b/dnsdbtool
@@ -0,0 +1,218 @@
+#!/usr/bin/python
+#    ldd - DNS implementation in Python
+#    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
+#
+#    This program is free software; you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation; either version 2 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program; if not, write to the Free Software
+#    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+
+
+import sys
+from getopt import getopt
+import socket
+
+from ldd import dbzone, rec, dn
+
+if len(sys.argv) < 3:
+    print "usage: dnsdbtool [-d] DBDIR DBNAME [COMMAND [ARGS...]]"
+    print "       dnsdbtool -c DBDIR DBNAME"
+    sys.exit(1)
+
+debug = False
+currrset = None
+curname = None
+opts, args = getopt(sys.argv[1:], "cd")
+for o, a in opts:
+    if o == "-c":
+        dbzone.dnsdb.create(args[0], args[1])
+        sys.exit(0)
+    if o == "-d":
+        debug = True
+
+db = dbzone.dnsdb(args[0], args[1])
+
+args = args[2:]
+
+class error(Exception):
+    def __init__(self, text):
+        self.text = text
+
+    def __str__(self):
+        return self.text
+
+def rrusage(name, syn):
+    u = ""
+    for e in syn:
+        u += " " + e[1].upper()
+    return "usage: " + name + u
+
+def mkrr(line):
+    rtype = rec.rtypebyname(line[0])
+    if rtype is None:
+        raise error("no such RR type " + line[0])
+    rtype = rec.rtypebyid(rtype)
+    syn = rtype[2]
+    if(len(line) != len(syn) + 1):
+        raise error(rrusage(rtype[1], syn))
+    return rec.rrdata(*line)
+
+def tokenize(line):
+    tokens = []
+    state = 0
+    for c in line + " ":
+        if state == 0:
+            if not c.isspace():
+                ctok = ""
+                state = 1
+        if state == 1:
+            if c.isspace():
+                tokens += [ctok]
+                state = 0
+            elif c == '"':
+                state = 2
+            elif c == "\\":
+                state = 3
+            else:
+                ctok += c
+        elif state == 2:
+            if c == '"':
+                state = 1
+            elif c == "\\":
+                state = 4
+            else:
+                ctok += c
+        elif state == 3:
+            ctok += c
+            state = 1
+        elif state == 4:
+            ctok += c
+            state = 2
+    return tokens
+
+def assertargs(cmd, line, num, usage):
+    if(len(line) < num):
+        raise error("usage: " + cmd + " " + usage)
+
+def runcommand(line):
+    global currrset, curname
+    cmd = line[0]
+    line = line[1:]
+    if cmd == "addnm":
+        assertargs(cmd, line, 3, "[-f FLAG] NAME TTL RR RRPARAM...")
+        opts, line = getopt(line, "f:")
+        opts = dict(opts)
+        head = rec.rrhead(dn.fromstring(line[0]), line[2])
+        ttl = int(line[1])
+        data = mkrr(line[2:])
+        rr = rec.rr(head, ttl, data)
+        if "-f" in opts: rr.setflags(opts["-f"].split(","))
+        db.addrr(head.name, rr)
+    elif cmd == "rmname":
+        assertargs(cmd, line, 1, "NAME")
+        name = dn.fromstring(line[0])
+        db.rmname(name)
+    elif cmd == "rmrt":
+        assertargs(cmd, line, 2, "NAME RRTYPE")
+        name = dn.fromstring(line[0])
+        db.rmrtype(name, line[1])
+    elif cmd == "lsnames":
+        for name in db.listnames():
+            print str(name)
+    elif cmd == "lsrr":
+        assertargs(cmd, line, 1, "NAME")
+        name = dn.fromstring(line[0])
+        rrset = db.lookup(name)
+#        dbzone.rootify(rrset, dn.fromstring("dolda2000.com."))
+        if rrset is None:
+            raise error("no such name in database")
+        for rr in rrset:
+            print str(rr)
+    elif cmd == "load":
+        assertargs(cmd, line, 1, "NAME")
+        name = dn.fromstring(line[0])
+        currrset = db.lookup(name)
+        curname = name
+        if currrset is None:
+            currrset = []
+    elif cmd == "add":
+        assertargs(cmd, line, 2, "[-f FLAG] TTL RR RRPARAM...")
+        opts, line = getopt(line, "f:")
+        opts = dict(opts)
+        head = rec.rrhead(curname, line[1])
+        ttl = int(line[0])
+        data = mkrr(line[1:])
+        rr = rec.rr(head, ttl, data)
+        if "-f" in opts: rr.setflags(opts["-f"].split(","))
+        currrset += [rr]
+    elif cmd == "ls":
+        if currrset is None:
+            raise error("no RRset loaded")
+        for i, rr in enumerate(currrset):
+            print str(i) + ": " + str(rr)
+    elif cmd == "rm":
+        assertargs(cmd, line, 1, "ID")
+        rrid = int(line[0])
+        currrset[rrid:rrid + 1] = []
+    elif cmd == "chttl":
+        assertargs(cmd, line, 2, "ID NEWTTL")
+        rrid = int(line[0])
+        ttl = int(line[1])
+        currrset[rrid].head.ttl = ttl
+    elif cmd == "chdt":
+        assertargs(cmd, line, 3, "ID NAME DATA")
+        rrid = int(line[0])
+        currrset[rrid].data[line[1]] = line[2]
+    elif cmd == "sf":
+        assertargs(cmd, line, 2, "ID FLAGS...")
+        rrid = int(line[0])
+        currrset[rrid].setflags(line[1:])
+    elif cmd == "cf":
+        assertargs(cmd, line, 2, "ID FLAGS...")
+        rrid = int(line[0])
+        currrset[rrid].clrflags(line[1:])
+    elif cmd == "store":
+        if len(line) > 0:
+            name = line[0]
+        else:
+            name = curname
+        db.set(name, currrset)
+    elif cmd == "?" or cmd == "help":
+        print "Available commands:"
+        print "addnm, rmname, rmrt, lsnames, lsrr, load, add, ls, rm,"
+        print "chttl, chdt, sf, cf, store"
+    else:
+        print "no such command: " + cmd
+
+if len(args) == 0:
+    while True:
+        if sys.stdin.isatty():
+            sys.stderr.write("> ")
+        line = sys.stdin.readline()
+        if line == "":
+            if sys.stdin.isatty(): print
+            break
+        try:
+            tokens = tokenize(line)
+            if len(tokens) > 0: runcommand(tokens)
+        except error:
+            sys.stderr.write(str(sys.exc_info()[1]) + "\n")
+else:
+    try:
+        runcommand(args)
+    except SystemExit:
+        raise
+    except:
+        if debug:
+            raise
+        else:
+            sys.stderr.write(str(sys.exc_info()[1]) + "\n")
diff --git a/ldd/__init__.py b/ldd/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/ldd/dbzone.py b/ldd/dbzone.py
new file mode 100644 (file)
index 0000000..7cce85d
--- /dev/null
@@ -0,0 +1,273 @@
+#    ldd - DNS implementation in Python
+#    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
+#
+#    This program is free software; you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation; either version 2 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program; if not, write to the Free Software
+#    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+
+import bsddb
+import threading
+import pickle
+import logging
+
+import server
+import proto
+import rec
+import dn
+
+logger = logging.getLogger("ldd.dbzone")
+
+class dnsdb:
+    def __init__(self, dbdir, dbfile):
+        self.env = bsddb.db.DBEnv()
+        self.env.open(dbdir, bsddb.db.DB_JOINENV | bsddb.db.DB_THREAD)
+        self.db = bsddb.db.DB()
+        self.db.open(dbdir + "/" + dbfile, flags = bsddb.db.DB_THREAD)
+
+    def create(self, dbdir, dbfile):
+        env = bsddb.db.DBEnv()
+        env.open(dbdir, bsddb.db.DB_CREATE | bsddb.db.DB_EXCL | bsddb.db.DB_INIT_MPOOL | bsddb.db.DB_INIT_CDB | bsddb.db.DB_THREAD)
+        db = bsddb.db.DB()
+        db.open(dbdir + "/" + dbfile, dbtype = bsddb.db.DB_HASH, flags = bsddb.db.DB_CREATE | bsddb.db.DB_EXCL | bsddb.db.DB_THREAD)
+        db.close()
+        env.close()
+    create = classmethod(create)
+
+    def close(self):
+        self.db.close()
+        self.env.close()
+
+    def decoderecord(self, name, record):
+        set = pickle.loads(record)
+        rrset = []
+        for cur in set:
+            head = rec.rrhead(name, cur[0])
+            data = cur[2]
+            newrr = rec.rr(head, cur[1], data)
+            newrr.setflags(cur[3])
+            rrset += [newrr]
+        return rrset
+
+    def encoderecord(self, rrset):
+        set = []
+        for rr in rrset:
+            set += [(rr.head.rtype, rr.ttl, rr.data, rr.flags)]
+        return pickle.dumps(set)
+
+    def lookup(self, name):
+        record = self.db.get(str(name))
+        if record is None:
+            return None
+        return self.decoderecord(name, record)
+
+    def set(self, name, rrset):
+        self.db.put(str(name), self.encoderecord(rrset))
+        return True
+    
+    def hasname(self, name):
+        record = self.db.get(str(name))
+        return record is not None
+
+    def rmname(self, name):
+        try:
+            self.db.delete(str(name))
+        except bsddb.db.DBNotFoundError:
+            return False
+        return True
+
+    def rmrtype(self, name, rtype):
+        if type(rtype) == str:
+            rtype = rec.rtypebyname(rtype)
+        rrset = self.lookup(name)
+        if rrset is None:
+            return False
+        for rr in rrset:
+            if rr.head.rtype == rtype:
+                rrset.remove(rr)
+        self.set(name, rrset)
+        return True
+
+    def addrr(self, name, rr):
+        rrset = self.lookup(name)
+        if rrset is None:
+            rrset = []
+        rrset += [rr]
+        self.set(name, rrset)
+        return True
+
+    def listnames(self):
+        cursor = self.db.cursor()
+        ret = cursor.first()
+        if ret is not None:
+            name, record = ret
+            yield name
+            while True:
+                ret = cursor.next()
+                if ret is None:
+                    break
+                name, record = ret
+                yield name
+        cursor.close()
+
+def rootify(rrset, origin):
+    for rr in rrset:
+        if not rr.head.name.rooted:
+            rr.head.name += origin
+        for dname, dval in rr.data.rdata.items():
+            if isinstance(dval, dn.domainname) and not dval.rooted:
+                rr.data.rdata[dname] += origin
+
+class dbhandler(server.handler):
+    def __init__(self, dbdir, dbfile):
+        self.db = dnsdb(dbdir, dbfile)
+        self.doddns = False
+        self.authkeys = []
+
+    def handle(self, query, pkt, origin):
+        resp = proto.responsefor(pkt)
+        if pkt.opcode == proto.QUERY:
+            rrset = self.db.lookup(query.name)
+            if rrset is None and query.name in origin:
+                rrset = self.db.lookup(query.name - origin)
+            if rrset is None:
+                return None
+            rootify(rrset, origin)
+            resp.anlist = [rr for rr in rrset if rr.head.rtype == query.rtype or rr.head.istype("CNAME")]
+            return resp
+        if pkt.opcode == proto.UPDATE:
+            logger.debug("got DDNS request")
+            if len(pkt.qlist) != 1 or not pkt.qlist[0].istype("SOA"):
+                resp.rescode = proto.FORMERR
+                return resp
+            if pkt.qlist[0].name != origin:
+                resp.rescode = proto.NOTAUTH
+                return resp
+
+            # Check prerequisites
+            for rr in pkt.anlist:
+                if rr.ttl != 0:
+                    resp.rescode = proto.FORMERR
+                    return resp
+                if rr.head.name not in origin:
+                    resp.rescode = proto.NOTZONE
+                    return resp
+                myname = rr.head.name - origin
+                rrset = self.db.lookup(myname)
+                if rr.head.rclass == rec.CLASSANY:
+                    if rr.data is not None:
+                        resp.rescode = proto.FORMERR
+                        return resp
+                    if rr.head.rtype == proto.QTANY:
+                        if rrset is None:
+                            resp.rescode = proto.NXDOMAIN
+                            return resp
+                    else:
+                        if rrset is not None:
+                            for rr2 in rrset:
+                                if rr2.head.name == myname and rr.head.rtype == rr2.head.rtype:
+                                    break
+                            else:
+                                resp.rescode = proto.NXRRSET
+                                return resp
+                elif rr.head.rclass == rec.CLASSNONE:
+                    if rr.data is not None:
+                        resp.rescode = proto.FORMERR
+                        return resp
+                    if rr.head.rtype == proto.QTANY:
+                        if rrset is not None:
+                            resp.rescode = proto.YXDOMAIN
+                            return resp
+                    else:
+                        if rrset is not None:
+                            for rr2 in rrset:
+                                if rr2.head.name == myname and rr.head.rtype == rr2.head.rtype:
+                                    resp.rescode = proto.YXRRSET
+                                    return resp
+                elif rr.head.rclass == rec.CLASSIN:
+                    if rrset is not None:
+                        for rr2 in rrset:
+                            if rr2.head.name == myname and rr.head.rtype == rr2.head.rtype and rr.data == rr2.data:
+                                break
+                        else:
+                            resp.rescode = proto.NXRRSET
+                            return resp
+                else:
+                    resp.rescode = FORMERR
+                    return resp
+
+            # Check for permission
+            if not self.doddns:
+                resp.rescode = proto.REFUSED
+                return resp
+            if type(self.authkeys) == list:
+                if pkt.tsigctx is None:
+                    resp.rescode = proto.REFUSED
+                    return resp
+                if pkt.tsigctx.error != 0:
+                    resp.rescode = proto.NOTAUTH
+                    return resp
+                if pkt.tsigctx.key not in self.authkeys:
+                    resp.rescode = proto.REFUSED
+                    return resp
+            elif type(self.authkeys) == None:
+                authorized = True
+
+            # Do precheck on updates
+            for rr in pkt.aulist:
+                if rr.head.name not in origin:
+                    resp.rescode = proto.NOTZONE
+                    return resp
+                if rr.head.rclass == rec.CLASSIN:
+                    if rr.head.rtype == proto.QTANY or rr.data is None:
+                        resp.rescode = proto.FORMERR
+                        return resp
+                elif rr.head.rclass == rec.CLASSANY:
+                    if rr.data is not None:
+                        resp.rescode = proto.FORMERR
+                        return resp
+                elif rr.head.rclass == rec.CLASSNONE:
+                    if rr.head.rtype == proto.QTANY or rr.ttl != 0 or rr.data is None:
+                        resp.rescode = proto.FORMERR
+                        return resp
+                else:
+                    resp.rescode = proto.FORMERR
+                    return resp
+
+            # Perform updates
+            for rr in pkt.aulist:
+                myname = rr.head.name - origin
+                if rr.head.rclass == rec.CLASSIN:
+                    logger.info("adding rr (%s)", rr)
+                    self.db.addrr(myname, rr)
+                elif rr.head.rclass == rec.CLASSANY:
+                    if rr.head.rtype == proto.QTANY:
+                        logger.info("removing rrset (%s)", rr.head.name)
+                        self.db.rmname(myname)
+                    else:
+                        logger.info("removing rrset (%s, %s)", rr.head.name, rr.head.rtype)
+                        self.db.rmrtype(myname, rr.head.rtype)
+                elif rr.head.rclass == rec.CLASSNONE:
+                    logger.info("removing rr (%s)", rr)
+                    rrset = self.db.lookup(myname)
+                    changed = False
+                    if rrset is not None:
+                        for rr2 in rrset:
+                            if rr2.head == rr.head and rr2.data == rr.data:
+                                rrset.remove(rr2)
+                                changed = True
+                        self.db.set(myname, rrset)
+
+            return resp
+            
+        resp.rescode = proto.NOTIMP
+        return resp
diff --git a/ldd/dn.py b/ldd/dn.py
new file mode 100644 (file)
index 0000000..17becac
--- /dev/null
+++ b/ldd/dn.py
@@ -0,0 +1,131 @@
+#    ldd - DNS implementation in Python
+#    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
+#
+#    This program is free software; you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation; either version 2 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program; if not, write to the Free Software
+#    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+
+class DNNotIn(Exception):
+    def __init__(self, a, b):
+        self.a = a
+        self.b = b
+
+    def __str__(self):
+        return str(self.a) + " not in " + str(self.b)
+
+class domainname:
+    "A class for abstract representations of domain names"
+    
+    def __init__(self, parts, rooted):
+        self.parts = parts
+        self.rooted = rooted
+    
+    def __repr__(self):
+        ret = ""
+        if len(self.parts) > 0:
+            for p in self.parts[:-1]:
+                ret = ret + p + '.'
+            ret = ret + self.parts[-1]
+        if self.rooted:
+            ret = ret + '.'
+        return ret
+
+    def __add__(self, y):
+        if self.rooted:
+            raise Exception("cannot append to a rooted domain name")
+        return(domainname(self.parts + y.parts, y.rooted))
+
+    def __getitem__(self, y):
+        return domainname([self.parts[y]], self.rooted and (y == -1 or y == len(self.parts) - 1))
+
+    def __getslice__(self, i, j):
+        return domainname(self.parts[i:j], self.rooted and j >= len(self.parts))
+
+    def __len__(self):
+        return len(self.parts)
+
+    def __eq__(self, y):
+        if type(y) == str:
+            y = fromstring(y)
+        if self.rooted != y.rooted:
+            return False
+        if len(self.parts) != len(y.parts):
+            return False
+        for i in range(len(self.parts)):
+            if self.parts[i].lower() != y.parts[i].lower():
+                return False
+        return True
+
+    def __ne__(self, y):
+        return not self.__eq__(y)
+    
+    def __contains__(self, y):
+        if len(self) > len(y):
+            return False
+        if len(self) == 0:
+            return self.rooted == y.rooted
+        return y[-len(self):] == self
+
+    def __sub__(self, y):
+        if self not in y:
+            raise DNNotIn(self, y)
+        return self[:len(self) - len(y)]
+
+    def __hash__(self):
+        ret = 0
+        for part in self.parts:
+            ret = ret ^ hash(part)
+        if self.rooted:
+            ret = ret ^ -1
+        return ret
+
+    def canonwire(self):
+        ret = ""
+        for p in self.parts:
+            ret += chr(len(p))
+            ret += p.lower()
+        ret += chr(0)
+        return ret
+    
+class DNError(Exception):
+    emptypart = 1
+    illegalchar = 2
+    def __init__(self, kind):
+        self.kind = kind
+    def __str__(self):
+        return {1: "empty part",
+                2: "illegal character"}[self.kind]
+
+def fromstring(name):
+    parts = []
+    if name == ".":
+        return domainname([], True)
+    if name == "":
+        return domainname([], False)
+    while name.find('.') >= 0:
+        cur = name.find('.')
+        if cur == 0:
+            raise DNError(DNError.emptypart)
+        part = name[:cur]
+        for c in part:
+            if ord(c) < 33:
+                raise DNError(DNError.illegalchar)
+        parts.append(part)
+        name = name[cur + 1:]
+    if len(name) > 0:
+        parts.append(name)
+        rooted = False
+    else:
+        rooted = True
+    return domainname(parts, rooted)
+
diff --git a/ldd/dnssec.py b/ldd/dnssec.py
new file mode 100644 (file)
index 0000000..aa42943
--- /dev/null
@@ -0,0 +1,138 @@
+#    ldd - DNS implementation in Python
+#    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
+#
+#    This program is free software; you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation; either version 2 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program; if not, write to the Free Software
+#    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+
+import base64
+import time
+import struct
+from Crypto.Hash import HMAC, MD5
+
+import proto, rec, dn
+
+class tsigkey:
+    def __init__(self, name, algo, secret):
+        if type(name) == str:
+            self.name = dn.fromstring(name)
+        else:
+            self.name = name
+        if type(algo) == str:
+            self.algo = algobyname[algo]
+        else:
+            self.algo = algo
+        self.secret = secret
+
+    def sign(self, message):
+        return self.algo.sign(self.secret, message)
+
+class tsigalgo:
+    def __init__(self, name, cname, function):
+        self.name = name
+        if type(cname) == str:
+            self.cname = dn.fromstring(cname)
+        else:
+            self.cname = cname
+        self.function = function
+
+    def sign(self, secret, message):
+        return self.function(secret, message)
+
+class tsigctx:
+    def __init__(self, key, pkt, sr):
+        self.key = key
+        self.prevmac = sr.data["mac"]
+        self.error = 0
+
+    def signpkt(self, pkt):
+        tsigsign(pkt, None, ctx = self, error = self.error)
+
+def tsigsign(pkt, key, stime = None, fudge = 300, error = 0, other = "", ctx = None):
+    if stime is None: stime = int(time.time())
+    msg = ""
+    if ctx is not None:
+        if key is None:
+            key = ctx.key
+        msg += struct.pack(">H", len(ctx.prevmac)) + ctx.prevmac
+    msg += pkt.encode()
+    msg += key.name.canonwire()
+    msg += struct.pack(">HL", rec.CLASSANY, 0)
+    msg += key.algo.cname.canonwire()
+    msg += struct.pack(">Q", stime)[-6:]
+    msg += struct.pack(">3H", fudge, error, len(other))
+    msg += other
+    digest = key.sign(msg)
+    pkt.addad(rec.rr((key.name, "TSIG", rec.CLASSANY), 0, rec.rrdata("TSIG", key.algo.cname, stime, fudge, digest, pkt.qid, error, other)))
+    pkt.signed = True
+
+def tsigverify(pkt, keys, vertime = None):
+    if vertime is None: vertime = int(time.time())
+    if len(pkt.adlist) < 1:
+        return proto.FORMERR
+    sr = pkt.adlist[-1]
+    pkt.adlist = pkt.adlist[:-1]
+    if not sr.head.istype("TSIG") or sr.head.rclass != rec.CLASSANY:
+        return proto.FORMERR
+    for key in keys:
+        if key.name == sr.head.name:
+            break
+    else:
+        return proto.BADKEY
+    if key.algo.cname != sr.data["algo"]:
+        return proto.BADKEY
+
+    pkt.tsigctx = ctx = tsigctx(key, pkt, sr)
+    
+    other = sr.data["other"]
+    msg = pkt.encode()
+    msg += key.name.canonwire()
+    msg += struct.pack(">HL", rec.CLASSANY, 0)
+    msg += key.algo.cname.canonwire()
+    msg += struct.pack(">Q", sr.data["stime"])[-6:]
+    msg += struct.pack(">3H", sr.data["fudge"], sr.data["err"], len(other))
+    msg += other
+    digest = key.sign(msg)
+    if digest != sr.data["mac"]:
+        pkt.tsigctx = proto.BADSIG
+        return proto.BADSIG
+    if vertime != 0:
+        if abs(vertime - sr.data["stime"]) > sr.data["fudge"]:
+            pkt.tsigctx = proto.BADTIME
+            return proto.BADTIME
+    return key
+
+def signhmacmd5(secret, message):
+    s = HMAC.HMAC(secret, digestmod = MD5)
+    s.update(message)
+    return s.digest()
+
+def readkeys(keyfile):
+    close = False
+    if type(keyfile) == str:
+        keyfile = open(keyfile, "r")
+        close = True
+    try:
+        ret = []
+        for line in keyfile:
+            words = line.split()
+            if len(words) < 3:
+                continue
+            ret += [tsigkey(dn.fromstring(words[0]), words[1], base64.b64decode(words[2]))]
+        return ret
+    finally:
+        if close: keyfile.close()
+
+algos = [tsigalgo("hmac-md5", "hmac-md5.sig-alg.reg.int.", signhmacmd5)]
+
+algobyname = dict([(a.name, a) for a in algos])
diff --git a/ldd/filters.py b/ldd/filters.py
new file mode 100644 (file)
index 0000000..a1798c8
--- /dev/null
@@ -0,0 +1,98 @@
+#    ldd - DNS implementation in Python
+#    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
+#
+#    This program is free software; you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation; either version 2 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program; if not, write to the Free Software
+#    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+
+import socket
+import time
+import fcntl
+import struct
+
+import server
+
+def linuxifip4hack(ifname):
+    req = ifname + ("\0" * (32 - len(ifname)))
+    sk = socket.socket()
+    res = fcntl.ioctl(sk.fileno(), 0x8915, req)
+    sk.close()
+    sockaddr = res[16:]
+    return sockaddr[4:8]
+
+class valuecache:
+    def __init__(self, func, expire):
+        self.func = func
+        self.expire = expire
+        self.last = 0
+
+    def __call__(self, *args):
+        now = int(time.time())
+        if self.last == 0 or now - self.last > self.expire:
+            self.val = self.func(*(args))
+            self.last = now
+        return self.val
+
+class prefix6to4(server.handler):
+    def __init__(self, next, v4addr):
+        self.next = next
+        if callable(v4addr):
+            self.packed = v4addr
+        elif len(v4addr) == 4:
+            self.packed = v4addr
+        else:
+            self.packed = socket.inet_pton(socket.AF_INET, v4addr)
+
+    def handle(self, *args):
+        resp = self.next.handle(*args)
+        if resp is None:
+            return None
+        for rr in resp.allrrs():
+            if rr.head.istype("AAAA"):
+                addr = rr.data["address"]
+                if addr[0:6] == "\x20\x02\x00\x00\x00\x00":
+                    packed = self.packed
+                    if callable(packed):
+                        packed = packed()
+                    addr = addr[0:2] + packed + addr[6:]
+                    rr.data["address"] = addr
+        return resp
+
+class addrfilter(server.handler):
+    def __init__(self, default = None, matchers = []):
+        self.matchers = matchers
+        self.default = default
+
+    def setdefault(self, handler):
+        self.default = handler
+
+    def addmatcher(self, af, prefix, preflen, handler):
+        self.matchers += [(af, socket.inet_pton(af, prefix), preflen, handler)]
+    
+    def handle(self, query, pkt, origin):
+        matchlen = -1
+        match = self.default
+        if pkt.addr is not None:
+            for af, prefix, preflen, handler in self.matchers:
+                if pkt.addr[0] == af:
+                    addr = socket.inet_pton(af, pkt.addr[1])
+                    bytes = preflen >> 3
+                    restmask = 255 ^ ((1 << (8 - (preflen & 7))) - 1)
+                    if prefix[0:bytes] == addr[0:bytes] and \
+                           (ord(prefix[bytes]) & restmask) == (ord(addr[bytes]) & restmask):
+                        if preflen > matchlen:
+                            matchlen = preflen
+                            match = handler
+        if match is not None:
+            return match.handle(query, pkt, origin)
+        return None
diff --git a/ldd/mdns.py b/ldd/mdns.py
new file mode 100644 (file)
index 0000000..b1ce504
--- /dev/null
@@ -0,0 +1,42 @@
+#    ldd - DNS implementation in Python
+#    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
+#
+#    This program is free software; you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation; either version 2 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program; if not, write to the Free Software
+#    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+
+import socket
+import struct
+
+ip4addr = "224.0.0.251"
+ip6addr = "ff02::fb"
+
+def mkip4sock(port = 5353):
+    sk = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
+    mcastinfo = socket.inet_pton(socket.AF_INET, ip4addr)
+    mcastinfo += socket.inet_pton(socket.AF_INET, "0.0.0.0")
+    sk.setsockopt(socket.SOL_IP, socket.IP_ADD_MEMBERSHIP, mcastinfo)
+    sk.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_TTL, 255)
+    sk.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    sk.bind(("", port))
+    return sk
+
+def mkip6sock(port = 5353):
+    sk = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    mcastinfo = socket.inet_pton(socket.AF_INET6, ip6addr)
+    mcastinfo += struct.pack("I", 0)
+    sk.setsockopt(socket.SOL_IP, socket.IPV6_JOIN_GROUP, mcastinfo)
+    sk.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255)
+    sk.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    sk.bind(("", port))
+    return sk
diff --git a/ldd/proto.py b/ldd/proto.py
new file mode 100644 (file)
index 0000000..b63451f
--- /dev/null
@@ -0,0 +1,276 @@
+#    ldd - DNS implementation in Python
+#    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
+#
+#    This program is free software; you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation; either version 2 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program; if not, write to the Free Software
+#    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+
+from random import randint
+import struct
+
+import dn
+import rec
+
+class malformedpacket(Exception):
+    def __init__(self, text, qid):
+        self.text = text
+        self.qid = qid
+
+    def __str__(self):
+        return self.text
+
+class packet:
+    "An abstract representation of a DNS query"
+
+    def __init__(self, qid = None, flags = 0, addr = None):
+        if qid is None: qid = randint(0, 65535)
+        self.qid = qid
+        self.qlist = []
+        self.anlist = []
+        self.aulist = []
+        self.adlist = []
+        self.opcode = 0
+        self.rescode = 0
+        self.addr = addr
+        self.signed = False
+        self.tsigctx = None
+        if type(flags) == int:
+            self.initflags(flags)
+        elif type(flags) == set:
+            self.flags = flags
+        else:
+            self.flags = set(flags)
+        
+    def setflags(self, flags):
+        flags = set(flags)
+        self.flags |= flags
+
+    def clrflags(self, flags):
+        flags = set(flags)
+        self.flags -= flags
+
+    def initflags(self, flags):
+        nf = set()
+        if flags & 0x8000: nf.add("resp")
+        if flags & 0x0400: nf.add("auth")
+        if flags & 0x0200: nf.add("trunc")
+        if flags & 0x0100: nf.add("recurse")
+        if flags & 0x0080: nf.add("recursed")
+        if flags & 0x0020: nf.add("isauthen")
+        if flags & 0x0010: nf.add("authok")
+        self.opcode =   (flags & 0x7800) >> 11
+        self.rescode =   flags & 0x000f
+        self.flags = nf
+
+    def encodeflags(self):
+        ret = 0
+        if "resp"     in self.flags: ret |= 0x8000
+        if "auth"     in self.flags: ret |= 0x0400
+        if "trunc"    in self.flags: ret |= 0x0200
+        if "recurse"  in self.flags: ret |= 0x0100
+        if "recursed" in self.flags: ret |= 0x0080
+        if "authok"   in self.flags: ret |= 0x0010
+        ret |= self.opcode << 11
+        ret |= self.rescode
+        return ret
+
+    def addq(self, rr):
+        self.qlist.append(rr)
+
+    def addan(self, rr):
+        for rr2 in self.anlist:
+            if rr2.head  == rr.head and rr2.data == rr.data:
+                break
+        else:
+            self.anlist.append(rr)
+
+    def addau(self, rr):
+        for rr2 in self.aulist:
+            if rr2.head  == rr.head and rr2.data == rr.data:
+                break
+        else:
+            self.aulist.append(rr)
+
+    def addad(self, rr):
+        for rr2 in self.adlist:
+            if rr2.head  == rr.head and rr2.data == rr.data:
+                break
+        else:
+            self.adlist.append(rr)
+
+    def allrrs(self):
+        return self.anlist + self.aulist + self.adlist
+
+    def merge(self, other):
+        for lst in ["anlist", "aulist", "adlist"]:
+            for rr in getattr(other, lst):
+                for rr2 in getattr(self, lst):
+                    if rr2.head == rr.head and rr2.data == rr.data:
+                        break
+                else:
+                    getattr(self, lst).append(rr)
+    
+    def getanswer(self, name, rtype):
+        for rr in self.anlist + self.aulist + self.adlist:
+            if rr.head.istype(rtype) and rr.head.name == name:
+                return rr
+        return None
+
+    def hasanswers(self):
+        for q in self.qlist:
+            for rr in self.anlist + self.aulist + self.adlist:
+                if rr.head.rtype == q.rtype and rr.head.name == q.name:
+                    break
+                if rr.head.istype("CNAME") and rr.head.name == q.name and self.getanswer(rr.data["priname"], q.rtype) is not None:
+                    break
+            else:
+                break
+        else:
+            return True
+        return False
+        
+    def __str__(self):
+        ret = ""
+        ret += "ID: " + str(self.qid) + "\n"
+        ret += "Flags: " + str(self.flags) + "\n"
+        ret += "Opcode: " + str(self.opcode) + "\n"
+        ret += "Resp. code: " + str(self.rescode) + "\n"
+        ret += "Queries:\n"
+        for rr in self.qlist:
+            ret += "\t" + str(rr) + "\n"
+        ret += "Answers:\n"
+        for rr in self.anlist:
+            ret += "\t" + str(rr) + "\n"
+        ret += "Auth RRs:\n"
+        for rr in self.aulist:
+            ret += "\t" + str(rr) + "\n"
+        ret += "Additional RRs:\n"
+        for rr in self.adlist:
+            ret += "\t" + str(rr) + "\n"
+        return ret
+
+    def encode(self):
+        ret = ""
+        ret += struct.pack(">6H", self.qid, self.encodeflags(), len(self.qlist), len(self.anlist), len(self.aulist), len(self.adlist))
+        offset = len(ret)
+        names = []
+        for rr in self.qlist:
+            rre, names = rr.encode(names, offset)
+            offset += len(rre)
+            ret += rre
+        for rr in self.anlist:
+            rre, names = rr.encode(names, offset)
+            offset += len(rre)
+            ret += rre
+        for rr in self.aulist:
+            rre, names = rr.encode(names, offset)
+            offset += len(rre)
+            ret += rre
+        for rr in self.adlist:
+            rre, names = rr.encode(names, offset)
+            offset += len(rre)
+            ret += rre
+        return ret
+
+def decodepacket(string):
+    offset = struct.calcsize(">6H")
+    qid, flags, qno, anno, auno, adno = struct.unpack(">6H", string[0:offset])
+    ret = packet(qid, flags)
+    try:
+        for i in range(qno):
+            crr, offset = rec.rrhead.decode(string, offset)
+            ret.addq(crr)
+        for i in range(anno):
+            crr, offset = rec.rr.decode(string, offset)
+            ret.addan(crr)
+        for i in range(auno):
+            crr, offset = rec.rr.decode(string, offset)
+            ret.addau(crr)
+        for i in range(adno):
+            crr, offset = rec.rr.decode(string, offset)
+            ret.addad(crr)
+    except rec.malformedrr, inst:
+        raise malformedpacket(str(inst), qid)
+    return ret
+
+def responsefor(pkt, rescode = 0):
+    resp = packet(pkt.qid, ["resp"])
+    resp.opcode = pkt.opcode
+    resp.rescode = rescode
+    resp.tsigctx = pkt.tsigctx
+    resp.qlist = pkt.qlist + []  # Make a copy
+    return resp
+
+def decodename(packet, offset):
+    parts = []
+    while True:
+        clen = ord(packet[offset])
+        offset += 1
+        if clen & 0xc0:
+            my = dn.domainname(parts, False)
+            cont, = struct.unpack(">H", chr(clen & 0x3f) + packet[offset])
+            res, discard = decodename(packet, cont)
+            return my + res, offset + 1
+        elif clen == 0:
+            return dn.domainname(parts, True), offset
+        else:
+            parts.append(packet[offset:offset + clen])
+            offset += clen
+
+def encodename(dn, names, offset):
+    ret = ""
+    for i in range(len(dn)):
+        for name, off in names:
+            if name == dn[i:]:
+                ret += chr(0xc0 + (off >> 8))
+                ret += chr(off & 0xff)
+                offset += 2
+                return ret, names
+        if offset < 16384:
+            names += [(dn[i:], offset)]
+        ret += chr(len(dn.parts[i]))
+        ret += dn.parts[i]
+        offset += 1 + len(dn.parts[i])
+    ret += chr(0)
+    offset += 1
+    return ret, names
+
+# Opcode constants
+QUERY = 0
+IQUERY = 1
+STATUS = 2
+UPDATE = 5
+
+# Response code constants
+#  RFC 1035:
+FORMERR = 1
+SERVFAIL = 2
+NXDOMAIN = 3
+NOTIMP = 4
+REFUSED = 5
+#  RFC 2136:
+YXDOMAIN = 6
+YXRRSET = 7
+NXRRSET = 8
+NOTAUTH = 9
+NOTZONE = 10
+#  RFC 2845:
+BADSIG = 16
+BADKEY = 17
+BADTIME = 18
+
+# Special RR types
+QTANY = 255
+QTMAILA = 254
+QTMAILB = 253
+QTAXFR = 252
diff --git a/ldd/rec.py b/ldd/rec.py
new file mode 100644 (file)
index 0000000..4145343
--- /dev/null
@@ -0,0 +1,342 @@
+#    ldd - DNS implementation in Python
+#    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
+#
+#    This program is free software; you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation; either version 2 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program; if not, write to the Free Software
+#    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+
+import socket
+import struct
+
+import proto
+import dn
+
+rtypes = []
+
+def addrtype(id, name, syntax):
+    rtypes.append((id, name, syntax))
+
+def rtypebyid(id):
+    for rtype in rtypes:
+        if rtype[0] == id:
+            return rtype
+    return None
+
+def rtypebyname(name):
+    for rtype in rtypes:
+        if rtype[1] == name.upper():
+            return rtype[0]
+    return None
+
+class error(Exception):
+    def __init__(self, text):
+        self.text = text
+
+    def __str__(self):
+        return self.text
+
+class malformedrr(Exception):
+    def __init__(self, text):
+        self.text = text
+
+    def __str__(self):
+        return self.text
+
+class rrhead:
+    def __init__(self, name = None, rtype = None, rclass = None):
+        if rclass is None: rclass = CLASSIN
+        if type(name) == str:
+            self.name = dn.fromstring(name)
+        else:
+            self.name = name
+        if type(rtype) == str:
+            self.rtype = rtypebyname(rtype)
+            if self.rtype is None:
+                raise error("no such rtype " + rtype)
+        else:
+            self.rtype = rtype
+        self.rclass = rclass
+
+    def encode(self, names, offset):
+        ret, names = proto.encodename(self.name, names, offset)
+        ret += struct.pack(">HH", self.rtype, self.rclass)
+        return ret, names
+
+    def __eq__(self, other):
+        return self.name == other.name and self.rtype == other.rtype
+
+    def __str__(self):
+        rtype = rtypebyid(self.rtype)
+        if rtype is None:
+            return "%02x RRhead %s" % (self.rtype, self.name)
+        else:
+            return "%s RRhead %s" % (rtype[1], self.name)
+
+    def istype(self, rtype):
+        if type(rtype) == str:
+            rtype = rtypebyname(rtype)
+        return self.rtype == rtype
+        
+    def decode(self, packet, offset):
+        name, offset = proto.decodename(packet, offset)
+        rtype, rclass = struct.unpack(">HH", packet[offset:offset + struct.calcsize(">HH")])
+        offset += struct.calcsize(">HH")
+        ret = rrhead(name, rtype, rclass)
+        return ret, offset
+    decode = classmethod(decode)
+
+class rrdata:
+    def __init__(self, rtype, *args):
+        if type(rtype) == tuple and type(args[0]) == dict:
+            self.rtype = rtype
+            self.rdata = args[0]
+            return
+        
+        if type(rtype) == str:
+            self.rtype = rtypebyname(rtype)
+            if self.rtype is None:
+                raise error("no such rtype " + rtype)
+        else:
+            self.rtype = rtype
+        rtid = self.rtype
+        self.rtype = rtypebyid(rtid)
+        if self.rtype is None:
+            raise error("no such rtype " + rtid)
+        self.rdata = {}
+        for i, e in enumerate(self.rtype[2]):
+            d = self.convdata(e[0], args[i])
+            self.rdata[e[1]] = d
+
+    def __eq__(self, other):
+        return(self.rdata == other.rdata)
+
+    def __str__(self):
+        ret = "{"
+        first = True
+        for e in self.rtype[2]:
+            if not first:
+                ret += ", "
+            first = False
+            ret += e[1] + ": "
+            d = self.rdata[e[1]]
+            if e[0] == "4":
+                ret += socket.inet_ntop(socket.AF_INET, d)
+            elif e[0] == "6":
+                ret += socket.inet_ntop(socket.AF_INET6, d)
+            elif e[0] == "s":
+                ret += '"' + d + '"'
+            else:
+                ret += str(d)
+        ret += "}"
+        return ret
+
+    def istype(self, rtype):
+        if type(rtype) == str:
+            rtype = rtypebyname(rtype)
+        return self.rtype[0] == rtype
+        
+    def convdata(self, dtype, data):
+        if dtype == "4":
+            if type(data) != str:
+                raise error("IPv4 address must be a string")
+            if len(data) == 4:
+                d = data
+            else:
+                d = socket.inet_pton(socket.AF_INET, data)
+        if dtype == "6":
+            if type(data) != str:
+                raise error("IPv6 address must be a string")
+            if len(data) == 16 and data.find(":") == -1:
+                d = data
+            else:
+                d = socket.inet_pton(socket.AF_INET6, data)
+        if dtype == "d":
+            if type(data) == str:
+                d = dn.fromstring(data)
+            elif isinstance(data, dn.domainname):
+                d = data
+            else:
+                raise error("Domain name must be either proper or string")
+        if dtype == "s":
+            d = str(data)
+        if dtype == "i":
+            d = int(data)
+        return d
+    
+    def __iter__(self):
+        return iter(self.rdata)
+    
+    def __getitem__(self, i):
+        return self.rdata[i]
+
+    def __setitem__(self, i, v):
+        for e in self.rtype[2]:
+            if e[1] == i:
+                break
+        else:
+            raise error("No such data for " + self.rtype[1] + " record: " + str(i))
+        self.rdata[i] = self.convdata(e[0], v)
+
+    def encode(self, names, offset):
+        ret = ""
+        for e in self.rtype[2]:
+            d = self.rdata[e[1]]
+            if e[2] == "strc":
+                ret += d
+                offset += len(d)
+            if e[2] == "cmdn":
+                buf, names = proto.encodename(d, names, offset)
+                ret += buf
+                offset += len(buf)
+            if e[2] == "lstr":
+                ret += chr(len(d)) + d
+                offset += 1 + len(d)
+            if e[2] == "llstr":
+                ret += struct.pack(">H", len(d)) + d
+                offset += struct.calcsize(">H") + len(d)
+            if e[2] == "short":
+                ret += struct.pack(">H", d)
+                offset += struct.calcsize(">H")
+            if e[2] == "long":
+                ret += struct.pack(">L", d)
+                offset += struct.calcsize(">L")
+            if e[2] == "int6":
+                ret += struct.pack(">Q", d)[-6:]
+                offset += 6
+        return ret, names
+
+    def decode(self, rtid, packet, offset, dlen):
+        rtype = rtypebyid(rtid)
+        origoff = offset
+        rdata = {}
+        if rtype is None:
+            rtype = (rtid, "Unknown", [("s", "unknown", "strc", dlen)])
+        for e in rtype[2]:
+            if e[2] == "strc":
+                d = packet[offset:offset + e[3]]
+                offset += e[3]
+            if e[2] == "cmdn":
+                d, offset = proto.decodename(packet, offset)
+            if e[2] == "lstr":
+                dl = ord(packet[offset])
+                offset += 1
+                d = packet[offset:offset + dl]
+                offset += dl
+            if e[2] == "llstr":
+                (dl,) = struct.unpack(">H", packet[offset:offset + struct.calcsize(">H")])
+                offset += struct.calcsize(">H")
+                d = packet[offset:offset + dl]
+                offset += dl
+            if e[2] == "short":
+                (d,) = struct.unpack(">H", packet[offset:offset + struct.calcsize(">H")])
+                offset += struct.calcsize(">H")
+            if e[2] == "long":
+                (d,) = struct.unpack(">L", packet[offset:offset + struct.calcsize(">L")])
+                offset += struct.calcsize(">L")
+            if e[2] == "int6":
+                (d,) = struct.unpack(">Q", ("\0" * (struct.calcsize(">Q") - 6)) + packet[offset:offset + 6])
+                offset += 6
+            rdata[e[1]] = d
+        if origoff + dlen != offset:
+            raise malformedrr(rtype[1] + " RR data length mismatch")
+        return rrdata(rtype, rdata)
+    decode = classmethod(decode)
+
+class rr:
+    def __init__(self, head, ttl, data):
+        if type(head) == tuple:
+            self.head = rrhead(*head)
+        else:
+            self.head = head
+        self.ttl = ttl
+        self.data = data
+        self.flags = set()
+
+    def setflags(self, flags):
+        self.flags |= set(flags)
+
+    def clrflags(self, flags):
+        self.flags -= set(flags)
+    
+    def encode(self, names, offset):
+        ret, names = self.head.encode(names, offset)
+        if self.data is None:
+            data = ""
+        else:
+            data, names = self.data.encode(names, offset + len(ret) + struct.calcsize(">LH"))
+        ret += struct.pack(">LH", self.ttl, len(data))
+        ret += data
+        return ret, names
+
+    def __eq__(self, other):
+        return self.head == other.head and self.ttl == other.ttl and self.data == other.data
+
+    def __str__(self):
+        rtype = rtypebyid(self.head.rtype)
+        if rtype is None:
+            ret = "%02x" % self.head.rtype
+        else:
+            ret = rtype[1]
+        ret += " RR %s, TTL=%i: %s" % (self.head.name, self.ttl, self.data)
+        if len(self.flags) > 0:
+            ret += " (Flags:"
+            for f in self.flags:
+                ret += " " + f
+            ret += ")"
+        return ret
+    
+    def decode(self, packet, offset):
+        head, offset = rrhead.decode(packet, offset)
+        ttl, dlen = struct.unpack(">LH", packet[offset:offset + struct.calcsize(">LH")])
+        offset += struct.calcsize(">LH")
+        if dlen == 0:
+            data = None
+        else:
+            data = rrdata.decode(head.rtype, packet, offset, dlen)
+            offset += dlen
+        return rr(head, ttl, data), offset
+    decode = classmethod(decode)
+
+addrtype(0x01, "A", [("4", "address", "strc", 4)])
+addrtype(0x02, "NS", [("d", "nsname", "cmdn")])
+addrtype(0x05, "CNAME", [("d", "priname", "cmdn")])
+addrtype(0x06, "SOA", [("d", "priserv", "cmdn"),
+                       ("d", "mailbox", "cmdn"),
+                       ("i", "serial", "long"),
+                       ("i", "refresh", "long"),
+                       ("i", "retry", "long"),
+                       ("i", "expire", "long"),
+                       ("i", "minttl", "long")])
+addrtype(0x0c, "PTR", [("d", "target", "cmdn")])
+addrtype(0x0f, "MX", [("i", "prio", "short"),
+                      ("d", "target", "cmdn")])
+addrtype(0x10, "TXT", [("s", "rrtext", "lstr")])
+addrtype(0x1c, "AAAA", [("6", "address", "strc", 16)])
+addrtype(0x21, "SRV", [("i", "prio", "short"),
+                       ("i", "weight", "short"),
+                       ("i", "port", "short"),
+                       ("d", "target", "cmdn")])
+addrtype(0xfa, "TSIG", [("d", "algo", "cmdn"),
+                        ("i", "stime", "int6"),
+                        ("i", "fudge", "short"),
+                        ("s", "mac", "llstr"),
+                        ("i", "orgid", "short"),
+                        ("i", "err", "short"),
+                        ("s", "other", "llstr")])
+
+CLASSIN = 1
+CLASSCS = 2
+CLASSCH = 3
+CLASSHS = 4
+CLASSNONE = 254
+CLASSANY = 255
diff --git a/ldd/rescache.py b/ldd/rescache.py
new file mode 100644 (file)
index 0000000..36e5731
--- /dev/null
@@ -0,0 +1,144 @@
+#    ldd - DNS implementation in Python
+#    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
+#
+#    This program is free software; you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation; either version 2 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program; if not, write to the Free Software
+#    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+
+import threading
+import time
+
+import resolver
+import proto
+import rec
+
+class nxdmark:
+    def __init__(self, expire, auth):
+        self.expire = expire
+        self.auth = auth
+
+class cacheresolver(resolver.resolver):
+    def __init__(self, resolver):
+        self.resolver = resolver
+        self.cache = dict()
+        self.cachelock = threading.Lock()
+
+    def getcached(self, name, rtype = proto.QTANY):
+        self.cachelock.acquire()
+        try:
+            if name not in self.cache:
+                return []
+            now = int(time.time())
+            if isinstance(self.cache[name], nxdmark):
+                if self.cache[name].expire < now:
+                    self.cache[name] = []
+                    return []
+                return self.cache[name]
+            ret = []
+            if rtype == proto.QTANY:
+                cond = lambda rt: True
+            elif type(rtype) == int:
+                cond = lambda rt: rtype == rt
+            elif type(rtype) == str:
+                rtid = rec.rtypebyname(rtype)
+                cond = lambda rt: rtid == rt
+            else:
+                rtset = set([((type(rtid) == str) and rec.rtypebyname(rtid)) or rtid for rtid in rtype])
+                cond = lambda rt: rt in rtset
+            for exp, trd, data, auth in self.cache[name]:
+                if exp > now and cond(trd):
+                    ret += [(rec.rr((name, trd), exp - now, data), auth)]
+            return ret
+        finally:
+            self.cachelock.release()
+
+    def dolookup(self, name, rtype):
+        try:
+            res = self.resolver.squery(name, rtype)
+        except resolver.servfail, resolver.unreachable:
+            return None
+        if res is None:
+            return None
+        if res.rescode == proto.NXDOMAIN:
+            ttl = 300
+            for rr in res.aulist:
+                if rr.head.istype("SOA"):
+                    ttl = rr.data["minttl"]
+            nc = nxdmark(int(time.time()) + ttl, res.aulist)
+            self.cachelock.acquire()
+            try:
+                self.cache[name] = nc
+            finally:
+                self.cachelock.release()
+            return nc
+        now = int(time.time())
+        self.cachelock.acquire()
+        try:
+            alltypes = set([rr.head.rtype for rr in res.allrrs()])
+            for name in set([rr.head.name for rr in res.allrrs()]):
+                if name in self.cache:
+                    self.cache[name] = [cl for cl in self.cache[name] if cl[1] not in alltypes]
+            for rr in res.allrrs():
+                if rr.head.name not in self.cache:
+                    self.cache[rr.head.name] = []
+                self.cache[rr.head.name] += [(now + rr.ttl, rr.head.rtype, rr.data, [rr for rr in res.aulist if rr.head.istype("NS")])]
+            return res
+        finally:
+            self.cachelock.release()
+
+    def addcached(self, packet, cis):
+        for item, auth in cis:
+            packet.addan(item)
+            for ns in auth:
+                packet.addau(ns)
+                nsal = self.getcached(ns.data["nsname"], ["A", "AAAA"])
+                if type(nsal) == list:
+                    for item, auth in nsal:
+                        packet.addad(item)
+
+    def resolve(self, packet):
+        res = proto.responsefor(packet)
+        for q in packet.qlist:
+            name = q.name
+            rtype = q.rtype
+            while True:
+                cis = self.getcached(name, rtype)
+                if isinstance(cis, nxdmark):
+                    if len(packet.qlist) == 1:
+                        res.rescode = proto.NXDOMAIN
+                        res.aulist = cis.auth
+                        return res
+                    continue
+                if len(cis) == 0:
+                    cics = self.getcached(name, "CNAME")
+                    if isinstance(cics, nxdmark):
+                        break
+                    if len(cics) > 0:
+                        self.addcached(res, cics)
+                        name = cics[0][0].data["priname"]
+                        continue
+                break
+            if len(cis) == 0:
+                tres = self.dolookup(name, rtype)
+                if isinstance(tres, nxdmark) and len(packet.qlist) == 1:
+                    res.rescode = proto.NXDOMAIN
+                    res.aulist = tres.auth
+                    return res
+                if tres is None and len(packet.qlist) == 1:
+                    res.rescode = proto.SERVFAIL
+                    return res
+                if tres is not None and tres.rescode == 0:
+                    res.merge(tres)
+            else:
+                self.addcached(res, cis)
+        return res
diff --git a/ldd/resolver.py b/ldd/resolver.py
new file mode 100644 (file)
index 0000000..9189b34
--- /dev/null
@@ -0,0 +1,301 @@
+#    ldd - DNS implementation in Python
+#    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
+#
+#    This program is free software; you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation; either version 2 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program; if not, write to the Free Software
+#    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+
+import socket
+import select
+import time
+import random
+
+import proto
+import rec
+import dn
+
+class error(Exception):
+    def __init__(self, text):
+        self.text = text
+
+    def __str__(self):
+        return self.text
+
+class servfail(error):
+    def __init__(self):
+        error.__init__(self, "SERVFAIL")
+
+class unreachable(error):
+    def __init__(self, server):
+        error.__init__(self, "could not reach server: " + str(server))
+
+def resolvecnames(pkt, res = None):
+    if res is None: res = default
+    for q in pkt.qlist:
+        cnrr = pkt.getanswer(q.name, rec.rtypebyname("CNAME"))
+        if cnrr is not None:
+            if pkt.getanswer(cnrr.data["priname"], q.rtype) is None:
+                try:
+                    resp = res.squery(cnrr.data["priname"], q.rtype)
+                except error:
+                    continue
+                if resp is None:
+                    continue
+                anrr = resp.getanswer(cnrr.data["priname"], q.rtype)
+                if anrr is None:
+                    continue
+                pkt.addan(anrr)
+
+def resolveadditional(pkt, rr, res = None):
+    if res is None: res = default
+    for name in rr.data:
+        if isinstance(rr.data[name], dn.domainname):
+            for rtype in ["A", "AAAA"]:
+                if pkt.getanswer(rr.data[name], rtype) is not None:
+                    continue
+                try:
+                    resp = res.squery(rr.data[name], rtype)
+                except error:
+                    continue
+                if resp is None:
+                    continue
+                anrr = resp.getanswer(rr.data[name], rtype)
+                if anrr is None:
+                    continue
+                pkt.addad(anrr)
+
+def extractaddrinfo(packet, name):
+    ret = []
+    for rr in packet.anlist + packet.adlist:
+        if rr.head.name == name:
+            if rr.head.istype("A"):
+                ret += [(socket.AF_INET, socket.inet_ntop(socket.AF_INET, rr.data["address"]))]
+            elif rr.head.istype("AAAA"):
+                ret += [(socket.AF_INET6, socket.inet_ntop(socket.AF_INET6, rr.data["address"]))]
+    return ret
+
+def resolve(packet, nameserver, recurse, retries = 3, timeout = 2000, hops = 0, cnameres = None, verbose = False, visited = None):
+    if cnameres is None: cnameres = default
+    if visited is None: visited = set()
+    visited |= set([nameserver])
+    sk = socket.socket(nameserver[0], socket.SOCK_DGRAM)
+    sk.bind(("", 0))
+    for i in range(retries):
+        sk.sendto(packet.encode(), nameserver[1:])
+        p = select.poll()
+        p.register(sk.fileno(), select.POLLIN)
+        fds = p.poll(timeout)
+        if (sk.fileno(), select.POLLIN) in fds:
+            break
+    else:
+        raise unreachable(nameserver)
+    ret = sk.recv(65536)
+    sk.close()
+    try:
+        resp = proto.decodepacket(ret)
+    except proto.malformedpacket, inst:
+        raise error(str(inst))
+    if resp.qid != packet.qid:
+        raise error("got response with wrong qid(?!)")
+    if "resp" not in resp.flags:
+        raise error("got query in response")
+    if resp.rescode != 0:
+        if resp.rescode == proto.SERVFAIL:
+            raise servfail()
+        if resp.rescode == proto.NXDOMAIN:
+            return resp
+        raise error("non-successful response (" + str(resp.rescode) + ")")
+    if recurse:
+        resolvecnames(resp, cnameres)
+    if not recurse or resp.hasanswers():
+        return resp
+    if not resp.hasanswers() and "auth" in resp.flags:
+        return resp
+    if hops > 30:
+        raise error("too many levels deep")
+    for rr in resp.aulist:
+        if verbose:
+            print (hops * " ") + "Checking " + str(rr)
+        if rr.head.istype("NS"):
+            if verbose:
+                print (hops * " ") + "Will try " + str(rr)
+            ai = extractaddrinfo(resp, rr.data["nsname"])
+            if len(ai) == 0:
+                if verbose:
+                    print (hops * " ") + "Resolving nameservers for " + str(rr.data["nsname"])
+                resolveadditional(resp, rr)
+                ai = extractaddrinfo(resp, rr.data["nsname"])
+            for ns in ai:
+                ns += (53,)
+                if ns in visited:
+                    if verbose:
+                        print (hops * " ") + "Will not try " + str(ns) + " again"
+                    continue
+                if verbose:
+                    print (hops * " ") + "Trying " + str(ns)
+                try:
+                    resp2 = resolve(packet, ns, recurse, retries, timeout, hops + 1, verbose = verbose, visited = visited)
+                except unreachable:
+                    if verbose:
+                        print (hops * " ") + "Could not reach " + str(ns)
+                    continue
+                if verbose:
+                    if resp2 is None:
+                        print (hops * " ") + "Got None"
+                    else:
+                        if "auth" in resp2.flags:
+                            austr = "Auth"
+                        else:
+                            austr = "Nonauth"
+                        print (hops * " ") + "Got " + str(resp2.hasanswers()) + " (" + austr + ")"
+                if resp2 is not None and resp2.hasanswers():
+                    return resp2
+                if resp2 is not None and not resp2.hasanswers() and "auth" in resp2.flags:
+                    return resp2
+    return None
+
+class resolver:
+    def __init__(self, nameserver, recurse, nsrecurse = True, retries = 3, timeout = 2000, verbose = False):
+        self.nameserver = nameserver
+        self.recurse = recurse
+        self.nsrecurse = nsrecurse
+        self.retries = retries
+        self.timeout = timeout
+        self.verbose = verbose
+
+    def resolve(self, packet):
+        return resolve(packet, self.nameserver, self.recurse, self.retries, self.timeout, verbose = self.verbose)
+
+    def squery(self, name, rtype):
+        packet = proto.packet()
+        try:
+            if self.nsrecurse: packet.setflags(["recurse"])
+        except AttributeError: pass
+        packet.addq(rec.rrhead(name, rtype))
+        return self.resolve(packet)
+
+class multiresolver(resolver):
+    def __init__(self, resolvers):
+        self.rl = [{"res": res, "qs": []} for res in resolvers]
+        self.lastclean = int(time.time())
+
+    def clean(self):
+        now = int(time.time())
+        if now - self.lastclean < 60:
+            return
+        self.lastclean = now
+        for r in self.rl:
+            nl = []
+            for q in r["qs"]:
+                if now - q["time"] < 1800:
+                    nl += [q]
+            r["qs"] = nl
+        
+    def resolve(self, packet):
+        self.clean()
+        l = []
+        ts = 0
+        for r in self.rl:
+            if len(r["qs"]) < 1:
+                score = 1.0
+            else:
+                score = float(sum([q["s"] for q in r["qs"]])) / len(r["qs"])
+            l += [(score, r)]
+            ts += score
+        c = random.random() * ts
+        for score, r in l:
+            c -= score
+            if c <= 0:
+                break
+        else:
+            assert(False)
+        try:
+            res = r["res"].resolve(packet)
+        except error:
+            r["qs"] = r["qs"][:10] + [{"time": int(time.time()), "s": 0}]
+            raise
+        r["qs"] = r["qs"][:10] + [{"time": int(time.time()), "s": 1}]
+        return res
+
+class sysresolver(resolver):
+    def __init__(self, conffile = "/etc/resolv.conf"):
+        nslist = []
+        prelist = []
+        a = open(conffile, "r")
+        for line in (l.strip() for l in a):
+            p = line.find(" ")
+            if p >= 0:
+                c = line[:p]
+                line = line[p + 1:]
+                if c == "nameserver":
+                    try:
+                        socket.inet_pton(socket.AF_INET, line)
+                    except socket.error: pass
+                    else:
+                        nslist += [(socket.AF_INET, line, 53)]
+                    try:
+                        socket.inet_pton(socket.AF_INET6, line)
+                    except socket.error: pass
+                    else:
+                        nslist += [(socket.AF_INET6, line, 53)]
+                if c == "domain" or c == "search":     # How do these differ?
+                    prelist += line.split()
+        a.close()
+        rl = []
+        for ns in nslist:
+            rl += [resolver(ns, False, True)]
+        self.resolver = multiresolver(rl)
+        self.prelist = []
+        for prefix in prelist:
+            pp = dn.fromstring(prefix)
+            pp.rooted = True
+            self.prelist += [pp]
+
+    def resolve(self, packet):
+        res = self.resolver.resolve(packet)
+        return res
+
+    def squery(self, name, rtype):
+        if type(name) == str:
+            name = dn.fromstring(name)
+        if not name.rooted:
+            namelist = [name + prefix for prefix in self.prelist] + [name + dn.fromstring(".")]
+        else:
+            namelist = [name]
+        for name in namelist:
+            packet = proto.packet()
+            packet.setflags(["recurse"])
+            packet.addq(rec.rrhead(name, rtype))
+            res = self.resolve(packet)
+            if res.rescode == 0:
+                break
+        return res
+
+sysres = sysresolver()
+rootresolvers = {"a": resolver((socket.AF_INET, "198.41.0.4", 53), True, False),
+                 "b": resolver((socket.AF_INET, "192.228.79.201", 53), True, False),
+                 "c": resolver((socket.AF_INET, "192.33.4.12", 53), True, False),
+                 "d": resolver((socket.AF_INET, "128.8.10.90", 53), True, False),
+                 "e": resolver((socket.AF_INET, "192.203.230.10", 53), True, False),
+                 "f": resolver((socket.AF_INET, "192.5.5.241", 53), True, False),
+                 "g": resolver((socket.AF_INET, "192.112.36.4", 53), True, False),
+                 "h": resolver((socket.AF_INET, "128.63.2.53", 53), True, False),
+                 "i": resolver((socket.AF_INET, "192.36.148.17", 53), True, False),
+                 "j": resolver((socket.AF_INET, "192.58.128.30", 53), True, False),
+                 "k": resolver((socket.AF_INET, "193.0.14.129", 53), True, False),
+                 "l": resolver((socket.AF_INET, "198.32.64.12", 53), True, False),
+                 "m": resolver((socket.AF_INET, "202.12.27.33", 53), True, False)
+                 }
+rootres = multiresolver(rootresolvers.values())
+
+default = sysres
diff --git a/ldd/server.py b/ldd/server.py
new file mode 100644 (file)
index 0000000..86fd885
--- /dev/null
@@ -0,0 +1,317 @@
+#    ldd - DNS implementation in Python
+#    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
+#
+#    This program is free software; you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation; either version 2 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program; if not, write to the Free Software
+#    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+
+import socket
+import threading
+import select
+import errno
+import logging
+import time
+
+import proto
+import rec
+import dn
+import resolver
+
+logger = logging.getLogger("ldd.server")
+
+class dnsserver:
+    class socklistener(threading.Thread):
+        def __init__(self, server):
+            threading.Thread.__init__(self)
+            self.server = server
+            self.alive = True
+
+        class sender:
+            def __init__(self, addr, sk):
+                self.addr = addr
+                self.sk = sk
+
+            def send(self, pkt):
+                logger.debug("sending response to %04x", pkt.qid)
+                self.sk.sendto(pkt.encode(), self.addr)
+
+        def run(self):
+            while self.alive:
+                p = select.poll()
+                for af, sk in self.server.sockets:
+                    p.register(sk.fileno(), select.POLLIN)
+                try:
+                    fds = p.poll(1000)
+                except select.error, e:
+                    if e[0] == errno.EINTR:
+                        continue
+                    raise
+                for fd, event in fds:
+                    if event & select.POLLIN == 0:
+                        continue
+                    for af, sk in self.server.sockets:
+                        if sk.fileno() == fd:
+                            break
+                    else:
+                        continue
+                    req, addr = sk.recvfrom(65536)
+                    try:
+                        pkt = proto.decodepacket(req)
+                    except proto.malformedpacket, inst:
+                        resp = proto.packet(inst.qid, ["resp"])
+                        resp.rescode = proto.FORMERR
+                        sk.sendto(resp.encode(), addr)
+                    else:
+                        logger.debug("got request (%04x) from %s", pkt.qid, addr[0])
+                        pkt.addr = (af,) + addr
+                        self.server.queuereq(pkt, dnsserver.socklistener.sender(addr, sk))
+
+        def stop(self):
+            self.alive = False
+
+    class dispatcher(threading.Thread):
+        def __init__(self, server):
+            threading.Thread.__init__(self)
+            self.server = server
+            self.alive = True
+
+        def run(self):
+            while self.alive:
+                req = self.server.dequeuereq()
+                if req is not None:
+                    pkt, sender = req
+                    resp = self.server.handle(pkt)
+                    if resp is None:
+                        resp = proto.responsefor(pkt, proto.SERVFAIL)
+                    sender.send(resp)
+
+    class queuemonitor(threading.Thread):
+        def __init__(self, server):
+            threading.Thread.__init__(self)
+            self.server = server
+
+        def run(self):
+            while(self.server.running):
+                self.server.queuelock.acquire()
+                if len(self.server.queue) > 0:
+                    peeked = self.server.queue[0]
+                else:
+                    peeked = None
+                self.server.queuelock.release()
+                if peeked is not None:
+                    if time.time() - peeked[0] > 1:
+                        newdsp = dnsserver.dispatcher(self.server)
+                        self.server.dispatchers += [newdsp]
+                        newdsp.start()
+                        logger.debug("starting new dispatcher, there are now %i", len(self.server.dispatchers))
+                time.sleep(1)
+
+    def __init__(self):
+        self.sockets = []
+        self.queue = []
+        self.zones = []
+        self.listener = None
+        self.dispatchers = []
+        self.running = False
+        self.queuelock = threading.Condition()
+        self.knownkeys = []
+
+    def handle(self, pkt):
+        resp = None
+
+        if len(self.knownkeys) > 0:
+            import dnssec
+            dnssec.tsigverify(pkt, self.knownkeys)
+        
+        for query in pkt.qlist:
+            match = None
+            for zone in self.zones:
+                if query.name in zone.origin:
+                    if match is None:
+                        match = zone
+                    elif len(zone.origin) > len(match.origin):
+                        match = zone
+            if match is None:
+                return None
+            else:
+                curresp = match.handle(query, pkt)
+                if resp is None:
+                    resp = curresp
+                else:
+                    resp.merge(curresp)
+        
+        if resp is not None and resp.tsigctx is not None and not resp.signed:
+            resp.tsigctx.signpkt(resp)
+
+        return resp
+
+    def addsock(self, af, socket):
+        self.sockets += [(af, socket)]
+
+    def addzone(self, zone):
+        self.zones += [zone]
+
+    def queuereq(self, req, sender):
+        self.queuelock.acquire()
+        self.queue += [(time.time(), req, sender)]
+        logger.debug("queue length+: %i", len(self.queue))
+        self.queuelock.notify()
+        self.queuelock.release()
+
+    def dequeuereq(self):
+        self.queuelock.acquire()
+        if len(self.queue) == 0:
+            self.queuelock.wait()
+        if len(self.queue) > 0:
+            ret = self.queue[0]
+            self.queue = self.queue[1:]
+        else:
+            ret = None
+        logger.debug("queue length-: %i", len(self.queue))
+        self.queuelock.release()
+        if ret is None:
+            return None
+        else:
+            return ret[1:]
+
+    def start(self):
+        if self.running:
+            raise Exception("already running")
+        lst = dnsserver.socklistener(self)
+        self.listener = lst
+        lst.start()
+        for i in xrange(10):
+            newdsp = dnsserver.dispatcher(self)
+            self.dispatchers += [newdsp]
+            newdsp.start()
+        self.running = True
+        self.monitor = dnsserver.queuemonitor(self)
+        self.monitor.start()
+
+    def stop(self):
+        self.listener.stop()
+        self.listener.join()
+        self.listener = None
+        for dsp in self.dispatchers:
+            dsp.alive = False
+        self.queuelock.acquire()
+        self.queuelock.notifyAll()
+        self.queuelock.release()
+        for dsp in self.dispatchers + []:
+            dsp.join()
+            self.dispatchers.remove(dsp)
+        self.running = False
+        self.monitor = None
+
+    def resolver(self, addr = None):
+        class myres(resolver.resolver):
+            def __init__(self, server, addr):
+                self.server = server
+                self.addr = addr
+            def resolve(self, packet):
+                if self.addr is not None:
+                    packet.addr = self.addr
+                packet.setflags(["internal"])
+                return self.server.handle(packet)
+        return myres(self, addr)
+
+class zone:
+    def __init__(self, origin, handler):
+        if type(origin) == str:
+            self.origin = dn.fromstring(origin)
+            self.origin.rooted = True
+        else:
+            self.origin = origin
+        self.handler = handler
+
+    def handle(self, query, pkt):
+        resp = self.handler.handle(query, pkt, self.origin)
+        return resp
+
+class authzone(zone):
+    def __init__(self, aurecres, *args):
+        self.aurecres = aurecres
+        zone.__init__(self, *args)
+
+    def handle(self, query, pkt):
+        resp = zone.handle(self, query, pkt)
+        if not "internal" in pkt.flags:
+            if resp is None:
+                resp = proto.responsefor(pkt)
+                soa = zone.handle(self, rec.rrhead(self.origin, "SOA"), pkt)
+                resp.aulist += soa.anlist
+                resp.rescode = proto.NXDOMAIN
+            else:
+                resolver.resolvecnames(resp, self.aurecres)
+                nsrecs = zone.handle(self, rec.rrhead(self.origin, "NS"), pkt)
+                if nsrecs is not None:
+                    resp.aulist += nsrecs.anlist
+                    for rr in nsrecs.anlist:
+                        resolver.resolveadditional(resp, rr, self.aurecres)
+        else:
+            if resp is None:
+                return None
+        resp.setflags(["auth"])
+        return resp
+
+class handler:
+    def handle(self, query, pkt, origin):
+        return None
+
+class forwarder(handler):
+    def __init__(self, nameserver, timeout = 2000, retries = 3):
+        self.nameserver = nameserver
+        self.timeout = timeout
+        self.retries = retries
+    
+    def handle(self, query, pkt, origin):
+        sk = socket.socket(self.nameserver[0], socket.SOCK_DGRAM)
+        sk.bind(("", 0))
+        p = select.poll()
+        p.register(sk.fileno(), select.POLLIN)
+        for i in range(self.retries):
+            sk.sendto(pkt.encode(), self.nameserver[1:])
+            fds = p.poll(self.timeout)
+            if (sk.fileno(), select.POLLIN) in fds:
+                break
+        else:
+            return None
+        resp = sk.recv(65536)
+        resp = proto.decodepacket(resp)
+        return resp
+    
+class recurser(handler):
+    def __init__(self, resolver):
+        self.resolver = resolver
+
+    def handle(self, query, pkt, origin):
+        try:
+            resp = self.resolver.resolve(pkt)
+        except resolver.error:
+            return None
+        return resp
+
+class chain(handler):
+    def __init__(self, chain):
+        self.chain = chain
+
+    def add(self, handler):
+        self.chain += [handler]
+    
+    def handle(self, *args):
+        for h in self.chain:
+            resp = h.handle(*args)
+            if resp is not None:
+                return resp
+        return None
+
diff --git a/lddd b/lddd
new file mode 100755 (executable)
index 0000000..aa1411a
--- /dev/null
+++ b/lddd
@@ -0,0 +1,78 @@
+#!/usr/bin/python
+#    ldd - DNS implementation in Python
+#    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
+#
+#    This program is free software; you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation; either version 2 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program; if not, write to the Free Software
+#    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+
+
+import os
+import sys
+import getopt
+import socket
+import signal
+import imp
+import logging
+
+from ldd import server
+
+cfname = "/etc/lddd/conf"
+port = 53
+daemonize = True
+opts, args = getopt.getopt(sys.argv[1:], "ndc:p:")
+for o, a in opts:
+    if o == "-d":
+        logging.basicConfig(level = logging.DEBUG)
+        daemonize = False
+    if o == "-c":
+        cfname = a
+    if o == "-n":
+        daemonize = False
+    if o == "-p":
+        port = int(a)
+
+logger = logging.getLogger("ldd.daemon")
+
+def diehandler(signum, frame):
+    global alive
+    alive = False
+
+for sig in [getattr(signal, "SIG" + s) for s in ["INT", "TERM"]]:
+    signal.signal(sig, diehandler)
+
+srv = server.dnsserver()
+
+cf = open(cfname, "r")
+cmod = imp.load_module("servconf", cf, cfname, ("", "r", imp.PY_SOURCE))
+cf.close()
+
+cmod.setup(srv)
+if(len(srv.sockets) < 1):
+    sk = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+    sk.bind(("", port))
+    srv.addsock(socket.AF_INET6, sk)
+logger.info("config OK, starting server")
+
+alive = True
+srv.start()
+
+if daemonize:
+    if(os.fork() != 0):
+        sys.exit(0)
+    os.chdir("/")
+
+while alive:
+    signal.pause()
+logger.info("terminating")
+srv.stop()
diff --git a/resolve b/resolve
new file mode 100755 (executable)
index 0000000..2321b25
--- /dev/null
+++ b/resolve
@@ -0,0 +1,63 @@
+#!/usr/bin/python
+#    ldd - DNS implementation in Python
+#    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
+#
+#    This program is free software; you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation; either version 2 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program; if not, write to the Free Software
+#    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+
+
+import socket
+import sys
+import getopt
+import time
+
+from ldd import resolver
+
+nameserver = (socket.AF_INET, "198.41.0.4", 53)
+rtype = "a"
+lrec = True
+srec = False
+verbose = False
+opts, args = getopt.getopt(sys.argv[1:], "vrRs:t:p:")
+
+for o, a in opts:
+    if o == "-s":
+        nameserver = (socket.AF_INET, a, 53)
+    if o == "-t":
+        if a == "any":
+            rtype = 255
+        else:
+            rtype = a
+    if o == "-p":
+        nameserver = nameserver[0:2] + (int(a),)
+    if o == "-r":
+        lrec = False
+    if o == "-R":
+        srec = True
+    if o == "-v":
+        verbose = True
+
+if len(args) < 1:
+    print "No target given"
+    sys.exit(1)
+
+res = resolver.resolver(nameserver, lrec, srec, verbose = verbose)
+try:
+    rsp = res.squery(args[0], rtype)
+except resolver.error, inst:
+    print "error: " + str(inst)
+except KeyboardInterrupt:
+    sys.exit(1)
+else:
+    print str(rsp)
diff --git a/setup.py b/setup.py
new file mode 100644 (file)
index 0000000..aa602b3
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,27 @@
+#    ldd - DNS implementation in Python
+#    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
+#
+#    This program is free software; you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation; either version 2 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program; if not, write to the Free Software
+#    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
+
+from distutils.core import setup
+
+setup(name="ldd",
+      version="0.1",
+      description="DNS implementation in Python",
+      author="Fredrik Tolf",
+      author_email="fredrik@dolda2000.com",
+      url="http://www.dolda2000.com/~fredrik/ldd/",
+      packages=["ldd"],
+      )