def test_tc_bit_not_set_in_answer_packet():
    """Responses that carry only answers must never have the TC bit set.

    Thirty TXT answers are queued with no questions. The output is expected to
    split into three packets, each of which must parse as valid with TC clear.
    """
    txt_payload = (
        b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1'
        b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A=='
    )
    response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA)
    for index in range(30):
        record = DNSText(
            ("HASS Bridge W9DN %s._hap._tcp.local." % index),
            const._TYPE_TXT,
            const._CLASS_IN | const._CLASS_UNIQUE,
            const._DNS_OTHER_TTL,
            txt_payload,
        )
        response.add_answer_at_time(record, 0)

    packets = response.packets()
    assert len(packets) == 3

    # Every packet, including the last, must be valid and not marked truncated.
    for raw in packets:
        incoming = r.DNSIncoming(raw)
        assert incoming.flags & const._FLAGS_TC == 0
        assert incoming.valid is True
def test_incoming_exception_handling(self):
    """A packet with junk bytes spliced into it must parse as invalid."""
    generated = r.DNSOutgoing(0)
    packet = generated.packet()
    # Splice eight junk bytes in after the first eight bytes of the packet.
    packet = packet[:8] + b"deadbeef" + packet[8:]
    # Parse once; the original parsed the same bytes twice and discarded the first result.
    parsed = r.DNSIncoming(packet)
    assert parsed.valid is False
def test_tc_bit_in_query_packet():
    """Queries whose known answers overflow one packet must flag truncation.

    One PTR question plus thirty TXT known answers should split into three
    packets. All but the final packet must report ``truncated``; every packet
    must parse as valid.
    """
    txt_payload = (
        b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1'
        b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A=='
    )
    query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA)
    type_ = "_hap._tcp.local."
    query.add_question(r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN))
    for index in range(30):
        record = DNSText(
            ("HASS Bridge W9DN %s._hap._tcp.local." % index),
            const._TYPE_TXT,
            const._CLASS_IN | const._CLASS_UNIQUE,
            const._DNS_OTHER_TTL,
            txt_payload,
        )
        query.add_answer_at_time(record, 0)

    packets = query.packets()
    assert len(packets) == 3

    final_index = len(packets) - 1
    for index, raw in enumerate(packets):
        incoming = r.DNSIncoming(raw)
        if index < final_index:
            # More known answers follow, so the packet must be marked truncated.
            assert incoming.truncated
        else:
            assert not incoming.truncated
        assert incoming.valid is True
def test_ptr_optimization():
    """Check the PTR response after registration and after the cache is cleared.

    A PTR query issued right after registering must produce no unicast or
    multicast response. Once the cache is cleared, the same query must yield a
    multicast response with one answer, no authorities, and exactly three
    additionals covering SRV, TXT and A.
    """
    # instantiate a zeroconf instance
    zc = Zeroconf(interfaces=['127.0.0.1'])
    try:
        # service definition
        type_ = "_test-srvc-type._tcp.local."
        name = "xxxyyy"
        registration_name = "%s.%s" % (name, type_)

        desc = {'path': '/~paulsm/'}
        info = ServiceInfo(
            type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
        )

        # register
        zc.register_service(info)

        # Verify we won't respond for 1s with the same multicast
        query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
        query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN))
        unicast_out, multicast_out = zc.query_handler.response(
            [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT
        )
        assert unicast_out is None
        assert multicast_out is None

        # Clear the cache to allow responding again
        _clear_cache(zc)

        # Verify we will now respond
        query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
        query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN))
        unicast_out, multicast_out = zc.query_handler.response(
            [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT
        )
        # Check for None before touching attributes so a missing response fails
        # with a clear assertion instead of an AttributeError.
        assert multicast_out is not None
        assert multicast_out.id == query.id
        assert unicast_out is None

        has_srv = has_txt = has_a = False
        nbr_additionals = 0
        nbr_answers = len(multicast_out.answers)
        nbr_authorities = len(multicast_out.authorities)
        for answer in multicast_out.additionals:
            nbr_additionals += 1
            if answer.type == const._TYPE_SRV:
                has_srv = True
            elif answer.type == const._TYPE_TXT:
                has_txt = True
            elif answer.type == const._TYPE_A:
                has_a = True
        assert nbr_answers == 1 and nbr_additionals == 3 and nbr_authorities == 0
        assert has_srv and has_txt and has_a

        # unregister
        zc.unregister_service(info)
    finally:
        # Always release the instance, even when an assertion above fails.
        zc.close()
def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming:
    """Build a parsed response packet describing the service for a state change.

    ``Updated`` produces a packet holding only the TXT record (TTL 120).
    Any other state produces PTR, SRV, TXT and A records, with a TTL of 0 for
    ``Removed`` and 120 otherwise. The service details (``service_name``,
    ``service_type``, ``service_server``, ``service_text``,
    ``service_address``) are read from the enclosing scope.
    """
    unique_in = const._CLASS_IN | const._CLASS_UNIQUE
    outgoing = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)

    if service_state_change == r.ServiceStateChange.Updated:
        outgoing.add_answer_at_time(
            r.DNSText(service_name, const._TYPE_TXT, unique_in, 120, service_text),
            0,
        )
        return r.DNSIncoming(outgoing.packets()[0])

    # A TTL of 0 is used for the Removed case; everything else keeps 120.
    ttl = 0 if service_state_change == r.ServiceStateChange.Removed else 120

    records = (
        r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name),
        r.DNSService(service_name, const._TYPE_SRV, unique_in, ttl, 0, 0, 80, service_server),
        r.DNSText(service_name, const._TYPE_TXT, unique_in, ttl, service_text),
        r.DNSAddress(service_server, const._TYPE_A, unique_in, ttl, socket.inet_aton(service_address)),
    )
    for record in records:
        outgoing.add_answer_at_time(record, 0)

    return r.DNSIncoming(outgoing.packets()[0])
def test_only_one_answer_can_by_large(self):
    """Only the first answer in a packet may push it past the typical size.

    Three oversized TXT answers and one small SRV answer must end up in four
    packets with one answer each.

    https://datatracker.ietf.org/doc/html/rfc6762#section-17
    """
    response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
    empty_query = r.DNSIncoming(r.DNSOutgoing(const._FLAGS_QR_QUERY).packets()[0])
    for _ in range(3):
        large_text = r.DNSText(
            "zoom._hap._tcp.local.",
            const._TYPE_TXT,
            const._CLASS_IN | const._CLASS_UNIQUE,
            1200,
            b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==' * 100,
        )
        response.add_answer(empty_query, large_text)
    small_service = r.DNSService(
        "testname1.local.",
        const._TYPE_SRV,
        const._CLASS_IN | const._CLASS_UNIQUE,
        const._DNS_HOST_TTL,
        0,
        0,
        80,
        "foo.local.",
    )
    response.add_answer(empty_query, small_service)
    assert len(response.answers) == 4

    packets = response.packets()
    assert len(packets) == 4

    # The first three packets each hold one oversized TXT record.
    for raw in packets[:3]:
        assert len(raw) <= const._MAX_MSG_ABSOLUTE
        assert len(raw) > const._MAX_MSG_TYPICAL
    # The last packet holds only the small SRV record.
    assert len(packets[3]) <= const._MAX_MSG_TYPICAL

    for raw in packets:
        incoming = r.DNSIncoming(raw)
        assert len(incoming.answers) == 1
def test_parse_own_packet_response(self):
    """A generated response with one non-ASCII SRV answer parses back with one answer."""
    srv_record = r.DNSService(
        "æøå.local.", r._TYPE_SRV, r._CLASS_IN, r._DNS_TTL, 0, 0, 80, "foo.local."
    )
    outgoing = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
    outgoing.add_answer_at_time(srv_record, 0)

    incoming = r.DNSIncoming(outgoing.packet())

    self.assertEqual(len(outgoing.answers), 1)
    self.assertEqual(len(outgoing.answers), len(incoming.answers))
def testLongName(self):
    """A question with a many-label name can be generated and parsed without raising."""
    long_question = r.DNSQuestion(
        "this.is.a.very.long.name.with.lots.of.parts.in.it.local.",
        r._TYPE_SRV,
        r._CLASS_IN,
    )
    outgoing = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
    # NOTE(review): camelCase addQuestion is used here while other tests in this
    # file call add_question -- confirm which spelling the targeted API provides.
    outgoing.addQuestion(long_question)
    # Parsing is the check; no assertion is made on the parsed result.
    r.DNSIncoming(outgoing.packet())
def test_any_query_for_ptr():
    """Test that queries for ANY will return PTR records.

    A service is added straight to the registry, an ANY question for its type
    is answered through the query handler, and the first multicast answer must
    have the type as its name and the registration name as its alias.
    """
    zc = Zeroconf(interfaces=['127.0.0.1'])
    try:
        type_ = "_anyptr._tcp.local."
        name = "knownname"
        registration_name = "%s.%s" % (name, type_)
        desc = {'path': '/~paulsm/'}
        server_name = "ash-2.local."
        ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1")
        info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address])
        zc.registry.add(info)

        _clear_cache(zc)
        generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
        question = r.DNSQuestion(type_, const._TYPE_ANY, const._CLASS_IN)
        generated.add_question(question)
        packets = generated.packets()
        _, multicast_out = zc.query_handler.response(
            [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT
        )
        assert multicast_out.answers[0][0].name == type_
        assert multicast_out.answers[0][0].alias == registration_name

        # unregister
        zc.registry.remove(info)
    finally:
        # Always release the instance, even when an assertion above fails.
        zc.close()
def test_same_name(self):
    """A packet holding the same question twice can be parsed without raising."""
    shared_name = "paired.local."
    duplicate = r.DNSQuestion(shared_name, r._TYPE_SRV, r._CLASS_IN)
    outgoing = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
    # Add the identical question object twice.
    for _ in range(2):
        outgoing.add_question(duplicate)
    r.DNSIncoming(outgoing.packet())
def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming:
    """Mock an incoming message for the case where the packet is split.

    The response carries only the A record followed by the SRV record, both
    with a TTL of 120. ``service_state_change`` is accepted but not consulted.
    Service details are read from the enclosing scope.
    """
    ttl = 120
    unique_in = const._CLASS_IN | const._CLASS_UNIQUE
    address_record = r.DNSAddress(
        service_server, const._TYPE_A, unique_in, ttl, socket.inet_aton(service_address)
    )
    service_record = r.DNSService(
        service_name, const._TYPE_SRV, unique_in, ttl, 0, 0, 80, service_server
    )

    outgoing = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
    for record in (address_record, service_record):
        outgoing.add_answer_at_time(record, 0)
    return r.DNSIncoming(outgoing.packets()[0])
def test_long_name(self):
    """A question with a many-label name can be generated and parsed without raising."""
    long_question = r.DNSQuestion(
        "this.is.a.very.long.name.with.lots.of.parts.in.it.local.",
        const._TYPE_SRV,
        const._CLASS_IN,
    )
    outgoing = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
    outgoing.add_question(long_question)
    first_packet = outgoing.packets()[0]
    # Parsing is the check; no assertion is made on the parsed result.
    r.DNSIncoming(first_packet)
def test_invalid_packets_ignored_and_does_not_cause_loop_exception():
    """Ensure an invalid packet cannot cause the loop to collapse.

    A corrupt packet is sent first, then a valid TXT response. After a short
    wait the TXT record must be present in the cache.
    """
    zc = Zeroconf(interfaces=['127.0.0.1'])
    try:
        generated = r.DNSOutgoing(0)
        packet = generated.packets()[0]
        # Splice eight junk bytes in after the first eight bytes of the packet.
        packet = packet[:8] + b'deadbeef' + packet[8:]
        parsed = r.DNSIncoming(packet)
        assert parsed.valid is False

        # Stand-in outgoing object whose packets() returns the corrupt bytes.
        mock_out = unittest.mock.Mock()
        mock_out.packets = lambda: [packet]
        zc.send(mock_out)

        generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
        entry = r.DNSText(
            "didnotcrashincoming._crash._tcp.local.",
            const._TYPE_TXT,
            const._CLASS_IN | const._CLASS_UNIQUE,
            500,
            b'path=/~paulsm/',
        )
        assert isinstance(entry, r.DNSText)
        assert isinstance(entry, r.DNSRecord)
        assert isinstance(entry, r.DNSEntry)

        generated.add_answer_at_time(entry, 0)
        zc.send(generated)
        # Wait briefly before closing and inspecting the cache.
        time.sleep(0.2)
    finally:
        # Always release the instance, even when an assertion or send fails.
        zc.close()
    assert zc.cache.get(entry) is not None
def test_incoming_circular_reference(self):
    """A fixed raw byte sequence must be reported as invalid when parsed.

    Going by the test name, the bytes presumably contain a circular name
    reference.
    """
    raw_hex = (
        '01005e0000fb542a1bf0577608004500006897934000ff11d81bc0a86a31e00000fb'
        '14e914e90054f9b2000084000000000100000000095f7365727669636573075f646e'
        '732d7364045f756470056c6f63616c00000c0001000011940018105f73706f746966'
        '792d636f6e6e656374045f746370c023'
    )
    incoming = r.DNSIncoming(bytes.fromhex(raw_hex))
    assert not incoming.valid
def test_aaaa_query():
    """Test that queries for AAAA records work.

    A service with a single IPv6 address is registered. An AAAA question for
    its server name must produce a multicast answer carrying that address.
    """
    zc = Zeroconf(interfaces=['127.0.0.1'])
    try:
        type_ = "_knownservice._tcp.local."
        name = "knownname"
        registration_name = "%s.%s" % (name, type_)
        desc = {'path': '/~paulsm/'}
        server_name = "ash-2.local."
        ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1")
        info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address])
        zc.register_service(info)

        _clear_cache(zc)
        generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
        question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN)
        generated.add_question(question)
        packets = generated.packets()
        # NOTE(review): a single DNSIncoming is passed here rather than a list of
        # them -- confirm this matches the response() signature in use.
        _, multicast_out = zc.query_handler.response(r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT)
        assert multicast_out.answers[0][0].address == ipv6_address

        # unregister
        zc.unregister_service(info)
    finally:
        # Always release the instance, even when an assertion above fails.
        zc.close()
def test_incoming_unknown_type(self):
    """A record built with the SOA type yields no parsed answers.

    The parsed packet must still classify itself as exactly one of query or
    response.
    """
    soa_typed = r.DNSAddress("a", r._TYPE_SOA, r._CLASS_IN, 1, b"a")
    outgoing = r.DNSOutgoing(0)
    outgoing.add_additional_answer(soa_typed)

    incoming = r.DNSIncoming(outgoing.packet())

    assert len(incoming.answers) == 0
    assert incoming.is_query() != incoming.is_response()
def test_match_question(self):
    """A question parsed back from a generated query equals the one that was added."""
    srv_question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN)
    outgoing = r.DNSOutgoing(r._FLAGS_QR_QUERY)
    outgoing.add_question(srv_question)

    incoming = r.DNSIncoming(outgoing.packet())

    self.assertEqual(len(outgoing.questions), 1)
    self.assertEqual(len(outgoing.questions), len(incoming.questions))
    self.assertEqual(srv_question, incoming.questions[0])
def mock_incoming_msg(
    service_state_change: r.ServiceStateChange, service_type: str, service_name: str, ttl: int
) -> r.DNSIncoming:
    """Return a parsed response holding one PTR record from service_type to service_name.

    ``service_state_change`` is accepted but not consulted; the record's TTL
    is taken directly from ``ttl``.
    """
    pointer = r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name)
    outgoing = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
    outgoing.add_answer_at_time(pointer, 0)
    first_packet = outgoing.packets()[0]
    return r.DNSIncoming(first_packet)
def mock_incoming_msg(records) -> r.DNSIncoming:
    """Add ``records`` as answers to a response and return its first packet, parsed."""
    outgoing = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
    for item in records:
        outgoing.add_answer_at_time(item, 0)
    raw_packets = outgoing.packets()
    return r.DNSIncoming(raw_packets[0])
def test_match_question(self):
    """A question parsed back from a generated query equals the one that was added."""
    srv_question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN)
    outgoing = r.DNSOutgoing(const._FLAGS_QR_QUERY)
    outgoing.add_question(srv_question)

    first_packet = outgoing.packets()[0]
    incoming = r.DNSIncoming(first_packet)

    assert len(outgoing.questions) == 1
    assert len(outgoing.questions) == len(incoming.questions)
    assert srv_question == incoming.questions[0]
def test_questions_do_not_end_up_every_packet(self):
    """Questions appear only in the first packet of a multi-packet query.

    https://datatracker.ietf.org/doc/html/rfc6762#section-7.2

    Sometimes a Multicast DNS querier will already have too many answers
    to fit in the Known-Answer Section of its query packets....  It MUST
    immediately follow the packet with another query packet containing no
    questions and as many more Known-Answer records as will fit.
    """
    query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
    for index in range(35):
        host = f"testname{index}.local."
        query.add_question(r.DNSQuestion(host, const._TYPE_SRV, const._CLASS_IN))
        known_answer = r.DNSService(
            host,
            const._TYPE_SRV,
            const._CLASS_IN | const._CLASS_UNIQUE,
            const._DNS_HOST_TTL,
            0,
            0,
            80,
            f"foo{index}.local.",
        )
        query.add_answer_at_time(known_answer, 0)

    assert len(query.questions) == 35
    assert len(query.answers) == 35

    packets = query.packets()
    assert len(packets) == 2
    for raw in packets:
        assert len(raw) <= const._MAX_MSG_TYPICAL

    # The first packet carries every question plus the answers that fit.
    leading = r.DNSIncoming(packets[0])
    assert len(leading.questions) == 35
    assert len(leading.answers) == 33

    # The follow-up packet carries no questions, only the remaining answers.
    trailing = r.DNSIncoming(packets[1])
    assert len(trailing.questions) == 0
    assert len(trailing.answers) == 2
def test_many_questions(self):
    """Test many questions get separated into multiple packets.

    One hundred SRV questions must split into two packets, each below the
    typical message size, holding 85 and 15 questions respectively.
    """
    generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
    for i in range(100):
        question = r.DNSQuestion(f"testname{i}.local.", const._TYPE_SRV, const._CLASS_IN)
        generated.add_question(question)
    assert len(generated.questions) == 100

    packets = generated.packets()
    assert len(packets) == 2
    assert len(packets[0]) < const._MAX_MSG_TYPICAL
    assert len(packets[1]) < const._MAX_MSG_TYPICAL

    parsed1 = r.DNSIncoming(packets[0])
    assert len(parsed1.questions) == 85
    parsed2 = r.DNSIncoming(packets[1])
    assert len(parsed2.questions) == 15
def test_incoming_ipv6(self):
    """An AAAA record parses back as a DNSAddress with the same packed address."""
    addr = "2606:2800:220:1:248:1893:25c8:1946"  # example.com
    packed = socket.inet_pton(socket.AF_INET6, addr)
    aaaa_record = r.DNSAddress('domain', r._TYPE_AAAA, r._CLASS_IN, 1, packed)

    outgoing = r.DNSOutgoing(0)
    outgoing.add_additional_answer(aaaa_record)
    incoming = r.DNSIncoming(outgoing.packet())

    first_answer = incoming.answers[0]
    assert isinstance(first_answer, r.DNSAddress)
    assert first_answer.address == packed
def test_dns_hinfo(self):
    """HINFO cpu/os fields parse back intact; a 257-character os value is rejected."""
    outgoing = r.DNSOutgoing(0)
    outgoing.add_additional_answer(DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'os'))
    incoming = r.DNSIncoming(outgoing.packet())
    hinfo = cast(r.DNSHinfo, incoming.answers[0])
    self.assertEqual(hinfo.cpu, u'cpu')
    self.assertEqual(hinfo.os, u'os')

    # Building the packet must fail when the os field is 257 characters long.
    oversized = r.DNSOutgoing(0)
    oversized.add_additional_answer(DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'x' * 257))
    with self.assertRaises(r.NamePartTooLongException):
        oversized.packet()
def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): """Sends an outgoing packet.""" pout = r.DNSIncoming(out.packets()[0]) nonlocal nbr_answers for answer in pout.answers: nbr_answers += 1 if not answer.ttl > expected_ttl / 2: unexpected_ttl.set() got_query.set() old_send(out, addr=addr, port=port)
def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT):
    """Inspect the outgoing packet's answers, then forward to the real send.

    Each answer increments ``nbr_queries[0]`` in the enclosing scope. An
    answer whose TTL is not above half of ``expected_ttl`` sets
    ``unexpected_ttl``. ``got_query`` is set on every call.
    """
    parsed_out = r.DNSIncoming(out.packet())
    for record in parsed_out.answers:
        nbr_queries[0] += 1
        if record.ttl <= expected_ttl / 2:
            unexpected_ttl.set()

    got_query.set()
    old_send(out, addr=addr, port=port)
def test_ptr_optimization():
    """Check that a PTR query is answered with SRV, TXT and A as additionals.

    ``zc.send`` is replaced with a wrapper that tallies the records in every
    outgoing message. After a PTR query the tallies must show one answer, three
    additionals and no authorities.
    """
    # instantiate a zeroconf instance
    zc = Zeroconf(interfaces=['127.0.0.1'])
    try:
        # service definition
        type_ = "_test-srvc-type._tcp.local."
        name = "xxxyyy"
        registration_name = "%s.%s" % (name, type_)

        desc = {'path': '/~paulsm/'}
        info = ServiceInfo(type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local.")

        # we are going to monkey patch the zeroconf send to count outgoing records
        old_send = zc.send

        nbr_answers = nbr_additionals = nbr_authorities = 0
        has_srv = has_txt = has_a = False

        def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT):
            """Tally the records in ``out`` and forward it to the real send."""
            nonlocal nbr_answers, nbr_additionals, nbr_authorities
            nonlocal has_srv, has_txt, has_a

            nbr_answers += len(out.answers)
            nbr_authorities += len(out.authorities)
            for answer in out.additionals:
                nbr_additionals += 1
                if answer.type == r._TYPE_SRV:
                    has_srv = True
                elif answer.type == r._TYPE_TXT:
                    has_txt = True
                elif answer.type == r._TYPE_A:
                    has_a = True

            old_send(out, addr=addr, port=port)

        # monkey patch the zeroconf send
        setattr(zc, "send", send)

        # register
        zc.register_service(info)
        # Discard the counts accumulated while registering; the has_* flags are
        # not reset here.
        nbr_answers = nbr_additionals = nbr_authorities = 0

        # query
        query = r.DNSOutgoing(r._FLAGS_QR_QUERY | r._FLAGS_AA)
        query.add_question(r.DNSQuestion(info.type, r._TYPE_PTR, r._CLASS_IN))
        zc.handle_query(r.DNSIncoming(query.packet()), r._MDNS_ADDR, r._MDNS_PORT)
        assert nbr_answers == 1 and nbr_additionals == 3 and nbr_authorities == 0
        assert has_srv and has_txt and has_a

        # unregister
        zc.unregister_service(info)
    finally:
        # The instance was previously never closed; release it on every path.
        zc.close()
def test_dns_hinfo(self):
    """HINFO cpu/os fields parse back intact; a 257-character os value is rejected."""
    outgoing = r.DNSOutgoing(0)
    outgoing.add_additional_answer(
        DNSHinfo("irrelevant", r._TYPE_HINFO, 0, 0, "cpu", "os"))
    incoming = r.DNSIncoming(outgoing.packet())
    self.assertEqual(incoming.answers[0].cpu, u"cpu")
    self.assertEqual(incoming.answers[0].os, u"os")

    # Building the packet must fail when the os field is 257 characters long.
    oversized = r.DNSOutgoing(0)
    oversized.add_additional_answer(
        DNSHinfo("irrelevant", r._TYPE_HINFO, 0, 0, "cpu", "x" * 257))
    with self.assertRaises(r.NamePartTooLongException):
        oversized.packet()
def test_multi_packet_known_answer_supression():
    """Known answers spread across several packets must still suppress a response.

    Three services of one type are added to the registry. A PTR query that
    lists all three pointers as known answers, padded until it spans more than
    one packet, must produce neither a unicast nor a multicast response.
    """
    zc = Zeroconf(interfaces=['127.0.0.1'])
    try:
        type_ = "_handlermultis._tcp.local."
        name = "knownname"
        name2 = "knownname2"
        name3 = "knownname3"

        registration_name = "%s.%s" % (name, type_)
        registration2_name = "%s.%s" % (name2, type_)
        registration3_name = "%s.%s" % (name3, type_)

        desc = {'path': '/~paulsm/'}
        server_name = "ash-2.local."
        server_name2 = "ash-3.local."
        server_name3 = "ash-4.local."

        info = ServiceInfo(
            type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")]
        )
        info2 = ServiceInfo(
            type_, registration2_name, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")]
        )
        info3 = ServiceInfo(
            type_, registration3_name, 80, 0, 0, desc, server_name3, addresses=[socket.inet_aton("10.0.1.2")]
        )
        zc.registry.add(info)
        zc.registry.add(info2)
        zc.registry.add(info3)

        now = current_time_millis()
        _clear_cache(zc)

        # Test PTR suppression
        generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
        question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)
        generated.add_question(question)
        for _ in range(1000):
            # Add so many answers we end up with another packet
            generated.add_answer_at_time(info.dns_pointer(), now)
        generated.add_answer_at_time(info2.dns_pointer(), now)
        generated.add_answer_at_time(info3.dns_pointer(), now)
        packets = generated.packets()
        assert len(packets) > 1
        unicast_out, multicast_out = zc.query_handler.response(
            [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT
        )
        assert unicast_out is None
        assert multicast_out is None

        # unregister
        zc.registry.remove(info)
        zc.registry.remove(info2)
        zc.registry.remove(info3)
    finally:
        # Always release the instance, even when an assertion above fails.
        zc.close()
def test_massive_probe_packet_split(self):
    """Test probe with many authoritative answers.

    Thirty questions and two hundred authority records must split into three
    packets no larger than the typical message size. Questions appear only in
    the first packet, and the TC flag is set on all but the last.
    """
    generated = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA)
    for _ in range(30):
        question = r.DNSQuestion("_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN | const._CLASS_UNIQUE)
        generated.add_question(question)
    assert len(generated.questions) == 30

    for _ in range(200):
        authorative_answer = r.DNSPointer(
            # Plain (non-f) string: "{i}" is literal, so all 200 records share one name.
            # NOTE(review): the per-packet counts asserted below presumably depend on
            # that; recompute them before turning this into a real f-string.
            "myservice{i}_tcp._tcp.local.",
            const._TYPE_PTR,
            const._CLASS_IN | const._CLASS_UNIQUE,
            const._DNS_OTHER_TTL,
            '123.local.',
        )
        generated.add_authorative_answer(authorative_answer)

    packets = generated.packets()
    assert len(packets) == 3
    assert len(packets[0]) <= const._MAX_MSG_TYPICAL
    assert len(packets[1]) <= const._MAX_MSG_TYPICAL
    assert len(packets[2]) <= const._MAX_MSG_TYPICAL

    parsed1 = r.DNSIncoming(packets[0])
    assert parsed1.questions[0].unicast is True
    assert len(parsed1.questions) == 30
    assert parsed1.num_authorities == 88
    assert parsed1.flags & const._FLAGS_TC == const._FLAGS_TC

    parsed2 = r.DNSIncoming(packets[1])
    assert len(parsed2.questions) == 0
    assert parsed2.num_authorities == 101
    assert parsed2.flags & const._FLAGS_TC == const._FLAGS_TC

    # Only the final packet has the TC flag clear.
    parsed3 = r.DNSIncoming(packets[2])
    assert len(parsed3.questions) == 0
    assert parsed3.num_authorities == 11
    assert parsed3.flags & const._FLAGS_TC == 0