def test_invalid_packets_ignored_and_does_not_cause_loop_exception():
    """Ensure an invalid packet cannot cause the loop to collapse.

    A corrupted packet is pushed through ``zc.send`` first; a valid TXT
    response sent afterwards must still end up in the cache.
    """
    zc = Zeroconf(interfaces=['127.0.0.1'])

    # Start from an empty packet and splice junk bytes into it.
    raw = r.DNSOutgoing(0).packets()[0]
    corrupt = raw[:8] + b'deadbeef' + raw[8:]
    assert r.DNSIncoming(corrupt).valid is False

    # Wrap the corrupt bytes in an object that quacks like DNSOutgoing.
    bogus_out = unittest.mock.Mock()
    bogus_out.packets = lambda: [corrupt]
    zc.send(bogus_out)

    response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
    txt_record = r.DNSText(
        "didnotcrashincoming._crash._tcp.local.",
        const._TYPE_TXT,
        const._CLASS_IN | const._CLASS_UNIQUE,
        500,
        b'path=/~paulsm/',
    )
    for expected_cls in (r.DNSText, r.DNSRecord, r.DNSEntry):
        assert isinstance(txt_record, expected_cls)
    response.add_answer_at_time(txt_record, 0)
    zc.send(response)

    # Give the receiving side time to process both packets.
    time.sleep(0.2)
    zc.close()
    assert zc.cache.get(txt_record) is not None
def test_ptr_optimization():
    """Verify PTR query handling for a freshly registered service.

    Right after registration an identical multicast PTR query gets no
    response. Once the cache is cleared, the query is answered with one PTR
    answer, three additionals (SRV, TXT and A) and no authorities.
    """
    # instantiate a zeroconf instance
    zc = Zeroconf(interfaces=['127.0.0.1'])

    # service definition
    type_ = "_test-srvc-type._tcp.local."
    name = "xxxyyy"
    registration_name = "%s.%s" % (name, type_)
    desc = {'path': '/~paulsm/'}
    info = ServiceInfo(
        type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
    )

    # register
    zc.register_service(info)

    # Verify we won't respond for 1s with the same multicast
    query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
    query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN))
    unicast_out, multicast_out = zc.query_handler.response(
        [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT
    )
    assert unicast_out is None
    assert multicast_out is None

    # Clear the cache to allow responding again
    _clear_cache(zc)

    # Verify we will now respond
    query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
    query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN))
    unicast_out, multicast_out = zc.query_handler.response(
        [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT
    )
    # Check for a response before touching its attributes so a missing
    # response fails this assertion rather than raising AttributeError.
    assert unicast_out is None
    assert multicast_out is not None
    assert multicast_out.id == query.id

    additional_types = [record.type for record in multicast_out.additionals]
    assert len(multicast_out.answers) == 1
    assert len(additional_types) == 3
    assert len(multicast_out.authorities) == 0
    assert const._TYPE_SRV in additional_types
    assert const._TYPE_TXT in additional_types
    assert const._TYPE_A in additional_types

    # unregister
    zc.unregister_service(info)
    zc.close()
def test_dns_hinfo(self):
    """Round-trip an HINFO record, then check an over-long OS field is rejected."""
    ok_out = r.DNSOutgoing(0)
    ok_out.add_additional_answer(DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'os'))
    hinfo = cast(r.DNSHinfo, r.DNSIncoming(ok_out.packet()).answers[0])
    self.assertEqual(hinfo.cpu, u'cpu')
    self.assertEqual(hinfo.os, u'os')

    # A 257-character OS string must make packet generation fail.
    too_long_out = r.DNSOutgoing(0)
    too_long_out.add_additional_answer(DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'x' * 257))
    self.assertRaises(r.NamePartTooLongException, too_long_out.packet)
def test_dns_hinfo(self):
    """Round-trip an HINFO record, then check an over-long OS field is rejected."""
    ok_out = r.DNSOutgoing(0)
    ok_out.add_additional_answer(
        DNSHinfo("irrelevant", r._TYPE_HINFO, 0, 0, "cpu", "os"))
    first_answer = r.DNSIncoming(ok_out.packet()).answers[0]
    self.assertEqual(first_answer.cpu, u"cpu")
    self.assertEqual(first_answer.os, u"os")

    # A 257-character OS string must make packet generation fail.
    too_long_out = r.DNSOutgoing(0)
    too_long_out.add_additional_answer(
        DNSHinfo("irrelevant", r._TYPE_HINFO, 0, 0, "cpu", "x" * 257))
    self.assertRaises(r.NamePartTooLongException, too_long_out.packet)
def test_only_one_answer_can_by_large(self):
    """Test that only the first answer in each packet can be large.

    Three oversized TXT answers and one small SRV answer must end up in four
    separate packets, each parsing back to exactly one answer.

    https://datatracker.ietf.org/doc/html/rfc6762#section-17
    """
    response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
    empty_query = r.DNSIncoming(
        r.DNSOutgoing(const._FLAGS_QR_QUERY).packets()[0])
    oversized_txt = b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==' * 100

    for _ in range(3):
        response.add_answer(
            empty_query,
            r.DNSText(
                "zoom._hap._tcp.local.",
                const._TYPE_TXT,
                const._CLASS_IN | const._CLASS_UNIQUE,
                1200,
                oversized_txt,
            ),
        )
    response.add_answer(
        empty_query,
        r.DNSService(
            "testname1.local.",
            const._TYPE_SRV,
            const._CLASS_IN | const._CLASS_UNIQUE,
            const._DNS_HOST_TTL,
            0,
            0,
            80,
            "foo.local.",
        ),
    )
    assert len(response.answers) == 4

    packets = response.packets()
    assert len(packets) == 4
    # Each of the first three packets carries one oversized TXT answer.
    for big_packet in packets[:3]:
        assert len(big_packet) <= const._MAX_MSG_ABSOLUTE
        assert len(big_packet) > const._MAX_MSG_TYPICAL
    assert len(packets[3]) <= const._MAX_MSG_TYPICAL

    for packet in packets:
        assert len(r.DNSIncoming(packet).answers) == 1
def test_aaaa_query():
    """Test that queries for AAAA records work.

    Registers a service whose only address is IPv6 and checks that an AAAA
    question for the server name returns that address first.
    """
    zc = Zeroconf(interfaces=['127.0.0.1'])

    service_type = "_knownservice._tcp.local."
    instance_name = "%s.%s" % ("knownname", service_type)
    host = "ash-2.local."
    v6_packed = socket.inet_pton(socket.AF_INET6, "2001:db8::1")
    info = ServiceInfo(
        service_type, instance_name, 80, 0, 0, {'path': '/~paulsm/'}, host, addresses=[v6_packed]
    )
    zc.register_service(info)
    _clear_cache(zc)

    query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
    query.add_question(r.DNSQuestion(host, const._TYPE_AAAA, const._CLASS_IN))
    first_packet = query.packets()[0]
    _, multicast_out = zc.query_handler.response(r.DNSIncoming(first_packet), "1.2.3.4", const._MDNS_PORT)
    assert multicast_out.answers[0][0].address == v6_packed

    # unregister
    zc.unregister_service(info)
    zc.close()
def test_register_and_lookup_type_by_uppercase_name(self):
    """A PTR query using an upper-cased service type still gets our service answered."""
    # instantiate a zeroconf instance
    zc = Zeroconf(interfaces=['127.0.0.1'])
    type_ = "_mylowertype._tcp.local."
    registration_name = "%s.%s" % ("Home", type_)
    expected_address = socket.inet_pton(socket.AF_INET, "1.2.3.4")

    zc.register_service(
        ServiceInfo(
            type_,
            name=registration_name,
            server="random123.local.",
            addresses=[expected_address],
            port=80,
            properties={"version": "1.0"},
        )
    )
    _clear_cache(zc)

    # With the cache cleared, nothing about the service is known locally.
    lookup = ServiceInfo(type_, registration_name)
    lookup.load_from_cache(zc)
    assert lookup.addresses == []

    query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
    query.add_question(
        r.DNSQuestion(type_.upper(), const._TYPE_PTR, const._CLASS_IN))
    zc.send(query)
    time.sleep(0.5)

    # The answer to the upper-case query should now populate the cache.
    lookup = ServiceInfo(type_, registration_name)
    lookup.load_from_cache(zc)
    assert lookup.addresses == [expected_address]
    assert lookup.properties == {b"version": b"1.0"}
    zc.close()
def test_any_query_for_ptr():
    """Test that queries for ANY will return PTR records.

    The first answer to an ANY question for the service type must be a PTR
    from the type to the registered instance name.
    """
    zc = Zeroconf(interfaces=['127.0.0.1'])

    service_type = "_anyptr._tcp.local."
    instance_name = "%s.%s" % ("knownname", service_type)
    info = ServiceInfo(
        service_type,
        instance_name,
        80,
        0,
        0,
        {'path': '/~paulsm/'},
        "ash-2.local.",
        addresses=[socket.inet_pton(socket.AF_INET6, "2001:db8::1")],
    )
    # Put the service straight into the registry rather than calling register_service.
    zc.registry.add(info)
    _clear_cache(zc)

    query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
    query.add_question(r.DNSQuestion(service_type, const._TYPE_ANY, const._CLASS_IN))
    incoming = [r.DNSIncoming(packet) for packet in query.packets()]
    _, multicast_out = zc.query_handler.response(incoming, "1.2.3.4", const._MDNS_PORT)

    first_record = multicast_out.answers[0][0]
    assert first_record.name == service_type
    assert first_record.alias == instance_name

    # unregister
    zc.registry.remove(info)
    zc.close()
def testLongName(self):
    """A question with a long, many-label name can be serialized and parsed without raising."""
    long_name = "this.is.a.very.long.name.with.lots.of.parts.in.it.local."
    outgoing = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
    outgoing.addQuestion(r.DNSQuestion(long_name, r._TYPE_SRV, r._CLASS_IN))
    r.DNSIncoming(outgoing.packet())
def test_sending_unicast():
    """Test sending unicast response.

    Sending to unicast documentation addresses must not put the record in our
    own cache, while the default (multicast) send must.
    """
    zc = Zeroconf(interfaces=['127.0.0.1'])
    outgoing = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
    record = r.DNSText(
        "didnotcrashincoming._crash._tcp.local.",
        const._TYPE_TXT,
        const._CLASS_IN | const._CLASS_UNIQUE,
        500,
        b'path=/~paulsm/',
    )
    outgoing.add_answer_at_time(record, 0)

    # https://www.iana.org/go/rfc3849 and Documentation (TEST-NET-2)
    for unicast_addr in ("2001:db8::1", "198.51.100.0"):
        zc.send(outgoing, unicast_addr, const._MDNS_PORT)
        time.sleep(0.2)
        assert zc.cache.get(record) is None

    zc.send(outgoing)
    time.sleep(0.2)
    assert zc.cache.get(record) is not None
    zc.close()
def request(self, zc, timeout):
    """Send A queries for ``self.name`` until ``self.address`` is set.

    Registers this object as a listener on ``zc`` for ``self.name`` and
    re-sends an A question with a doubling retry interval. ``self.address``
    is presumably filled in by this object's listener callback (not visible
    here).

    :param zc: zeroconf instance used to listen, send and wait.
    :param timeout: overall time budget in seconds.
    :returns: True once ``self.address`` is set, False on timeout.
    """
    # zc.wait() and the record-expiry check inside add_answer_at_time() both
    # work on zeroconf's millisecond clock, so the loop runs in milliseconds.
    # ``timeout`` is still taken in seconds so existing callers keep working.
    now = zeroconf.current_time_millis()
    delay = 200  # initial retry interval in ms (0.2 s)
    next_ = now + delay
    last = now + timeout * 1000
    try:
        zc.add_listener(
            self, zeroconf.DNSQuestion(self.name, zeroconf._TYPE_ANY, zeroconf._CLASS_IN))
        while self.address is None:
            if last <= now:
                # Timeout
                return False
            if next_ <= now:
                out = zeroconf.DNSOutgoing(zeroconf._FLAGS_QR_QUERY)
                out.add_question(
                    zeroconf.DNSQuestion(self.name, zeroconf._TYPE_A, zeroconf._CLASS_IN))
                # Offer any cached A record as a known answer; ``now`` must be
                # in ms so expired records are recognised as such.
                out.add_answer_at_time(
                    zc.cache.get_by_details(self.name, zeroconf._TYPE_A, zeroconf._CLASS_IN), now)
                zc.send(out)
                next_ = now + delay
                delay *= 2
            zc.wait(min(next_, last) - now)
            now = zeroconf.current_time_millis()
    finally:
        zc.remove_listener(self)
    return True
def test_same_name(self):
    """A packet that carries the same question twice can still be parsed."""
    outgoing = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
    repeated = r.DNSQuestion("paired.local.", r._TYPE_SRV, r._CLASS_IN)
    for _ in range(2):
        outgoing.add_question(repeated)
    r.DNSIncoming(outgoing.packet())
def request(self, zc, timeout):
    """Returns true if the service could be discovered on the
    network, and updates this object with details discovered.

    First, cached records are applied. These are the SRV and TXT records for
    ``self.name`` and, once the server host name is known, the A record for
    ``self.server``. If that already fills in server, address, text and port,
    no query is sent. Otherwise SRV/TXT questions (plus an A question once
    the server is known) are sent with a doubling retry interval until all
    four are known or ``timeout`` milliseconds have elapsed.

    :param zc: zeroconf instance used for the cache, the listener and sending.
    :param timeout: time budget in milliseconds.
    :returns: True when server, address, text and port are all set,
        False on timeout.
    """
    now = zeroconf.current_time_millis()
    delay = zeroconf._LISTENER_TIME
    next_ = now + delay
    last = now + timeout

    for record_type in (
        (zeroconf._TYPE_SRV, zeroconf._CLASS_IN),
        (zeroconf._TYPE_TXT, zeroconf._CLASS_IN),
    ):
        cached = zc.cache.get_by_details(self.name, *record_type)
        if cached:
            self.update_record(zc, now, cached)
    # A records are keyed by the host name (self.server), not the service
    # instance name; the query loop below asks for them the same way.
    # Checking after SRV also covers a server name learned from a cached SRV
    # record (assuming update_record sets self.server from it).
    if self.server is not None:
        cached = zc.cache.get_by_details(self.server, zeroconf._TYPE_A, zeroconf._CLASS_IN)
        if cached:
            self.update_record(zc, now, cached)

    if None not in (self.server, self.address, self.text, self.port):
        return True

    try:
        zc.add_listener(self, zeroconf.DNSQuestion(self.name, zeroconf._TYPE_ANY, zeroconf._CLASS_IN))
        while None in (self.server, self.address, self.text, self.port):
            if last <= now:
                return False
            if next_ <= now:
                out = zeroconf.DNSOutgoing(zeroconf._FLAGS_QR_QUERY)
                out.add_question(
                    zeroconf.DNSQuestion(self.name, zeroconf._TYPE_SRV, zeroconf._CLASS_IN))
                if self.port is not None:
                    out.add_answer_at_time(
                        zc.cache.get_by_details(
                            self.name, zeroconf._TYPE_SRV, zeroconf._CLASS_IN), now)
                out.add_question(
                    zeroconf.DNSQuestion(self.name, zeroconf._TYPE_TXT, zeroconf._CLASS_IN))
                out.add_answer_at_time(
                    zc.cache.get_by_details(
                        self.name, zeroconf._TYPE_TXT, zeroconf._CLASS_IN), now)
                if self.server is not None:
                    out.add_question(
                        zeroconf.DNSQuestion(self.server, zeroconf._TYPE_A, zeroconf._CLASS_IN))
                    out.add_answer_at_time(
                        zc.cache.get_by_details(
                            self.server, zeroconf._TYPE_A, zeroconf._CLASS_IN), now)
                zc.send(out)
                next_ = now + delay
                delay *= 2
            zc.wait(min(next_, last) - now)
            now = zeroconf.current_time_millis()
    finally:
        zc.remove_listener(self)
    return True
def test_long_name(self):
    """A question with a long, many-label name can be serialized and parsed without raising."""
    outgoing = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
    outgoing.add_question(
        r.DNSQuestion(
            "this.is.a.very.long.name.with.lots.of.parts.in.it.local.",
            const._TYPE_SRV,
            const._CLASS_IN,
        )
    )
    first_packet = outgoing.packets()[0]
    r.DNSIncoming(first_packet)
def test_tc_bit_not_set_in_answer_packet():
    """Verify the TC bit is not set when there are no questions and answers exceed the packet size."""
    out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA)
    txt_payload = (
        b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1'
        b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A=='
    )
    for idx in range(30):
        out.add_answer_at_time(
            DNSText(
                "HASS Bridge W9DN %s._hap._tcp.local." % idx,
                const._TYPE_TXT,
                const._CLASS_IN | const._CLASS_UNIQUE,
                const._DNS_OTHER_TTL,
                txt_payload,
            ),
            0,
        )

    packets = out.packets()
    assert len(packets) == 3
    # A pure answer split must never carry the TC bit in any packet.
    for raw in packets:
        parsed = r.DNSIncoming(raw)
        assert parsed.flags & const._FLAGS_TC == 0
        assert parsed.valid is True
def test_incoming_exception_handling(self):
    """Parsing a corrupted packet yields an invalid message instead of raising."""
    generated = r.DNSOutgoing(0)
    packet = generated.packet()
    # Splice junk bytes into the header region so it no longer matches the data.
    packet = packet[:8] + b"deadbeef" + packet[8:]
    # Parse once; the previous duplicate parse added nothing.
    parsed = r.DNSIncoming(packet)
    assert parsed.valid is False
def test_parse_own_packet_response(self):
    """A response with a non-ASCII SRV name parses back with the same answer count."""
    outgoing = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
    service = r.DNSService(
        "æøå.local.", r._TYPE_SRV, r._CLASS_IN, r._DNS_TTL, 0, 0, 80, "foo.local."
    )
    outgoing.add_answer_at_time(service, 0)
    reparsed = r.DNSIncoming(outgoing.packet())
    self.assertEqual(len(outgoing.answers), 1)
    self.assertEqual(len(outgoing.answers), len(reparsed.answers))
def mock_split_incoming_msg(
        service_state_change: r.ServiceStateChange) -> r.DNSIncoming:
    """Mock an incoming message for the case where the packet is split.

    The message holds only an A record for ``service_server`` and an SRV
    record for ``service_name``, both with TTL 120. ``service_server``,
    ``service_address`` and ``service_name`` are free names resolved outside
    this function. ``service_state_change`` is accepted but not used.
    """
    ttl = 120
    unique_in = const._CLASS_IN | const._CLASS_UNIQUE
    response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
    address_record = r.DNSAddress(
        service_server, const._TYPE_A, unique_in, ttl, socket.inet_aton(service_address)
    )
    service_record = r.DNSService(
        service_name, const._TYPE_SRV, unique_in, ttl, 0, 0, 80, service_server
    )
    for record in (address_record, service_record):
        response.add_answer_at_time(record, 0)
    return r.DNSIncoming(response.packets()[0])
def test_tc_bit_in_query_packet():
    """Verify the TC bit is set when known answers exceed the packet size."""
    out = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA)
    out.add_question(r.DNSQuestion("_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN))
    txt_payload = (
        b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1'
        b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A=='
    )
    for idx in range(30):
        out.add_answer_at_time(
            DNSText(
                "HASS Bridge W9DN %s._hap._tcp.local." % idx,
                const._TYPE_TXT,
                const._CLASS_IN | const._CLASS_UNIQUE,
                const._DNS_OTHER_TTL,
                txt_payload,
            ),
            0,
        )

    packets = out.packets()
    assert len(packets) == 3
    # All but the last packet of the split query carry the TC bit.
    last_index = len(packets) - 1
    for position, raw in enumerate(packets):
        parsed = r.DNSIncoming(raw)
        if position < last_index:
            assert parsed.truncated
        else:
            assert not parsed.truncated
        assert parsed.valid is True
def test_numbers(self):
    """An empty response packet reports zero questions, answers, authorities and additionals."""
    generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
    # Named ``packet`` rather than ``bytes`` to avoid shadowing the builtin.
    packet = generated.packet()
    # Bytes 4-12 hold the four 16-bit section counts in network byte order.
    (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', packet[4:12])
    self.assertEqual(num_questions, 0)
    self.assertEqual(num_answers, 0)
    self.assertEqual(num_authorities, 0)
    self.assertEqual(num_additionals, 0)
def generate_host(zc, host_name, type_):
    """Send a PTR + SRV response advertising ``host_name`` under ``type_`` via ``zc``."""
    instance_name = '.'.join((host_name, type_))
    response = r.DNSOutgoing(r._FLAGS_QR_RESPONSE | r._FLAGS_AA)
    # NOTE(review): the SRV record is named after the service type and targets
    # the instance name; kept as-is since callers may depend on it.
    records = (
        r.DNSPointer(type_, r._TYPE_PTR, r._CLASS_IN, r._DNS_OTHER_TTL, instance_name),
        r.DNSService(type_, r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, instance_name),
    )
    for record in records:
        response.add_answer_at_time(record, 0)
    zc.send(response)
def mock_incoming_msg(
    service_state_change: r.ServiceStateChange, service_type: str, service_name: str, ttl: int
) -> r.DNSIncoming:
    """Build a parsed response holding one PTR record ``service_type`` -> ``service_name``.

    ``service_state_change`` is accepted but not used here.
    """
    response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
    pointer = r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name)
    response.add_answer_at_time(pointer, 0)
    first_packet = response.packets()[0]
    return r.DNSIncoming(first_packet)
def test_incoming_unknown_type(self):
    """A record of an unhandled type (SOA here) does not show up in the parsed answers."""
    outgoing = r.DNSOutgoing(0)
    outgoing.add_additional_answer(r.DNSAddress("a", r._TYPE_SOA, r._CLASS_IN, 1, b"a"))
    incoming = r.DNSIncoming(outgoing.packet())
    assert len(incoming.answers) == 0
    # A message is exactly one of query or response.
    assert incoming.is_query() != incoming.is_response()
def test_match_question(self):
    """A single SRV question survives a serialize/parse round trip unchanged."""
    asked = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN)
    outgoing = r.DNSOutgoing(r._FLAGS_QR_QUERY)
    outgoing.add_question(asked)
    incoming = r.DNSIncoming(outgoing.packet())
    self.assertEqual(len(outgoing.questions), 1)
    self.assertEqual(len(outgoing.questions), len(incoming.questions))
    self.assertEqual(asked, incoming.questions[0])
def mock_incoming_msg(records) -> r.DNSIncoming:
    """Put ``records`` as answers into a response and return its first packet, parsed."""
    response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
    for rec in records:
        response.add_answer_at_time(rec, 0)
    first_packet = response.packets()[0]
    return r.DNSIncoming(first_packet)
def test_numbers(self):
    """An empty response packet reports zero questions, answers, authorities and additionals."""
    generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
    # Named ``packet`` rather than ``bytes`` to avoid shadowing the builtin.
    packet = generated.packets()[0]
    # Bytes 4-12 hold the four 16-bit section counts in network byte order.
    (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', packet[4:12])
    assert num_questions == 0
    assert num_answers == 0
    assert num_authorities == 0
    assert num_additionals == 0
def test_match_question(self):
    """A single SRV question survives a serialize/parse round trip unchanged."""
    asked = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN)
    outgoing = r.DNSOutgoing(const._FLAGS_QR_QUERY)
    outgoing.add_question(asked)
    incoming = r.DNSIncoming(outgoing.packets()[0])
    assert len(outgoing.questions) == 1
    assert len(outgoing.questions) == len(incoming.questions)
    assert asked == incoming.questions[0]
def test_suppress_answer(self):
    """Known-answer suppression drops only answers exactly matched by a known answer with enough TTL."""
    query_out = r.DNSOutgoing(r._FLAGS_QR_QUERY)
    query_out.add_question(r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN))

    answer1 = r.DNSService(
        "testname1.local.", r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, "foo.local."
    )
    staleanswer2 = r.DNSService(
        "testname2.local.", r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL / 2, 0, 0, 80, "foo.local."
    )
    answer2 = r.DNSService(
        "testname2.local.", r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, "foo.local."
    )
    for known in (answer1, staleanswer2):
        query_out.add_answer_at_time(known, 0)
    query = r.DNSIncoming(query_out.packet())

    response = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)

    # Should be suppressed
    response.add_answer(query, answer1)
    assert len(response.answers) == 0

    # Should not be suppressed, TTL in query is too short
    response.add_answer(query, answer2)
    assert len(response.answers) == 1

    # Should not be suppressed: each variant differs from answer1 in exactly
    # one of name, type or class.
    variants = (
        ("name", "testname3.local."),
        ("type", r._TYPE_A),
        ("class_", r._CLASS_NONE),
    )
    for expected_count, (attr, value) in enumerate(variants, start=2):
        variant = copy.copy(answer1)
        setattr(variant, attr, value)
        response.add_answer(query, variant)
        assert len(response.answers) == expected_count
def mock_incoming_msg(
        service_state_change: r.ServiceStateChange) -> r.DNSIncoming:
    """Build a parsed response that simulates ``service_state_change``.

    For ``Updated`` only a TXT record is sent. Otherwise PTR, SRV, TXT and A
    records are sent, with TTL 0 for ``Removed`` and 120 for anything else.
    ``service_type``, ``service_name``, ``service_server``, ``service_text``
    and ``service_address`` are free names resolved outside this function.
    """
    ttl = 120
    unique_in = const._CLASS_IN | const._CLASS_UNIQUE
    response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)

    if service_state_change == r.ServiceStateChange.Updated:
        response.add_answer_at_time(
            r.DNSText(service_name, const._TYPE_TXT, unique_in, ttl, service_text),
            0,
        )
        return r.DNSIncoming(response.packets()[0])

    if service_state_change == r.ServiceStateChange.Removed:
        ttl = 0

    full_record_set = (
        r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name),
        r.DNSService(service_name, const._TYPE_SRV, unique_in, ttl, 0, 0, 80, service_server),
        r.DNSText(service_name, const._TYPE_TXT, unique_in, ttl, service_text),
        r.DNSAddress(service_server, const._TYPE_A, unique_in, ttl, socket.inet_aton(service_address)),
    )
    for record in full_record_set:
        response.add_answer_at_time(record, 0)
    return r.DNSIncoming(response.packets()[0])
def test_incoming_ipv6(self):
    """An AAAA record carrying a packed IPv6 address round-trips through a packet."""
    packed = socket.inet_pton(socket.AF_INET6, "2606:2800:220:1:248:1893:25c8:1946")  # example.com
    outgoing = r.DNSOutgoing(0)
    outgoing.add_additional_answer(r.DNSAddress('domain', r._TYPE_AAAA, r._CLASS_IN, 1, packed))
    first = r.DNSIncoming(outgoing.packet()).answers[0]
    assert isinstance(first, r.DNSAddress)
    assert first.address == packed