Initial import
[ldd.git] / ldd / proto.py
CommitLineData
769e7ed9 1# ldd - DNS implementation in Python
2# Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
3#
4# This program is free software; you can redistribute it and/or modify
5# it under the terms of the GNU General Public License as published by
6# the Free Software Foundation; either version 2 of the License, or
7# (at your option) any later version.
8#
9# This program is distributed in the hope that it will be useful,
10# but WITHOUT ANY WARRANTY; without even the implied warranty of
11# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12# GNU General Public License for more details.
13#
14# You should have received a copy of the GNU General Public License
15# along with this program; if not, write to the Free Software
16# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
17
18from random import randint
19import struct
20
21import dn
22import rec
23
class malformedpacket(Exception):
    """Raised when a DNS packet cannot be decoded.

    Attributes:
      text -- human-readable description of what was wrong; also what
              str() of the exception returns.
      qid  -- query ID of the offending packet as given by the raiser
              (presumably so a caller can address an error reply; may be
              None when it is not known).
    """

    def __init__(self, text, qid):
        self.text, self.qid = text, qid

    def __str__(self):
        message = self.text
        return message
31
class packet:
    """An abstract representation of a DNS message (query or response).

    Header fields are kept as qid, flags (a set of flag names), opcode and
    rescode; the four sections are the lists qlist (questions), anlist
    (answers), aulist (authority records) and adlist (additional records).
    """

    # Header flag bits paired with the names used for them in self.flags.
    # initflags and encodeflags both walk this table, so decoding and
    # encoding stay symmetric.
    _flagbits = (
        ("resp", 0x8000),      # QR
        ("auth", 0x0400),      # AA
        ("trunc", 0x0200),     # TC
        ("recurse", 0x0100),   # RD
        ("recursed", 0x0080),  # RA
        ("isauthen", 0x0020),  # AD
        ("authok", 0x0010),    # CD
    )

    def __init__(self, qid = None, flags = 0, addr = None):
        """Create an empty packet.

        qid   -- 16-bit query ID; a random one is chosen when None.
        flags -- either a raw 16-bit header flags word (an int, which also
                 sets opcode and rescode) or an iterable of flag names such
                 as "resp" or "recurse".
        addr  -- stored unchanged as self.addr (not used by this class;
                 presumably the peer address -- confirm with callers).
        """
        if qid is None:
            qid = randint(0, 65535)
        self.qid = qid
        self.qlist = []
        self.anlist = []
        self.aulist = []
        self.adlist = []
        self.opcode = 0
        self.rescode = 0
        self.addr = addr
        self.signed = False
        self.tsigctx = None
        if isinstance(flags, int):
            # Raw header word: also sets opcode and rescode.
            self.initflags(flags)
        else:
            # Always copy, so setflags/clrflags never mutate a set that the
            # caller (or another packet) still holds.
            self.flags = set(flags)

    def setflags(self, flags):
        """Add every flag name in the iterable flags."""
        self.flags |= set(flags)

    def clrflags(self, flags):
        """Remove every flag name in the iterable flags."""
        self.flags -= set(flags)

    def initflags(self, flags):
        """Decode the raw 16-bit header flags word flags.

        Replaces self.flags, and sets self.opcode and self.rescode from the
        OPCODE and RCODE fields of the word.
        """
        self.flags = set([name for name, bit in self._flagbits if flags & bit])
        self.opcode = (flags & 0x7800) >> 11
        self.rescode = flags & 0x000f

    def encodeflags(self):
        """Return the raw 16-bit header flags word for this packet."""
        ret = 0
        for name, bit in self._flagbits:
            if name in self.flags:
                ret |= bit
        # Mask to the 4-bit OPCODE/RCODE fields so out-of-range values
        # (e.g. a rescode of 16 or more) cannot spill into the flag bits.
        ret |= (self.opcode & 0xf) << 11
        ret |= self.rescode & 0x000f
        return ret

    def _addunique(self, lst, rr):
        """Append rr to lst unless an RR with equal head and data is present."""
        for rr2 in lst:
            if rr2.head == rr.head and rr2.data == rr.data:
                return
        lst.append(rr)

    def addq(self, rr):
        """Append rr to the question section (no duplicate check)."""
        self.qlist.append(rr)

    def addan(self, rr):
        """Add rr to the answer section unless an equal RR is already there."""
        self._addunique(self.anlist, rr)

    def addau(self, rr):
        """Add rr to the authority section unless an equal RR is already there."""
        self._addunique(self.aulist, rr)

    def addad(self, rr):
        """Add rr to the additional section unless an equal RR is already there."""
        self._addunique(self.adlist, rr)

    def allrrs(self):
        """Return a new list of all answer, authority and additional RRs."""
        return self.anlist + self.aulist + self.adlist

    def merge(self, other):
        """Add the answer/authority/additional RRs of other to this packet.

        RRs already present in the corresponding section are skipped.
        """
        for lst in ["anlist", "aulist", "adlist"]:
            mine = getattr(self, lst)
            for rr in getattr(other, lst):
                self._addunique(mine, rr)

    def getanswer(self, name, rtype):
        """Return the first RR (in any non-question section) named name whose
        head satisfies istype(rtype), or None."""
        for rr in self.allrrs():
            if rr.head.istype(rtype) and rr.head.name == name:
                return rr
        return None

    def hasanswers(self):
        """Return True if every question is answered by an RR in this packet.

        A question counts as answered by an RR with the same name and type,
        or by a CNAME for the question name whose target ("priname") has an
        RR of the question's type here. True when there are no questions.
        """
        rrs = self.allrrs()
        for q in self.qlist:
            for rr in rrs:
                if rr.head.rtype == q.rtype and rr.head.name == q.name:
                    break
                if (rr.head.istype("CNAME") and rr.head.name == q.name and
                        self.getanswer(rr.data["priname"], q.rtype) is not None):
                    break
            else:
                return False
        return True

    def __str__(self):
        lines = [
            "ID: " + str(self.qid),
            "Flags: " + str(self.flags),
            "Opcode: " + str(self.opcode),
            "Resp. code: " + str(self.rescode),
        ]
        for title, section in (("Queries:", self.qlist),
                               ("Answers:", self.anlist),
                               ("Auth RRs:", self.aulist),
                               ("Additional RRs:", self.adlist)):
            lines.append(title)
            lines.extend(["\t" + str(rr) for rr in section])
        return "\n".join(lines) + "\n"

    def encode(self):
        """Return the wire-format encoding of this packet.

        Each record's encode(names, offset) receives the running offset and
        the names list returned by the previous record (presumably the name
        compression table used by encodename).
        """
        header = struct.pack(">6H", self.qid, self.encodeflags(),
                             len(self.qlist), len(self.anlist),
                             len(self.aulist), len(self.adlist))
        parts = [header]
        offset = len(header)
        names = []
        for section in (self.qlist, self.anlist, self.aulist, self.adlist):
            for rr in section:
                rre, names = rr.encode(names, offset)
                offset += len(rre)
                parts.append(rre)
        return "".join(parts)
184
def decodepacket(string):
    """Decode a raw DNS message string into a packet.

    The header's section counts drive how many questions (rec.rrhead) and
    resource records (rec.rr) are decoded from the rest of the message.

    Raises malformedpacket if the message is too short to hold the 12-octet
    header, or if decoding a question/record raises rec.malformedrr (or a
    nested malformedpacket). In the latter cases the exception carries this
    packet's query ID.
    """
    offset = struct.calcsize(">6H")
    if len(string) < offset:
        # Previously this surfaced as struct.error. Recover the query ID when
        # at least two octets arrived, so callers reading inst.qid get it.
        qid = None
        if len(string) >= 2:
            qid, = struct.unpack(">H", string[0:2])
        raise malformedpacket("Packet too short for DNS header", qid)
    qid, flags, qno, anno, auno, adno = struct.unpack(">6H", string[0:offset])
    ret = packet(qid, flags)
    # (record count, decoder, section adder) in wire order.
    sections = ((qno, rec.rrhead, ret.addq),
                (anno, rec.rr, ret.addan),
                (auno, rec.rr, ret.addau),
                (adno, rec.rr, ret.addad))
    try:
        for count, decoder, add in sections:
            for i in range(count):
                crr, offset = decoder.decode(string, offset)
                add(crr)
    except (rec.malformedrr, malformedpacket) as inst:
        # A nested decoder may not know the query ID; attach it here.
        raise malformedpacket(str(inst), qid)
    return ret
205
def responsefor(pkt, rescode = 0):
    """Build an empty response packet for the query pkt.

    The response reuses pkt's query ID, opcode and TSIG context, gets a
    copy of pkt's question list, has the "resp" flag set and carries
    rescode. The "recurse" (RD) flag is copied from the query, as RFC 1035
    section 4.1.1 requires.
    """
    flags = ["resp"]
    if "recurse" in pkt.flags:
        flags.append("recurse")
    resp = packet(pkt.qid, flags)
    resp.opcode = pkt.opcode
    resp.rescode = rescode
    resp.tsigctx = pkt.tsigctx
    # Copy, so appending questions to the response leaves pkt untouched.
    resp.qlist = list(pkt.qlist)
    return resp
213
def decodename(packet, offset):
    """Decode a possibly compressed domain name starting at offset.

    packet -- the raw message string (note: shadows the packet class here).
    offset -- index of the name's first length octet.

    Returns (name, offset). The returned offset is the index just past the
    name as it is laid out at this position; a compression pointer counts
    as its two octets.

    Raises malformedpacket (with qid None) for the reserved label types
    0x40/0x80. It also raises malformedpacket for a compression pointer
    that does not point strictly before the start of the name being decoded.
    Because each followed pointer must move strictly backwards, pointer
    loops are impossible and recursion depth is bounded. A name that runs
    past the end of packet raises IndexError.
    """
    start = offset
    parts = []
    while True:
        clen = ord(packet[offset])
        offset += 1
        if (clen & 0xc0) == 0xc0:
            # Compression pointer: the low 6 bits of this octet and all 8
            # bits of the next one form a 14-bit message offset.
            cont = ((clen & 0x3f) << 8) | ord(packet[offset])
            if cont >= start:
                raise malformedpacket("Compression pointer does not point backwards", None)
            my = dn.domainname(parts, False)
            res, discard = decodename(packet, cont)
            return my + res, offset + 1
        elif clen & 0xc0:
            raise malformedpacket("Reserved label type %#04x" % (clen & 0xc0), None)
        elif clen == 0:
            return dn.domainname(parts, True), offset
        else:
            parts.append(packet[offset:offset + clen])
            offset += clen
229
def encodename(dn, names, offset):
    """Encode the domain name dn at message position offset, with compression.

    dn     -- the name to encode; must support len(), suffix slicing dn[i:]
              and a parts list of label strings. (This parameter shadows
              the dn module inside the function.)
    names  -- list of (suffix, offset) pairs already written to the message.
              It is extended in place with the suffixes written here, as
              long as their offset fits in a 14-bit pointer.
    offset -- message offset at which the encoding will be placed.

    Returns (encoded string, names). The first suffix found in names is
    replaced by a two-octet compression pointer. Otherwise the name ends
    with a zero-length label.

    Raises ValueError for a label longer than 63 octets. Its length octet
    would otherwise be read back as a reserved label type or a pointer.
    """
    ret = []
    for i in range(len(dn)):
        suffix = dn[i:]
        for name, off in names:
            if name == suffix:
                # Pointer: top two bits set, then the 14-bit target offset.
                ret.append(chr(0xc0 | (off >> 8)))
                ret.append(chr(off & 0xff))
                return "".join(ret), names
        label = dn.parts[i]
        if len(label) > 63:
            raise ValueError("DNS label longer than 63 octets: %r" % (label,))
        # Only offsets that fit in a 14-bit pointer can be referenced later.
        if offset < 16384:
            names += [(suffix, offset)]
        ret.append(chr(len(label)))
        ret.append(label)
        offset += 1 + len(label)
    ret.append(chr(0))
    return "".join(ret), names
247
# Opcode constants: values of the 4-bit OPCODE header field (packet.opcode).
QUERY = 0      # standard query
IQUERY = 1     # inverse query
STATUS = 2     # server status request
UPDATE = 5     # dynamic update (RFC 2136)

# Response code constants (packet.rescode).
# RFC 1035:
FORMERR = 1    # format error
SERVFAIL = 2   # server failure
NXDOMAIN = 3   # name error: the queried name does not exist
NOTIMP = 4     # not implemented
REFUSED = 5    # refused for policy reasons
# RFC 2136:
YXDOMAIN = 6   # a name that ought not to exist does exist
YXRRSET = 7    # an RR set that ought not to exist does exist
NXRRSET = 8    # an RR set that ought to exist does not
NOTAUTH = 9    # server is not authoritative for the zone
NOTZONE = 10   # a name is not within the zone
# RFC 2845 (TSIG). These do not fit in the 4-bit header RCODE field;
# RFC 2845 carries them in the TSIG record's error field.
BADSIG = 16    # TSIG signature failure
BADKEY = 17    # key not recognized
BADTIME = 18   # signature out of time window

# Special RR types: QTYPE values valid only in questions (RFC 1035 3.2.3).
QTANY = 255    # request for all records
QTMAILA = 254  # mail agent RRs (obsolete, see MX)
QTMAILB = 253  # mailbox-related RRs
QTAXFR = 252   # transfer of an entire zone