Initial import
[ldd.git] / ldd / rescache.py
CommitLineData
769e7ed9 1# ldd - DNS implementation in Python
2# Copyright (C) 2006 Fredrik Tolf <fredrik@dolda2000.com>
3#
4# This program is free software; you can redistribute it and/or modify
5# it under the terms of the GNU General Public License as published by
6# the Free Software Foundation; either version 2 of the License, or
7# (at your option) any later version.
8#
9# This program is distributed in the hope that it will be useful,
10# but WITHOUT ANY WARRANTY; without even the implied warranty of
11# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12# GNU General Public License for more details.
13#
14# You should have received a copy of the GNU General Public License
15# along with this program; if not, write to the Free Software
16# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
17
18import threading
19import time
20
21import resolver
22import proto
23import rec
24
class nxdmark:
    """Marker stored in the cache in place of a record list for a name
    whose lookup returned NXDOMAIN.

    expire -- absolute time, in seconds since the epoch, after which the
              marker is considered stale
    auth   -- the authority-section records that came with the NXDOMAIN
              response; handed back to clients together with the error
    """

    def __init__(self, expire, auth):
        self.expire = expire
        self.auth = auth

    def __repr__(self):
        # Purely a debugging aid; nothing in the cache logic relies on it.
        return "nxdmark(expire=%r, auth=%r)" % (self.expire, self.auth)
29
class cacheresolver(resolver.resolver):
    """Resolver that answers from an in-memory cache and falls back to a
    wrapped resolver on a cache miss.

    self.cache maps a name to either
      * an nxdmark, meaning the name was reported as NXDOMAIN, or
      * a list of (expire, rtype, data, auth) tuples, one per cached
        record, where `expire` is an absolute timestamp and `auth` is the
        list of NS records from the authority section of the response
        the record came from.

    Every access to self.cache happens with self.cachelock held.
    """

    def __init__(self, resolver):
        # The parameter keeps its original name for compatibility; inside
        # this method it shadows the `resolver` module.
        self.resolver = resolver
        self.cache = dict()
        self.cachelock = threading.Lock()

    def _typefilter(self, rtype):
        """Return a predicate telling whether a numeric record type
        matches `rtype`.

        `rtype` may be proto.QTANY, a numeric type, a type name, or an
        iterable mixing numeric types and type names.
        """
        if rtype == proto.QTANY:
            return lambda rt: True
        if isinstance(rtype, int):
            return lambda rt: rt == rtype
        if isinstance(rtype, str):
            rtid = rec.rtypebyname(rtype)
            return lambda rt: rt == rtid
        wanted = set()
        for item in rtype:
            if isinstance(item, str):
                item = rec.rtypebyname(item)
            wanted.add(item)
        return lambda rt: rt in wanted

    def getcached(self, name, rtype = proto.QTANY):
        """Look `name` up in the cache.

        Returns the nxdmark itself if an unexpired NXDOMAIN marker is
        cached for the name; otherwise a (possibly empty) list of
        (rr, auth) pairs for the unexpired records matching `rtype`,
        with each rr's TTL set to the time remaining until expiry.
        """
        self.cachelock.acquire()
        try:
            if name not in self.cache:
                return []
            now = int(time.time())
            entry = self.cache[name]
            if isinstance(entry, nxdmark):
                if entry.expire < now:
                    # Stale negative entry: drop it so that later
                    # lookups treat the name as simply uncached.
                    self.cache[name] = []
                    return []
                return entry
            cond = self._typefilter(rtype)
            ret = []
            for exp, trd, data, auth in entry:
                if exp > now and cond(trd):
                    ret.append((rec.rr((name, trd), exp - now, data), auth))
            return ret
        finally:
            self.cachelock.release()

    def dolookup(self, name, rtype):
        """Query the wrapped resolver and store the outcome in the cache.

        Returns None if the wrapped resolver raised servfail/unreachable
        or returned None, the newly cached nxdmark if the response code
        was NXDOMAIN, and the response itself otherwise.
        """
        try:
            res = self.resolver.squery(name, rtype)
        except (resolver.servfail, resolver.unreachable):
            # Must be a tuple: the old "except A, B" spelling caught only
            # A and bound the exception object to B.
            return None
        if res is None:
            return None
        if res.rescode == proto.NXDOMAIN:
            # Negative TTL comes from the SOA in the authority section,
            # with 300 seconds as the fallback when there is none.
            ttl = 300
            for rr in res.aulist:
                if rr.head.istype("SOA"):
                    ttl = rr.data["minttl"]
            nc = nxdmark(int(time.time()) + ttl, res.aulist)
            self.cachelock.acquire()
            try:
                self.cache[name] = nc
            finally:
                self.cachelock.release()
            return nc
        now = int(time.time())
        allrrs = list(res.allrrs())
        nsauth = [rr for rr in res.aulist if rr.head.istype("NS")]
        self.cachelock.acquire()
        try:
            alltypes = set([rr.head.rtype for rr in allrrs])
            # Evict previously cached records of any type present in this
            # response, for every name the response mentions.
            for rname in set([rr.head.name for rr in allrrs]):
                if rname not in self.cache:
                    continue
                old = self.cache[rname]
                if isinstance(old, nxdmark):
                    # Records have now arrived for a name that was cached
                    # as non-existent; the marker is not a list and
                    # cannot be filtered or appended to, so replace it.
                    self.cache[rname] = []
                else:
                    self.cache[rname] = [cl for cl in old if cl[1] not in alltypes]
            for rr in allrrs:
                if rr.head.name not in self.cache:
                    self.cache[rr.head.name] = []
                # Each entry gets its own copy of the NS list.
                self.cache[rr.head.name].append((now + rr.ttl, rr.head.rtype, rr.data, list(nsauth)))
            return res
        finally:
            self.cachelock.release()

    def addcached(self, packet, cis):
        """Add cached results to `packet`.

        `cis` is a list of (rr, auth) pairs as returned by getcached.
        Each rr goes into the answer section, each NS record in its auth
        list into the authority section, and any cached A/AAAA records
        for those name servers into the additional section.
        """
        for item, auth in cis:
            packet.addan(item)
            for ns in auth:
                packet.addau(ns)
                nsal = self.getcached(ns.data["nsname"], ["A", "AAAA"])
                # getcached may hand back an nxdmark instead of a list.
                if isinstance(nsal, list):
                    for glue, _ in nsal:
                        packet.addad(glue)

    def _chasecached(self, res, name, rtype):
        """Follow cached CNAME records starting at `name`.

        Every CNAME followed is added to `res`.  Returns (name, cis),
        where `name` is the last name reached and `cis` is what
        getcached returned for it: an nxdmark, a non-empty list of
        (rr, auth) pairs, or an empty list if nothing usable is cached.
        """
        seen = set()
        while True:
            cis = self.getcached(name, rtype)
            if isinstance(cis, nxdmark) or len(cis) > 0:
                return name, cis
            cics = self.getcached(name, "CNAME")
            if isinstance(cics, nxdmark) or len(cics) == 0:
                return name, cis
            seen.add(name)
            target = cics[0][0].data["priname"]
            if target in seen:
                # Cached CNAME cycle: stop here rather than loop forever.
                return name, cis
            self.addcached(res, cics)
            name = target

    def resolve(self, packet):
        """Build a response to `packet` from the cache, querying the
        wrapped resolver for whatever is not cached.

        NXDOMAIN and SERVFAIL are reported only when the packet carries
        exactly one question; with several questions, a question that
        fails is skipped and the remaining ones are still processed.
        """
        res = proto.responsefor(packet)
        single = len(packet.qlist) == 1
        for q in packet.qlist:
            name, cis = self._chasecached(res, q.name, q.rtype)
            if isinstance(cis, nxdmark):
                if single:
                    res.rescode = proto.NXDOMAIN
                    res.aulist = cis.auth
                    return res
                continue
            if len(cis) > 0:
                self.addcached(res, cis)
                continue
            tres = self.dolookup(name, q.rtype)
            if isinstance(tres, nxdmark):
                if single:
                    res.rescode = proto.NXDOMAIN
                    res.aulist = tres.auth
                    return res
                continue
            if tres is None:
                if single:
                    res.rescode = proto.SERVFAIL
                    return res
                continue
            if tres.rescode == 0:
                res.merge(tres)
        return res