Initial import
[ldd.git] / ldd / resolver.py
CommitLineData
# ldd - DNS implementation in Python
# Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
17
18import socket
19import select
20import time
21import random
22
23import proto
24import rec
25import dn
26
class error(Exception):
    """Base class for the failures raised by this resolver module.

    The human-readable message is stored in ``text`` and is exactly what
    ``str()`` of the exception yields.
    """

    def __init__(self, text):
        self.text = text

    def __str__(self):
        message = self.text
        return message
33
class servfail(error):
    """Raised by resolve() when a nameserver replies with SERVFAIL."""

    def __init__(self):
        msg = "SERVFAIL"
        error.__init__(self, msg)
37
class unreachable(error):
    """Raised by resolve() when no reply arrives from a nameserver.

    ``server`` is the address tuple that was queried; it is only used to
    build the message text.
    """

    def __init__(self, server):
        description = "could not reach server: " + str(server)
        error.__init__(self, description)
41
def resolvecnames(pkt, res = None):
    """Chase one level of CNAME for every question in *pkt*, in place.

    For each question whose name has a CNAME answer in *pkt*, and whose
    CNAME target has no answer of the question's type in *pkt* yet, the
    target is queried through *res* (the module default resolver when
    None). The answer record found for the target is appended to *pkt*
    with ``addan``. Lookup errors, empty replies and replies without a
    matching answer are silently skipped.
    """
    if res is None:
        res = default
    for question in pkt.qlist:
        cname = pkt.getanswer(question.name, rec.rtypebyname("CNAME"))
        if cname is None:
            continue
        target = cname.data["priname"]
        if pkt.getanswer(target, question.rtype) is not None:
            # Target already answered inside this packet.
            continue
        try:
            reply = res.squery(target, question.rtype)
        except error:
            continue
        if reply is None:
            continue
        answer = reply.getanswer(target, question.rtype)
        if answer is not None:
            pkt.addan(answer)
58
def resolveadditional(pkt, rr, res = None):
    """Look up addresses for every domain name referenced by *rr*.

    Each value of ``rr.data`` that is a ``dn.domainname`` is queried for
    "A" and then "AAAA" records via *res* (the module default resolver
    when None), skipping a type when *pkt* already has such an answer.
    Records found are appended to *pkt*'s additional section with
    ``addad``; lookup errors and misses are ignored.
    """
    if res is None:
        res = default
    for key in rr.data:
        target = rr.data[key]
        if not isinstance(target, dn.domainname):
            continue
        for rtype in ("A", "AAAA"):
            if pkt.getanswer(target, rtype) is not None:
                continue
            try:
                reply = res.squery(target, rtype)
            except error:
                continue
            if reply is None:
                continue
            found = reply.getanswer(target, rtype)
            if found is not None:
                pkt.addad(found)
76
def extractaddrinfo(packet, name):
    """Collect the addresses of *name* from *packet*'s A/AAAA records.

    Scans the answer section and then the additional section. Returns a
    list of ``(address family, printable address)`` pairs in the order
    the records appear; other record types are ignored.
    """
    # Checked in this order; the first matching type wins for a record.
    kinds = (("A", socket.AF_INET), ("AAAA", socket.AF_INET6))
    found = []
    for record in packet.anlist + packet.adlist:
        if record.head.name != name:
            continue
        for tname, family in kinds:
            if record.head.istype(tname):
                text = socket.inet_ntop(family, record.data["address"])
                found.append((family, text))
                break
    return found
86
def _exchange(packet, nameserver, retries, timeout):
    """Send *packet* over UDP to *nameserver* and return the raw reply bytes.

    The query is (re)sent up to *retries* times, waiting *timeout*
    milliseconds after each send. The socket is closed on every exit path.
    Raises unreachable if nothing arrives.
    """
    sk = socket.socket(nameserver[0], socket.SOCK_DGRAM)
    try:
        sk.bind(("", 0))
        fd = sk.fileno()
        poller = select.poll()
        poller.register(fd, select.POLLIN)
        for attempt in range(retries):
            sk.sendto(packet.encode(), nameserver[1:])
            # poll() can report POLLIN combined with other bits, so test the
            # bit instead of comparing the whole (fd, events) tuple.
            if any(efd == fd and events & select.POLLIN
                   for efd, events in poller.poll(timeout)):
                # NOTE(review): the datagram's source address is not checked
                # against nameserver; only the qid check guards the reply.
                return sk.recv(65536)
        raise unreachable(nameserver)
    finally:
        sk.close()


def _decodereply(data, packet):
    """Decode *data* and check that it is a response to *packet*.

    Raises error for malformed data, a mismatched query id, or a message
    that lacks the "resp" flag.
    """
    try:
        resp = proto.decodepacket(data)
    except proto.malformedpacket as inst:
        raise error(str(inst))
    if resp.qid != packet.qid:
        raise error("got response with wrong qid(?!)")
    if "resp" not in resp.flags:
        raise error("got query in response")
    return resp


def resolve(packet, nameserver, recurse, retries = 3, timeout = 2000, hops = 0, cnameres = None, verbose = False, visited = None):
    """Send *packet* to *nameserver* and return the decoded response.

    nameserver -- (family, host, port) tuple; the family picks the socket
                  type and the remaining items are the sendto() address.
    recurse    -- when true, chase CNAMEs in the answer through *cnameres*
                  and, if there is still no answer and the reply is not
                  authoritative, follow the NS records in the authority
                  section to other nameservers.
    retries    -- send attempts per nameserver before giving up on it.
    timeout    -- wait per attempt, in milliseconds.
    hops       -- current referral depth; following a referral when it
                  exceeds 30 raises error.
    cnameres   -- resolver used for CNAME chasing (module default if None);
                  also passed on to the nested calls for referrals.
    verbose    -- print an indented trace of the referral walk.
    visited    -- set of nameserver tuples already queried. It is updated
                  in place so no server is asked twice during one walk.

    Returns the response (NXDOMAIN replies included), or None when no
    referred nameserver produced an answer or an authoritative reply.
    Raises unreachable when *nameserver* does not reply, servfail on a
    SERVFAIL reply, and error for other failures. Unreachable referred
    servers are skipped, but their other errors propagate.
    """
    if cnameres is None: cnameres = default
    if visited is None: visited = set()
    visited.add(nameserver)
    resp = _decodereply(_exchange(packet, nameserver, retries, timeout), packet)
    if resp.rescode != 0:
        if resp.rescode == proto.SERVFAIL:
            raise servfail()
        if resp.rescode == proto.NXDOMAIN:
            return resp
        raise error("non-successful response (" + str(resp.rescode) + ")")
    if recurse:
        resolvecnames(resp, cnameres)
    if not recurse or resp.hasanswers():
        return resp
    # Recursing and still no answer: an authoritative empty reply is final.
    if "auth" in resp.flags:
        return resp
    if hops > 30:
        raise error("too many levels deep")
    indent = hops * " "
    for rr in resp.aulist:
        if verbose:
            print(indent + "Checking " + str(rr))
        if not rr.head.istype("NS"):
            continue
        if verbose:
            print(indent + "Will try " + str(rr))
        nsname = rr.data["nsname"]
        ai = extractaddrinfo(resp, nsname)
        if not ai:
            # The reply carries no address records for this nameserver;
            # look them up and add them to resp's additional section.
            if verbose:
                print(indent + "Resolving nameservers for " + str(nsname))
            resolveadditional(resp, rr)
            ai = extractaddrinfo(resp, nsname)
        for addrinfo in ai:
            ns = addrinfo + (53,)
            if ns in visited:
                if verbose:
                    print(indent + "Will not try " + str(ns) + " again")
                continue
            if verbose:
                print(indent + "Trying " + str(ns))
            try:
                resp2 = resolve(packet, ns, recurse, retries, timeout, hops + 1,
                                cnameres = cnameres, verbose = verbose, visited = visited)
            except unreachable:
                if verbose:
                    print(indent + "Could not reach " + str(ns))
                continue
            if verbose:
                if resp2 is None:
                    print(indent + "Got None")
                else:
                    austr = "Auth" if "auth" in resp2.flags else "Nonauth"
                    print(indent + "Got " + str(resp2.hasanswers()) + " (" + austr + ")")
            if resp2 is not None and (resp2.hasanswers() or "auth" in resp2.flags):
                return resp2
    return None
166
class resolver:
    """Resolver that sends every query to one fixed nameserver.

    nameserver -- (family, host, port) tuple handed to the module-level
                  resolve() function.
    recurse    -- passed to resolve(): chase CNAMEs and follow referrals.
    nsrecurse  -- whether squery() sets the "recurse" flag on its queries.
    retries    -- send attempts per nameserver.
    timeout    -- wait per attempt, in milliseconds.
    verbose    -- print a trace of the referral walk.
    """

    def __init__(self, nameserver, recurse, nsrecurse = True, retries = 3, timeout = 2000, verbose = False):
        self.nameserver = nameserver
        self.recurse = recurse
        self.nsrecurse = nsrecurse
        self.retries = retries
        self.timeout = timeout
        self.verbose = verbose

    def resolve(self, packet):
        """Resolve *packet* against this resolver's nameserver."""
        # Bare name: this calls the module-level resolve() function.
        return resolve(packet, self.nameserver, self.recurse, self.retries, self.timeout, verbose = self.verbose)

    def squery(self, name, rtype):
        """Build a single-question query for *name*/*rtype* and resolve it."""
        packet = proto.packet()
        # Subclasses such as multiresolver skip this __init__ and therefore
        # have no nsrecurse attribute. Treat that as False explicitly rather
        # than catching AttributeError around setflags(), which also hid
        # genuine errors raised by setflags() itself.
        if getattr(self, "nsrecurse", False):
            packet.setflags(["recurse"])
        packet.addq(rec.rrhead(name, rtype))
        return self.resolve(packet)
186
class multiresolver(resolver):
    """Spread queries over several resolvers, favouring the reliable ones.

    A short history of outcomes is kept for each resolver. Each entry has
    s = 1 when resolve() returned and s = 0 when it raised error. A
    resolver's score is the mean of its history (1.0 when it has none).
    Each query goes to a resolver drawn at random with probability
    proportional to its score. History older than 30 minutes is dropped.
    """

    # Number of previous outcomes kept per resolver. The newest outcome is
    # appended on top, so a history holds at most histsize + 1 entries.
    histsize = 10

    def __init__(self, resolvers):
        self.rl = [{"res": res, "qs": []} for res in resolvers]
        self.lastclean = int(time.time())

    def clean(self):
        """Drop outcomes older than 1800 s; a no-op if last run < 60 s ago."""
        now = int(time.time())
        if now - self.lastclean < 60:
            return
        self.lastclean = now
        for r in self.rl:
            r["qs"] = [q for q in r["qs"] if now - q["time"] < 1800]

    def _pick(self):
        """Return one entry of self.rl, chosen at random weighted by score."""
        if not self.rl:
            raise error("no resolvers available")
        scored = []
        for r in self.rl:
            qs = r["qs"]
            if qs:
                score = float(sum(q["s"] for q in qs)) / len(qs)
            else:
                score = 1.0
            scored.append((score, r))
        c = random.random() * sum(score for score, r in scored)
        for score, r in scored:
            c -= score
            if c <= 0:
                return r
        # Only reachable when rounding leaves c slightly above zero (which
        # implies some positive score): take the last positive-score entry.
        return next((r for score, r in reversed(scored) if score > 0), scored[-1][1])

    def _record(self, r, success):
        """Append an outcome to r's history, keeping only the newest ones."""
        hist = r["qs"]
        # Keep the NEWEST histsize entries; slicing with [:histsize] would
        # keep the oldest ones and freeze the score on stale results.
        recent = hist[max(0, len(hist) - self.histsize):]
        r["qs"] = recent + [{"time": int(time.time()), "s": 1 if success else 0}]

    def resolve(self, packet):
        """Resolve *packet* with a randomly chosen resolver.

        The outcome is recorded in that resolver's history. An error raised
        by the chosen resolver is recorded as a failure and re-raised.
        Raises error if no resolvers were given.
        """
        self.clean()
        r = self._pick()
        try:
            res = r["res"].resolve(packet)
        except error:
            self._record(r, False)
            raise
        self._record(r, True)
        return res
229
class sysresolver(resolver):
    """Resolver configured from a resolv.conf-style file.

    Each "nameserver" line holding an IPv4 or IPv6 literal becomes a
    resolver behind a multiresolver. The "domain"/"search" line supplies
    the suffixes that squery() tries for relative names.
    """

    def __init__(self, conffile = "/etc/resolv.conf"):
        nslist = []
        prelist = []
        with open(conffile, "r") as conf:
            for line in conf:
                # Keyword and value(s) may be separated by any whitespace
                # (spaces or tabs, possibly several), not just one space.
                fields = line.split()
                if len(fields) < 2:
                    continue
                keyword, args = fields[0], fields[1:]
                if keyword == "nameserver":
                    addr = args[0]
                    for family in (socket.AF_INET, socket.AF_INET6):
                        try:
                            socket.inet_pton(family, addr)
                        except socket.error:
                            continue
                        nslist.append((family, addr, 53))
                elif keyword in ("domain", "search"):
                    # resolv.conf(5): domain and search are mutually
                    # exclusive; when several appear, the last one wins.
                    prelist = args
        self.resolver = multiresolver([resolver(ns, False, True) for ns in nslist])
        self.prelist = []
        for prefix in prelist:
            suffix = dn.fromstring(prefix)
            suffix.rooted = True
            self.prelist.append(suffix)

    def resolve(self, packet):
        """Resolve *packet* through the configured nameservers."""
        return self.resolver.resolve(packet)

    def squery(self, name, rtype):
        """Query *name* (a string or dn name) for records of type *rtype*.

        A relative name is tried with each search suffix in turn and then
        with the root appended; a rooted name is tried as is. Returns the
        first response whose rescode is 0, otherwise the last response.
        """
        if isinstance(name, str):
            name = dn.fromstring(name)
        if name.rooted:
            candidates = [name]
        else:
            candidates = [name + suffix for suffix in self.prelist] + [name + dn.fromstring(".")]
        res = None
        for candidate in candidates:
            packet = proto.packet()
            packet.setflags(["recurse"])
            packet.addq(rec.rrhead(candidate, rtype))
            res = self.resolve(packet)
            if res.rescode == 0:
                break
        return res
283
# Resolver configured from /etc/resolv.conf; built when the module is imported.
sysres = sysresolver()

# IPv4 addresses of the thirteen root servers, following the IANA root
# hints file (named.root). b, d, h and l have been renumbered since this
# table was first written, and their old addresses are no longer
# guaranteed to answer.
rootresolvers = {
    "a": resolver((socket.AF_INET, "198.41.0.4", 53), True, False),
    "b": resolver((socket.AF_INET, "170.247.170.2", 53), True, False),
    "c": resolver((socket.AF_INET, "192.33.4.12", 53), True, False),
    "d": resolver((socket.AF_INET, "199.7.91.13", 53), True, False),
    "e": resolver((socket.AF_INET, "192.203.230.10", 53), True, False),
    "f": resolver((socket.AF_INET, "192.5.5.241", 53), True, False),
    "g": resolver((socket.AF_INET, "192.112.36.4", 53), True, False),
    "h": resolver((socket.AF_INET, "198.97.190.53", 53), True, False),
    "i": resolver((socket.AF_INET, "192.36.148.17", 53), True, False),
    "j": resolver((socket.AF_INET, "192.58.128.30", 53), True, False),
    "k": resolver((socket.AF_INET, "193.0.14.129", 53), True, False),
    "l": resolver((socket.AF_INET, "199.7.83.42", 53), True, False),
    "m": resolver((socket.AF_INET, "202.12.27.33", 53), True, False),
}
rootres = multiresolver(rootresolvers.values())

# Resolver used by resolve(), resolvecnames() and resolveadditional()
# when no explicit resolver is supplied.
default = sysres