コード例 #1
0
ファイル: dns_client.py プロジェクト: kliyer-ai/NDS-DNS
def resolve():
    """Resolve a hostname given on the command line and print the result.

    Parses the command-line options, starts a SocketWrapper on port 53,
    resolves the hostname with a Resolver backed by a RecordCache(3600) and
    prints the hostname, the alias addresses and the IP addresses.

    The socket is shut down and the cache file is written even when
    resolution raises, so a failed lookup does not leave the started socket
    wrapper running.
    """
    parser = ArgumentParser(description="DNS Client")
    parser.add_argument("hostname", help="hostname to resolve")
    parser.add_argument("--timeout",
                        metavar="time",
                        type=int,
                        default=5,
                        help="resolver timeout")
    parser.add_argument("-c",
                        "--caching",
                        action="store_true",
                        help="Enable caching")
    parser.add_argument("-t",
                        "--ttl",
                        metavar="time",
                        type=int,
                        default=0,
                        help="TTL value of cached entries (if > 0)")
    args = parser.parse_args()

    s = SocketWrapper(53)
    s.start()
    rc = RecordCache(3600)
    try:
        resolver = Resolver(args.timeout, args.caching, args.ttl, s, rc)
        hostname, aliaslist, ipaddrlist = resolver.gethostbyname(args.hostname)
    finally:
        # Always release the started socket and persist the cache, even if
        # the lookup raised; previously both were skipped on error.
        s.shutdown()
        rc.write_cache_file()

    print(hostname)
    print([rr.rdata.address for rr in aliaslist])
    print([rr.rdata.address for rr in ipaddrlist])
コード例 #2
0
    def __init__(self, timeout, caching, ttl):
        """Initialize the resolver

        Args:
            timeout (int): resolver timeout
            caching (bool): caching is enabled if True
            ttl (int): ttl of cache entries (if > 0)
        """
        self.cache = RecordCache(10000000)
        self.timeout = timeout
        self.caching = caching
        self.ttl = ttl
        # Root server hints (name -> IPv4 address).
        # h.root-servers.net moved from 128.63.2.53 to 198.97.190.53; the old
        # address is retired.
        # NOTE(review): other root addresses (e.g. b.root) have also been
        # renumbered over time; confirm against the current IANA root hints.
        self.hints = {
            "a.root-servers.net": "198.41.0.4",
            "b.root-servers.net": "192.228.79.201",
            "c.root-servers.net": "192.33.4.12",
            "d.root-servers.net": "199.7.91.13",
            "e.root-servers.net": "192.203.230.10",
            "f.root-servers.net": "192.5.5.241",
            "g.root-servers.net": "192.112.36.4",
            "h.root-servers.net": "198.97.190.53",
            "i.root-servers.net": "192.36.148.17",
            "j.root-servers.net": "192.58.128.30",
            "k.root-servers.net": "193.0.14.129",
            "l.root-servers.net": "199.7.83.42",
            "m.root-servers.net": "202.12.27.33"
        }
コード例 #3
0
    def setUp(self):
        """Seed the cache file with a record whose IPv4 address is invalid.

        The address 192.168.123.456 has an octet above 255; the record is
        stored on ``self.rr`` and written to disk through a fresh RecordCache.
        """
        invalid_rdata = RecordData.create(Type.A, "192.168.123.456")
        self.rr = ResourceRecord(
            "invalid.invalid", Type.A, Class.IN, 3, invalid_rdata
        )

        seeded = RecordCache()
        seeded.add_record(self.rr)
        seeded.write_cache_file()
コード例 #4
0
 def test_cache_lookup(self):
     """Add a record to the cache, then check that lookup returns exactly it."""
     record = ResourceRecord(
         "wiki.nl",
         Type.A,
         Class.IN,
         self.ttl,
         RecordData.create(Type.A, "192.168.123.456"),
     )
     store = RecordCache()
     store.add_record(record)
     found = store.lookup("wiki.nl", Type.A, Class.IN)
     self.assertEqual([record], found)
コード例 #5
0
 def test_cache_lookup(self):
     """A record added to an empty cache is the only result of its lookup."""
     address_data = RecordData.create(Type.A, "192.168.123.456")
     wiki_rr = ResourceRecord("wiki.nl", Type.A, Class.IN, self.ttl,
                              address_data)
     rc = RecordCache()
     rc.add_record(wiki_rr)
     self.assertEqual([wiki_rr], rc.lookup("wiki.nl", Type.A, Class.IN))
 def test_expired_cache_entry(self):
     """A ttl=0 record in a RecordCache(0) is not found: lookup gives None."""
     rc = RecordCache(0)
     expired = ResourceRecord(
         name=Name("bonobo.putin"),
         type_=Type.A,
         class_=Class.IN,
         ttl=0,
         rdata=ARecordData("1.0.0.1"),
     )
     rc.add_record(expired)
     result = rc.lookup(Name("bonobo.putin"), Type.A, Class.IN)
     self.assertEqual(result, None)
コード例 #7
0
    def test_TTL_expiration(self):
        """
        cache a record, wait till ttl expires, see if record is removed from cache
        """
        rr = ResourceRecord("wiki.nl", Type.A, Class.IN, self.ttl, RecordData.create(Type.A, "192.168.123.456"))
        cache = RecordCache()
        cache.add_record(rr)

        # Sleep one second past the ttl. Sleeping exactly ttl lands on the
        # expiry boundary, where a cache that compares whole-second timestamps
        # or uses a strict comparison may still treat the record as live,
        # making this test flaky.
        time.sleep(rr.ttl + 1)

        lookup_vals = cache.lookup("wiki.nl", Type.A, Class.IN)
        self.assertFalse(lookup_vals)
 def test_invalid_domain_from_cache(self):
     """Cache an A record (ttl=60) in a RecordCache(0) and look it up.

     NOTE(review): despite the test name, the assertion expects the cached
     record itself to come back from the lookup.
     """
     rc = RecordCache(0)
     cached = ResourceRecord(name=Name("bonobo.putin"),
                             type_=Type.A,
                             class_=Class.IN,
                             ttl=60,
                             rdata=ARecordData("1.0.0.1"))
     rc.add_record(cached)
     hit = rc.lookup(Name("bonobo.putin"), Type.A, Class.IN)
     self.assertEqual(hit, cached)
 def test_ttl_overwrite(self):
     """A record added with ttl=0 to a RecordCache(60) comes back with ttl 60."""
     rc = RecordCache(60)
     zero_ttl = ResourceRecord(name=Name("bonobo.putin"),
                               type_=Type.A,
                               class_=Class.IN,
                               ttl=0,
                               rdata=ARecordData("1.0.0.1"))
     rc.add_record(zero_ttl)
     entry = rc.lookup(Name("bonobo.putin"), Type.A, Class.IN)
     # The cached entry must equal the record, but carry the cache's ttl.
     self.assertEqual(entry, zero_ttl)
     self.assertEqual(entry.ttl, 60)
コード例 #10
0
    def test_TTL_expiration(self):
        """
        cache a record, wait till ttl expires, see if record is removed from cache
        """
        rr = ResourceRecord("wiki.nl", Type.A, Class.IN, self.ttl,
                            RecordData.create(Type.A, "192.168.123.456"))
        cache = RecordCache()
        cache.add_record(rr)

        # Sleep one second past the ttl. Sleeping exactly ttl lands on the
        # expiry boundary, where a cache that compares whole-second timestamps
        # or uses a strict comparison may still treat the record as live,
        # making this test flaky.
        time.sleep(rr.ttl + 1)

        lookup_vals = cache.lookup("wiki.nl", Type.A, Class.IN)
        self.assertFalse(lookup_vals)
コード例 #11
0
    def __init__(self, port, caching, ttl):
        """Set up server state, start its socket wrapper and create its cache.

        Args:
            port (int): port that the server's SocketWrapper listens on
            caching (bool): server uses resolver with caching if true
            ttl (int): ttl for records (if > 0) of cache
        """
        self.caching, self.ttl, self.port = caching, ttl, port
        self.done = False
        self.catalog = Catalog()

        wrapper = SocketWrapper(port)
        self.sock = wrapper
        wrapper.start()

        self.cache = RecordCache(ttl)
コード例 #12
0
    def __init__(self, caching, cache=None):
        """ Initialize the resolver

        Args:
            caching (bool): caching is enabled if True
            cache: cache object to use. When omitted (None), a RecordCache is
                created if caching is enabled, otherwise a MockedCache.
        """
        self.caching = caching
        # Test against None explicitly: a caller-supplied cache that happens
        # to be falsy (e.g. one defining __len__ while still empty) must be
        # used, not silently replaced by a fresh cache.
        if cache is not None:
            self.CACHE = cache
        elif self.caching:
            self.CACHE = RecordCache()
        else:
            self.CACHE = MockedCache()

        self.STYPE = Type.A
        self.SCLASS = Class.IN
        self.SLIST = []
        # Safety belt: (root server name, IPv4 address) pairs.
        self.SBELT = [('A.ROOT-SERVERS.NET', '198.41.0.4'),
                      ('B.ROOT-SERVERS.NET', '192.228.79.201'),
                      ('C.ROOT-SERVERS.NET', '192.33.4.12'),
                      ('D.ROOT-SERVERS.NET', '199.7.91.13'),
                      ('E.ROOT-SERVERS.NET', '192.203.230.10'),
                      ('F.ROOT-SERVERS.NET', '192.5.5.241'),
                      ('G.ROOT-SERVERS.NET', '192.112.36.4'),
                      ('H.ROOT-SERVERS.NET', '198.97.190.53'),
                      ('I.ROOT-SERVERS.NET', '192.36.148.17'),
                      ('J.ROOT-SERVERS.NET', '192.58.128.30'),
                      ('K.ROOT-SERVERS.NET', '193.0.14.129'),
                      ('L.ROOT-SERVERS.NET', '199.7.83.42'),
                      ('M.ROOT-SERVERS.NET', '202.12.27.33')]
        self.aliases = []
        self.addresses = []
        self.timeout = 3
コード例 #13
0
    def __init__(self, timeout, caching, ttl, rootip="198.41.0.4"):
        """Store the resolver settings and, if caching, load the cache file.

        Args:
            timeout: resolver timeout
            caching (bool): caching is enabled if True
            ttl (int): ttl of cache entries (if > 0)
            rootip (str): address of the first server to query
        """
        self.timeout, self.caching, self.ttl = timeout, caching, ttl
        self.doLogging = False
        self.rootip = rootip
        self.rd = 0

        if not self.caching:
            # No cache attribute is created when caching is disabled.
            return
        self.cache = RecordCache(ttl)
        self.cache.read_cache_file()
コード例 #14
0
    def __init__(self, port, caching, ttl):
        """Set up the server: load the 'zone' master file and a record cache.

        Args:
            port (int): port that server is listening on
            caching (bool): server uses resolver with caching if true
            ttl (int): ttl for records (if > 0) of cache
        """
        self.caching, self.ttl, self.port = caching, ttl, port
        self.done = False

        self.zone = Zone()
        self.zone.read_master_file("zone")

        self.cache = RecordCache(ttl)
        # The cache object always exists; it is only loaded from disk when
        # caching is enabled.
        if caching:
            self.cache.read_cache_file()

        self.doLogging = True
コード例 #15
0
    def __init__(self, timeout, caching, ttl, sock=None, cache=None):
        """Initialize the resolver

        Args:
            timeout: resolver timeout
            caching (bool): caching is enabled if True
            ttl (int): ttl of cache entries (if > 0)
            sock: SocketWrapper to use; when None, a new one on port 53 is
                created and started
            cache: RecordCache to use; when None, RecordCache(ttl) is created
        """
        self.timeout = timeout
        self.caching = caching
        self.rc = cache
        self.ttl = ttl
        # Identity check: ``== None`` can be hijacked by a custom __eq__.
        if self.rc is None:
            self.rc = RecordCache(self.ttl)
        self.zone = Zone()
        self.sock = sock
        if self.sock is None:
            self.sock = SocketWrapper(53)
            self.sock.start()
コード例 #16
0
class Server:
    """A recursive DNS server.

    FLAGS
    The DNS Name Server is aware of DNS flags, namely it should use its DNS
    Resolver only if recursion is enabled in the query. Moreover, if it is
    authoritative over the requested FQDN in the query, it should then send
    an authoritative response.
    """

    def __init__(self, port, caching, ttl):
        """Initialize the server

        Args:
            port (int): port that server is listening on
            caching (bool): server uses resolver with caching if true
            ttl (int): ttl for records (if > 0) of cache
        """
        self.caching = caching
        self.ttl = ttl
        self.port = port
        self.done = False
        self.catalog = Catalog()
        self.sock = SocketWrapper(self.port)
        self.sock.start()
        self.cache = RecordCache(self.ttl)

    def serve(self):
        """Start serving requests.

        Every message returned by ``self.sock.msgThere(-1)`` is handed to a new
        RequestHandler, which is started immediately. The outer loop re-checks
        ``self.done`` only after at least one message has been received.
        """
        while not self.done:
            data = None
            while not data:
                msgs = self.sock.msgThere(-1)
                for m in msgs:
                    data, addr = m
                    # Named ``handler`` so it does not shadow the ``re`` module.
                    handler = RequestHandler(data, addr, self.catalog,
                                             self.sock, self.caching,
                                             self.cache, self.ttl)
                    handler.start()

    def shutdown(self):
        """Shut the server down: persist the cache, stop looping, close socket."""
        self.cache.write_cache_file()  # just to be sure
        self.done = True
        self.sock.shutdown()
 def test_no_cache_entry(self):
     """Resolve Google's public DNS hostnames starting from an empty cache."""
     resolver = Resolver(5, RecordCache(0))
     expected = (
         ("google-public-dns-a.google.com", "8.8.8.8"),
         ("google-public-dns-b.google.com", "8.8.4.4"),
     )
     for host, address in expected:
         self.assertEqual(resolver.gethostbyname(host), (host, [], [address]))
コード例 #18
0
 def testCache(self):
     """An unknown name yields []; an added record is returned by lookup."""
     cac = RecordCache(0)
     missing = cac.lookup("abcdefghqqqq.com", Type.A, Class.IN)
     self.assertEqual(missing, [])

     record = ResourceRecord("blabla.com", Type.A, Class.IN, 0,
                             ARecordData("111.111.111.111"))
     cac.add_record(record)
     self.assertEqual(cac.lookup("blabla.com", Type.A, Class.IN), record)
def resolve():
    """Resolve a hostname from the command line and print the result.

    With ``--caching`` the record cache is loaded from disk, handed to the
    resolver and written back after the lookup; otherwise the resolver is
    built without a cache.
    """
    parser = ArgumentParser(description="DNS Client")
    parser.add_argument("hostname", help="hostname to resolve")
    parser.add_argument("--timeout", metavar="time", type=int, default=5,
                        help="resolver timeout")
    parser.add_argument("-c", "--caching", action="store_true",
                        help="Enable caching")
    parser.add_argument("-t", "--ttl", metavar="time", type=int, default=0,
                        help="TTL value of cached entries (if > 0)")
    args = parser.parse_args()

    use_cache = args.caching
    cache = RecordCache(args.ttl)
    if use_cache:
        cache.read_cache_file()
    resolver = (Resolver(args.timeout, cache) if use_cache
                else Resolver(args.timeout))

    hostname, aliaslist, ipaddrlist = resolver.gethostbyname(args.hostname)
    if use_cache:
        cache.write_cache_file()

    for item in (hostname, aliaslist, ipaddrlist):
        print(item)
コード例 #20
0
    def setUp(self):
        """Write a cache file containing one A record with an invalid address.

        The octet 456 is out of range; the record is kept on ``self.rr`` so
        tests can refer to it.
        """
        self.rr = ResourceRecord(
            "invalid.invalid",
            Type.A,
            Class.IN,
            3,
            RecordData.create(Type.A, "192.168.123.456"),
        )
        on_disk = RecordCache()
        on_disk.add_record(self.rr)
        on_disk.write_cache_file()
コード例 #21
0
ファイル: resolver.py プロジェクト: W-M-T/pyDNS
 def __init__(self, timeout, caching, ttl, nameservers=None, use_rs=True):
     """ Initialize the resolver

     Args:
         timeout: socket timeout used for server queries
         caching (bool): caching is enabled if True
         ttl (int): ttl of cache entries (if > 0); non-positive values become 0
         nameservers ([str]): initial server addresses to query. The list is
             copied, so the caller's list is never modified.
         use_rs (bool): append dns.consts.ROOT_SERVERS to the server list
     """
     self.timeout = timeout
     self.caching = caching
     # This check is not needed for the resolver created via the server, but
     # it is for the resolver created by the client.
     self.ttl = ttl if ttl > 0 else 0
     # Always define the attribute so code checking ``self.cache is not None``
     # works when caching is off.
     self.cache = RecordCache() if caching else None
     # Copy instead of aliasing. The former mutable default ``[]`` was
     # extended in place below, so every Resolver shared one list that grew by
     # ROOT_SERVERS on each instantiation (and a passed-in list was mutated).
     self.nameservers = list(nameservers) if nameservers is not None else []
     if use_rs:
         self.nameservers += dns.consts.ROOT_SERVERS
コード例 #22
0
 def setUp(self):
     """Truncate the "testCache" file and seed it with one A record (ttl 2)."""
     cache_path = "testCache"
     with open(cache_path, "w") as handle:
         handle.write("")

     self.RC = RecordCache(2, cache_path)
     self.r1 = ResourceRecord.from_dict({
         "type": "A",
         "name": "dnsIsAwesome.com",
         "class": "IN",
         "ttl": 2,
         "rdata": {"address": "192.123.12.23"},
     })
     self.RC.add_record(self.r1)
     self.RC.write_cache_file()
コード例 #23
0
def run_server():
    """Parse CLI options, register the "gumpe." zone and run the server.

    With ``--caching`` a RecordCache is loaded from disk and installed as
    ``Server.cache`` before serving, and written back once serving stops.
    Ctrl-C (KeyboardInterrupt) shuts the server down.
    """
    parser = ArgumentParser(description="DNS Server")
    parser.add_argument("-c", "--caching", action="store_true",
                        help="Enable caching")
    parser.add_argument("-t", "--ttl", metavar="time", type=int, default=0,
                        help="TTL value of cached entries (if > 0)")
    parser.add_argument("-p", "--port", type=int, default=53,
                        help="Port which server listens on")
    args = parser.parse_args()

    gumpe_zone = Zone()
    gumpe_zone.read_master_file("zone")
    Server.catalog.add_zone("gumpe.", gumpe_zone)

    # Only non-None when caching was requested.
    cache = None
    if args.caching:
        cache = RecordCache(args.ttl)
        cache.read_cache_file()
        Server.cache = cache

    server = Server(args.port)
    try:
        server.serve()
    except KeyboardInterrupt:
        server.shutdown()

    if cache is not None:
        cache.write_cache_file()
 def test_invalid_domain_from_cache(self):
     """Resolve a name served entirely from cached A and CNAME records."""
     cache = RecordCache(0)
     resolver = Resolver(5, cache)

     a_record = ResourceRecord(name=Name("bonobo.putin"),
                               type_=Type.A,
                               class_=Class.IN,
                               ttl=60,
                               rdata=ARecordData("1.0.0.1"))
     cache.add_record(a_record)

     cname_record = ResourceRecord(name=Name("bonobo.putin"),
                                   type_=Type.CNAME,
                                   class_=Class.IN,
                                   ttl=60,
                                   rdata=CNAMERecordData(Name("putin.bonobo")))
     cache.add_record(cname_record)

     expected = ("bonobo.putin", ["putin.bonobo."], ["1.0.0.1"])
     self.assertEqual(resolver.gethostbyname("bonobo.putin"), expected)
 def test_expired_cache_entry(self):
     """Cached ttl=0 A/CNAME records are ignored, so resolution yields nothing."""
     cache = RecordCache(0)
     resolver = Resolver(5, cache)

     a_record = ResourceRecord(name=Name("hw.gumpe"),
                               type_=Type.A,
                               class_=Class.IN,
                               ttl=0,
                               rdata=ARecordData("1.0.0.2"))
     cache.add_record(a_record)

     cname_record = ResourceRecord(name=Name("hw.gumpe"),
                                   type_=Type.CNAME,
                                   class_=Class.IN,
                                   ttl=0,
                                   rdata=CNAMERecordData(Name("gumpe.hw")))
     cache.add_record(cname_record)

     outcome = resolver.gethostbyname("hw.gumpe")
     self.assertEqual(outcome, ("hw.gumpe", [], []))
コード例 #26
0
ファイル: resolver.py プロジェクト: W-M-T/pyDNS
class Resolver(object):
    """ DNS resolver that walks the hierarchy with non-recursive queries """

    def __init__(self, timeout, caching, ttl, nameservers=None, use_rs=True):
        """ Initialize the resolver

        Args:
            timeout: socket timeout (seconds) for each server query
            caching (bool): caching is enabled if True
            ttl (int): ttl of cache entries (if > 0); non-positive values become 0
            nameservers ([str]): initial server addresses to query. The list is
                copied, so the caller's list is never modified.
            use_rs (bool): append dns.consts.ROOT_SERVERS to the server list
        """
        self.timeout = timeout
        self.caching = caching
        # This check is not needed for the resolver created via the server,
        # but it is for the resolver created by the client.
        self.ttl = ttl if ttl > 0 else 0
        # Always define the attribute; save_cache() tests it against None.
        self.cache = RecordCache() if caching else None
        # Copy instead of aliasing. The former mutable default ``[]`` was
        # extended in place below, so every Resolver shared one list that grew
        # by ROOT_SERVERS on each instantiation (and a passed-in list was
        # mutated).
        self.nameservers = list(nameservers) if nameservers is not None else []
        if use_rs:
            self.nameservers += dns.consts.ROOT_SERVERS

    def is_valid_hostname(self, hostname):
        """ Check if hostname could be a valid hostname

        Args:
            hostname (str): the hostname that is to be checked

        Returns:
            bool: True if every dot-separated label consists of letters,
            digits and inner hyphens
        """
        # Raw string: same pattern as before, without invalid-escape warnings.
        valid_hostnames = r"^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$"
        return re.match(valid_hostnames, hostname) is not None

    def save_cache(self):
        """ Save the cache if appropriate """
        if self.caching:
            if self.cache is not None:
                self.cache.write_cache_file()

    def ask_server(self, query, server):
        """ Send query to a server

        When caching, A and CNAME records from the response are stored with
        this resolver's ttl and the current timestamp.

        Args:
            query (Message): the query that is to be sent
            server (str): IP address of the server that the query must be sent to

        Returns:
            response (Message or None): the parsed response, or None when the
            server timed out or answered with a mismatching id
        """
        response = None

        # The context manager closes the socket on every path, including the
        # early return and timeouts (it was previously never closed).
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.settimeout(self.timeout)
            try:
                sock.sendto(query.to_bytes(), (server, 53))
                data = sock.recv(1024)
                response = dns.message.Message.from_bytes(data)
                if response.header.ident != query.header.ident:
                    return None

                if self.caching:
                    for record in response.additionals + response.answers + response.authorities:
                        if record.type_ == Type.A or record.type_ == Type.CNAME:
                            record.ttl = self.ttl
                            record.timestamp = int(time.time())
                            self.cache.add_record(record)

            except socket.timeout:
                pass

        return response

    def gethostbyname(self, hostname):
        """ Resolve hostname to an IP address

        Checks the cache first (when caching), then queries the configured
        name servers in order with recursion disabled, prepending NS names
        from the authority section as new hints until an A record for the
        hostname or one of its aliases is found.

        Args:
            hostname (str): the FQDN that we want to resolve

        Returns:
            hostname (str): the FQDN that we want to resolve,
            aliaslist ([str]): list of aliases of the hostname,
            ipaddrlist ([str]): list of IP addresses of the hostname
            Both lists are empty when the hostname is invalid or resolution
            fails.
        """
        print("==GETHOSTNAME START=================")
        aliaslist = []
        ipaddrlist = []

        # Check if the hostname is valid
        if not self.is_valid_hostname(hostname):
            return hostname, [], []

        # Check if the information is in the cache
        if self.caching:
            for alias in self.cache.lookup(hostname, Type.CNAME, Class.IN):
                aliaslist.append(alias.rdata.data)

            for address in self.cache.lookup(hostname, Type.A, Class.IN):
                ipaddrlist.append(address.rdata.data)

            if ipaddrlist:
                print("We found an address in the cache!")
                return hostname, aliaslist, ipaddrlist

        # Walk the hints. Slicing/concatenation below builds new lists, so
        # self.nameservers itself is never modified.
        hints = self.nameservers

        while hints:
            # Get the server to ask
            hint = hints[0]
            hints = hints[1:]

            # Build the query to send to that server
            identifier = randint(0, 65535)

            question = dns.message.Question(hostname, Type.A, Class.IN)
            header = dns.message.Header(identifier, 0, 1, 0, 0, 0)
            header.qr = 0
            header.opcode = 0
            header.rd = 0
            query = dns.message.Message(header, [question])

            # Try to get a response
            response = self.ask_server(query, hint)

            if response is None:  # No (valid) response; try the next server
                print("Server at " + hint + " did not respond.")
                continue

            # First collect the aliases. The membership test used to refer to
            # an undefined name ``aliases`` and raised NameError on any CNAME.
            for answer in response.answers + response.additionals:
                if answer.type_ == Type.CNAME and answer.rdata.data not in aliaslist:
                    aliaslist.append(answer.rdata.data)

            # Then try to get an address for the hostname or one of its aliases
            for answer in response.answers:
                if answer.type_ == Type.A and (answer.name == hostname or answer.name in aliaslist):
                    ipaddrlist.append(answer.rdata.data)

            if ipaddrlist:
                print("We found an address using the recursive search!")
                return hostname, aliaslist, ipaddrlist

            # Referral: try the delegated name servers next
            for nameserver in response.authorities:
                if nameserver.type_ == Type.NS:
                    if self.caching:
                        self.cache.add_record(nameserver)
                    hints = [nameserver.rdata.data] + hints

        print("Recursive search for " + hostname + " was a total failure")
        return hostname, [], []
コード例 #27
0
    def test_cache_disk_io(self):
        """Round-trip a record through the cache file, then look it up."""
        record = ResourceRecord("wiki.nl", Type.A, Class.IN, self.ttl,
                                RecordData.create(Type.A, "192.168.123.456"))

        writer = RecordCache()
        writer.write_cache_file()  # overwrite the current cache file
        writer.add_record(record)
        writer.write_cache_file()

        # A fresh cache must see the record only through the file on disk.
        reader = RecordCache()
        reader.read_cache_file()
        self.assertEqual([record], reader.lookup("wiki.nl", Type.A, Class.IN))
コード例 #28
0
class Resolver:
    """DNS resolver that starts at a single root server and follows referrals"""

    def __init__(self, timeout, caching, ttl, rootip="198.41.0.4"):
        """Initialize the resolver

        Args:
            timeout: socket timeout (seconds) for each query
            caching (bool): caching is enabled if True
            ttl (int): ttl of cache entries (if > 0)
            rootip (str): address of the first server to query
        """
        self.timeout = timeout
        self.caching = caching
        self.ttl = ttl
        self.doLogging = False
        self.rootip = rootip
        self.rd = 0

        if self.caching:
            self.cache = RecordCache(ttl)
            self.cache.read_cache_file()

    def log(self, *args, end="\n"):
        """Print args only when self.doLogging is set."""
        if self.doLogging:
            print(*args, end=end)

    def logHeader(self, header):
        """Log the flag fields of a message header on one line."""
        self.log("\tFLAGS", end="")
        self.log(" QR", header.qr, end=";")
        self.log(" OPCODE", header.opcode, end=";")
        self.log(" AA", header.aa, end=";")
        self.log(" TC", header.tc, end=";")
        self.log(" RD", header.rd, end=";")
        self.log(" RA", header.ra, end=";")
        self.log(" Z", header.z, end=";")
        self.log(" RCODE", header.rcode, end=";\n")

    def check_cache(self, hostname):
        """Look hostname up in the cache.

        Returns:
            (True, addresses, cnames) when A or CNAME records for hostname are
            cached; otherwise (False, addresses, ns_names) for the longest
            suffix of hostname with cached NS records, where addresses are the
            cached A records of those name servers; (False, [], []) if nothing.
        """
        iplist = self.cache.lookup(Name(hostname), Type.A, Class.IN)
        namelist = self.cache.lookup(Name(hostname), Type.CNAME, Class.IN)
        if not (iplist == [] and namelist == []):
            iplist = [x.rdata.address for x in iplist]
            namelist = [x.rdata.cname for x in namelist]
            return True, iplist, namelist

        labels = hostname.split('.')
        for i in range(len(labels)):
            test = '.'.join(labels[i:])
            namelist = self.cache.lookup(Name(test), Type.NS, Class.IN)
            namelist = [x.rdata.nsdname for x in namelist]
            if not namelist == []:
                for x in namelist:
                    newips = self.cache.lookup(x, Type.A, Class.IN)
                    iplist.extend(newips)
                iplist = [x.rdata.address for x in iplist]
                return False, iplist, namelist
        return False, [], []

    def send_request(self, ip, name):
        """Send a single A/IN query for name to ip:53 and return its sections.

        When caching, A, CNAME and NS records from the response are cached.
        A socket timeout propagates to the caller, as before.

        Returns:
            (answers, authorities, additionals) of the response message
        """
        question = Question(name, Type.A, Class.IN)
        header = Header(9001, 0, 1, 0, 0, 0)
        header.qr = 0
        header.opcode = 0
        header.rd = self.rd
        query = Message(header, [question])

        # The context manager closes the socket even when recv() times out
        # (previously it leaked on that path).
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.settimeout(self.timeout)
            sock.sendto(query.to_bytes(), (ip, 53))
            # Receive response
            data = sock.recv(512)

        response = Message.from_bytes(data)
        self.logHeader(response.header)
        if self.caching:
            for r in response.resources:
                if r.type_ == Type.A or r.type_ == Type.CNAME or r.type_ == Type.NS:
                    self.cache.add_record(r)

        return response.answers, response.authorities, response.additionals

    def resolve_request(self, ip, hostname):
        """Query ip for hostname and classify the response.

        Returns:
            (True, addresses, names) if the answer section is non-empty;
            otherwise (False, glue A addresses from the additional section,
            NS names from the authority section).
        """
        self.log("\nRESOLVING REQUEST", hostname, "at:", ip)
        answers, authorities, additionals = self.send_request(ip, hostname)

        namelist = []
        ipaddrlist = []
        if len(answers) != 0:
            self.log("\n\tGOT RESPONSE")
            for answer in answers:
                if answer.type_ == Type.A:
                    ipaddrlist.append(answer.rdata.address)
                    self.log("\t\t", answer.type_, answer.rdata.address)
                if answer.type_ == Type.CNAME:
                    # NOTE(review): records the queried name, not the CNAME target.
                    namelist.append(hostname)
                    self.log("\t\t", answer.type_, answer.rdata.cname)
            return True, ipaddrlist, namelist

        if len(additionals) != 0:
            self.log("\n\tGOT ADDITIONALS")
            for answer in additionals:
                if answer.type_ == Type.A:
                    ipaddrlist.append(answer.rdata.address)
                    self.log('\t\t', answer.type_, answer.rdata.address)

        if len(authorities) != 0:
            self.log("\n\tGOT AUTHORITIES")
            for answer in authorities:
                if answer.type_ == Type.NS:
                    namelist.append(answer.rdata.nsdname)
                    self.log('\t\t', answer.type_, answer.rdata.nsdname)

        return False, ipaddrlist, namelist

    def gethostbyname(self, hostname):
        """Translate a host name to IPv4 address.

        Starts from cached data (when caching) or from self.rootip, and walks
        down referrals: glue addresses are tried directly, NS names without
        glue are resolved recursively first.

        Args:
            hostname (str): the hostname to resolve

        Returns:
            (str, [str], [str]): (hostname, aliaslist, ipaddrlist); both lists
            are empty on failure
        """

        self.log("NEW QUERY:", hostname)
        serveriplist = [self.rootip]

        if self.caching:
            res, iplist, namelist = self.check_cache(hostname)
            if not (iplist == [] and namelist == []):
                self.log("CACHE HIT")
                if res:
                    return hostname, namelist, iplist
                else:
                    serveriplist = iplist

        while len(serveriplist) != 0:
            res, iplist, namelist = self.resolve_request(serveriplist[0], Name(hostname))
            if res:
                self.log("END OF QUERY:", hostname)
                if self.caching:
                    self.cache.write_cache_file()
                return hostname, namelist, iplist
            elif len(iplist) == 0 and not len(namelist) == 0:
                # Referral without glue: resolve the NS names first.
                newlist = []
                for x in namelist:
                    newhostname, newaliases, newips = self.gethostbyname(str(x))
                    newlist.extend(newips)
                # Drop the server just asked. Re-appending it (as before)
                # rebuilt an identical list whenever the NS names resolved to
                # nothing, looping forever on the same referral.
                newlist.extend(serveriplist[1:])
                serveriplist = newlist
            else:
                iplist.extend(serveriplist[1:])
                serveriplist = iplist
        self.log("FAILURE")
        return hostname, [], []
コード例 #29
0
class Resolver:
    """DNS resolver that sends queries through a shared SocketWrapper"""
    def __init__(self, timeout, caching, ttl, sock=None, cache=None):
        """Initialize the resolver

        Args:
            timeout: resolver timeout (stored as self.timeout)
            caching (bool): caching is enabled if True
            ttl (int): ttl of cache entries (if > 0)
            sock: SocketWrapper to use; when None, a new one on port 53 is
                created and started
            cache: RecordCache to use; when None, RecordCache(ttl) is created
        """
        self.timeout = timeout
        self.caching = caching
        self.rc = cache
        self.ttl = ttl
        # Identity check: ``== None`` can be hijacked by a custom __eq__.
        if self.rc is None:
            self.rc = RecordCache(self.ttl)
        self.zone = Zone()
        self.sock = sock
        if self.sock is None:
            self.sock = SocketWrapper(53)
            self.sock.start()

    def _make_id(self):
        """Build a query id from the current second and milliseconds.

        The id is the decimal concatenation of ``tm_sec`` and the first three
        fractional digits of ``time.time()``; since tm_sec <= 61 it stays at
        or below 61999, within the 16-bit id range.
        """
        gm = time.gmtime()
        mss = str(time.time()).split(".")[1][0:3]
        gms = str(gm.tm_sec)
        return int(gms + mss)

    def gethostbyname(self, hostname):
        """Translate a host name to IPv4 address.

        Starts from cached NS records (with cached glue where available) for
        the name, or from the A records in 'dns/root.zone', and follows
        referrals and CNAMEs until an A answer arrives or no servers remain.

        Args:
            hostname (str): the hostname to resolve

        Returns:
            (str, list, list): (hostname, aliaslist, ipaddrlist); the lists
            hold resource records (CNAME and A records), not strings
        """
        alias_list = []
        a_list = []
        slist = []
        found = False

        acs = self.getRecordsFromCache(hostname, Type.A, Class.IN)
        if acs:
            a_list += acs
            return hostname, alias_list, a_list

        nscs = self.matchByLabel(hostname, Type.NS, Class.IN)
        for ns in nscs:
            glue = self.getRecordsFromCache(str(ns.rdata.nsdname))
            if glue:
                slist += glue
            else:
                slist += [ns]

        # Named query_id so it does not shadow the builtin ``id``.
        query_id = self._make_id()

        # Create and send query
        question = Question(Name(hostname), Type.A, Class.IN)
        header = Header(query_id, 0, 1, 0, 0, 0)
        header.qr = 0  # 0 for query
        header.opcode = 0  # standard query
        header.rd = 0  # not recursive
        query = Message(header, [question])

        # NOTE(review): the root zone is re-read on every call, including the
        # recursive ones below.
        self.zone.read_master_file('dns/root.zone')

        sbelt = []
        for root in list(self.zone.records.values()):
            sbelt += [r for r in root if r.type_ == Type.A]

        while not found:
            if slist:
                rr = slist.pop()
                if rr.type_ == Type.A:
                    addr = rr.rdata.address
                    self.sock.send((query, addr, 53))
                elif rr.type_ == Type.NS:
                    fqdn = str(rr.rdata.nsdname)
                    _, _, a_rrs = self.gethostbyname(fqdn)
                    slist += a_rrs
                    continue
                elif rr.type_ == Type.CNAME:
                    fqdn = str(rr.rdata.cname)
                    _, cname_rrs, a_rrs = self.gethostbyname(fqdn)
                    a_list += a_rrs
                    alias_list += cname_rrs
                    break
                # NOTE(review): any other record type falls through to the
                # receive loop without having sent a query.

            elif sbelt:
                rr = sbelt.pop()
                addr = rr.rdata.address
                self.sock.send((query, addr, 53))
            else:
                break

            # Receive response (polls the socket wrapper until a message with
            # this id shows up; there is no timeout here)
            data = None
            while not data:
                data = self.sock.msgThere(query_id)
            response, _ = data[0]

            for answer in response.answers:
                if answer.type_ == Type.A:
                    self.addRecordToCache(answer)
                    a_list.append(answer)
                    found = True
                if answer.type_ == Type.CNAME:
                    self.addRecordToCache(answer)
                    alias_list.append(answer)
                    slist += [answer]

            nss = []
            for auth in response.authorities:
                if auth.type_ == Type.NS:
                    nss.append(auth)
                    self.addRecordToCache(auth)

            a_add = {}
            for add in response.additionals:
                if add.type_ == Type.A:
                    name = str(add.name)
                    a_add[name] = add
                    self.addRecordToCache(add)

            # Prefer glue addresses; fall back to the NS record itself, which
            # gets resolved when popped.
            for ns in nss:
                name = str(ns.rdata.nsdname)
                if name in a_add:
                    slist += [a_add[name]]
                else:
                    slist += [ns]

        return hostname, alias_list, a_list

    def addRecordToCache(self, record):
        """Add record to the cache when caching is enabled."""
        if self.caching:
            self.rc.add_record(record)

    def getRecordsFromCache(self, dname, t=Type.A, c=Class.IN):
        """Return cached records for (dname, t, c), or [] when not caching."""
        if self.caching:
            return self.rc.lookup(dname, t, c)
        else:
            return []

    def matchByLabel(self, dname, type_, class_):
        """Delegate to the cache's matchByLabel, or return [] when not caching."""
        if self.caching:
            return self.rc.matchByLabel(dname, type_, class_)
        else:
            return []

    def shutdown(self):
        """
        to be used only when initialized for testing
        :return:
        """
        self.sock.shutdown()
コード例 #30
0
    def test_cache_disk_io(self):
        """
        Persist a record to the cache file and read it back with a new cache
        """
        wiki_rr = ResourceRecord(
            "wiki.nl",
            Type.A,
            Class.IN,
            self.ttl,
            RecordData.create(Type.A, "192.168.123.456"),
        )
        first = RecordCache()
        first.write_cache_file()  # overwrite the current cache file

        # add the record to the cache and write to disk
        first.add_record(wiki_rr)
        first.write_cache_file()

        # read from disk again through a brand-new cache
        second = RecordCache()
        second.read_cache_file()
        looked_up = second.lookup("wiki.nl", Type.A, Class.IN)
        self.assertEqual([wiki_rr], looked_up)
コード例 #31
0
 def setUp(self):
     """Create the resolver test cache and a caching resolver (ttl -1)."""
     cache_file = "dns/ResolverTestCache.cache"  # never written to
     self.RC = RecordCache(100, cache_file)
     self.res = Resolver(5, True, -1)
コード例 #32
0
class Resolver:
    """DNS resolver that walks the name-server hierarchy from the root hints."""

    def __init__(self, timeout, caching, ttl):
        """Initialize the resolver

        Args:
            timeout (int): socket timeout, in seconds, for each query
            caching (bool): caching is enabled if True
            ttl (int): ttl of cache entries (if > 0)
        """
        self.cache = RecordCache(10000000)
        self.timeout = timeout
        self.caching = caching
        self.ttl = ttl
        # Root name servers (name -> IPv4 address); the starting server list.
        self.hints = {
            "a.root-servers.net": "198.41.0.4",
            "b.root-servers.net": "192.228.79.201",
            "c.root-servers.net": "192.33.4.12",
            "d.root-servers.net": "199.7.91.13",
            "e.root-servers.net": "192.203.230.10",
            "f.root-servers.net": "192.5.5.241",
            "g.root-servers.net": "192.112.36.4",
            "h.root-servers.net": "128.63.2.53",
            "i.root-servers.net": "192.36.148.17",
            "j.root-servers.net": "192.58.128.30",
            "k.root-servers.net": "193.0.14.129",
            "l.root-servers.net": "199.7.83.42",
            "m.root-servers.net": "202.12.27.33"
        }

    def gethostbyname(self, hostname):
        """Translate a host name to IPv4 address.

        Iteratively asks name servers, starting from the root hints, and
        follows referrals (NS authorities plus glue A records) and CNAME
        answers, roughly as described in RFC 1034 section 5.3.3. When caching
        is enabled, the cache is consulted before each query and newly seen
        A records are stored in it.

        Args:
            hostname (str): the hostname to resolve

        Returns:
            (str, [str], [str]): (hostname, aliaslist, ipaddrlist); the
            address list is empty when no server produced an A record.
        """
        aliaslist = []
        ipaddrlist = []
        # Servers still to ask: name -> IPv4 address, or None without glue.
        servers = self.hints.copy()
        # Name currently being resolved; a CNAME replaces it, so the original
        # hostname is kept separately for the return value.
        domain = hostname
        # (server name, queried domain) pairs already tried; guarantees the
        # loop terminates even with cyclic referrals.
        asked = set()

        while servers:
            if self.caching:
                domain = self._absorb_answers(self._cached_records(domain),
                                              domain, aliaslist, ipaddrlist)
                if ipaddrlist:
                    break

            name, server = servers.popitem()
            if (name, domain) in asked:
                continue
            asked.add((name, domain))

            if server is None:
                # Referral without glue: resolve the name server itself first.
                _, _, ns_addresses = self.gethostbyname(name)
                if not ns_addresses:
                    continue
                server = ns_addresses[0]

            try:
                response = self.sendQuestion(domain, server)
            except Exception:
                # Unreachable server, timeout or unparsable reply: try the
                # next candidate instead of aborting the whole resolution.
                continue

            previous = domain
            domain = self._absorb_answers(response.answers, domain, aliaslist,
                                          ipaddrlist, cache_new=True)
            if ipaddrlist:
                break
            if domain != previous:
                # Followed a CNAME: the canonical name may live in a different
                # zone, so start again from the root servers.
                servers = self.hints.copy()
            else:
                new_servers = self.getNewServers(response.authorities,
                                                 response.additionals)
                servers = {**servers, **new_servers}
        return hostname, aliaslist, ipaddrlist

    def _cached_records(self, dname):
        """Return the cached A and CNAME records (class IN) for dname."""
        return (self.cache.lookup(dname, Type.A, Class.IN)
                + self.cache.lookup(dname, Type.CNAME, Class.IN))

    def _absorb_answers(self, records, domain, aliaslist, ipaddrlist,
                        cache_new=False):
        """Fold A and CNAME records into the result lists.

        New A-record addresses are appended to ipaddrlist (and, if cache_new
        and caching is enabled, the record is cached). A CNAME switches the
        name being resolved to its target, which is recorded in aliaslist.

        Returns:
            str: the name to continue resolving.
        """
        for rr in records:
            if rr.type_ == Type.A:
                address = rr.rdata.to_dict()['address']
                if address not in ipaddrlist:
                    ipaddrlist.append(address)
                    if cache_new and self.caching:
                        self.cache.add_record(rr)
            elif rr.type_ == Type.CNAME:
                domain = str(rr.to_dict()['rdata']['cname'])
                if domain not in aliaslist:
                    aliaslist.append(domain)
        return domain

    def getNewServers(self, authorities, additionals):
        """Collect the name servers suggested by a referral.

        Args:
            authorities: authority records of a response; only NS records
                are used (SOA and others carry no nsdname).
            additionals: additional records; A records serve as glue.

        Returns:
            dict: NS name -> IPv4 address from the last matching glue A
            record, or None when no glue was supplied.
        """
        glue = {}
        for rr in additionals:
            if rr.type_ == Type.A:
                glue[str(rr.name)] = rr.rdata.to_dict()['address']

        serverlist = {}
        for rr in authorities:
            if rr.type_ != Type.NS:
                continue
            nsname = str(rr.rdata.to_dict()['nsdname'])
            serverlist[nsname] = glue.get(nsname)
        return serverlist

    def sendQuestion(self, hostname, server):
        """Send one A/IN query for hostname to server:53 over UDP.

        The socket is always closed, even when the exchange fails.

        Returns:
            Message: the parsed response.

        Raises:
            OSError: on send/receive failure, including socket.timeout after
                self.timeout seconds. Message.from_bytes may raise on a
                malformed reply (exception type not visible here).
        """
        identifier = 9001  # placeholder query ID
        question = Question(Name(hostname), Type.A, Class.IN)
        header = Header(identifier, 0, 1, 0, 0, 0)
        header.qr = 0
        header.opcode = 0
        header.rd = 1
        query = Message(header, [question])

        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.settimeout(self.timeout)
            sock.sendto(query.to_bytes(), (server, 53))
            data = sock.recv(512)
        return Message.from_bytes(data)
コード例 #33
0
    def gethostbyname(self, hostname, dnsserv='192.112.36.4'):
        """Translate a host name to IPv4 address.

        Sends an A/IN query for hostname to dnsserv. It follows CNAME and NS
        answers on the same socket. Otherwise it recurses into the name
        servers listed in the authority section until one yields an address.

        Args:
            hostname (str): the hostname to resolve
            dnsserv (str): server to query first; an IPv4 address or a host
                name (socket.sendto resolves host names itself)

        Returns:
            (str, [str], [str]): (hostname, aliaslist, ipaddrlist); the
            address list is empty when resolution fails.

        Raises:
            OSError: if the first exchange with dnsserv fails, e.g.
                socket.timeout after self.timeout seconds.
        """
        ipaddrlist = []
        cnames = []
        if self.caching:
            rcache = RecordCache(self.ttl)
            rcord = rcache.lookup(hostname, Type.ANY, Class.IN)
            if rcord:
                for rec in rcord:
                    if rec.type_ == Type.A:
                        ipaddrlist.append(rec.rdata.address)
                    elif rec.type_ == Type.CNAME:
                        cnames.append(rec.rdata.cname)
            if ipaddrlist:
                return hostname, cnames, ipaddrlist
            elif cnames:
                return self.gethostbyname(cnames[0], dnsserv)

        # Every sendto below needs a str host; recursive calls may pass a Name.
        dnsserv = str(dnsserv)

        # Create the query
        question = Question(Name(str(hostname)), Type.A, Class.IN)
        header = Header(9001, 0, 1, 0, 0, 0)
        header.qr = 0
        header.opcode = 0
        header.rd = 1
        query = Message(header, [question])

        aliaslist = cnames
        ipaddrlist = []
        dnslist = []
        # Aliases already re-queried; stops a CNAME loop from spinning forever.
        queried_aliases = set()

        # The socket is released before recursing into other name servers.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.settimeout(self.timeout)
            sock.sendto(query.to_bytes(), (dnsserv, 53))
            response = Message.from_bytes(sock.recv(2048))
            print("Number of answers: " + str(len(response.answers)))
            print("Number of authorities: " + str(len(response.authorities)))
            print("Number of additionals: " + str(len(response.additionals)))

            while response.answers:
                for answer in response.answers:
                    if answer.type_ == Type.A:
                        print("found A RR")
                        if self.caching:
                            rcache.add_record(answer)
                        ipaddrlist.append(answer.rdata.address)
                    if answer.type_ == Type.CNAME:
                        aliaslist.append(answer.rdata.cname)
                    if answer.type_ == Type.NS:
                        dnslist.append(answer.rdata.nsdname)

                if ipaddrlist:
                    return hostname, aliaslist, ipaddrlist
                elif aliaslist:
                    # Ask for the most recent alias so CNAME chains advance.
                    alias = str(aliaslist[-1])
                    if alias in queried_aliases:
                        break
                    queried_aliases.add(alias)
                    question = Question(Name(alias), Type.A, Class.IN)
                    query = Message(header, [question])
                    sock.sendto(query.to_bytes(), (dnsserv, 53))
                    response = Message.from_bytes(sock.recv(2048))
                elif dnslist:
                    nsname = dnslist.pop()
                    maybe_dnsserv = self.getnsaddr(nsname, response.additionals)
                    if maybe_dnsserv:
                        dnsserv = str(maybe_dnsserv)
                    sock.sendto(query.to_bytes(), (dnsserv, 53))
                    response = Message.from_bytes(sock.recv(2048))
                else:
                    break

        if response.authorities:
            for authority in response.authorities:
                if authority.type_ != Type.NS:
                    # e.g. SOA in a negative answer: it has no nsdname.
                    continue
                dnslist.append(authority.rdata.nsdname)
            while dnslist:
                nsname = dnslist.pop()
                # Prefer the glue address; otherwise fall back to the NS name.
                next_dnsserv = (self.getnsaddr(nsname, response.additionals)
                                or nsname)
                try:
                    hname, aliasl, ipaddrl = self.gethostbyname(hostname,
                                                                next_dnsserv)
                except OSError:
                    # This name server is unreachable or timed out; try the next.
                    continue
                if ipaddrl:
                    return hname, aliasl, ipaddrl

        # Nothing resolved: still return the documented 3-tuple (not None).
        return hostname, aliaslist, []
コード例 #34
0
class Server:
    """A recursive DNS server"""

    def __init__(self, port, caching, ttl):
        """Initialize the server

        Args:
            port (int): port that server is listening on
            caching (bool): server uses resolver with caching if true
            ttl (int): ttl for records (if > 0) of cache
        """
        self.caching = caching
        self.ttl = ttl
        self.port = port
        self.done = False
        # Authoritative data, read from the master file named 'zone'.
        self.zone = Zone()
        self.zone.read_master_file('zone')
        self.cache = RecordCache(ttl)
        if self.caching:
            self.cache.read_cache_file()
        self.doLogging = True

    def log(self, *args, end="\n"):
        """print() the arguments when logging is enabled."""
        if self.doLogging:
            print(*args, end=end)

    def search_zone(self, question):
        """Find the closest zone entry for a domain name.

        Tries the full name, then each parent suffix (the empty suffix is
        looked up as the root "."), and stops at the first entry found.

        Returns:
            (bool, list): (True, records) for an exact match, (False, records)
            for the closest enclosing name, or (False, []) if nothing matched.
        """
        qname = str(question)
        labels = qname.split('.')
        for i in range(len(labels)):
            candidate = '.'.join(labels[i:]) or '.'
            records = self.zone.records.get(candidate, None)
            if records is not None:
                return candidate == qname, records
        return False, []

    def zone_resolution(self, questions):
        """Answer questions from the local zone.

        A records are answers only for an exact name match. NS records
        (delegations) go to the authority section, and their in-zone address
        record, if any, goes to the additional section.

        Returns:
            (list, list, list): (answers, authorities, additionals) accumulated
            over all questions.
        """
        answers = []
        authorities = []
        additionals = []
        for q in questions:
            self.log("\tRESOLVING:", q.qname)
            exact, rlist = self.search_zone(q.qname)
            for r in rlist:
                if r.type_ == Type.A and exact:
                    answers.append(r)
                if r.type_ == Type.NS:
                    authorities.append(r)
                    # Separate names so the NS lookup cannot change `exact`.
                    ns_exact, ns_records = self.search_zone(r.rdata.nsdname)
                    if ns_exact:
                        additionals.append(ns_records[0])
        return answers, authorities, additionals

    def check_cache(self, hostname):
        """Look up cached A and CNAME records for hostname.

        Returns:
            (bool, list, list): (found, addresses, canonical names).
        """
        iplist = self.cache.lookup(hostname, Type.A, Class.IN)
        namelist = self.cache.lookup(hostname, Type.CNAME, Class.IN)
        if iplist or namelist:
            return (True,
                    [x.rdata.address for x in iplist],
                    [x.rdata.cname for x in namelist])
        return False, [], []

    def _records_for(self, qname, iplist, namelist):
        """Build A and CNAME answer records for qname using self.ttl."""
        records = [ResourceRecord(qname, Type.A, Class.IN, self.ttl, ARecordData(ip))
                   for ip in iplist]
        records += [ResourceRecord(qname, Type.CNAME, Class.IN, self.ttl, CNAMERecordData(n))
                    for n in namelist]
        return records

    def consult_cache(self, questions):
        """Return answer records for every question found in the cache."""
        answers = []
        for q in questions:
            found, iplist, namelist = self.check_cache(q.qname)
            if found:
                answers += self._records_for(q.qname, iplist, namelist)
        return answers

    def build_message(self, id, rd, aa, rcode, questions, answers, authorities, additionals):
        """Build a response Message (QR=1, RA=1) with section counts filled in."""
        header = Header(id, 0, len(questions), len(answers), len(authorities), len(additionals))
        header.qr = 1
        header.opcode = 0
        header.rd = rd
        header.ra = 1
        header.aa = aa
        header.rcode = rcode
        return Message(header, questions=questions, answers=answers, authorities=authorities, additionals=additionals)

    def _resolve(self, questions):
        """Resolve questions with a new Resolver(5, True, 0); return answer records."""
        resolver = Resolver(5, True, 0)
        resolver.rd = 0
        resolver.rootip = "198.41.0.4"
        answers = []
        for q in questions:
            self.log("\t\tRESOLVING:", q.qname)
            try:
                hostname, namelist, iplist = resolver.gethostbyname(str(q.qname))
            except OSError:
                # e.g. socket.timeout: leave this question unanswered rather
                # than crashing the serve loop.
                continue
            if hostname == str(q.qname):
                answers += self._records_for(q.qname, iplist, namelist)
        return answers

    def _answer(self, message, address):
        """Build the response for one parsed request.

        Sources are tried in order: local zone, record cache, then (only if
        the client set RD) the recursive resolver.
        """
        rd = message.header.rd
        rcode = 0
        aa = 1

        self.log("REQUEST RECIEVED:", address)
        answers, authorities, additionals = self.zone_resolution(message.questions)

        if not answers:
            if not authorities and not additionals:
                self.log("\tZONE RESOLUTION FAILED")
                answers = self.consult_cache(message.questions)
                if answers:
                    aa = 0
                else:
                    self.log("\tCACHE LOOKUP FAILED")
                    rcode = 3  # name error (RFC 1035)

            if rcode == 3 and rd == 1:
                self.log("\tCALLING RESOLVER")
                answers = self._resolve(message.questions)
                # Data obtained from other servers is not authoritative here.
                aa = 0
                if answers:
                    rcode = 0

        self.log("SENDING RESPONSE:", rcode, "\n")
        return self.build_message(message.header.ident, rd, aa, rcode, message.questions,
                                  answers, authorities, additionals)

    def serve(self):
        """Start serving requests

        Handles one UDP datagram at a time on self.port until self.done is
        set; the socket is closed when the loop exits.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.bind(("", self.port))
            while not self.done:
                data, address = sock.recvfrom(65565)
                try:
                    message = Message.from_bytes(data)
                except Exception:
                    # A malformed datagram must not take the whole server down.
                    self.log("MALFORMED REQUEST IGNORED:", address)
                    continue
                response = self._answer(message, address)
                sock.sendto(response.to_bytes(), address)

    def shutdown(self):
        """Shut the server down

        Only sets self.done; serve() notices it when its loop condition is
        next checked, i.e. after the blocking recvfrom returns.
        """
        self.done = True