Initial import
[ldd.git] / ldd / resolver.py
1 #    ldd - DNS implementation in Python
2 #    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
3 #
4 #    This program is free software; you can redistribute it and/or modify
5 #    it under the terms of the GNU General Public License as published by
6 #    the Free Software Foundation; either version 2 of the License, or
7 #    (at your option) any later version.
8 #
9 #    This program is distributed in the hope that it will be useful,
10 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
11 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 #    GNU General Public License for more details.
13 #
14 #    You should have received a copy of the GNU General Public License
15 #    along with this program; if not, write to the Free Software
16 #    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
17
18 import socket
19 import select
20 import time
21 import random
22
23 import proto
24 import rec
25 import dn
26
class error(Exception):
    """Base exception for resolver failures; str() yields the stored message."""

    def __str__(self):
        return self.text

    def __init__(self, text):
        # The message is kept on the instance and rendered verbatim by __str__.
        self.text = text
33
class servfail(error):
    """Raised when a server answers with the SERVFAIL result code."""

    def __init__(self):
        # The message is fixed, so the constructor takes no arguments.
        message = "SERVFAIL"
        error.__init__(self, message)
37
class unreachable(error):
    """Raised when no reply could be obtained from *server*."""

    def __init__(self, server):
        message = "could not reach server: " + str(server)
        error.__init__(self, message)
41
def resolvecnames(pkt, res = None):
    """Chase one level of CNAME for each question in *pkt*.

    For every question whose name has a CNAME answer in *pkt*, but whose
    CNAME target has no answer of the question's type yet, the target is
    queried through *res* (the module default when None).  A matching answer
    from that lookup is appended to *pkt*'s answer section.  Failed or empty
    lookups are skipped silently.
    """
    if res is None:
        res = default
    for q in pkt.qlist:
        cnrr = pkt.getanswer(q.name, rec.rtypebyname("CNAME"))
        if cnrr is None:
            continue
        target = cnrr.data["priname"]
        if pkt.getanswer(target, q.rtype) is not None:
            # The packet already answers the CNAME target; nothing to add.
            continue
        try:
            resp = res.squery(target, q.rtype)
        except error:
            continue
        if resp is None:
            continue
        found = resp.getanswer(target, q.rtype)
        if found is not None:
            pkt.addan(found)
58
def resolveadditional(pkt, rr, res = None):
    """Add A/AAAA records for every domain name referenced by *rr*.

    Each value in rr.data that is a dn.domainname is looked up as "A" and
    then "AAAA" through *res* (the module default when None), unless *pkt*
    already holds such an answer.  Results are appended to *pkt*'s
    additional section.  Failed or empty lookups are skipped.
    """
    if res is None:
        res = default
    for key in rr.data:
        target = rr.data[key]
        if not isinstance(target, dn.domainname):
            continue
        for rtype in ("A", "AAAA"):
            if pkt.getanswer(target, rtype) is not None:
                continue
            try:
                resp = res.squery(target, rtype)
            except error:
                continue
            if resp is None:
                continue
            found = resp.getanswer(target, rtype)
            if found is not None:
                pkt.addad(found)
76
def extractaddrinfo(packet, name):
    """Collect the addresses *packet* carries for *name*.

    Scans the answer section, then the additional section, and returns a
    list of (family, address-string) tuples in packet order.  There is one
    tuple per A (AF_INET) or AAAA (AF_INET6) record whose owner name equals
    *name*.
    """
    families = (("A", socket.AF_INET), ("AAAA", socket.AF_INET6))
    found = []
    for rr in packet.anlist + packet.adlist:
        if rr.head.name != name:
            continue
        for rtype, family in families:
            if rr.head.istype(rtype):
                found.append((family, socket.inet_ntop(family, rr.data["address"])))
                break
    return found
86
def _exchange(packet, nameserver, retries, timeout):
    """Send *packet* to *nameserver* over UDP and return the raw reply.

    nameserver is a (family, address, port) tuple.  The query is sent up to
    *retries* times, and after each send we wait *timeout* milliseconds
    (select.poll units) for the socket to become readable.

    Raises unreachable if no reply arrives, or if any socket operation
    fails.  This way a server we cannot talk to at all (for example an
    address family the host has no route for) is reported the same way as
    one that times out.  The socket is closed on every path.
    """
    try:
        sk = socket.socket(nameserver[0], socket.SOCK_DGRAM)
    except socket.error:
        raise unreachable(nameserver)
    try:
        try:
            sk.bind(("", 0))
            poller = select.poll()
            poller.register(sk.fileno(), select.POLLIN)
            query = packet.encode()
            for attempt in range(retries):
                sk.sendto(query, nameserver[1:])
                for fd, events in poller.poll(timeout):
                    # poll() reports an event bitmask; accept POLLIN even if
                    # other bits are set alongside it.
                    if fd == sk.fileno() and events & select.POLLIN:
                        return sk.recv(65536)
        except socket.error:
            raise unreachable(nameserver)
    finally:
        sk.close()
    raise unreachable(nameserver)

def resolve(packet, nameserver, recurse, retries = 3, timeout = 2000, hops = 0, cnameres = None, verbose = False, visited = None):
    """Send *packet* to *nameserver*, following NS referrals if *recurse*.

    packet     -- query packet; the same packet is sent to every server asked.
    nameserver -- (family, address, port) tuple of the first server to ask.
    recurse    -- if false, the first valid response is returned as-is.  If
                  true, CNAMEs in the response are chased via *cnameres*.
                  When the response is neither an answer nor authoritative,
                  the NS records in its authority section are then followed
                  recursively.
    retries, timeout -- per-server send attempts and poll wait (milliseconds).
    hops       -- current referral depth; following referrals beyond depth 30
                  raises error.
    cnameres   -- resolver used for CNAME chasing (module default if None);
                  passed down unchanged to every referral.
    verbose    -- print a trace of the referral walk.
    visited    -- set of servers already asked.  It is shared and mutated in
                  place across the recursion, so no server is asked twice.

    Returns the response packet (NXDOMAIN responses included).  Returns
    None when every referred server was exhausted without a final response.
    Raises unreachable (first server only; unreachable referrals are
    skipped), servfail, or error.
    """
    if cnameres is None: cnameres = default
    if visited is None: visited = set()
    visited.add(nameserver)
    raw = _exchange(packet, nameserver, retries, timeout)
    try:
        resp = proto.decodepacket(raw)
    except proto.malformedpacket as inst:
        raise error(str(inst))
    if resp.qid != packet.qid:
        raise error("got response with wrong qid(?!)")
    if "resp" not in resp.flags:
        raise error("got query in response")
    if resp.rescode != 0:
        if resp.rescode == proto.SERVFAIL:
            raise servfail()
        if resp.rescode == proto.NXDOMAIN:
            return resp
        raise error("non-successful response (" + str(resp.rescode) + ")")
    if recurse:
        resolvecnames(resp, cnameres)
    # An answer, or an authoritative (possibly empty) response, is final.
    if not recurse or resp.hasanswers() or "auth" in resp.flags:
        return resp
    if hops > 30:
        raise error("too many levels deep")
    indent = hops * " "
    for rr in resp.aulist:
        if verbose:
            print(indent + "Checking " + str(rr))
        if not rr.head.istype("NS"):
            continue
        if verbose:
            print(indent + "Will try " + str(rr))
        nsname = rr.data["nsname"]
        ai = extractaddrinfo(resp, nsname)
        if not ai:
            # This response carries no address for the NS name; look it up.
            if verbose:
                print(indent + "Resolving nameservers for " + str(nsname))
            resolveadditional(resp, rr)
            ai = extractaddrinfo(resp, nsname)
        for ns in ai:
            ns += (53,)
            if ns in visited:
                if verbose:
                    print(indent + "Will not try " + str(ns) + " again")
                continue
            if verbose:
                print(indent + "Trying " + str(ns))
            try:
                resp2 = resolve(packet, ns, recurse, retries, timeout, hops + 1, cnameres = cnameres, verbose = verbose, visited = visited)
            except unreachable:
                if verbose:
                    print(indent + "Could not reach " + str(ns))
                continue
            if verbose:
                if resp2 is None:
                    print(indent + "Got None")
                else:
                    if "auth" in resp2.flags:
                        austr = "Auth"
                    else:
                        austr = "Nonauth"
                    print(indent + "Got " + str(resp2.hasanswers()) + " (" + austr + ")")
            if resp2 is not None and (resp2.hasanswers() or "auth" in resp2.flags):
                return resp2
    return None
166
class resolver:
    """Resolver that sends every query to one fixed nameserver.

    nameserver -- (family, address, port) tuple the queries are sent to.
    recurse    -- passed to the module-level resolve(); if true, NS referrals
                  are followed locally and CNAMEs are chased.
    nsrecurse  -- if true, squery() sets the "recurse" flag on its queries.
    retries, timeout -- passed to resolve(); timeout is in milliseconds.
    verbose    -- passed to resolve() to trace the referral walk.
    """

    def __init__(self, nameserver, recurse, nsrecurse = True, retries = 3, timeout = 2000, verbose = False):
        self.nameserver = nameserver
        self.recurse = recurse
        self.nsrecurse = nsrecurse
        self.retries = retries
        self.timeout = timeout
        self.verbose = verbose

    def resolve(self, packet):
        """Resolve *packet* against this resolver's nameserver."""
        # Calls the module-level resolve() function, not this method.
        return resolve(packet, self.nameserver, self.recurse, self.retries, self.timeout, verbose = self.verbose)

    def squery(self, name, rtype):
        """Build a single-question query for *name*/*rtype* and resolve it."""
        packet = proto.packet()
        # Subclasses that bypass __init__ (e.g. multiresolver) have no
        # nsrecurse attribute; treat that as "do not set the flag".  Only
        # that missing attribute is tolerated: errors from setflags propagate.
        if getattr(self, "nsrecurse", False):
            packet.setflags(["recurse"])
        packet.addq(rec.rrhead(name, rtype))
        return self.resolve(packet)
186
class multiresolver(resolver):
    """Spread queries over several resolvers, favouring reliable ones.

    For each backing resolver a short history of recent outcomes is kept
    (1 = returned, 0 = raised error).  Each query goes to one resolver,
    drawn at random with probability proportional to the mean of its
    history; a resolver with no history scores 1.0.  clean() discards
    outcomes older than 1800 seconds.
    """

    def __init__(self, resolvers):
        self.rl = [{"res": res, "qs": []} for res in resolvers]
        self.lastclean = int(time.time())

    def clean(self):
        """Drop history entries older than 1800 s; does work at most every 60 s."""
        now = int(time.time())
        if now - self.lastclean < 60:
            return
        self.lastclean = now
        for r in self.rl:
            r["qs"] = [q for q in r["qs"] if now - q["time"] < 1800]

    def resolve(self, packet):
        """Resolve *packet* through one backing resolver picked by weighted draw.

        Raises error if there are no backing resolvers.  An error raised by
        the chosen resolver is recorded as a failure and re-raised.
        """
        self.clean()
        if not self.rl:
            raise error("no resolvers available")
        weighted = []
        total = 0
        for r in self.rl:
            if r["qs"]:
                score = float(sum([q["s"] for q in r["qs"]])) / len(r["qs"])
            else:
                score = 1.0
            weighted.append((score, r))
            total += score
        c = random.random() * total
        for score, r in weighted:
            c -= score
            if c <= 0:
                break
        # If float rounding lets the loop run off the end, r is simply the
        # last resolver, which is an acceptable pick.
        # History keeps the newest ten outcomes plus this one; slicing from
        # the end is what lets old results age out.
        try:
            res = r["res"].resolve(packet)
        except error:
            r["qs"] = r["qs"][-10:] + [{"time": int(time.time()), "s": 0}]
            raise
        r["qs"] = r["qs"][-10:] + [{"time": int(time.time()), "s": 1}]
        return res
229
class sysresolver(resolver):
    """Stub resolver configured from a resolv.conf-style file.

    Queries carry the "recurse" flag and are load-balanced over the listed
    nameservers through a multiresolver.  Unrooted names are tried with
    each search-list suffix first, and finally as an absolute name.
    """

    def __init__(self, conffile = "/etc/resolv.conf"):
        nslist = []
        prelist = []
        a = open(conffile, "r")
        try:
            for line in a:
                # split() copes with tabs and repeated blanks, which a
                # single-space split silently mis-parsed.
                fields = line.split()
                if len(fields) < 2:
                    continue
                c = fields[0]
                if c == "nameserver":
                    addr = fields[1]
                    for family in (socket.AF_INET, socket.AF_INET6):
                        try:
                            socket.inet_pton(family, addr)
                        except socket.error:
                            continue
                        nslist.append((family, addr, 53))
                elif c == "domain" or c == "search":
                    # resolv.conf(5): "domain" and "search" are mutually
                    # exclusive; if several appear, the last instance wins.
                    prelist = fields[1:]
        finally:
            a.close()
        if not nslist:
            # resolv.conf(5): with no nameserver entries, the name server on
            # the local machine is used.
            nslist.append((socket.AF_INET, "127.0.0.1", 53))
        self.resolver = multiresolver([resolver(ns, False, True) for ns in nslist])
        self.prelist = []
        for prefix in prelist:
            pp = dn.fromstring(prefix)
            pp.rooted = True
            self.prelist.append(pp)

    def resolve(self, packet):
        """Resolve *packet* through the configured nameservers."""
        return self.resolver.resolve(packet)

    def squery(self, name, rtype):
        """Look up *name*/*rtype*, applying the search list to unrooted names.

        *name* may be a string or a domain name object.  Returns the first
        response whose rescode is 0, or else the response for the last
        candidate name tried.
        """
        if isinstance(name, str):
            name = dn.fromstring(name)
        if name.rooted:
            namelist = [name]
        else:
            namelist = [name + prefix for prefix in self.prelist] + [name + dn.fromstring(".")]
        for cand in namelist:
            packet = proto.packet()
            packet.setflags(["recurse"])
            packet.addq(rec.rrhead(cand, rtype))
            res = self.resolve(packet)
            if res.rescode == 0:
                break
        return res
283
sysres = sysresolver()

# IPv4 addresses of the root name servers a-m, per the IANA root hints file
# (named.root).  b, d, h and l have been renumbered since this table was
# first written; the old addresses are no longer the published ones.
roothints = {"a": "198.41.0.4",
             "b": "170.247.170.2",
             "c": "192.33.4.12",
             "d": "199.7.91.13",
             "e": "192.203.230.10",
             "f": "192.5.5.241",
             "g": "192.112.36.4",
             "h": "198.97.190.53",
             "i": "192.36.148.17",
             "j": "192.58.128.30",
             "k": "193.0.14.129",
             "l": "199.7.83.42",
             "m": "202.12.27.33",
             }
# One iterative (recurse=True) resolver per root server; queries sent to
# them do not set the "recurse" flag (nsrecurse=False).
rootresolvers = dict((letter, resolver((socket.AF_INET, addr, 53), True, False))
                     for letter, addr in roothints.items())
rootres = multiresolver(rootresolvers.values())

default = sysres