def test_gotResolverErrorResetsResponseAttributes(self):
    """
    L{server.DNSServerFactory.gotResolverError} does not let request
    attributes leak into the response: the reply it sends has AD and CD set
    to 0 and its record sections are empty.
    """
    factory = server.DNSServerFactory()
    sent = []

    def recordReply(protocol, response, address):
        # Capture the reply instead of writing it to a protocol.
        sent.append(response)

    factory.sendReply = recordReply

    request = dns.Message(authenticData=True, checkingDisabled=True)
    # Populate every record section so a leak would be visible.
    for section in ("answers", "authority", "additional"):
        setattr(request, section, [object(), object()])

    factory.gotResolverError(
        failure.Failure(error.DomainError()),
        protocol=None,
        message=request,
        address=None,
    )

    expected = [dns.Message(rCode=3, answer=True)]
    self.assertEqual(expected, sent)
def test_handleOtherLogging(self):
    """
    When C{verbose > 0}, L{server.DNSServerFactory.handleOther} logs the
    address the message came from.
    """
    factory = NoResponseDNSServerFactory(verbose=1)
    expectedLog = ["Unknown op code (0) from ('::1', 53)"]
    assertLogMessage(
        self,
        expectedLog,
        factory.handleOther,
        message=dns.Message(),
        protocol=NoopProtocol(),
        address=('::1', 53),
    )
def test_handleNotifyLogging(self):
    """
    When C{verbose > 0}, L{server.DNSServerFactory.handleNotify} logs the
    address the message came from.
    """
    factory = NoResponseDNSServerFactory(verbose=1)
    expectedLog = ["Notify message from ('::1', 53)"]
    assertLogMessage(
        self,
        expectedLog,
        factory.handleNotify,
        message=dns.Message(),
        protocol=NoopProtocol(),
        address=('::1', 53),
    )
def _makeMessage(self):
    """
    Create the L{dns.Message} with the fixed header values shared by the
    callers of this helper.

    @return: A new L{dns.Message} built from those header values.
    """
    # hooray they all have the same message format
    header = dict(
        id=999,
        answer=1,
        opCode=0,
        recDes=0,
        recAv=1,
        auth=1,
        rCode=0,
        trunc=0,
        maxSize=0,
    )
    return dns.Message(**header)
def test_gotResolverErrorLogging(self):
    """
    L{server.DNSServerFactory.gotResolverError} logs a message if
    C{verbose > 0}.
    """
    factory = NoResponseDNSServerFactory(verbose=1)
    expectedLog = ["Lookup failed"]
    assertLogMessage(
        self,
        expectedLog,
        factory.gotResolverError,
        failure.Failure(error.DomainError()),
        protocol=NoopProtocol(),
        message=dns.Message(),
        address=None,
    )
def test_messageReceivedTimestamp(self):
    """
    L{server.DNSServerFactory.messageReceived} stamps the received message
    with the value returned by C{time.time}.
    """
    message = dns.Message()
    factory = NoResponseDNSServerFactory()

    # A unique sentinel proves the value came from the patched clock.
    now = object()
    self.patch(server.time, 'time', lambda: now)

    factory.messageReceived(message=message, proto=None, address=None)

    self.assertEqual(message.timeReceived, now)
def test_responseFromMessageTimeReceived(self):
    """
    The response built by L{server.DNSServerFactory._responseFromMessage}
    carries the same C{timeReceived} value as the request it was built from.
    """
    request = dns.Message()
    request.timeReceived = 1234

    response = server.DNSServerFactory()._responseFromMessage(
        message=request)

    self.assertEqual(request.timeReceived, response.timeReceived)
def test_responseFromMessageMaxSize(self):
    """
    The response built by L{server.DNSServerFactory._responseFromMessage}
    carries the same C{maxSize} value as the request it was built from.
    """
    request = dns.Message()
    request.maxSize = 0

    response = server.DNSServerFactory()._responseFromMessage(
        message=request)

    self.assertEqual(request.maxSize, response.maxSize)
def messageFromRawData(id, answer, opCode, auth, trunc, recDes, recAv, rCode,
                       nqueries, rrhnans, rrhnns, rrhnadd):
    """
    Build a L{dns.Message} from individual header values and record
    sections.

    The eight header arguments are copied verbatim onto the attributes of
    the same name, and C{maxSize} is set to 0.

    @param nqueries: Value for C{queries}; any falsy value becomes C{[]}.
    @param rrhnans: Value for C{answers}; any falsy value becomes C{[]}.
    @param rrhnns: Value for C{authority}; any falsy value becomes C{[]}.
    @param rrhnadd: Value for C{additional}; any falsy value becomes C{[]}.

    @return: The populated L{dns.Message}.
    """
    message = dns.Message()
    message.maxSize = 0

    message.id = id
    message.answer = answer
    message.opCode = opCode
    message.auth = auth
    message.trunc = trunc
    message.recDes = recDes
    message.recAv = recAv
    message.rCode = rCode

    # by default nqueries, rrhnans... would be '' when matches nothing
    # we should fix it in parsley instead of here
    message.queries = nqueries if nqueries else []
    message.answers = rrhnans if rrhnans else []
    message.authority = rrhnns if rrhnns else []
    message.additional = rrhnadd if rrhnadd else []
    return message
def test_messageReceivedLoggingNoQuery(self):
    """
    When the message has no queries and C{verbose} is C{>0},
    L{server.DNSServerFactory.messageReceived} logs that the query was
    empty.
    """
    emptyMessage = dns.Message()
    factory = NoResponseDNSServerFactory(verbose=1)
    expectedLog = ["Empty query from ('192.0.2.100', 53)"]
    assertLogMessage(
        self,
        expectedLog,
        factory.messageReceived,
        message=emptyMessage,
        proto=None,
        address=('192.0.2.100', 53),
    )
def test_gotResolverResponseResetsResponseAttributes(self):
    """
    L{server.DNSServerFactory.gotResolverResponse} does not let request
    attributes leak into the response: the reply it sends has AD and CD set
    to 0 and none of the records from the request's record sections are
    copied across.
    """
    factory = server.DNSServerFactory()
    sent = []

    def recordReply(protocol, response, address):
        # Capture the reply instead of writing it to a protocol.
        sent.append(response)

    factory.sendReply = recordReply

    request = dns.Message(authenticData=True, checkingDisabled=True)
    # Populate every record section so a leak would be visible.
    for section in ("answers", "authority", "additional"):
        setattr(request, section, [object(), object()])

    factory.gotResolverResponse(
        ([], [], []),
        protocol=None,
        message=request,
        address=None,
    )

    expected = [dns.Message(rCode=0, answer=True)]
    self.assertEqual(expected, sent)
def test_handleOther(self):
    """
    L{server.DNSServerFactory.handleOther} causes a response message to be
    sent whose C{rCode} is L{dns.ENOTIMP}.
    """
    factory = server.DNSServerFactory()
    raised = self.assertRaises(
        RaisingProtocol.WriteMessageArguments,
        factory.handleOther,
        message=dns.Message(),
        protocol=RaisingProtocol(),
        address=None,
    )
    # The exception carries the (args, kwargs) given to writeMessage.
    positional, _keywords = raised.args
    (reply,) = positional
    self.assertEqual(reply.rCode, dns.ENOTIMP)
def test_handleStatusLogging(self):
    """
    When C{verbose > 0}, L{server.DNSServerFactory.handleStatus} logs the
    address the message came from.
    """
    factory = NoResponseDNSServerFactory(verbose=1)
    expectedLog = ["Status request from ('::1', 53)"]
    assertLogMessage(
        self,
        expectedLog,
        factory.handleStatus,
        message=dns.Message(),
        protocol=NoopProtocol(),
        address=("::1", 53),
    )
def test_messageReceivedLogging1(self):
    """
    With C{verbose} set to C{1},
    L{server.DNSServerFactory.messageReceived} logs the query type of every
    query in the message.
    """
    message = dns.Message()
    for queryType in (dns.MX, dns.AAAA):
        message.addQuery(name='example.com', type=queryType)

    factory = NoResponseDNSServerFactory(verbose=1)
    expectedLog = ["MX AAAA query from ('192.0.2.100', 53)"]
    assertLogMessage(
        self,
        expectedLog,
        factory.messageReceived,
        message=message,
        proto=None,
        address=('192.0.2.100', 53),
    )
def test_gotResolverResponseLogging(self):
    """
    When C{verbose > 0}, L{server.DNSServerFactory.gotResolverResponse}
    logs the total number of records in the response.
    """
    factory = NoResponseDNSServerFactory(verbose=1)
    # One record per section, so the logged total is three.
    sections = ([dns.RRHeader()], [dns.RRHeader()], [dns.RRHeader()])
    expectedLog = ["Lookup found 3 records"]
    assertLogMessage(
        self,
        expectedLog,
        factory.gotResolverResponse,
        sections,
        protocol=NoopProtocol(),
        message=dns.Message(),
        address=None,
    )
def test_sendReplyLoggingNoAnswers(self):
    """
    L{server.DNSServerFactory.sendReply} logs a "no answers" message if the
    supplied message has no answers, followed by the time taken to process
    the query.
    """
    # Clock reads 2 and the message was received at 1, so one second has
    # elapsed.
    self.patch(server.time, 'time', lambda: 2)
    reply = dns.Message()
    reply.timeReceived = 1

    factory = server.DNSServerFactory(verbose=2)
    expectedLogs = [
        "Replying with no answers",
        "Processed query in 1.000 seconds",
    ]
    assertLogMessage(
        self,
        expectedLogs,
        factory.sendReply,
        protocol=NoopProtocol(),
        message=reply,
        address=None,
    )
def test_sendReplyWithoutAddress(self):
    """
    Given a protocol but no address tuple,
    L{server.DNSServerFactory.sendReply} passes only the message to
    C{protocol.writeMessage}.
    """
    reply = dns.Message()
    factory = server.DNSServerFactory()
    raised = self.assertRaises(
        RaisingProtocol.WriteMessageArguments,
        factory.sendReply,
        protocol=RaisingProtocol(),
        message=reply,
        address=None,
    )
    # The exception carries the (args, kwargs) given to writeMessage.
    positional, keywords = raised.args
    self.assertEqual(positional, (reply,))
    self.assertEqual(keywords, {})
def test_sendReplyWithAddress(self):
    """
    Given a protocol I{and} an address tuple,
    L{server.DNSServerFactory.sendReply} passes that address on to
    C{protocol.writeMessage}.
    """
    reply = dns.Message()
    origin = object()
    factory = server.DNSServerFactory()
    raised = self.assertRaises(
        RaisingProtocol.WriteMessageArguments,
        factory.sendReply,
        protocol=RaisingProtocol(),
        message=reply,
        address=origin,
    )
    # The exception carries the (args, kwargs) given to writeMessage.
    positional, keywords = raised.args
    self.assertEqual(positional, (reply, origin))
    self.assertEqual(keywords, {})
def test_noAnswer(self):
    """
    If a request returns a L{dns.NS} response, but we can't connect to the
    given server, the request fails with the error returned at connection.
    """
    cannedResponses = []

    def query(self, *args):
        # Pop from the message list, so that it blows up if more queries
        # are run than expected.
        return succeed(cannedResponses.pop(0))

    def queryProtocol(self, *args, **kwargs):
        return defer.fail(socket.gaierror("Couldn't connect"))

    resolver = Resolver(servers=[("0.0.0.0", 0)])
    resolver._query = query
    # Let's patch dns.DNSDatagramProtocol.query, as there is no easy way to
    # customize it.
    self.patch(dns.DNSDatagramProtocol, "query", queryProtocol)

    nameServer = dns.Record_NS(name="ns.twistedmatrix.com", ttl=700)
    referral = dns.RRHeader(
        name="fooba.com",
        type=dns.NS,
        cls=dns.IN,
        ttl=700,
        auth=False,
        payload=nameServer,
    )
    response = dns.Message(
        id=999,
        answer=1,
        opCode=0,
        recDes=0,
        recAv=1,
        auth=1,
        rCode=0,
        trunc=0,
        maxSize=0,
    )
    response.answers = [referral]
    cannedResponses.append(response)

    lookup = resolver.getHostByName("fooby.com")
    return self.assertFailure(lookup, socket.gaierror)
def test_handleQuery(self):
    """
    L{server.DNSServerFactory.handleQuery} dispatches the first query of the
    supplied message to L{server.DNSServerFactory.resolver.query}.
    """
    message = dns.Message()
    for name in (b'one.example.com', b'two.example.com'):
        message.addQuery(name)

    factory = server.DNSServerFactory()
    factory.resolver = RaisingResolver()

    raised = self.assertRaises(
        RaisingResolver.QueryArguments,
        factory.handleQuery,
        message=message,
        protocol=NoopProtocol(),
        address=None,
    )
    # The exception carries the (args, kwargs) given to resolver.query.
    positional, _keywords = raised.args
    (dispatched,) = positional
    self.assertEqual(dispatched, message.queries[0])
def test_messageReceivedAllowQuery(self):
    """
    L{server.DNSServerFactory.messageReceived} hands every message to
    L{server.DNSServerFactory.allowQuery}, together with the receiving
    protocol and the origin address.
    """
    received = dns.Message()
    protocolSentinel = object()
    addressSentinel = object()

    factory = RaisingDNSServerFactory()
    raised = self.assertRaises(
        RaisingDNSServerFactory.AllowQueryArguments,
        factory.messageReceived,
        message=received,
        proto=protocolSentinel,
        address=addressSentinel,
    )
    # The exception carries the (args, kwargs) given to allowQuery.
    positional, keywords = raised.args
    self.assertEqual(
        positional, (received, protocolSentinel, addressSentinel))
    self.assertEqual(keywords, {})
def test_simpleQuery(self):
    """
    Test content received after a query.

    A query is issued, a response carrying the same message id is fed to
    C{datagramReceived}, and the Deferred returned by the query is expected
    to fire with a message whose first answer is the address C{'1.2.3.4'}.

    @return: The Deferred returned by C{self.proto.query}, with the
        verifying callback attached.
    """
    d = self.proto.query(('127.0.0.1', 21345), [dns.Query('foo')])
    self.assertEqual(len(self.proto.liveMessages), 1)

    m = dns.Message()
    # Exactly one query is outstanding (asserted above), so the only key in
    # liveMessages is the id the response must carry.  next(iter(...)) is
    # used because dict views cannot be indexed on Python 3.
    m.id = next(iter(self.proto.liveMessages))
    m.answers = [dns.RRHeader(payload=dns.Record_A(address='1.2.3.4'))]

    def cb(result):
        self.assertEqual(result.answers[0].payload.dottedQuad(), '1.2.3.4')

    d.addCallback(cb)
    self.proto.datagramReceived(m.toStr(), ('127.0.0.1', 21345))
    return d
def test_gotResolverErrorCallsResponseFromMessage(self):
    """
    L{server.DNSServerFactory.gotResolverError} generates its response by
    calling L{server.DNSServerFactory._responseFromMessage}.
    """
    factory = NoResponseDNSServerFactory()
    # Replace the hook with one that raises, exposing the arguments used.
    factory._responseFromMessage = raiser

    request = dns.Message()
    request.timeReceived = 1

    raised = self.assertRaises(
        RaisedArguments,
        factory.gotResolverError,
        failure.Failure(error.DomainError()),
        protocol=None,
        message=request,
        address=None,
    )

    expectedCall = ((), {"message": request, "rCode": dns.ENAME})
    self.assertEqual(expectedCall, (raised.args, raised.kwargs))
def test_truncatedMessage(self):
    """
    A truncated message causes an equivalent request to be made over TCP,
    and the record sections of the TCP response become the result.
    """
    truncated = dns.Message(trunc=True)
    truncated.addQuery(b"example.com")

    def queryTCP(queries):
        # The retry must ask for exactly what the truncated message asked.
        self.assertEqual(queries, truncated.queries)
        tcpResponse = dns.Message()
        tcpResponse.answers = ["answer"]
        tcpResponse.authority = ["authority"]
        tcpResponse.additional = ["additional"]
        return defer.succeed(tcpResponse)

    self.resolver.queryTCP = queryTCP

    expectedSections = (["answer"], ["authority"], ["additional"])
    d = self.resolver.filterAnswers(truncated)
    d.addCallback(self.assertEqual, expectedSections)
    return d
def test_gotResolverResponse(self):
    """
    L{server.DNSServerFactory.gotResolverResponse} takes a tuple of resource
    record lists and triggers a response message holding those very lists.
    """
    factory = server.DNSServerFactory()
    # Three distinct list objects, so the identity checks below are
    # meaningful.
    answers, authority, additional = [], [], []

    raised = self.assertRaises(
        RaisingProtocol.WriteMessageArguments,
        factory.gotResolverResponse,
        (answers, authority, additional),
        protocol=RaisingProtocol(),
        message=dns.Message(),
        address=None,
    )
    # The exception carries the (args, kwargs) given to writeMessage.
    positional, _keywords = raised.args
    (reply,) = positional
    self.assertIs(reply.answers, answers)
    self.assertIs(reply.authority, authority)
    self.assertIs(reply.additional, additional)
def datagramReceived(self, data, addr):
    """
    Read a datagram, extract the message in it and trigger the associated
    Deferred.

    Datagrams that fail to decode are logged and dropped.  So are datagrams
    whose first answer is a type 1 record with an address in C{GFW_LIST}.
    A message whose id is in C{self.liveMessages} fires the matching
    Deferred; any other message is handed to
    C{self.controller.messageReceived} unless its id is in C{self.resends}.

    @param data: The raw datagram payload to decode.
    @param addr: The address the datagram came from; used in log output and
        passed to C{self.controller.messageReceived}.
    """
    m = dns.Message()
    try:
        m.fromStr(data)
    except EOFError:
        log.msg("Truncated packet (%d bytes) from %s" % (len(data), addr))
        return
    except Exception:
        # Nothing should trigger this, but since we're potentially
        # invoking a lot of different decoding methods, we might as well
        # be extra cautious.  Anything that triggers this is itself
        # buggy.  Only Exception is caught so that KeyboardInterrupt and
        # SystemExit still propagate.
        log.err(failure.Failure(), "Unexpected decoding error")
        return

    # Filter spurious ips. If answer section matches any address in GFW_LIST
    # we discard this datagram directly
    ans = m.answers
    if (
        ans
        and isinstance(ans[0], dns.RRHeader)
        # Record type 1 is the A (IPv4 address) type per RFC 1035.
        and ans[0].type == 1
        and ans[0].payload.dottedQuad() in GFW_LIST
    ):
        log.msg("Spurious IP detected")
        return

    if m.id in self.liveMessages:
        d, canceller = self.liveMessages[m.id]
        del self.liveMessages[m.id]
        canceller.cancel()
        # XXX we shouldn't need this hack of catching exception on
        # callback()
        try:
            d.callback(m)
        except Exception:
            log.err()
    elif m.id not in self.resends:
        self.controller.messageReceived(m, self, addr)
def test_differentProtocolAfterTimeout(self):
    """
    When a query issued by L{client.Resolver.query} times out, the retry is
    made with a new protocol instance.
    """
    resolver = client.Resolver(servers=[('example.com', 53)])
    usedProtocols = []
    # First attempt times out, second one succeeds.
    outcomes = [
        defer.fail(failure.Failure(DNSQueryTimeoutError(None))),
        defer.succeed(dns.Message()),
    ]

    class FakeProtocol(object):
        def __init__(self):
            self.transport = StubPort()

        def query(self, address, query, timeout=10, id=None):
            # Remember which instance handled this attempt.
            usedProtocols.append(self)
            return outcomes.pop(0)

    resolver._connectedProtocol = FakeProtocol
    resolver.query(dns.Query(b'foo.example.com'))

    distinctProtocols = set(usedProtocols)
    self.assertEqual(len(distinctProtocols), 2)
def test_messageReceivedLogging2(self):
    """
    With C{verbose} set to C{2},
    L{server.DNSServerFactory.messageReceived} logs the repr of every query
    in the message.
    """
    message = dns.Message()
    for queryType in (dns.MX, dns.AAAA):
        message.addQuery(name="example.com", type=queryType)

    factory = NoResponseDNSServerFactory(verbose=2)
    expectedLog = [
        "<Query example.com MX IN> "
        "<Query example.com AAAA IN> query from ('192.0.2.100', 53)"
    ]
    assertLogMessage(
        self,
        expectedLog,
        factory.messageReceived,
        message=message,
        proto=None,
        address=("192.0.2.100", 53),
    )
def test_multipleSequentialRequests(self):
    """
    Once a response has been received for a query issued with
    L{client.Resolver.query}, issuing the same query again results in a new
    network request.
    """
    resolver = client.Resolver(servers=[('example.com', 53)])
    resolver.protocol = StubDNSDatagramProtocol()
    pending = resolver.protocol.queries
    query = dns.Query('foo.example.com', dns.A)

    # The first query should be passed to the underlying protocol.
    resolver.query(query)
    self.assertEqual(len(pending), 1)

    # Deliver the response.
    firstRequest = pending.pop()
    firstRequest[-1].callback(dns.Message())

    # Repeating the first query should touch the protocol again.
    resolver.query(query)
    self.assertEqual(len(pending), 1)
def test_gotResolverResponseCallsResponseFromMessage(self):
    """
    L{server.DNSServerFactory.gotResolverResponse} generates its response by
    calling L{server.DNSServerFactory._responseFromMessage}.
    """
    factory = NoResponseDNSServerFactory()
    # Replace the hook with one that raises, exposing the arguments used.
    factory._responseFromMessage = raiser

    request = dns.Message()
    request.timeReceived = 1

    raised = self.assertRaises(
        RaisedArguments,
        factory.gotResolverResponse,
        ([], [], []),
        protocol=None,
        message=request,
        address=None,
    )

    expectedKeywords = {
        "message": request,
        "rCode": dns.OK,
        "answers": [],
        "authority": [],
        "additional": [],
    }
    self.assertEqual(
        ((), expectedKeywords),
        (raised.args, raised.kwargs),
    )