Example #1
0
class DNSDistProtobufTest(DNSDistTest):
    """Base class for dnsdist tests that export protobuf messages.

    A background TCP listener collects the length-prefixed protobuf payloads
    sent by dnsdist into ``_protobufQueue``; the ``checkProtobuf*`` helpers
    then validate the fields of the decoded ``PBDNSMessage`` objects.
    """
    _protobufServerPort = 4242
    _protobufQueue = Queue()
    _protobufServerID = 'dnsdist-server-1'
    _protobufCounter = 0

    @staticmethod
    def _recvExactly(conn, count):
        """Read exactly ``count`` bytes from ``conn``.

        ``socket.recv`` may legitimately return fewer bytes than requested,
        so keep reading until everything has arrived. Returns ``None`` when
        the peer closes the connection first.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = conn.recv(remaining)
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def ProtobufListener(cls, port):
        """Accept connections on 127.0.0.1:``port`` and queue every message.

        Each message on the wire is a 2-byte big-endian length followed by
        that many bytes of serialized protobuf data. Loops forever; meant to
        be the target of a daemon thread.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the protobuf listener: %s" % str(e))
            sock.close()
            # only raises SystemExit in this (listener) thread
            sys.exit(1)

        sock.listen(100)
        while True:
            (conn, _) = sock.accept()
            try:
                while True:
                    header = cls._recvExactly(conn, 2)
                    if not header:
                        break
                    (datalen, ) = struct.unpack("!H", header)
                    data = cls._recvExactly(conn, datalen)
                    if not data:
                        break

                    cls._protobufQueue.put(data, True, timeout=2.0)
            finally:
                conn.close()

    @classmethod
    def startResponders(cls):
        """Start the UDP/TCP backend responders and the protobuf listener.

        All of them run as daemon threads so they never keep the interpreter
        alive once the tests are done.
        """
        cls._UDPResponder = threading.Thread(name='UDP Responder',
                                             target=cls.UDPResponder,
                                             args=[
                                                 cls._testServerPort,
                                                 cls._toResponderQueue,
                                                 cls._fromResponderQueue
                                             ])
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()

        cls._TCPResponder = threading.Thread(name='TCP Responder',
                                             target=cls.TCPResponder,
                                             args=[
                                                 cls._testServerPort,
                                                 cls._toResponderQueue,
                                                 cls._fromResponderQueue
                                             ])
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

        cls._protobufListener = threading.Thread(
            name='Protobuf Listener',
            target=cls.ProtobufListener,
            args=[cls._protobufServerPort])
        cls._protobufListener.daemon = True
        cls._protobufListener.start()

    def getFirstProtobufMessage(self):
        """Pop the oldest received payload and decode it as a PBDNSMessage.

        Fails the test when nothing has been received yet.
        """
        self.assertFalse(self._protobufQueue.empty())
        data = self._protobufQueue.get(False)
        self.assertTrue(data)
        msg = dnsmessage_pb2.PBDNSMessage()
        msg.ParseFromString(data)
        return msg

    def checkProtobufBase(self,
                          msg,
                          protocol,
                          query,
                          initiator,
                          normalQueryResponse=True):
        """Check the fields shared by every protobuf message.

        :param msg: decoded ``PBDNSMessage``
        :param protocol: expected ``socketProtocol`` value
        :param query: DNS message whose ID (and, when ``normalQueryResponse``
            is true, wire length) must match ``msg``
        :param initiator: expected IPv4 source address, as text
        """
        self.assertTrue(msg)
        self.assertTrue(msg.HasField('timeSec'))
        self.assertTrue(msg.HasField('socketFamily'))
        self.assertEqual(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
        self.assertTrue(msg.HasField('from'))
        # 'from' is a Python keyword, so the field is read via getattr
        fromvalue = getattr(msg, 'from')
        self.assertEqual(socket.inet_ntop(socket.AF_INET, fromvalue),
                         initiator)
        self.assertTrue(msg.HasField('socketProtocol'))
        self.assertEqual(msg.socketProtocol, protocol)
        self.assertTrue(msg.HasField('messageId'))
        self.assertTrue(msg.HasField('id'))
        self.assertEqual(msg.id, query.id)
        self.assertTrue(msg.HasField('inBytes'))
        self.assertTrue(msg.HasField('serverIdentity'))
        self.assertEqual(msg.serverIdentity,
                         self._protobufServerID.encode('utf-8'))

        if normalQueryResponse:
            # compare inBytes with length of query/response
            self.assertEqual(msg.inBytes, len(query.to_wire()))
        # dnsdist doesn't set the existing EDNS Subnet for now,
        # although it might be set from Lua
        # self.assertTrue(msg.HasField('originalRequestorSubnet'))
        # self.assertEqual(len(msg.originalRequestorSubnet), 4)
        # self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')

    def checkProtobufQuery(self,
                           msg,
                           protocol,
                           query,
                           qclass,
                           qtype,
                           qname,
                           initiator='127.0.0.1'):
        """Check a protobuf message describing a query received by dnsdist."""
        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
        self.checkProtobufBase(msg, protocol, query, initiator)
        # dnsdist doesn't fill the responder field for responses
        # because it doesn't keep the information around.
        self.assertTrue(msg.HasField('to'))
        self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.to),
                         '127.0.0.1')
        self.assertTrue(msg.HasField('question'))
        self.assertTrue(msg.question.HasField('qClass'))
        self.assertEqual(msg.question.qClass, qclass)
        self.assertTrue(msg.question.HasField('qType'))
        self.assertEqual(msg.question.qType, qtype)
        self.assertTrue(msg.question.HasField('qName'))
        self.assertEqual(msg.question.qName, qname)

    def checkProtobufTags(self, tags, expectedTags):
        """Check that ``tags`` and ``expectedTags`` hold the same set of tags.

        Order and duplicates are ignored; on failure the differing tags are
        reported.
        """
        self.assertEqual(set(tags), set(expectedTags),
                         "Protobuf tags don't match")

    def checkProtobufQueryConvertedToResponse(self,
                                              msg,
                                              protocol,
                                              response,
                                              initiator='127.0.0.0'):
        """Check a query message that a Lua hook turned into a response.

        The default ``initiator`` matches a requestor truncated to a /24
        (presumably by a Lua callback - confirm against the caller).
        """
        self.assertEqual(msg.type,
                         dnsmessage_pb2.PBDNSMessage.DNSResponseType)
        # skip comparing inBytes (size of the query) with the length of the generated response
        self.checkProtobufBase(msg, protocol, response, initiator, False)
        self.assertTrue(msg.HasField('response'))
        self.assertTrue(msg.response.HasField('queryTimeSec'))

    def checkProtobufResponse(self,
                              msg,
                              protocol,
                              response,
                              initiator='127.0.0.1'):
        """Check a protobuf message describing a response sent by dnsdist."""
        self.assertEqual(msg.type,
                         dnsmessage_pb2.PBDNSMessage.DNSResponseType)
        self.checkProtobufBase(msg, protocol, response, initiator)
        self.assertTrue(msg.HasField('response'))
        self.assertTrue(msg.response.HasField('queryTimeSec'))

    def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl):
        """Check class, type, name and TTL of one response RR; rdata must exist."""
        self.assertTrue(record.HasField('class'))
        # 'class' is a Python keyword, so the field is read via getattr
        self.assertEqual(getattr(record, 'class'), rclass)
        self.assertTrue(record.HasField('type'))
        self.assertEqual(record.type, rtype)
        self.assertTrue(record.HasField('name'))
        self.assertEqual(record.name, rname)
        self.assertTrue(record.HasField('ttl'))
        self.assertEqual(record.ttl, rttl)
        self.assertTrue(record.HasField('rdata'))
Example #2
0
class TestDnstapOverRemoteLogger(DNSDistTest):
    """Check the dnstap messages dnsdist sends through a RemoteLogger.

    A background TCP listener collects the length-prefixed dnstap payloads
    into ``_remoteLoggerQueue``; each test sends a query over UDP and then TCP
    and checks the logged query and response messages.
    """
    _remoteLoggerServerPort = 4243
    _remoteLoggerQueue = Queue()
    _remoteLoggerCounter = 0
    # Presumably interpolated, in order, into the template's %s placeholders
    # (backend server port, remote logger port) - see DNSDistTest.
    _config_params = ['_testServerPort', '_remoteLoggerServerPort']
    _config_template = """
    extrasmn = newSuffixMatchNode()
    extrasmn:add(newDNSName('extra.dnstap.tests.powerdns.com.'))

    luatarget = 'lua.dnstap.tests.powerdns.com.'

    function alterDnstapQuery(dq, tap)
      if extrasmn:check(dq.qname) then
        tap:setExtra("Type,Query")
      end
    end

    function alterDnstapResponse(dq, tap)
      if extrasmn:check(dq.qname) then
        tap:setExtra("Type,Response")
      end
    end

    function luaFunc(dq)
      dq.dh:setQR(true)
      dq.dh:setRCode(dnsdist.NXDOMAIN)
      return DNSAction.None, ""
    end

    newServer{address="127.0.0.1:%s", useClientSubnet=true}
    rl = newRemoteLogger('127.0.0.1:%s')

    addAction(AllRule(), DnstapLogAction("a.server", rl, alterDnstapQuery))				-- Send dnstap message before lookup

    addAction(luatarget, LuaAction(luaFunc))				-- Send dnstap message before lookup

    addResponseAction(AllRule(), DnstapLogResponseAction("a.server", rl, alterDnstapResponse))	-- Send dnstap message after lookup

    addAction('spoof.dnstap.tests.powerdns.com.', SpoofAction("192.0.2.1"))
    """

    @staticmethod
    def _recvExactly(conn, count):
        """Read exactly ``count`` bytes from ``conn``.

        ``socket.recv`` may return fewer bytes than requested, so keep reading
        until everything has arrived. Returns ``None`` when the peer closes
        the connection first.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = conn.recv(remaining)
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def RemoteLoggerListener(cls, port):
        """Accept connections on 127.0.0.1:``port`` and queue every message.

        Each message is a 2-byte big-endian length followed by that many bytes
        of serialized dnstap data. Loops forever; meant to be the target of a
        daemon thread.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the remote logger listener: %s" % str(e))
            sock.close()
            # only raises SystemExit in this (listener) thread
            sys.exit(1)

        sock.listen(100)
        while True:
            (conn, _) = sock.accept()
            try:
                while True:
                    header = cls._recvExactly(conn, 2)
                    if not header:
                        break
                    (datalen, ) = struct.unpack("!H", header)
                    data = cls._recvExactly(conn, datalen)
                    if not data:
                        break

                    cls._remoteLoggerQueue.put(data, True, timeout=2.0)
            finally:
                conn.close()

    @classmethod
    def startResponders(cls):
        """Start the default responders plus the remote logger listener."""
        DNSDistTest.startResponders()

        cls._remoteLoggerListener = threading.Thread(
            name='RemoteLogger Listener',
            target=cls.RemoteLoggerListener,
            args=[cls._remoteLoggerServerPort])
        cls._remoteLoggerListener.daemon = True
        cls._remoteLoggerListener.start()

    def getFirstDnstap(self):
        """Pop the oldest received payload and decode it as a Dnstap message.

        Fails the test when nothing has been received yet.
        """
        self.assertFalse(self._remoteLoggerQueue.empty())
        data = self._remoteLoggerQueue.get(False)
        self.assertTrue(data)
        dnstap = dnstap_pb2.Dnstap()
        dnstap.ParseFromString(data)
        return dnstap

    @staticmethod
    def _makeQueryAndResponse(name):
        """Build an A query for ``name`` and a CNAME -> A response to it."""
        target = 'target.dnstap.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        rrset = dns.rrset.from_text(name, 3600, dns.rdataclass.IN,
                                    dns.rdatatype.CNAME, target)
        response.answer.append(rrset)

        rrset = dns.rrset.from_text(target, 3600, dns.rdataclass.IN,
                                    dns.rdatatype.A, '127.0.0.1')
        response.answer.append(rrset)
        return (query, response)

    def _checkDnstapExtraOrNone(self, dnstap, expected):
        """Check the ``extra`` payload, or its absence when ``expected`` is None."""
        if expected is None:
            checkDnstapNoExtra(self, dnstap)
        else:
            checkDnstapExtra(self, dnstap, expected)

    def _sendAndCheckDnstap(self, query, response, queryExtra, responseExtra):
        """Send ``query`` over UDP then TCP and check both dnstap messages.

        ``queryExtra``/``responseExtra`` are the expected ``extra`` payloads
        of the logged query and response, or ``None`` when none is expected.
        """
        for (sender, protocol) in [(self.sendUDPQuery, dnstap_pb2.UDP),
                                   (self.sendTCPQuery, dnstap_pb2.TCP)]:
            (receivedQuery, receivedResponse) = sender(query, response)
            self.assertTrue(receivedQuery)
            self.assertTrue(receivedResponse)
            receivedQuery.id = query.id
            self.assertEqual(query, receivedQuery)
            self.assertEqual(response, receivedResponse)

            # give the dnstap messages time to get here
            time.sleep(1)

            # the query is logged before the lookup, the response after it
            dnstap = self.getFirstDnstap()
            checkDnstapQuery(self, dnstap, protocol, query)
            self._checkDnstapExtraOrNone(dnstap, queryExtra)

            dnstap = self.getFirstDnstap()
            checkDnstapResponse(self, dnstap, protocol, response)
            self._checkDnstapExtraOrNone(dnstap, responseExtra)

    def testDnstap(self):
        """
        Dnstap: Send query and responses packed in dnstap to a remotelogger server
        """
        name = 'query.dnstap.tests.powerdns.com.'
        (query, response) = self._makeQueryAndResponse(name)
        self._sendAndCheckDnstap(query, response, None, None)

    def testDnstapExtra(self):
        """
        DnstapExtra: Send query and responses packed in dnstap to a remotelogger server. Extra data is filled out.
        """
        name = 'extra.dnstap.tests.powerdns.com.'
        (query, response) = self._makeQueryAndResponse(name)
        self._sendAndCheckDnstap(query, response,
                                 b"Type,Query", b"Type,Response")
Example #3
0
class TestDnstapOverFrameStreamTcpLogger(DNSDistTest):
    """Check the dnstap messages dnsdist sends through a FrameStream TCP logger."""
    _fstrmLoggerPort = 4000
    _fstrmLoggerQueue = Queue()
    _fstrmLoggerCounter = 0
    _config_params = ['_testServerPort', '_fstrmLoggerPort']
    _config_template = """
    newServer{address="127.0.0.1:%s", useClientSubnet=true}
    fslu = newFrameStreamTcpLogger('127.0.0.1:%s')

    addAction(AllRule(), DnstapLogAction("a.server", fslu))
    """

    @classmethod
    def FrameStreamUnixListener(cls, port):
        """Accept Frame Streams connections on TCP 127.0.0.1:``port``.

        Despite its name, this listens on a TCP socket, not a Unix one. Each
        connection is handed to ``fstrm_handle_bidir_connection`` with a
        callback that queues the payloads it receives on
        ``_fstrmLoggerQueue``. Loops forever; meant to be the target of a
        daemon thread.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # allow re-binding the port while an earlier socket is in TIME_WAIT
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the framestream listener: %s" % str(e))
            sock.close()
            # only raises SystemExit in this (listener) thread
            sys.exit(1)

        def enqueue(data):
            cls._fstrmLoggerQueue.put(data, True, timeout=2.0)

        sock.listen(100)
        while True:
            (conn, _) = sock.accept()
            try:
                fstrm_handle_bidir_connection(conn, enqueue)
            finally:
                conn.close()

    @classmethod
    def startResponders(cls):
        """Start the default responders plus the Frame Streams listener."""
        DNSDistTest.startResponders()

        cls._fstrmLoggerListener = threading.Thread(
            name='FrameStreamUnixListener',
            target=cls.FrameStreamUnixListener,
            args=[cls._fstrmLoggerPort])
        cls._fstrmLoggerListener.daemon = True
        cls._fstrmLoggerListener.start()

    def getFirstDnstap(self):
        """Wait up to 2 seconds for the next payload and decode it as Dnstap."""
        data = self._fstrmLoggerQueue.get(True, timeout=2.0)
        self.assertTrue(data)
        dnstap = dnstap_pb2.Dnstap()
        dnstap.ParseFromString(data)
        return dnstap

    def testDnstapOverFrameStreamTcp(self):
        """
        Dnstap: Send query packed in dnstap to a tcp socket fstrmlogger server
        """
        name = 'query.dnstap.tests.powerdns.com.'

        target = 'target.dnstap.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        rrset = dns.rrset.from_text(name, 3600, dns.rdataclass.IN,
                                    dns.rdatatype.CNAME, target)
        response.answer.append(rrset)

        rrset = dns.rrset.from_text(target, 3600, dns.rdataclass.IN,
                                    dns.rdatatype.A, '127.0.0.1')
        response.answer.append(rrset)

        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
        self.assertTrue(receivedQuery)
        self.assertTrue(receivedResponse)
        receivedQuery.id = query.id
        self.assertEqual(query, receivedQuery)
        self.assertEqual(response, receivedResponse)

        # check the dnstap message corresponding to the UDP query
        # (getFirstDnstap blocks until it arrives, so no sleep is needed)
        dnstap = self.getFirstDnstap()

        checkDnstapQuery(self, dnstap, dnstap_pb2.UDP, query)
        checkDnstapNoExtra(self, dnstap)
Example #4
0
class TestProtobuf(DNSDistTest):
    # Port the protobuf listener binds to; the template's newRemoteLogger
    # points dnsdist at it.
    _protobufServerPort = 4242
    # Raw protobuf payloads received by the listener thread.
    _protobufQueue = Queue()
    # Identity expected in each message's serverIdentity field.
    _protobufServerID = 'dnsdist-server-1'
    _protobufCounter = 0
    # One entry per %s placeholder in _config_template, in order.
    # '_protobufServerID' is listed twice on purpose: the template has two
    # serverID='%s' placeholders (query action and response action).
    _config_params = [
        '_testServerPort', '_protobufServerPort', '_protobufServerID',
        '_protobufServerID'
    ]
    # dnsdist Lua configuration: alterLuaFirst tags every query, then
    # alterProtobufQuery / alterProtobufResponse edit the exported protobuf
    # messages. Names under lua.protobuf.tests.powerdns.com. get a truncated
    # requestor, and their query message is rewritten to look like a response.
    _config_template = """
    luasmn = newSuffixMatchNode()
    luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.'))

    function alterProtobufResponse(dq, protobuf)
      if luasmn:check(dq.qname) then
        requestor = newCA(dq.remoteaddr:toString())		-- called by testLuaProtobuf()
        if requestor:isIPv4() then
          requestor:truncate(24)
        else
          requestor:truncate(56)
        end
        protobuf:setRequestor(requestor)

        local tableTags = {}
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")

        protobuf:setTagArray(tableTags)

        protobuf:setTag('TestLabel3,TestData3')

        protobuf:setTag("Response,456")

      else

        local tableTags = {} 					-- called by testProtobuf()
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")
        protobuf:setTagArray(tableTags)

        protobuf:setTag('TestLabel3,TestData3')

        protobuf:setTag("Response,456")

      end
    end

    function alterProtobufQuery(dq, protobuf)

      if luasmn:check(dq.qname) then
        requestor = newCA(dq.remoteaddr:toString())		-- called by testLuaProtobuf()
        if requestor:isIPv4() then
          requestor:truncate(24)
        else
          requestor:truncate(56)
        end
        protobuf:setRequestor(requestor)

        local tableTags = {}
        tableTags = dq:getTagArray()				-- get table from DNSQuery

        local tablePB = {}
          for k, v in pairs( tableTags) do
          table.insert(tablePB, k .. "," .. v)
        end

        protobuf:setTagArray(tablePB)				-- store table in protobuf
        protobuf:setTag("Query,123")				-- add another tag entry in protobuf

        protobuf:setResponseCode(dnsdist.NXDOMAIN)        	-- set protobuf response code to be NXDOMAIN

        local strReqName = dq.qname:toString()		  	-- get request dns name

        protobuf:setProtobufResponseType()			-- set protobuf to look like a response and not a query, with 0 default time

        blobData = '\127' .. '\000' .. '\000' .. '\001'		-- 127.0.0.1, note: lua 5.1 can only embed decimal not hex

        protobuf:addResponseRR(strReqName, 1, 1, 123, blobData) -- add a RR to the protobuf

        protobuf:setBytes(65)					-- set the size of the query to confirm in checkProtobufBase

      else

        local tableTags = {}                                    -- called by testProtobuf()
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")

        protobuf:setTagArray(tableTags)
        protobuf:setTag('TestLabel3,TestData3')
        protobuf:setTag("Query,123")

      end
    end

    function alterLuaFirst(dq)					-- called when dnsdist receives new request
      local tt = {}
      tt["TestLabel1"] = "TestData1"
      tt["TestLabel2"] = "TestData2"

      dq:setTagArray(tt)

      dq:setTag("TestLabel3","TestData3")
      return DNSAction.None, ""				-- continue to the next rule
    end

    newServer{address="127.0.0.1:%s", useClientSubnet=true}
    rl = newRemoteLogger('127.0.0.1:%s')

    addAction(AllRule(), LuaAction(alterLuaFirst))							-- Add tags to DNSQuery first

    addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery, {serverID='%s'}))				-- Send protobuf message before lookup

    addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true, {serverID='%s'}))	-- Send protobuf message after lookup

    """

    @classmethod
    def ProtobufListener(cls, port):
        """Accept connections on 127.0.0.1:``port`` and queue every message.

        Each message on the wire is a 2-byte big-endian length followed by
        that many bytes of serialized protobuf data. Loops forever; meant to
        be the target of a daemon thread.
        """
        def recvExactly(conn, count):
            # socket.recv may return fewer bytes than requested: keep reading
            # until ``count`` bytes arrived, or return None on EOF.
            chunks = []
            remaining = count
            while remaining > 0:
                chunk = conn.recv(remaining)
                if not chunk:
                    return None
                chunks.append(chunk)
                remaining -= len(chunk)
            return b''.join(chunks)

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the protobuf listener: %s" % str(e))
            sock.close()
            # only raises SystemExit in this (listener) thread
            sys.exit(1)

        sock.listen(100)
        while True:
            (conn, _) = sock.accept()
            try:
                while True:
                    header = recvExactly(conn, 2)
                    if not header:
                        break
                    (datalen, ) = struct.unpack("!H", header)
                    data = recvExactly(conn, datalen)
                    if not data:
                        break

                    cls._protobufQueue.put(data, True, timeout=2.0)
            finally:
                conn.close()

    @classmethod
    def startResponders(cls):
        """Start the UDP/TCP backend responders and the protobuf listener.

        Every thread is a daemon so it never keeps the interpreter alive once
        the tests are done. ``Thread.daemon`` replaces the deprecated
        ``setDaemon()``.
        """
        cls._UDPResponder = threading.Thread(name='UDP Responder',
                                             target=cls.UDPResponder,
                                             args=[
                                                 cls._testServerPort,
                                                 cls._toResponderQueue,
                                                 cls._fromResponderQueue
                                             ])
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()

        cls._TCPResponder = threading.Thread(name='TCP Responder',
                                             target=cls.TCPResponder,
                                             args=[
                                                 cls._testServerPort,
                                                 cls._toResponderQueue,
                                                 cls._fromResponderQueue
                                             ])
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

        cls._protobufListener = threading.Thread(
            name='Protobuf Listener',
            target=cls.ProtobufListener,
            args=[cls._protobufServerPort])
        cls._protobufListener.daemon = True
        cls._protobufListener.start()

    def getFirstProtobufMessage(self):
        """Pop the oldest queued protobuf payload and decode it.

        Fails the test when the listener has not received anything yet.
        """
        pending = self._protobufQueue
        self.assertFalse(pending.empty())
        payload = pending.get(False)
        self.assertTrue(payload)
        decoded = dnsmessage_pb2.PBDNSMessage()
        decoded.ParseFromString(payload)
        return decoded

    def checkProtobufBase(self,
                          msg,
                          protocol,
                          query,
                          initiator,
                          normalQueryResponse=True):
        """Check the fields shared by every protobuf message.

        :param msg: decoded ``PBDNSMessage``
        :param protocol: expected ``socketProtocol`` value
        :param query: DNS message whose ID (and, when ``normalQueryResponse``
            is true, wire length) must match ``msg``
        :param initiator: expected IPv4 source address, as text
        """
        self.assertTrue(msg)
        self.assertTrue(msg.HasField('timeSec'))
        self.assertTrue(msg.HasField('socketFamily'))
        self.assertEqual(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
        self.assertTrue(msg.HasField('from'))
        # 'from' is a Python keyword, so the field is read via getattr
        fromvalue = getattr(msg, 'from')
        self.assertEqual(socket.inet_ntop(socket.AF_INET, fromvalue),
                         initiator)
        self.assertTrue(msg.HasField('socketProtocol'))
        self.assertEqual(msg.socketProtocol, protocol)
        self.assertTrue(msg.HasField('messageId'))
        self.assertTrue(msg.HasField('id'))
        self.assertEqual(msg.id, query.id)
        self.assertTrue(msg.HasField('inBytes'))
        self.assertTrue(msg.HasField('serverIdentity'))
        # serverIdentity is a protobuf bytes field: compare against bytes,
        # otherwise the check can never succeed on Python 3
        self.assertEqual(msg.serverIdentity,
                         self._protobufServerID.encode('utf-8'))

        if normalQueryResponse:
            # compare inBytes with length of query/response
            self.assertEqual(msg.inBytes, len(query.to_wire()))
        # dnsdist doesn't set the existing EDNS Subnet for now,
        # although it might be set from Lua
        # self.assertTrue(msg.HasField('originalRequestorSubnet'))
        # self.assertEqual(len(msg.originalRequestorSubnet), 4)
        # self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')

    def checkProtobufQuery(self,
                           msg,
                           protocol,
                           query,
                           qclass,
                           qtype,
                           qname,
                           initiator='127.0.0.1'):
        """Check a protobuf message describing a query received by dnsdist.

        Besides the common fields, verifies the local address and the
        question's class, type and name.
        """
        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
        self.checkProtobufBase(msg, protocol, query, initiator)
        # dnsdist doesn't fill the responder field for responses
        # because it doesn't keep the information around.
        self.assertTrue(msg.HasField('to'))
        self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.to),
                         '127.0.0.1')
        self.assertTrue(msg.HasField('question'))
        self.assertTrue(msg.question.HasField('qClass'))
        self.assertEqual(msg.question.qClass, qclass)
        self.assertTrue(msg.question.HasField('qType'))
        # compare the type itself (this previously re-checked qClass, which
        # only passed because IN and A are both 1)
        self.assertEqual(msg.question.qType, qtype)
        self.assertTrue(msg.question.HasField('qName'))
        self.assertEqual(msg.question.qName, qname)

    def checkProtobufTags(self, tags, expectedTags):
        """Check that ``tags`` and ``expectedTags`` contain the same tags.

        Order and duplicates are ignored. Comparing the sets directly makes a
        failure report exactly which tags are missing or unexpected.
        """
        self.assertEqual(set(tags), set(expectedTags),
                         "Protobuf tags don't match")

    def checkProtobufQueryConvertedToResponse(self,
                                              msg,
                                              protocol,
                                              response,
                                              initiator='127.0.0.0'):
        """Check a query message that a Lua hook turned into a response.

        The default ``initiator`` matches the /24 truncation applied to the
        requestor by the Lua callbacks in this class's configuration.
        """
        self.assertEqual(msg.type,
                         dnsmessage_pb2.PBDNSMessage.DNSResponseType)
        # skip comparing inBytes (size of the query) with the length of the generated response
        self.checkProtobufBase(msg, protocol, response, initiator, False)
        self.assertTrue(msg.HasField('response'))
        self.assertTrue(msg.response.HasField('queryTimeSec'))

    def checkProtobufResponse(self,
                              msg,
                              protocol,
                              response,
                              initiator='127.0.0.1'):
        """Check a protobuf message describing a response sent by dnsdist."""
        self.assertEqual(msg.type,
                         dnsmessage_pb2.PBDNSMessage.DNSResponseType)
        self.checkProtobufBase(msg, protocol, response, initiator)
        self.assertTrue(msg.HasField('response'))
        self.assertTrue(msg.response.HasField('queryTimeSec'))

    def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl):
        """Check class, type, name and TTL of one response RR; rdata must exist."""
        self.assertTrue(record.HasField('class'))
        # 'class' is a Python keyword, so the field is read via getattr
        self.assertEqual(getattr(record, 'class'), rclass)
        self.assertTrue(record.HasField('type'))
        self.assertEqual(record.type, rtype)
        self.assertTrue(record.HasField('name'))
        self.assertEqual(record.name, rname)
        self.assertTrue(record.HasField('ttl'))
        self.assertEqual(record.ttl, rttl)
        self.assertTrue(record.HasField('rdata'))

    def testProtobuf(self):
        """
        Protobuf: Send data to a protobuf server
        """
        name = 'query.protobuf.tests.powerdns.com.'

        target = 'target.protobuf.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        rrset = dns.rrset.from_text(name, 3600, dns.rdataclass.IN,
                                    dns.rdatatype.CNAME, target)
        response.answer.append(rrset)

        rrset = dns.rrset.from_text(target, 3600, dns.rdataclass.IN,
                                    dns.rdatatype.A, '127.0.0.1')
        response.answer.append(rrset)

        # tags added by the alterProtobufQuery / alterProtobufResponse hooks
        commonTags = [
            u"TestLabel1,TestData1", u"TestLabel2,TestData2",
            u"TestLabel3,TestData3"
        ]

        # the same exchange is checked over UDP, then over TCP
        for (sender, protocol) in [
                (self.sendUDPQuery, dnsmessage_pb2.PBDNSMessage.UDP),
                (self.sendTCPQuery, dnsmessage_pb2.PBDNSMessage.TCP)]:
            (receivedQuery, receivedResponse) = sender(query, response)
            self.assertTrue(receivedQuery)
            self.assertTrue(receivedResponse)
            receivedQuery.id = query.id
            self.assertEqual(query, receivedQuery)
            self.assertEqual(response, receivedResponse)

            # let the protobuf messages the time to get there
            time.sleep(1)

            # check the protobuf message corresponding to the query
            msg = self.getFirstProtobufMessage()
            self.checkProtobufQuery(msg, protocol, query,
                                    dns.rdataclass.IN, dns.rdatatype.A, name)
            self.checkProtobufTags(msg.response.tags,
                                   commonTags + [u"Query,123"])

            # check the protobuf message corresponding to the response
            msg = self.getFirstProtobufMessage()
            self.checkProtobufResponse(msg, protocol, response)
            self.checkProtobufTags(msg.response.tags,
                                   commonTags + [u"Response,456"])
            self.assertEqual(len(msg.response.rrs), 2)
            rr = msg.response.rrs[0]
            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN,
                                             dns.rdatatype.CNAME, name, 3600)
            # rdata is a protobuf bytes field; decode before comparing to str
            self.assertEqual(rr.rdata.decode('utf-8'), target)
            rr = msg.response.rrs[1]
            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN,
                                             dns.rdatatype.A, target, 3600)
            self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata),
                             '127.0.0.1')

    def testLuaProtobuf(self):
        """
        Protobuf: Check that the Lua callback rewrote the initiator

        The same query is sent once over UDP and once over TCP. For each
        transport, dnsdist must forward it unchanged, and two protobuf
        messages must then be exported, in this order:
          1. the query (checked as "converted to response"),
          2. the response, carrying the single A record.
        Both must report the rewritten initiator and the expected tags.
        """
        name = 'lua.protobuf.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name, 3600, dns.rdataclass.IN,
                                    dns.rdatatype.A, '127.0.0.1')
        response.answer.append(rrset)

        # initiator address expected in the protobuf messages once the Lua
        # callback has rewritten it
        rewrittenInitiator = '127.0.0.0'
        commonTags = [
            u"TestLabel1,TestData1", u"TestLabel2,TestData2",
            u"TestLabel3,TestData3"
        ]

        # UDP first, then TCP, exactly as the checks were ordered before
        transports = [
            (self.sendUDPQuery, dnsmessage_pb2.PBDNSMessage.UDP),
            (self.sendTCPQuery, dnsmessage_pb2.PBDNSMessage.TCP),
        ]
        for sendQuery, protocol in transports:
            (receivedQuery, receivedResponse) = sendQuery(query, response)
            self.assertTrue(receivedQuery)
            self.assertTrue(receivedResponse)
            receivedQuery.id = query.id
            self.assertEqual(query, receivedQuery)
            self.assertEqual(response, receivedResponse)

            # let the protobuf messages the time to get there
            time.sleep(1)

            # check the protobuf message corresponding to the query
            msg = self.getFirstProtobufMessage()
            self.checkProtobufQueryConvertedToResponse(
                msg, protocol, response, rewrittenInitiator)
            self.checkProtobufTags(msg.response.tags,
                                   commonTags + [u"Query,123"])

            # check the protobuf message corresponding to the response
            msg = self.getFirstProtobufMessage()
            self.checkProtobufResponse(msg, protocol, response,
                                       rewrittenInitiator)
            self.checkProtobufTags(msg.response.tags,
                                   commonTags + [u"Response,456"])
            self.assertEqual(len(msg.response.rrs), 1)
            for rr in msg.response.rrs:
                self.checkProtobufResponseRecord(rr, dns.rdataclass.IN,
                                                 dns.rdatatype.A, name, 3600)
                self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata),
                                 '127.0.0.1')
Exemple #5
0
class TestCarbon(DNSDistTest):
    """
    Carbon export tests.

    dnsdist is configured with three backends (one set down, two set up)
    and two carbonServer() targets. Each target is backed by a local TCP
    responder that queues the raw payload of every connection it accepts.
    """

    _carbonServer1Port = 8000
    _carbonServer1Name = "carbonname1"
    _carbonServer2Port = 8001
    _carbonServer2Name = "carbonname2"
    # raw payloads received by each responder, one entry per connection
    _carbonQueue1 = Queue()
    _carbonQueue2 = Queue()
    _carbonInterval = 2
    # number of connections accepted, keyed by responder thread name
    _carbonCounters = {}
    # '_carbonInterval' is listed twice on purpose: the template has one
    # interval placeholder per carbonServer() line.
    _config_params = ['_carbonServer1Port', '_carbonServer1Name', '_carbonInterval',
                      '_carbonServer2Port', '_carbonServer2Name', '_carbonInterval']
    _config_template = """
    s = newServer{address="127.0.0.1:5353"}
    s:setDown()
    s = newServer{address="127.0.0.1:5354"}
    s:setUp()
    s = newServer{address="127.0.0.1:5355"}
    s:setUp()
    carbonServer("127.0.0.1:%s", "%s", %s)
    carbonServer("127.0.0.1:%s", "%s", %s)
    """

    @classmethod
    def CarbonResponder(cls, port):
        """
        Accept TCP connections on 127.0.0.1:port forever.

        Everything a peer sends is read until it closes the connection (or
        goes silent for 2 seconds). The payload is then pushed to the queue
        matching `port`. The connection counter of the current thread is
        bumped *before* the payload is queued, so a consumer that has
        dequeued the payload always sees the updated counter.
        Exits the process if the port cannot be bound.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the Carbon responder: %s" % str(e))
            sys.exit(1)

        threadName = threading.current_thread().name
        if port == cls._carbonServer1Port:
            queue = cls._carbonQueue1
        else:
            queue = cls._carbonQueue2

        sock.listen(100)
        try:
            while True:
                (conn, _) = sock.accept()
                conn.settimeout(2.0)
                chunks = []
                try:
                    while True:
                        try:
                            data = conn.recv(4096)
                        except socket.timeout:
                            # a silent peer must not kill the responder
                            # thread; keep whatever was received so far
                            break
                        if not data:
                            break
                        chunks.append(data)
                finally:
                    conn.close()

                cls._carbonCounters[threadName] = cls._carbonCounters.get(threadName, 0) + 1
                queue.put(b''.join(chunks), True, timeout=2.0)
        finally:
            sock.close()

    @classmethod
    def startResponders(cls):
        """Start one daemon responder thread per configured carbon server."""
        cls._CarbonResponder1 = threading.Thread(name='Carbon Responder 1', target=cls.CarbonResponder,
                                                 args=[cls._carbonServer1Port], daemon=True)
        cls._CarbonResponder1.start()

        cls._CarbonResponder2 = threading.Thread(name='Carbon Responder 2', target=cls.CarbonResponder,
                                                 args=[cls._carbonServer2Port], daemon=True)
        cls._CarbonResponder2.start()

    def _getCarbonData(self):
        """
        Wait for one carbon interval (plus a second of slack), then pop one
        payload from each server queue.

        Returns (data1, data2, after), where `after` is the time the payloads
        were collected; no metric timestamp may be later than it.
        """
        time.sleep(self._carbonInterval + 1)

        # first server
        self.assertFalse(self._carbonQueue1.empty())
        data1 = self._carbonQueue1.get(False)
        # second server
        self.assertFalse(self._carbonQueue2.empty())
        data2 = self._carbonQueue2.get(False)
        return data1, data2, time.time()

    def _checkCarbonLine(self, line, after, expectedValue=None):
        """
        Check one "<path> <value> <timestamp>" carbon line: the value and the
        timestamp must be integers, the timestamp must be <= `after`, and the
        value must equal `expectedValue` when it is given.
        """
        parts = line.split(b' ')
        self.assertEqual(len(parts), 3)
        self.assertTrue(parts[1].isdigit())
        if expectedValue is not None:
            self.assertEqual(int(parts[1]), expectedValue)
        self.assertTrue(parts[2].isdigit())
        self.assertTrue(int(parts[2]) <= int(after))

    def _checkCarbonPayload(self, data, serverName, after):
        """Every line of `data` must be a well-formed dnsdist.<name>.main. metric."""
        self.assertTrue(data)
        self.assertTrue(len(data.splitlines()) > 1)
        expectedStart = b"dnsdist.%s.main." % serverName.encode('UTF-8')
        for line in data.splitlines():
            self.assertTrue(line.startswith(expectedStart))
            self._checkCarbonLine(line, after)

    def _checkCarbonServersMetrics(self, data, serverName, after):
        """
        The default pool's "servers" metric must be 3 and its "servers-up"
        metric 2, matching the backends declared in _config_template.
        """
        self.assertTrue(data)
        self.assertTrue(len(data.splitlines()) > 1)
        expectedStart = b"dnsdist.%s.main.pools._default_.servers" % serverName.encode('UTF-8')
        for line in data.splitlines():
            if expectedStart not in line:
                continue
            # `line` is bytes, so the needle must be bytes too
            if b'servers-up' in line:
                self._checkCarbonLine(line, after, 2)
            else:
                self._checkCarbonLine(line, after, 3)

    def testCarbon(self):
        """
        Carbon: send data to 2 carbon servers
        """
        data1, data2, after = self._getCarbonData()

        self._checkCarbonPayload(data1, self._carbonServer1Name, after)
        self._checkCarbonPayload(data2, self._carbonServer2Name, after)

        # make sure every carbon server has received at least one connection;
        # look each responder up by name so a missing one fails the test
        for responder in (self._CarbonResponder1, self._CarbonResponder2):
            self.assertTrue(self._carbonCounters.get(responder.name, 0) >= 1)

    def testCarbonServerUp(self):
        """
        Carbon: set up 2 carbon servers
        """
        data1, data2, after = self._getCarbonData()

        # both carbon servers must receive the same servers / servers-up
        # values, matching the configuration in the class definition
        self._checkCarbonServersMetrics(data1, self._carbonServer1Name, after)
        self._checkCarbonServersMetrics(data2, self._carbonServer2Name, after)
Exemple #6
0
class TestTeeAction(DNSDistTest):
    """
    TeeAction tests.

    A queries match the first rule (TeeAction with `true`) and AAAA queries
    match the second (TeeAction with `false`). Both tee to the same local UDP
    responder. Judging by the tests below, the boolean controls whether an
    ECS option is added to the teed copy.
    """

    _consoleKey = DNSDistTest.generateConsoleKey()
    _consoleKeyB64 = base64.b64encode(_consoleKey).decode('ascii')
    _teeServerPort = 5390
    _toTeeQueue = Queue()
    _fromTeeQueue = Queue()
    _config_template = """
    setKey("%s")
    controlSocket("127.0.0.1:%s")
    newServer{address="127.0.0.1:%d"}
    addAction(QTypeRule(DNSQType.A), TeeAction("127.0.0.1:%d", true))
    addAction(QTypeRule(DNSQType.AAAA), TeeAction("127.0.0.1:%d", false))
    """
    # '_teeServerPort' is listed twice on purpose: once per TeeAction above.
    _config_params = [
        '_consoleKeyB64', '_consolePort', '_testServerPort', '_teeServerPort',
        '_teeServerPort'
    ]
    # position of each TeeAction in the rule chain, following the order of
    # the addAction() calls in _config_template (used with getAction(N))
    _teeActionECSIndex = 0
    _teeActionNoECSIndex = 1
    # expected output of getAction(N):printStats() after N error-free queries
    _teeStatsTemplate = """noerrors\t%d
nxdomains\t0
other-rcode\t0
queries\t%d
recv-errors\t0
refuseds\t0
responses\t%d
send-errors\t0
servfails\t0
tcp-drops\t0
"""

    @classmethod
    def startResponders(cls):
        """Start the regular UDP/TCP backend responders and the Tee responder."""
        print("Launching responders..")

        cls._UDPResponder = threading.Thread(name='UDP Responder',
                                             target=cls.UDPResponder,
                                             args=[
                                                 cls._testServerPort,
                                                 cls._toResponderQueue,
                                                 cls._fromResponderQueue
                                             ],
                                             daemon=True)
        cls._UDPResponder.start()

        cls._TCPResponder = threading.Thread(name='TCP Responder',
                                             target=cls.TCPResponder,
                                             args=[
                                                 cls._testServerPort,
                                                 cls._toResponderQueue,
                                                 cls._fromResponderQueue,
                                                 False, True
                                             ],
                                             daemon=True)
        cls._TCPResponder.start()

        cls._TeeResponder = threading.Thread(
            name='Tee Responder',
            target=cls.UDPResponder,
            args=[cls._teeServerPort, cls._toTeeQueue, cls._fromTeeQueue],
            daemon=True)
        cls._TeeResponder.start()

    def _sendThroughTee(self, query, response, numberOfQueries):
        """
        Send `query` over UDP `numberOfQueries` times and check that the
        backend sees it unchanged and that `response` comes back.

        Yields, for each iteration, the copy of the query that the Tee server
        received.
        """
        for _ in range(numberOfQueries):
            # push the response to the Tee server
            self._toTeeQueue.put(response, True, 2.0)

            (receivedQuery,
             receivedResponse) = self.sendUDPQuery(query, response)
            self.assertTrue(receivedQuery)
            self.assertTrue(receivedResponse)
            receivedQuery.id = query.id
            self.assertEqual(query, receivedQuery)
            self.assertEqual(response, receivedResponse)

            # retrieve the query from the Tee server
            yield self._fromTeeQueue.get(True, 2.0)

    def _checkTeeStats(self, actionIndex, numberOfQueries):
        """
        The TeeAction at `actionIndex` must report `numberOfQueries` queries,
        responses and NOERROR answers, and no errors.
        """
        stats = self.sendConsoleCommand("getAction(%d):printStats()" % actionIndex)
        self.assertEqual(
            stats, self._teeStatsTemplate % (numberOfQueries, numberOfQueries, numberOfQueries))

    def testTeeWithECS(self):
        """
        TeeAction: ECS
        """
        name = 'ecs.tee.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        rrset = dns.rrset.from_text(name, 3600, dns.rdataclass.IN,
                                    dns.rdatatype.A, '192.0.2.1')
        response.answer.append(rrset)

        # the teed copy must carry an ECS option for 127.0.0.1/24
        ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 24)
        expectedQuery = dns.message.make_query(name,
                                               'A',
                                               'IN',
                                               use_edns=True,
                                               options=[ecso],
                                               payload=512)
        expectedQuery.id = query.id

        numberOfQueries = 10
        for teedQuery in self._sendThroughTee(query, response, numberOfQueries):
            self.checkQueryEDNSWithECS(expectedQuery, teedQuery)

        # A queries are handled by the first TeeAction
        self._checkTeeStats(self._teeActionECSIndex, numberOfQueries)

    def testTeeWithoutECS(self):
        """
        TeeAction: No ECS
        """
        name = 'noecs.tee.tests.powerdns.com.'
        query = dns.message.make_query(name, 'AAAA', 'IN')
        response = dns.message.make_response(query)

        rrset = dns.rrset.from_text(name, 3600, dns.rdataclass.IN,
                                    dns.rdatatype.AAAA, '2001:DB8::1')
        response.answer.append(rrset)

        numberOfQueries = 10
        for teedQuery in self._sendThroughTee(query, response, numberOfQueries):
            # the teed copy must be the original query, with no EDNS added
            self.checkMessageNoEDNS(query, teedQuery)

        # AAAA queries are handled by the *second* TeeAction; checking index 0
        # here would only pass if testTeeWithECS had run first
        self._checkTeeStats(self._teeActionNoECSIndex, numberOfQueries)