Initial import
[ldd.git] / ldd / proto.py
1 #    ldd - DNS implementation in Python
2 #    Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
3 #
4 #    This program is free software; you can redistribute it and/or modify
5 #    it under the terms of the GNU General Public License as published by
6 #    the Free Software Foundation; either version 2 of the License, or
7 #    (at your option) any later version.
8 #
9 #    This program is distributed in the hope that it will be useful,
10 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
11 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 #    GNU General Public License for more details.
13 #
14 #    You should have received a copy of the GNU General Public License
15 #    along with this program; if not, write to the Free Software
16 #    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
17
18 from random import randint
19 import struct
20
21 import dn
22 import rec
23
24 class malformedpacket(Exception):
25     def __init__(self, text, qid):
26         self.text = text
27         self.qid = qid
28
29     def __str__(self):
30         return self.text
31
32 class packet:
33     "An abstract representation of a DNS query"
34
35     def __init__(self, qid = None, flags = 0, addr = None):
36         if qid is None: qid = randint(0, 65535)
37         self.qid = qid
38         self.qlist = []
39         self.anlist = []
40         self.aulist = []
41         self.adlist = []
42         self.opcode = 0
43         self.rescode = 0
44         self.addr = addr
45         self.signed = False
46         self.tsigctx = None
47         if type(flags) == int:
48             self.initflags(flags)
49         elif type(flags) == set:
50             self.flags = flags
51         else:
52             self.flags = set(flags)
53         
54     def setflags(self, flags):
55         flags = set(flags)
56         self.flags |= flags
57
58     def clrflags(self, flags):
59         flags = set(flags)
60         self.flags -= flags
61
62     def initflags(self, flags):
63         nf = set()
64         if flags & 0x8000: nf.add("resp")
65         if flags & 0x0400: nf.add("auth")
66         if flags & 0x0200: nf.add("trunc")
67         if flags & 0x0100: nf.add("recurse")
68         if flags & 0x0080: nf.add("recursed")
69         if flags & 0x0020: nf.add("isauthen")
70         if flags & 0x0010: nf.add("authok")
71         self.opcode =   (flags & 0x7800) >> 11
72         self.rescode =   flags & 0x000f
73         self.flags = nf
74
75     def encodeflags(self):
76         ret = 0
77         if "resp"     in self.flags: ret |= 0x8000
78         if "auth"     in self.flags: ret |= 0x0400
79         if "trunc"    in self.flags: ret |= 0x0200
80         if "recurse"  in self.flags: ret |= 0x0100
81         if "recursed" in self.flags: ret |= 0x0080
82         if "authok"   in self.flags: ret |= 0x0010
83         ret |= self.opcode << 11
84         ret |= self.rescode
85         return ret
86
87     def addq(self, rr):
88         self.qlist.append(rr)
89
90     def addan(self, rr):
91         for rr2 in self.anlist:
92             if rr2.head  == rr.head and rr2.data == rr.data:
93                 break
94         else:
95             self.anlist.append(rr)
96
97     def addau(self, rr):
98         for rr2 in self.aulist:
99             if rr2.head  == rr.head and rr2.data == rr.data:
100                 break
101         else:
102             self.aulist.append(rr)
103
104     def addad(self, rr):
105         for rr2 in self.adlist:
106             if rr2.head  == rr.head and rr2.data == rr.data:
107                 break
108         else:
109             self.adlist.append(rr)
110
111     def allrrs(self):
112         return self.anlist + self.aulist + self.adlist
113
114     def merge(self, other):
115         for lst in ["anlist", "aulist", "adlist"]:
116             for rr in getattr(other, lst):
117                 for rr2 in getattr(self, lst):
118                     if rr2.head == rr.head and rr2.data == rr.data:
119                         break
120                 else:
121                     getattr(self, lst).append(rr)
122     
123     def getanswer(self, name, rtype):
124         for rr in self.anlist + self.aulist + self.adlist:
125             if rr.head.istype(rtype) and rr.head.name == name:
126                 return rr
127         return None
128
129     def hasanswers(self):
130         for q in self.qlist:
131             for rr in self.anlist + self.aulist + self.adlist:
132                 if rr.head.rtype == q.rtype and rr.head.name == q.name:
133                     break
134                 if rr.head.istype("CNAME") and rr.head.name == q.name and self.getanswer(rr.data["priname"], q.rtype) is not None:
135                     break
136             else:
137                 break
138         else:
139             return True
140         return False
141         
142     def __str__(self):
143         ret = ""
144         ret += "ID: " + str(self.qid) + "\n"
145         ret += "Flags: " + str(self.flags) + "\n"
146         ret += "Opcode: " + str(self.opcode) + "\n"
147         ret += "Resp. code: " + str(self.rescode) + "\n"
148         ret += "Queries:\n"
149         for rr in self.qlist:
150             ret += "\t" + str(rr) + "\n"
151         ret += "Answers:\n"
152         for rr in self.anlist:
153             ret += "\t" + str(rr) + "\n"
154         ret += "Auth RRs:\n"
155         for rr in self.aulist:
156             ret += "\t" + str(rr) + "\n"
157         ret += "Additional RRs:\n"
158         for rr in self.adlist:
159             ret += "\t" + str(rr) + "\n"
160         return ret
161
162     def encode(self):
163         ret = ""
164         ret += struct.pack(">6H", self.qid, self.encodeflags(), len(self.qlist), len(self.anlist), len(self.aulist), len(self.adlist))
165         offset = len(ret)
166         names = []
167         for rr in self.qlist:
168             rre, names = rr.encode(names, offset)
169             offset += len(rre)
170             ret += rre
171         for rr in self.anlist:
172             rre, names = rr.encode(names, offset)
173             offset += len(rre)
174             ret += rre
175         for rr in self.aulist:
176             rre, names = rr.encode(names, offset)
177             offset += len(rre)
178             ret += rre
179         for rr in self.adlist:
180             rre, names = rr.encode(names, offset)
181             offset += len(rre)
182             ret += rre
183         return ret
184
185 def decodepacket(string):
186     offset = struct.calcsize(">6H")
187     qid, flags, qno, anno, auno, adno = struct.unpack(">6H", string[0:offset])
188     ret = packet(qid, flags)
189     try:
190         for i in range(qno):
191             crr, offset = rec.rrhead.decode(string, offset)
192             ret.addq(crr)
193         for i in range(anno):
194             crr, offset = rec.rr.decode(string, offset)
195             ret.addan(crr)
196         for i in range(auno):
197             crr, offset = rec.rr.decode(string, offset)
198             ret.addau(crr)
199         for i in range(adno):
200             crr, offset = rec.rr.decode(string, offset)
201             ret.addad(crr)
202     except rec.malformedrr, inst:
203         raise malformedpacket(str(inst), qid)
204     return ret
205
206 def responsefor(pkt, rescode = 0):
207     resp = packet(pkt.qid, ["resp"])
208     resp.opcode = pkt.opcode
209     resp.rescode = rescode
210     resp.tsigctx = pkt.tsigctx
211     resp.qlist = pkt.qlist + []  # Make a copy
212     return resp
213
214 def decodename(packet, offset):
215     parts = []
216     while True:
217         clen = ord(packet[offset])
218         offset += 1
219         if clen & 0xc0:
220             my = dn.domainname(parts, False)
221             cont, = struct.unpack(">H", chr(clen & 0x3f) + packet[offset])
222             res, discard = decodename(packet, cont)
223             return my + res, offset + 1
224         elif clen == 0:
225             return dn.domainname(parts, True), offset
226         else:
227             parts.append(packet[offset:offset + clen])
228             offset += clen
229
230 def encodename(dn, names, offset):
231     ret = ""
232     for i in range(len(dn)):
233         for name, off in names:
234             if name == dn[i:]:
235                 ret += chr(0xc0 + (off >> 8))
236                 ret += chr(off & 0xff)
237                 offset += 2
238                 return ret, names
239         if offset < 16384:
240             names += [(dn[i:], offset)]
241         ret += chr(len(dn.parts[i]))
242         ret += dn.parts[i]
243         offset += 1 + len(dn.parts[i])
244     ret += chr(0)
245     offset += 1
246     return ret, names
247
248 # Opcode constants
249 QUERY = 0
250 IQUERY = 1
251 STATUS = 2
252 UPDATE = 5
253
254 # Response code constants
255 #  RFC 1035:
256 FORMERR = 1
257 SERVFAIL = 2
258 NXDOMAIN = 3
259 NOTIMP = 4
260 REFUSED = 5
261 #  RFC 2136:
262 YXDOMAIN = 6
263 YXRRSET = 7
264 NXRRSET = 8
265 NOTAUTH = 9
266 NOTZONE = 10
267 #  RFC 2845:
268 BADSIG = 16
269 BADKEY = 17
270 BADTIME = 18
271
272 # Special RR types
273 QTANY = 255
274 QTMAILA = 254
275 QTMAILB = 253
276 QTAXFR = 252