def testCache(self):
    """Miss on an unknown name, then hit after adding an A record.

    NOTE(review): the miss is compared against ``[]`` while the hit is
    compared against the bare record rather than ``[record]``; confirm which
    return shape ``RecordCache.lookup`` really has.
    """
    record_cache = RecordCache(0)
    missing = record_cache.lookup("abcdefghqqqq.com", Type.A, Class.IN)
    self.assertEqual(missing, [])

    record = ResourceRecord(
        "blabla.com", Type.A, Class.IN, 0, ARecordData("111.111.111.111")
    )
    record_cache.add_record(record)
    found = record_cache.lookup("blabla.com", Type.A, Class.IN)
    self.assertEqual(found, record)
    def setUp(self):
        """Write a cache file that holds one A record with a bogus address."""
        # 456 is not a valid IPv4 octet, so this record data is invalid.
        invalid_rdata = RecordData.create(Type.A, "192.168.123.456")
        self.rr = ResourceRecord(
            "invalid.invalid", Type.A, Class.IN, 3, invalid_rdata
        )

        seeded_cache = RecordCache()
        seeded_cache.add_record(self.rr)
        seeded_cache.write_cache_file()
    def setUp(self):
        """Persist a single invalid A record through the record cache."""
        self.rr = ResourceRecord(
            "invalid.invalid",
            Type.A,
            Class.IN,
            3,
            # The final octet (456) cannot occur in a real IPv4 address.
            RecordData.create(Type.A, "192.168.123.456"),
        )

        cache = RecordCache()
        cache.add_record(self.rr)
        cache.write_cache_file()
 def test_cache_lookup(self):
     """Store one A record in a fresh cache and read it back by name."""
     record = ResourceRecord(
         "wiki.nl", Type.A, Class.IN, self.ttl,
         RecordData.create(Type.A, "192.168.123.456"),
     )
     cache = RecordCache()
     cache.add_record(record)
     self.assertEqual([record], cache.lookup("wiki.nl", Type.A, Class.IN))
 def test_cache_lookup(self):
     """A matching lookup returns exactly the one record that was added."""
     wiki_a = ResourceRecord(
         "wiki.nl",
         Type.A,
         Class.IN,
         self.ttl,
         RecordData.create(Type.A, "192.168.123.456"),
     )
     cache = RecordCache()
     cache.add_record(wiki_a)
     found = cache.lookup("wiki.nl", Type.A, Class.IN)
     expected = [wiki_a]
     self.assertEqual(expected, found)
    def test_TTL_expiration(self):
        """
        cache a record, wait till ttl expires, see if record is removed from cache

        The sleep runs one second past the record's TTL. Sleeping exactly
        ``ttl`` seconds leaves the record on its expiry boundary. There the
        outcome may depend on whether the cache compares ages with ``>`` or
        ``>=``, and on timestamp rounding, which makes the test flaky.
        """
        rr = ResourceRecord("wiki.nl", Type.A, Class.IN, self.ttl,
                            RecordData.create(Type.A, "192.168.123.456"))
        cache = RecordCache()
        cache.add_record(rr)

        # Wait strictly past the expiry boundary.
        time.sleep(rr.ttl + 1)

        lookup_vals = cache.lookup("wiki.nl", Type.A, Class.IN)
        self.assertFalse(lookup_vals)
 def test_expired_cache_entry(self):
     """A ttl=0 record added to RecordCache(0) is looked up as None."""
     cache = RecordCache(0)
     cache.add_record(
         ResourceRecord(
             name=Name("bonobo.putin"),
             type_=Type.A,
             class_=Class.IN,
             ttl=0,
             rdata=ARecordData("1.0.0.1"),
         )
     )
     result = cache.lookup(Name("bonobo.putin"), Type.A, Class.IN)
     self.assertEqual(result, None)
 def test_invalid_domain_from_cache(self):
     """An A record with ttl=60 in RecordCache(0) is returned by lookup."""
     cache = RecordCache(0)
     fresh = ResourceRecord(
         name=Name("bonobo.putin"),
         type_=Type.A,
         class_=Class.IN,
         ttl=60,
         rdata=ARecordData("1.0.0.1"),
     )
     cache.add_record(fresh)
     hit = cache.lookup(Name("bonobo.putin"), Type.A, Class.IN)
     self.assertEqual(hit, fresh)
    def test_TTL_expiration(self):
        """
        cache a record, wait till ttl expires, see if record is removed from cache

        Sleeps one second longer than the TTL so the record is unambiguously
        expired. Sleeping exactly ``ttl`` seconds races the expiry boundary:
        whether the record is still served may depend on ``>`` vs ``>=`` and
        on timestamp rounding inside the cache.
        """
        rr = ResourceRecord("wiki.nl", Type.A, Class.IN, self.ttl,
                            RecordData.create(Type.A, "192.168.123.456"))
        cache = RecordCache()
        cache.add_record(rr)

        time.sleep(rr.ttl + 1)  # strictly past the expiry boundary

        lookup_vals = cache.lookup("wiki.nl", Type.A, Class.IN)
        self.assertFalse(lookup_vals)
 def test_ttl_overwrite(self):
     """A ttl=0 record added to RecordCache(60) comes back with ttl 60."""
     cache = RecordCache(60)
     original = ResourceRecord(
         name=Name("bonobo.putin"),
         type_=Type.A,
         class_=Class.IN,
         ttl=0,
         rdata=ARecordData("1.0.0.1"),
     )
     cache.add_record(original)
     entry = cache.lookup(Name("bonobo.putin"), Type.A, Class.IN)
     self.assertEqual(entry, original)
     self.assertEqual(entry.ttl, 60)
Example #11
0
 def setUp(self):
     """Blank the "testCache" file, then store one A record in it."""
     cache_file = "testCache"
     # Create or truncate the file so earlier runs leave nothing behind.
     with open(cache_file, "w") as handle:
         handle.write("")
     self.RC = RecordCache(2, cache_file)
     self.r1 = ResourceRecord.from_dict({
         "type": "A",
         "name": "dnsIsAwesome.com",
         "class": "IN",
         "ttl": 2,
         "rdata": {"address": "192.123.12.23"},
     })
     self.RC.add_record(self.r1)
     self.RC.write_cache_file()
    def test_cache_disk_io(self):
        """Round-trip a record through the cache file and look it up again."""
        rr = ResourceRecord(
            "wiki.nl", Type.A, Class.IN, self.ttl,
            RecordData.create(Type.A, "192.168.123.456"),
        )
        writer = RecordCache()
        # Overwrite whatever cache file is already on disk.
        writer.write_cache_file()
        writer.add_record(rr)
        writer.write_cache_file()

        reader = RecordCache()
        reader.read_cache_file()
        self.assertEqual([rr], reader.lookup("wiki.nl", Type.A, Class.IN))
    def test_cache_disk_io(self):
        """A record saved with write_cache_file is found after read_cache_file."""
        record = ResourceRecord(
            "wiki.nl",
            Type.A,
            Class.IN,
            self.ttl,
            RecordData.create(Type.A, "192.168.123.456"),
        )

        source_cache = RecordCache()
        source_cache.write_cache_file()  # clobber any existing cache file
        source_cache.add_record(record)
        source_cache.write_cache_file()

        reloaded_cache = RecordCache()
        reloaded_cache.read_cache_file()
        found = reloaded_cache.lookup("wiki.nl", Type.A, Class.IN)
        self.assertEqual([record], found)
 def test_invalid_domain_from_cache(self):
     """Resolve a name whose A and CNAME records (ttl=60) are already cached.

     Expects the alias list to hold the CNAME target and the address list
     to hold the cached A record's address.
     """
     cache = RecordCache(0)
     resolver = Resolver(5, cache)
     a_record = ResourceRecord(
         name=Name("bonobo.putin"),
         type_=Type.A,
         class_=Class.IN,
         ttl=60,
         rdata=ARecordData("1.0.0.1"),
     )
     cache.add_record(a_record)
     cname_record = ResourceRecord(
         name=Name("bonobo.putin"),
         type_=Type.CNAME,
         class_=Class.IN,
         ttl=60,
         rdata=CNAMERecordData(Name("putin.bonobo")),
     )
     cache.add_record(cname_record)
     expected = ("bonobo.putin", ["putin.bonobo."], ["1.0.0.1"])
     self.assertEqual(resolver.gethostbyname("bonobo.putin"), expected)
 def test_expired_cache_entry(self):
     """Resolving a name whose cached A/CNAME records have ttl=0 yields nothing."""
     cache = RecordCache(0)
     resolver = Resolver(5, cache)

     def remember(type_, rdata):
         # Cache a ttl=0 record for hw.gumpe with the given type and data.
         cache.add_record(
             ResourceRecord(
                 name=Name("hw.gumpe"),
                 type_=type_,
                 class_=Class.IN,
                 ttl=0,
                 rdata=rdata,
             ))

     remember(Type.A, ARecordData("1.0.0.2"))
     remember(Type.CNAME, CNAMERecordData(Name("gumpe.hw")))
     outcome = resolver.gethostbyname("hw.gumpe")
     self.assertEqual(outcome, ("hw.gumpe", [], []))
Example #16
0
class Resolver:
    """DNS resolver"""
    def __init__(self, timeout, caching, ttl, sock=None, cache=None):
        """Initialize the resolver

        Args:
            timeout: how long to wait for each server's reply (assumed to be
                seconds; ``None`` waits indefinitely)
            caching (bool): caching is enabled if True
            ttl (int): ttl of cache entries (if > 0)
            sock: socket wrapper to use; when omitted a ``SocketWrapper(53)``
                is created and started
            cache: record cache to use; when omitted a ``RecordCache(ttl)``
                is created
        """
        self.timeout = timeout
        self.caching = caching
        self.rc = cache
        self.ttl = ttl
        if self.rc is None:
            self.rc = RecordCache(self.ttl)
        self.zone = Zone()
        self.sock = sock
        if self.sock is None:
            self.sock = SocketWrapper(53)
            self.sock.start()

    def _make_id(self):
        """Build a query ID from the current time.

        Concatenates the UTC seconds-of-minute with up to the first three
        fractional digits of ``time.time()``. The result is at most 61999,
        so it fits in the 16-bit DNS ID field.
        """
        gm = time.gmtime()
        mss = str(time.time()).split(".")[1][0:3]
        gms = str(gm.tm_sec)
        return int(gms + mss)

    def _await_response(self, ident):
        """Poll the socket wrapper for the reply to query ``ident``.

        Returns the first truthy value of ``self.sock.msgThere(ident)``. It
        returns ``None`` if ``self.timeout`` seconds pass first (a ``None``
        timeout waits indefinitely). The bound keeps an unresponsive server
        from hanging resolution forever.
        """
        deadline = None
        if self.timeout is not None:
            deadline = time.monotonic() + self.timeout
        while True:
            data = self.sock.msgThere(ident)
            if data:
                return data
            if deadline is not None and time.monotonic() >= deadline:
                return None
            time.sleep(0.001)  # yield the CPU instead of busy-spinning

    def gethostbyname(self, hostname):
        """Translate a host name to IPv4 address.

        Resolves iteratively in the style of RFC 1034 section 5.3.3:
        - Cached A records are returned directly.
        - Otherwise queries go to cached name servers or, failing those, to
          the root servers listed in ``dns/root.zone``.
        - Referrals and CNAME answers are followed.

        Args:
            hostname (str): the hostname to resolve

        Returns:
            (str, list, list): (hostname, aliaslist, ipaddrlist). In this
            implementation both lists hold resource records (CNAME and A
            records respectively), not plain strings.
        """
        alias_list = []
        a_list = []
        slist = []
        found = False

        # Answer straight from the cache when A records are already known.
        cached_a = self.getRecordsFromCache(hostname, Type.A, Class.IN)
        if cached_a:
            a_list += cached_a
            return hostname, alias_list, a_list

        # Seed the server list from cached NS records, preferring their
        # cached address (A) records when present.
        for ns in self.matchByLabel(hostname, Type.NS, Class.IN):
            glue = self.getRecordsFromCache(str(ns.rdata.nsdname))
            if glue:
                slist += glue
            else:
                slist += [ns]

        ident = self._make_id()

        # Create the query
        question = Question(Name(hostname), Type.A, Class.IN)
        header = Header(ident, 0, 1, 0, 0, 0)
        header.qr = 0  # 0 for query
        header.opcode = 0  # standard query
        header.rd = 0  # not recursive
        query = Message(header, [question])

        # Fallback servers: the A records found in the root zone file.
        self.zone.read_master_file('dns/root.zone')
        sbelt = []
        for root in self.zone.records.values():
            sbelt += [r for r in root if r.type_ == Type.A]

        while not found:
            if slist:
                rr = slist.pop()
                if rr.type_ == Type.A:
                    self.sock.send((query, rr.rdata.address, 53))
                elif rr.type_ == Type.NS:
                    # Name server without a known address: resolve it first
                    # and queue the resulting A records.
                    _, _, a_rrs = self.gethostbyname(str(rr.rdata.nsdname))
                    slist += a_rrs
                    continue
                elif rr.type_ == Type.CNAME:
                    # Restart resolution at the canonical name.
                    _, cname_rrs, a_rrs = self.gethostbyname(
                        str(rr.rdata.cname))
                    a_list += a_rrs
                    alias_list += cname_rrs
                    break
            elif sbelt:
                rr = sbelt.pop()
                self.sock.send((query, rr.rdata.address, 53))
            else:
                break

            # Receive response; give up on this server after the timeout.
            data = self._await_response(ident)
            if data is None:
                continue
            response, _ = data[0]

            for answer in response.answers:
                if answer.type_ == Type.A:
                    self.addRecordToCache(answer)
                    a_list.append(answer)
                    found = True
                elif answer.type_ == Type.CNAME:
                    self.addRecordToCache(answer)
                    alias_list.append(answer)
                    slist.append(answer)

            nss = []
            for auth in response.authorities:
                if auth.type_ == Type.NS:
                    nss.append(auth)
                    self.addRecordToCache(auth)

            glue_by_name = {}
            for add in response.additionals:
                if add.type_ == Type.A:
                    glue_by_name[str(add.name)] = add
                    self.addRecordToCache(add)

            # Queue each referred name server, using its glue record if any.
            for ns in nss:
                slist.append(glue_by_name.get(str(ns.rdata.nsdname), ns))

        return hostname, alias_list, a_list

    def addRecordToCache(self, record):
        """Add ``record`` to the record cache when caching is enabled."""
        if self.caching:
            self.rc.add_record(record)

    def getRecordsFromCache(self, dname, t=Type.A, c=Class.IN):
        """Return cached records for (dname, t, c); [] when caching is off."""
        if self.caching:
            return self.rc.lookup(dname, t, c)
        else:
            return []

    def matchByLabel(self, dname, type_, class_):
        """Delegate to the cache's ``matchByLabel``; [] when caching is off."""
        if self.caching:
            return self.rc.matchByLabel(dname, type_, class_)
        else:
            return []

    def shutdown(self):
        """
        to be used only when initialized for testing
        :return:
        """
        self.sock.shutdown()
Example #17
0
class Resolver:
    """DNS resolver"""

    def __init__(self, timeout, caching, ttl, rootip="198.41.0.4"):
        """Initialize the resolver

        Args:
            timeout: socket timeout applied to every query
            caching (bool): caching is enabled if True
            ttl (int): ttl of cache entries (if > 0)
            rootip (str): address of the server resolution starts from
        """
        self.timeout = timeout
        self.caching = caching
        self.ttl = ttl
        self.doLogging = False
        self.rootip = rootip
        self.rd = 0

        if self.caching:
            self.cache = RecordCache(ttl)
            self.cache.read_cache_file()

    def log(self, *args, end="\n"):
        """Print ``args`` when ``self.doLogging`` is enabled."""
        if self.doLogging:
            print(*args, end=end)

    def logHeader(self, header):
        """Log the flag fields of a message header on one line."""
        self.log("\tFLAGS", end="")
        self.log(" QR", header.qr, end=";")
        self.log(" OPCODE", header.opcode, end=";")
        self.log(" AA", header.aa, end=";")
        self.log(" TC", header.tc, end=";")
        self.log(" RD", header.rd, end=";")
        self.log(" RA", header.ra, end=";")
        self.log(" Z", header.z, end=";")
        self.log(" RCODE", header.rcode, end=";\n")

    def check_cache(self, hostname):
        """Look ``hostname`` up in the cache.

        Returns:
            (bool, list, list), one of:
            - ``(True, addresses, cnames)`` when A or CNAME records for
              ``hostname`` itself are cached;
            - ``(False, ns_addresses, ns_names)`` for the longest suffix of
              ``hostname`` with cached NS records (``ns_addresses`` may be
              empty when no glue is cached);
            - ``(False, [], [])`` when nothing relevant is cached.
        """
        iplist = self.cache.lookup(Name(hostname), Type.A, Class.IN)
        namelist = self.cache.lookup(Name(hostname), Type.CNAME, Class.IN)
        if iplist or namelist:
            return (True, [x.rdata.address for x in iplist],
                    [x.rdata.cname for x in namelist])

        labels = hostname.split('.')
        for i in range(len(labels)):
            suffix = '.'.join(labels[i:])
            ns_names = [x.rdata.nsdname for x in
                        self.cache.lookup(Name(suffix), Type.NS, Class.IN)]
            if ns_names:
                ns_records = []
                for ns_name in ns_names:
                    ns_records.extend(
                        self.cache.lookup(ns_name, Type.A, Class.IN))
                return False, [x.rdata.address for x in ns_records], ns_names
        return False, [], []

    def send_request(self, ip, name):
        """Send an A/IN query for ``name`` to ``ip`` port 53 and parse the reply.

        When caching is enabled, A, CNAME and NS records from
        ``response.resources`` are added to the cache. The socket is closed
        even when sending or receiving fails. Socket errors (including
        timeouts) propagate to the caller.

        Returns:
            (answers, authorities, additionals) of the response
        """
        question = Question(name, Type.A, Class.IN)
        header = Header(9001, 0, 1, 0, 0, 0)
        header.qr = 0
        header.opcode = 0
        header.rd = self.rd
        query = Message(header, [question])

        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.settimeout(self.timeout)
            sock.sendto(query.to_bytes(), (ip, 53))
            data = sock.recv(512)

        response = Message.from_bytes(data)
        self.logHeader(response.header)
        if self.caching:
            for r in response.resources:
                if r.type_ in (Type.A, Type.CNAME, Type.NS):
                    self.cache.add_record(r)

        return response.answers, response.authorities, response.additionals

    def resolve_request(self, ip, hostname):
        """Query ``ip`` for ``hostname`` and classify the reply.

        Returns:
            (bool, list, list), one of:
            - ``(True, addresses, cnames)`` when the reply has answers;
            - ``(False, glue_addresses, ns_names)`` otherwise, taken from
              the additional and authority sections.
        """
        self.log("\nRESOLVING REQUEST", hostname, "at:", ip)
        answers, authorities, additionals = self.send_request(ip, hostname)

        namelist = []
        ipaddrlist = []
        if answers:
            self.log("\n\tGOT RESPONSE")
            for answer in answers:
                if answer.type_ == Type.A:
                    ipaddrlist.append(answer.rdata.address)
                    self.log("\t\t", answer.type_, answer.rdata.address)
                if answer.type_ == Type.CNAME:
                    # Record the alias target, matching what check_cache
                    # returns for cached CNAME records.
                    namelist.append(answer.rdata.cname)
                    self.log("\t\t", answer.type_, answer.rdata.cname)
            return True, ipaddrlist, namelist

        if additionals:
            self.log("\n\tGOT ADDITIONALS")
            for answer in additionals:
                if answer.type_ == Type.A:
                    ipaddrlist.append(answer.rdata.address)
                    self.log('\t\t', answer.type_, answer.rdata.address)

        if authorities:
            self.log("\n\tGOT AUTHORITIES")
            for answer in authorities:
                if answer.type_ == Type.NS:
                    namelist.append(answer.rdata.nsdname)
                    self.log('\t\t', answer.type_, answer.rdata.nsdname)

        return False, ipaddrlist, namelist

    def gethostbyname(self, hostname):
        """Translate a host name to IPv4 address.

        Follows the iterative algorithm of RFC 1034 section 5.3.3:
        - Start at the root server, or at cached name servers of an
          enclosing zone.
        - Follow referrals.
        - Resolve glue-less name servers recursively.
        Servers that fail to answer are skipped.

        Args:
            hostname (str): the hostname to resolve

        Returns:
            (str, [str], [str]): (hostname, aliaslist, ipaddrlist); both lists
            are empty when resolution fails.
        """
        self.log("NEW QUERY:", hostname)
        serveriplist = [self.rootip]

        if self.caching:
            res, iplist, namelist = self.check_cache(hostname)
            if iplist or namelist:
                self.log("CACHE HIT")
                if res:
                    return hostname, namelist, iplist
                if iplist:
                    # Start at the cached name servers; without cached glue
                    # keep the root server instead of giving up immediately.
                    serveriplist = iplist

        while serveriplist:
            server = serveriplist[0]
            try:
                res, iplist, namelist = self.resolve_request(
                    server, Name(hostname))
            except OSError:  # includes socket timeouts
                self.log("\tNO RESPONSE FROM", server)
                serveriplist = serveriplist[1:]
                continue

            if res:
                self.log("END OF QUERY:", hostname)
                if self.caching:
                    self.cache.write_cache_file()
                return hostname, namelist, iplist
            elif not iplist and namelist:
                # Referral without glue: resolve the name servers' addresses.
                # Continue with the remaining servers, not the one that just
                # referred us, which would loop forever.
                newlist = []
                for ns_name in namelist:
                    _, _, ns_ips = self.gethostbyname(str(ns_name))
                    newlist.extend(ns_ips)
                serveriplist = newlist + serveriplist[1:]
            else:
                serveriplist = iplist + serveriplist[1:]

        self.log("FAILURE")
        return hostname, [], []
Example #18
0
class Resolver:
    """DNS resolver"""
    def __init__(self, timeout, caching, ttl):
        """Initialize the resolver

        Args:
            timeout: socket timeout (seconds) used for each query
            caching (bool): caching is enabled if True
            ttl (int): ttl of cache entries (if > 0)
        """
        # NOTE(review): the cache is built with a fixed ttl of 10000000 rather
        # than ``ttl``; confirm this is intended.
        self.cache = RecordCache(10000000)
        self.timeout = timeout
        self.caching = caching
        self.ttl = ttl
        # Root server hints (name -> IPv4 address).
        # NOTE(review): some of these addresses have since been renumbered;
        # compare against IANA's current root hints file.
        self.hints = {
            "a.root-servers.net": "198.41.0.4",
            "b.root-servers.net": "192.228.79.201",
            "c.root-servers.net": "192.33.4.12",
            "d.root-servers.net": "199.7.91.13",
            "e.root-servers.net": "192.203.230.10",
            "f.root-servers.net": "192.5.5.241",
            "g.root-servers.net": "192.112.36.4",
            "h.root-servers.net": "128.63.2.53",
            "i.root-servers.net": "192.36.148.17",
            "j.root-servers.net": "192.58.128.30",
            "k.root-servers.net": "193.0.14.129",
            "l.root-servers.net": "199.7.83.42",
            "m.root-servers.net": "202.12.27.33"
        }

    def gethostbyname(self, hostname):
        """Translate a host name to IPv4 address.

        Starting from the root hints, repeatedly asks one known server for
        the current domain:
        - Follows CNAME answers and adds referred name servers.
        - Resolves glue-less name servers recursively.
        - Stops as soon as an address is known.
        Unreachable servers are skipped.

        Args:
            hostname (str): the hostname to resolve

        Returns:
            (str, [str], [str]): (hostname, aliaslist, ipaddrlist)
        """
        aliaslist = []
        ipaddrlist = []
        # Servers still to ask: name -> address, or None when the address
        # has to be resolved first. Starts with the root hints.
        servers = self.hints.copy()
        # Name currently being queried. It changes when a CNAME is followed,
        # while ``hostname`` is kept for the return value.
        domain = hostname
        finished = False
        while servers and not finished:
            if self.caching:
                domain = self._apply_cache(domain, aliaslist, ipaddrlist)
                if ipaddrlist:
                    break
            name, server = servers.popitem()
            if server is None:
                # Referral without glue: resolve the server's address first.
                _, _, addresses = self.gethostbyname(name)
                if not addresses:
                    continue
                server = addresses[0]
            try:
                response = self.sendQuestion(domain, server)
            except Exception:
                # Unreachable/slow server or undecodable reply: try the next.
                continue
            for ans in response.answers:
                if ans.type_ == 1:  # A record
                    address = ans.rdata.to_dict()['address']
                    if address not in ipaddrlist:
                        ipaddrlist.append(address)
                        if self.caching:
                            self.cache.add_record(ans)
                elif ans.type_ == 5:  # CNAME record
                    domain = ans.to_dict()['rdata']['cname']
                    if domain not in aliaslist:
                        aliaslist.append(str(domain))
            new_servers = self.getNewServers(response.authorities,
                                             response.additionals)
            if ipaddrlist:
                finished = True
            servers = {**servers, **new_servers}
        return hostname, aliaslist, ipaddrlist

    def _apply_cache(self, domain, aliaslist, ipaddrlist):
        """Merge cached A/CNAME records for ``domain`` into the result lists.

        Appends new addresses to ``ipaddrlist`` and new aliases to
        ``aliaslist`` in place. Returns the name to query next: the CNAME
        target if one is cached, else ``domain``.
        """
        cached = (self.cache.lookup(domain, Type.A, Class.IN)
                  + self.cache.lookup(domain, Type.CNAME, Class.IN))
        for ce in cached:
            if ce.type_ == 1:  # A record
                address = ce.rdata.to_dict()['address']
                if address not in ipaddrlist:
                    ipaddrlist.append(address)
            elif ce.type_ == 5:  # CNAME record
                domain = ce.to_dict()['rdata']['cname']
                if domain not in aliaslist:
                    aliaslist.append(str(domain))
        return domain

    def getNewServers(self, authorities, additionals):
        """Used to combine the suggested servers to check.

        Maps each name server named in ``authorities`` to the address of a
        matching A record in ``additionals`` (glue), or to ``None`` when no
        glue is present. SOA records (type 6) are skipped.
        """
        serverlist = {}
        for b in authorities:
            if b.type_ == 6:  # SOA
                continue
            nsdname = str(b.rdata.to_dict()['nsdname'])
            serverlist[nsdname] = None
            for a in additionals:
                if str(a.name) == nsdname and a.type_ == 1:
                    serverlist[nsdname] = a.rdata.to_dict()['address']
        return serverlist

    def sendQuestion(self, hostname, server):
        """Used to handle the sending and receiving of requests/responses.

        Sends one recursion-desired A/IN query for ``hostname`` to
        ``server`` on UDP port 53 and returns the decoded reply. The socket
        is closed even if sending or receiving raises (e.g. on timeout).
        """
        identifier = 9001  # placeholder
        question = Question(Name(hostname), Type.A, Class.IN)
        header = Header(identifier, 0, 1, 0, 0, 0)
        header.qr = 0
        header.opcode = 0
        header.rd = 1
        query = Message(header, [question])

        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.settimeout(self.timeout)
            sock.sendto(query.to_bytes(), (server, 53))
            data = sock.recv(512)
        return Message.from_bytes(data)
Example #19
0
    def gethostbyname(self, hostname, dnsserv='192.112.36.4'):
        """Translate a host name to IPv4 address.

        Sends a recursion-desired A query for ``hostname`` to ``dnsserv``.
        It follows CNAME answers and NS records, then recurses into the name
        servers listed in the authority section of the last reply.

        Args:
            hostname (str): the hostname to resolve
            dnsserv (str): server to query first (default 192.112.36.4)

        Returns:
            (str, [str], [str]): (hostname, aliaslist, ipaddrlist); the lists
            are empty when no address could be found.
        """
        ipaddrlist = []
        cnames = []
        if self.caching:
            # NOTE(review): a new RecordCache is created on every call; this
            # only finds earlier entries if RecordCache shares state between
            # instances (e.g. through its cache file) - confirm.
            rcache = RecordCache(self.ttl)
            rcord = rcache.lookup(hostname, Type.ANY, Class.IN)
            if rcord:
                for rec in rcord:
                    if rec.type_ == Type.A:
                        ipaddrlist.append(rec.rdata.address)
                    elif rec.type_ == Type.CNAME:
                        cnames.append(rec.rdata.cname)
            if ipaddrlist:
                return hostname, cnames, ipaddrlist
            elif cnames:
                return self.gethostbyname(cnames[0], dnsserv)

        aliaslist = cnames
        dnslist = []
        # Aliases already queried (as strings), so a CNAME chain cannot make
        # the loop below re-send the same question forever.
        asked_aliases = set()

        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.settimeout(self.timeout)

            # Create and send query
            question = Question(Name(str(hostname)), Type.A, Class.IN)
            header = Header(9001, 0, 1, 0, 0, 0)
            header.qr = 0
            header.opcode = 0
            header.rd = 1
            query = Message(header, [question])
            sock.sendto(query.to_bytes(), (str(dnsserv), 53))

            # Receive response
            response = Message.from_bytes(sock.recv(2048))
            print("Number of answers: " + str(len(response.answers)))
            print("Number of authorities: " + str(len(response.authorities)))
            print("Number of additionals: " + str(len(response.additionals)))

            while response.answers:
                for answer in response.answers:
                    if answer.type_ == Type.A:
                        print("found A RR")
                        if self.caching:
                            rcache.add_record(answer)
                        ipaddrlist.append(answer.rdata.address)
                    if answer.type_ == Type.CNAME:
                        aliaslist.append(answer.rdata.cname)
                    if answer.type_ == Type.NS:
                        dnslist.append(answer.rdata.nsdname)
                if ipaddrlist:
                    return hostname, aliaslist, ipaddrlist
                elif aliaslist:
                    # Follow the most recently learned alias (the end of the
                    # CNAME chain so far); stop if it was already asked.
                    alias = aliaslist[-1]
                    if str(alias) in asked_aliases:
                        break
                    asked_aliases.add(str(alias))
                    question = Question(Name(alias), Type.A, Class.IN)
                    query = Message(header, [question])
                    sock.sendto(query.to_bytes(), (dnsserv, 53))
                    response = Message.from_bytes(sock.recv(2048))
                elif dnslist:
                    nsname = dnslist.pop()
                    maybe_dnsserv = self.getnsaddr(nsname,
                                                   response.additionals)
                    if maybe_dnsserv:
                        dnsserv = maybe_dnsserv
                    sock.sendto(query.to_bytes(), (dnsserv, 53))
                    response = Message.from_bytes(sock.recv(2048))
                else:
                    break

        if response.authorities:
            for authority in response.authorities:
                if authority.type_ != Type.NS:
                    continue  # e.g. SOA records carry no nsdname
                dnslist.append(authority.rdata.nsdname)
            while dnslist:
                nsname = dnslist.pop()
                # Prefer the glue address from the additional section; without
                # glue, fall back to the server's name.
                next_dnsserv = (self.getnsaddr(nsname, response.additionals)
                                or nsname)
                hname, aliasl, ipaddrl = self.gethostbyname(hostname,
                                                            next_dnsserv)
                if ipaddrl:
                    return hname, aliasl, ipaddrl
        return hostname, [], []
Example #20
0
File: resolver.py Project: W-M-T/pyDNS
class Resolver(object):
    """ DNS resolver """

    # Hostname syntax: dot-separated labels of letters, digits and hyphens,
    # where no label starts or ends with a hyphen. Raw strings keep the regex
    # escapes from being read as (invalid) string escapes.
    # NOTE(review): a trailing root dot ("example.com.") is rejected.
    _VALID_HOSTNAME = re.compile(
        r"^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*"
        r"([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$"
    )

    def __init__(self, timeout, caching, ttl, nameservers=None, use_rs=True):
        """ Initialize the resolver

        Args:
            timeout: socket timeout used for every query
            caching (bool): caching is enabled if True
            ttl (int): ttl of cache entries (if > 0)
            nameservers (list): servers to ask before the root servers; the
                given list is copied, never modified
            use_rs (bool): append dns.consts.ROOT_SERVERS to the servers
        """
        self.timeout = timeout
        self.caching = caching
        # This check is not needed for the resolver created via the server,
        # but it is for the resolver created by the client.
        self.ttl = ttl if ttl > 0 else 0
        self.cache = RecordCache() if caching else None
        # Fresh list, so the += below never mutates the caller's list or a
        # default shared between instances.
        self.nameservers = list(nameservers) if nameservers else []
        if use_rs:
            self.nameservers += dns.consts.ROOT_SERVERS

    def is_valid_hostname(self, hostname):
        """ Check if hostname could be a valid hostname

        Args:
            hostname (str): the hostname that is to be checked

        Returns:
            re.Match or None: truthy if hostname could be valid
        """
        return self._VALID_HOSTNAME.match(hostname)

    def save_cache(self):
        """ Save the cache if appropriate """
        if self.caching:
            if self.cache is not None:
                self.cache.write_cache_file()

    def ask_server(self, query, server):
        """ Send query to a server

        Args:
            query (Message): the query that is to be sent
            server (str): IP address of the server that the query must be sent to

        Returns:
            Message or None: the decoded response. ``None`` on timeout or when
            the response ID does not match the query ID.
        """
        response = None
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.settimeout(self.timeout)
            try:
                sock.sendto(query.to_bytes(), (server, 53))
                data = sock.recv(1024)
                response = dns.message.Message.from_bytes(data)
                if response.header.ident != query.header.ident:
                    return None

                if self.caching:
                    for record in (response.additionals + response.answers
                                   + response.authorities):
                        if record.type_ == Type.A or record.type_ == Type.CNAME:
                            record.ttl = self.ttl
                            record.timestamp = int(time.time())
                            self.cache.add_record(record)
            except socket.timeout:
                pass

        return response

    def gethostbyname(self, hostname):
        """ Resolve hostname to an IP address

        Args:
            hostname (str): the FQDN that we want to resolve

        Returns:
            hostname (str): the FQDN that we want to resolve,
            aliaslist ([str]): list of aliases of the hostname,
            ipaddrlist ([str]): list of IP addresses of the hostname

        """
        print("==GETHOSTNAME START=================")
        aliaslist = []
        ipaddrlist = []

        # Check if the hostname is valid
        if not self.is_valid_hostname(hostname):
            return hostname, [], []

        # Check if the information is in the cache
        if self.caching:
            for alias in self.cache.lookup(hostname, Type.CNAME, Class.IN):
                aliaslist.append(alias.rdata.data)

            for address in self.cache.lookup(hostname, Type.A, Class.IN):
                ipaddrlist.append(address.rdata.data)

            if ipaddrlist:
                print("We found an address in the cache!")
                return hostname, aliaslist, ipaddrlist

        # Do the iterative search, starting with the configured servers.
        hints = self.nameservers

        while hints:
            # Get the server to ask
            hint, hints = hints[0], hints[1:]

            # Build the query to send to that server
            identifier = randint(0, 65535)

            question = dns.message.Question(hostname, Type.A, Class.IN)
            header = dns.message.Header(identifier, 0, 1, 0, 0, 0)
            header.qr = 0
            header.opcode = 0
            header.rd = 0
            query = dns.message.Message(header, [question])

            # Try to get a response
            response = self.ask_server(query, hint)

            if response is None:
                # No response from this server, so check the next one.
                # str(): referred hints are NS names, not necessarily str.
                print("Server at " + str(hint) + " did not respond.")
                continue

            # Analyze the response: first collect the aliases...
            for answer in response.answers + response.additionals:
                if (answer.type_ == Type.CNAME
                        and answer.rdata.data not in aliaslist):
                    aliaslist.append(answer.rdata.data)

            # ...then try to get an address for the name or one of its aliases.
            for answer in response.answers:
                if answer.type_ == Type.A and (answer.name == hostname
                                               or answer.name in aliaslist):
                    ipaddrlist.append(answer.rdata.data)

            if ipaddrlist:
                print("We found an address using the recursive search!")
                return hostname, aliaslist, ipaddrlist

            # Referral: ask the name servers from the authority section next.
            for nameserver in response.authorities:
                if nameserver.type_ == Type.NS:
                    if self.caching:
                        self.cache.add_record(nameserver)
                    hints = [nameserver.rdata.data] + hints

        print("Recursive search for " + hostname + " was a total failure")
        return hostname, [], []