예제 #1
0
def register_service(zc, info, ttl=60, send_num=3):
    """
    Like zeroconf.Zeroconf.register_service(), but just broadcasts
    ``send_num`` announcement packets and then returns.

    Every packet carries PTR, SRV and TXT answers for ``info`` (plus an A
    answer when ``info.address`` is set). Packets are spaced
    ``_REGISTER_TIME`` milliseconds apart, waiting on ``zc`` in between.

    :param zc: Zeroconf instance used to send packets and to wait.
    :param info: ServiceInfo describing the service to announce.
    :param ttl: TTL given to every announced record.
    :param send_num: number of announcement packets to send (default 3).
    """
    logger.info("Registering service: %s", info)
    now = current_time_millis()
    next_time = now
    sent = 0
    # Honour ``send_num``; the loop bound used to be hard-coded to 3.
    while sent < send_num:
        if now < next_time:
            sleep_time = next_time - now
            logger.debug("sleeping %s", sleep_time)
            zc.wait(sleep_time)
            now = current_time_millis()
            continue
        out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
        out.add_answer_at_time(
            DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, ttl, info.name), 0)
        out.add_answer_at_time(
            DNSService(info.name, _TYPE_SRV, _CLASS_IN, ttl, info.priority,
                       info.weight, info.port, info.server), 0)
        out.add_answer_at_time(
            DNSText(info.name, _TYPE_TXT, _CLASS_IN, ttl, info.text), 0)
        if info.address:
            out.add_answer_at_time(
                DNSAddress(info.server, _TYPE_A, _CLASS_IN, ttl, info.address),
                0)
        zc.send(out)
        sent += 1
        next_time += _REGISTER_TIME
    logger.debug("done registering service")
예제 #2
0
    def request(self, zc, timeout):
        """Return True if the service could be discovered on the network.

        The object is first updated from records already in ``zc.cache``. If
        server, address, text or port is still unknown, this object is added
        as a listener and queries are sent, with the interval between them
        doubling each time, until every field is known or ``timeout`` runs out.

        :param zc: Zeroconf instance providing the cache, send/wait and
            listener registration.
        :param timeout: how long to keep querying, in milliseconds.
        :return: True once server, address, text and port are all known;
            False if the timeout elapses first.
        """
        now = zeroconf.current_time_millis()
        delay = zeroconf._LISTENER_TIME
        next_ = now + delay
        last = now + timeout

        # SRV and TXT records are cached under the service instance name,
        # while A records are cached under the host (server) name. The query
        # below asks for them the same way. Looking A up under ``self.name``
        # could never hit the cache.
        cache_lookups = [
            (self.name, zeroconf._TYPE_SRV),
            (self.name, zeroconf._TYPE_TXT),
        ]
        if self.server is not None:
            cache_lookups.append((self.server, zeroconf._TYPE_A))
        for record_name, record_type in cache_lookups:
            cached = zc.cache.get_by_details(record_name, record_type, zeroconf._CLASS_IN)
            if cached:
                self.update_record(zc, now, cached)

        if None not in (self.server, self.address, self.text, self.port):
            return True

        try:
            zc.add_listener(self, zeroconf.DNSQuestion(self.name, zeroconf._TYPE_ANY, zeroconf._CLASS_IN))
            while None in (self.server, self.address, self.text, self.port):
                if last <= now:
                    return False
                if next_ <= now:
                    out = zeroconf.DNSOutgoing(zeroconf._FLAGS_QR_QUERY)
                    out.add_question(
                        zeroconf.DNSQuestion(self.name, zeroconf._TYPE_SRV, zeroconf._CLASS_IN))
                    # Cached records are attached as known answers (checked
                    # against ``now``) alongside each question.
                    if self.port is not None:
                        out.add_answer_at_time(
                            zc.cache.get_by_details(
                                self.name, zeroconf._TYPE_SRV, zeroconf._CLASS_IN), now)

                    out.add_question(
                        zeroconf.DNSQuestion(self.name, zeroconf._TYPE_TXT, zeroconf._CLASS_IN))
                    out.add_answer_at_time(
                        zc.cache.get_by_details(
                            self.name, zeroconf._TYPE_TXT, zeroconf._CLASS_IN), now)

                    if self.server is not None:
                        out.add_question(
                            zeroconf.DNSQuestion(self.server, zeroconf._TYPE_A, zeroconf._CLASS_IN))
                        out.add_answer_at_time(
                            zc.cache.get_by_details(
                                self.server, zeroconf._TYPE_A, zeroconf._CLASS_IN), now)
                    zc.send(out)
                    next_ = now + delay
                    delay *= 2  # back off: double the gap before the next query

                zc.wait(min(next_, last) - now)
                now = zeroconf.current_time_millis()
        finally:
            zc.remove_listener(self)

        return True
예제 #3
0
def register_service(zc, info, ttl=60, send_num=3):
    """
    Like zeroconf.Zeroconf.register_service(), but just broadcasts
    ``send_num`` announcement packets and then returns.

    Every packet carries PTR, SRV and TXT answers for ``info`` (plus an A
    answer when ``info.address`` is set). Packets are spaced
    ``_REGISTER_TIME`` milliseconds apart, waiting on ``zc`` in between.

    :param zc: Zeroconf instance used to send packets and to wait.
    :param info: ServiceInfo describing the service to announce.
    :param ttl: TTL given to every announced record.
    :param send_num: number of announcement packets to send (default 3).
    """
    logger.info("Registering service: %s", info)
    now = current_time_millis()
    next_time = now
    sent = 0
    # Honour ``send_num``; the loop bound used to be hard-coded to 3.
    while sent < send_num:
        if now < next_time:
            sleep_time = next_time - now
            logger.debug("sleeping %s", sleep_time)
            zc.wait(sleep_time)
            now = current_time_millis()
            continue
        out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
        out.add_answer_at_time(DNSPointer(info.type, _TYPE_PTR,
                                          _CLASS_IN, ttl, info.name), 0)
        out.add_answer_at_time(DNSService(info.name, _TYPE_SRV,
                                          _CLASS_IN, ttl, info.priority, info.weight, info.port,
                                          info.server), 0)
        out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN,
                                       ttl, info.text), 0)
        if info.address:
            out.add_answer_at_time(DNSAddress(info.server, _TYPE_A,
                                              _CLASS_IN, ttl, info.address), 0)
        zc.send(out)
        sent += 1
        next_time += _REGISTER_TIME
    logger.debug("done registering service")
예제 #4
0
 def test_dns_record_is_recent(self):
     """A fresh record is recent early in its TTL, but no longer from a third of it on."""
     start = current_time_millis()
     ttl_seconds = 8
     record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, ttl_seconds)
     # (offset in milliseconds from ``start``, expected is_recent() result)
     cases = (
         (8 / 4.1 * 1000, True),
         (8 / 3 * 1000, False),
         (8 / 2 * 1000, False),
         (8 * 1000, False),
     )
     for offset, expected in cases:
         assert record.is_recent(start + offset) is expected
예제 #5
0
파일: zeroconf.py 프로젝트: krahabb/esphome
    def host_status(self, key: str) -> bool:
        """Report whether any cached record named ``key`` still counts as online.

        A record counts as online until ``DashboardStatus.OFFLINE_AFTER`` has
        passed since its ``created`` time. When the cache holds no records
        for ``key``, the host is reported offline.
        """
        records = self.zc.cache.entries_with_name(key)
        if records:
            now = current_time_millis()
            for record in records:
                if record.created + DashboardStatus.OFFLINE_AFTER >= now:
                    return True
        return False
예제 #6
0
def test_multi_packet_known_answer_supression():
    """Known answers spread over several query packets must still suppress replies.

    Three services of one type are registered. A PTR query for that type
    carries so many known answers that it spills into more than one packet,
    and the responder must send nothing on either the unicast or multicast
    path. The Zeroconf instance is closed even if an assertion fails.
    """
    zc = Zeroconf(interfaces=['127.0.0.1'])
    try:
        type_ = "_handlermultis._tcp.local."
        name = "knownname"
        name2 = "knownname2"
        name3 = "knownname3"

        registration_name = "%s.%s" % (name, type_)
        registration2_name = "%s.%s" % (name2, type_)
        registration3_name = "%s.%s" % (name3, type_)

        desc = {'path': '/~paulsm/'}
        server_name = "ash-2.local."
        server_name2 = "ash-3.local."
        server_name3 = "ash-4.local."

        info = ServiceInfo(
            type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")]
        )
        info2 = ServiceInfo(
            type_, registration2_name, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")]
        )
        info3 = ServiceInfo(
            type_, registration3_name, 80, 0, 0, desc, server_name3, addresses=[socket.inet_aton("10.0.1.2")]
        )
        zc.registry.add(info)
        zc.registry.add(info2)
        zc.registry.add(info3)

        now = current_time_millis()
        _clear_cache(zc)
        # Test PTR suppression
        generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
        question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)
        generated.add_question(question)
        for _ in range(1000):
            # Add so many answers we end up with another packet
            generated.add_answer_at_time(info.dns_pointer(), now)
        generated.add_answer_at_time(info2.dns_pointer(), now)
        generated.add_answer_at_time(info3.dns_pointer(), now)
        packets = generated.packets()
        assert len(packets) > 1
        unicast_out, multicast_out = zc.query_handler.response(
            [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT
        )
        assert unicast_out is None
        assert multicast_out is None
        # unregister
        zc.registry.remove(info)
        zc.registry.remove(info2)
        zc.registry.remove(info3)
    finally:
        zc.close()
예제 #7
0
    def test_dns_record_reset_ttl(self):
        """reset_ttl() copies both the TTL and the creation time of another record."""
        def make_record():
            return r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL)

        older = make_record()
        # Sleep so the two records get observably different creation times.
        time.sleep(1)
        newer = make_record()
        now = r.current_time_millis()

        assert older.created != newer.created
        assert older.get_remaining_ttl(now) != newer.get_remaining_ttl(now)

        older.reset_ttl(newer)

        assert older.ttl == newer.ttl
        assert older.created == newer.created
        assert older.get_remaining_ttl(now) == newer.get_remaining_ttl(now)
예제 #8
0
파일: zeroconf.py 프로젝트: krahabb/esphome
 def run(self) -> None:
     """Publish host statuses and ping stale hosts until ``stop_event`` is set.

     On each pass, ``host_status`` is computed for every entry of
     ``key_to_host`` and handed to ``on_update``. An A-record query is then
     sent for every host in ``query_hosts`` whose cached records are missing
     or all at least ``DashboardStatus.PING_AFTER`` old. Finally the loop
     blocks on ``query_event`` before starting the next pass.
     """
     while not self.stop_event.is_set():
         statuses = {}
         for key, host in self.key_to_host.items():
             statuses[key] = self.host_status(host)
         self.on_update(statuses)

         now = current_time_millis()
         for host in self.query_hosts:
             entries = self.zc.cache.entries_with_name(host)
             if entries and not all(
                     entry.created + DashboardStatus.PING_AFTER <= now
                     for entry in entries):
                 # At least one cached record is still fresh: skip the ping.
                 continue
             query = DNSOutgoing(_FLAGS_QR_QUERY)
             query.add_question(DNSQuestion(host, _TYPE_A, _CLASS_IN))
             self.zc.send(query)

         self.query_event.wait()
         self.query_event.clear()
예제 #9
0
 def test_adding_expired_answer(self):
     """An answer that is already expired at the given time is left out of the packet."""
     service = r.DNSService(
         "æøå.local.",
         const._TYPE_SRV,
         const._CLASS_IN | const._CLASS_UNIQUE,
         const._DNS_HOST_TTL,
         0,
         0,
         80,
         "foo.local.",
     )
     # Check expiry 1000 seconds from now, by which point the record is treated as expired.
     far_future = current_time_millis() + 1000000
     outgoing = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
     outgoing.add_answer_at_time(service, far_future)
     incoming = r.DNSIncoming(outgoing.packets()[0])
     assert len(outgoing.answers) == 0
     assert len(outgoing.answers) == len(incoming.answers)
예제 #10
0
 def test_service_info_rejects_expired_records(self):
     """Verify records that are expired are rejected.

     A fresh TXT record updates the ServiceInfo properties, while a TXT record
     whose creation and expiration times lie far in the past is ignored. The
     Zeroconf instance is closed even if an assertion fails.
     """
     zc = r.Zeroconf(interfaces=['127.0.0.1'])
     try:
         desc = {'path': '/~paulsm/'}
         service_name = 'name._type._tcp.local.'
         service_type = '_type._tcp.local.'
         service_server = 'ash-1.local.'
         service_address = socket.inet_aton("10.0.1.2")
         ttl = 120
         now = r.current_time_millis()
         info = ServiceInfo(
             service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address]
         )
         # Matching updates
         info.update_record(
             zc,
             now,
             r.DNSText(
                 service_name,
                 const._TYPE_TXT,
                 const._CLASS_IN | const._CLASS_UNIQUE,
                 ttl,
                 b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==',
             ),
         )
         assert info.properties[b"ci"] == b"2"
         # Expired record: force its timestamps into the distant past.
         expired_record = r.DNSText(
             service_name,
             const._TYPE_TXT,
             const._CLASS_IN | const._CLASS_UNIQUE,
             ttl,
             b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==',
         )
         expired_record.created = 1000
         expired_record._expiration_time = 1000
         info.update_record(zc, now, expired_record)
         assert info.properties[b"ci"] == b"2"
     finally:
         zc.close()
예제 #11
0
    def test_massive_probe_packet_split(self):
        """Test that a probe with many authoritative answers is split across packets.

        30 QU questions plus 200 authoritative answers must produce three
        packets within _MAX_MSG_TYPICAL. All questions go in the first packet,
        and the TC bit is set on every packet except the last.
        """
        generated = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA)
        for _ in range(30):
            question = r.DNSQuestion("_hap._tcp.local.", const._TYPE_PTR,
                                     const._CLASS_IN | const._CLASS_UNIQUE)
            generated.add_question(question)
        assert len(generated.questions) == 30
        for _ in range(200):
            # NOTE(review): this name has no f-prefix, so every record gets the
            # literal name "myservice{i}_tcp._tcp.local.". The per-packet
            # counts asserted below were presumably measured with that, so
            # turning it into an f-string would likely change them.
            authorative_answer = r.DNSPointer(
                "myservice{i}_tcp._tcp.local.",
                const._TYPE_PTR,
                const._CLASS_IN | const._CLASS_UNIQUE,
                const._DNS_OTHER_TTL,
                '123.local.',
            )
            generated.add_authorative_answer(authorative_answer)
        packets = generated.packets()
        assert len(packets) == 3
        assert len(packets[0]) <= const._MAX_MSG_TYPICAL
        assert len(packets[1]) <= const._MAX_MSG_TYPICAL
        assert len(packets[2]) <= const._MAX_MSG_TYPICAL

        parsed1 = r.DNSIncoming(packets[0])
        assert parsed1.questions[0].unicast is True
        assert len(parsed1.questions) == 30
        assert parsed1.num_authorities == 88
        assert parsed1.flags & const._FLAGS_TC == const._FLAGS_TC
        parsed2 = r.DNSIncoming(packets[1])
        assert len(parsed2.questions) == 0
        assert parsed2.num_authorities == 101
        assert parsed2.flags & const._FLAGS_TC == const._FLAGS_TC
        parsed3 = r.DNSIncoming(packets[2])
        assert len(parsed3.questions) == 0
        assert parsed3.num_authorities == 11
        assert parsed3.flags & const._FLAGS_TC == 0
예제 #12
0
    def test_many_questions_with_many_known_answers(self):
        """Test many questions and known answers get separated into multiple packets.

        30 questions plus 200 known answers must produce three packets within
        _MAX_MSG_TYPICAL. All questions go in the first packet, and the TC bit
        is set on every packet except the last.
        """
        generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
        for _ in range(30):
            question = r.DNSQuestion("_hap._tcp.local.", const._TYPE_PTR,
                                     const._CLASS_IN)
            generated.add_question(question)
        assert len(generated.questions) == 30
        now = current_time_millis()
        for _ in range(200):
            # NOTE(review): this name has no f-prefix, so every record gets the
            # literal name "myservice{i}_tcp._tcp.local.". The per-packet
            # counts asserted below were presumably measured with that, so
            # turning it into an f-string would likely change them.
            known_answer = r.DNSPointer(
                "myservice{i}_tcp._tcp.local.",
                const._TYPE_PTR,
                const._CLASS_IN | const._CLASS_UNIQUE,
                const._DNS_OTHER_TTL,
                '123.local.',
            )
            generated.add_answer_at_time(known_answer, now)
        packets = generated.packets()
        assert len(packets) == 3
        assert len(packets[0]) <= const._MAX_MSG_TYPICAL
        assert len(packets[1]) <= const._MAX_MSG_TYPICAL
        assert len(packets[2]) <= const._MAX_MSG_TYPICAL

        parsed1 = r.DNSIncoming(packets[0])
        assert len(parsed1.questions) == 30
        assert len(parsed1.answers) == 88
        assert parsed1.flags & const._FLAGS_TC == const._FLAGS_TC
        parsed2 = r.DNSIncoming(packets[1])
        assert len(parsed2.questions) == 0
        assert len(parsed2.answers) == 101
        assert parsed2.flags & const._FLAGS_TC == const._FLAGS_TC
        parsed3 = r.DNSIncoming(packets[2])
        assert len(parsed3.questions) == 0
        assert len(parsed3.answers) == 11
        assert parsed3.flags & const._FLAGS_TC == 0
예제 #13
0
    def restore_from_cache(self):
        """Replay unexpired TXT records from the zeroconf cache through ``self.handler``.

        For every cache entry whose key ends with ``self.suffix``, each
        unexpired ``DNSText`` record is paired with a ``DNSAddress`` cached
        under ``<first label of the record key>.local.``. The address and the
        decoded TXT data are then passed to ``self.handler`` in a new asyncio
        task. A record whose host key or address is missing is skipped with a
        debug log, and the remaining records are still processed. Any other
        error aborts the current cache entry and is logged as a warning.
        """
        now = current_time_millis()
        cache: dict = self.zc.cache.cache
        for key, records in cache.items():
            if not key.endswith(self.suffix):
                continue
            try:
                for record in records.keys():
                    if not isinstance(record, DNSText) or record.is_expired(now):
                        continue
                    # Separate name so the outer loop's ``key`` is not clobbered.
                    host_key = record.key.split(".", 1)[0] + ".local."
                    # Keep the try narrow: a missing host entry or address only
                    # skips this record instead of every record under ``key``.
                    try:
                        address = next(
                            entry.address for entry in cache[host_key].keys()
                            if isinstance(entry, DNSAddress))
                    except KeyError:
                        _LOGGER.debug(f"Can't find key in zeroconf cache: {host_key}")
                        continue
                    except StopIteration:
                        _LOGGER.debug(f"Can't find address for {host_key}")
                        continue
                    host = str(ipaddress.ip_address(address))
                    data = self.decode_text(record.text)
                    # NOTE(review): the returned task is not stored; asyncio only
                    # keeps a weak reference to running tasks, so consider
                    # holding a reference until it completes.
                    asyncio.create_task(self.handler(record.name, host, data))
            except Exception as e:
                _LOGGER.warning("Can't restore zeroconf cache", exc_info=e)
예제 #14
0
    def test_handle_response(self):
        """Cached TXT records follow added, updated, split and removed responses.

        Responses for one service are injected in sequence: a full add, a
        TXT-only update, a split packet carrying only the A and SRV records,
        and a TTL-0 removal. After each step the cached TXT record is checked.
        """
        def mock_incoming_msg(
                service_state_change: r.ServiceStateChange) -> r.DNSIncoming:
            """Build a response for the given state change.

            Updated -> TXT only; Added -> PTR/SRV/TXT/A with ttl 120;
            Removed -> the same four records with ttl 0. ``service_text`` is
            read from the enclosing scope when the message is built.
            """
            ttl = 120
            generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)

            if service_state_change == r.ServiceStateChange.Updated:
                generated.add_answer_at_time(
                    r.DNSText(
                        service_name,
                        const._TYPE_TXT,
                        const._CLASS_IN | const._CLASS_UNIQUE,
                        ttl,
                        service_text,
                    ),
                    0,
                )
                return r.DNSIncoming(generated.packets()[0])

            if service_state_change == r.ServiceStateChange.Removed:
                ttl = 0

            generated.add_answer_at_time(
                r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN,
                             ttl, service_name), 0)
            generated.add_answer_at_time(
                r.DNSService(
                    service_name,
                    const._TYPE_SRV,
                    const._CLASS_IN | const._CLASS_UNIQUE,
                    ttl,
                    0,
                    0,
                    80,
                    service_server,
                ),
                0,
            )
            generated.add_answer_at_time(
                r.DNSText(service_name, const._TYPE_TXT,
                          const._CLASS_IN | const._CLASS_UNIQUE, ttl,
                          service_text),
                0,
            )
            generated.add_answer_at_time(
                r.DNSAddress(
                    service_server,
                    const._TYPE_A,
                    const._CLASS_IN | const._CLASS_UNIQUE,
                    ttl,
                    socket.inet_aton(service_address),
                ),
                0,
            )

            return r.DNSIncoming(generated.packets()[0])

        def mock_split_incoming_msg(
                service_state_change: r.ServiceStateChange) -> r.DNSIncoming:
            """Mock an incoming message for the case where the packet is split."""
            ttl = 120
            generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
            generated.add_answer_at_time(
                r.DNSAddress(
                    service_server,
                    const._TYPE_A,
                    const._CLASS_IN | const._CLASS_UNIQUE,
                    ttl,
                    socket.inet_aton(service_address),
                ),
                0,
            )
            generated.add_answer_at_time(
                r.DNSService(
                    service_name,
                    const._TYPE_SRV,
                    const._CLASS_IN | const._CLASS_UNIQUE,
                    ttl,
                    0,
                    0,
                    80,
                    service_server,
                ),
                0,
            )
            return r.DNSIncoming(generated.packets()[0])

        service_name = 'name._type._tcp.local.'
        service_type = '_type._tcp.local.'
        service_server = 'ash-2.local.'
        service_text = b'path=/~paulsm/'
        service_address = '10.0.1.2'

        zeroconf = r.Zeroconf(interfaces=['127.0.0.1'])

        try:
            # service added
            _inject_response(zeroconf,
                             mock_incoming_msg(r.ServiceStateChange.Added))
            dns_text = zeroconf.cache.get_by_details(service_name,
                                                     const._TYPE_TXT,
                                                     const._CLASS_IN)
            assert dns_text is not None
            assert cast(
                r.DNSText, dns_text
            ).text == service_text  # service_text is b'path=/~paulsm/'
            all_dns_text = zeroconf.cache.get_all_by_details(
                service_name, const._TYPE_TXT, const._CLASS_IN)
            assert [dns_text] == all_dns_text

            # https://tools.ietf.org/html/rfc6762#section-10.2
            # Instead of merging this new record additively into the cache in addition
            # to any previous records with the same name, rrtype, and rrclass,
            # all old records with that name, rrtype, and rrclass that were received
            # more than one second ago are declared invalid,
            # and marked to expire from the cache in one second.
            time.sleep(1.1)

            # service updated. currently only text record can be updated
            service_text = b'path=/~humingchun/'
            _inject_response(zeroconf,
                             mock_incoming_msg(r.ServiceStateChange.Updated))
            dns_text = zeroconf.cache.get_by_details(service_name,
                                                     const._TYPE_TXT,
                                                     const._CLASS_IN)
            assert dns_text is not None
            assert cast(
                r.DNSText, dns_text
            ).text == service_text  # service_text is b'path=/~humingchun/'

            time.sleep(1.1)

            # The split message only has a SRV and A record.
            # This should not evict TXT records from the cache
            _inject_response(
                zeroconf,
                mock_split_incoming_msg(r.ServiceStateChange.Updated))
            time.sleep(1.1)
            dns_text = zeroconf.cache.get_by_details(service_name,
                                                     const._TYPE_TXT,
                                                     const._CLASS_IN)
            assert dns_text is not None
            assert cast(
                r.DNSText, dns_text
            ).text == service_text  # service_text is b'path=/~humingchun/'

            # service removed
            _inject_response(zeroconf,
                             mock_incoming_msg(r.ServiceStateChange.Removed))
            dns_text = zeroconf.cache.get_by_details(service_name,
                                                     const._TYPE_TXT,
                                                     const._CLASS_IN)
            # Fail with a clear assertion rather than an AttributeError on None.
            assert dns_text is not None
            assert dns_text.is_expired(current_time_millis() + 1000)

        finally:
            zeroconf.close()
예제 #15
0
def test_known_answer_supression_service_type_enumeration_query():
    """Known PTR answers suppress replies to a service-type enumeration query.

    Two services of different types are registered. A bare enumeration query
    must get a multicast answer, while the same query listing both types as
    known answers must get none. Both services are unregistered; previously
    the second ServiceInfo overwrote ``info``, so the first was never
    unregistered. The Zeroconf instance is closed even if an assertion fails.
    """
    zc = Zeroconf(interfaces=['127.0.0.1'])
    try:
        type_ = "_knownservice._tcp.local."
        name = "knownname"
        registration_name = "%s.%s" % (name, type_)
        desc = {'path': '/~paulsm/'}
        server_name = "ash-2.local."
        info = ServiceInfo(type_,
                           registration_name,
                           80,
                           0,
                           0,
                           desc,
                           server_name,
                           addresses=[socket.inet_aton("10.0.1.2")])
        zc.register_service(info)

        type_2 = "_knownservice2._tcp.local."
        registration_name2 = "%s.%s" % (name, type_2)
        server_name2 = "ash-3.local."
        info2 = ServiceInfo(type_2,
                            registration_name2,
                            80,
                            0,
                            0,
                            desc,
                            server_name2,
                            addresses=[socket.inet_aton("10.0.1.2")])
        zc.register_service(info2)
        now = current_time_millis()
        _clear_cache(zc)

        # Test PTR suppression
        generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
        question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME,
                                 const._TYPE_PTR, const._CLASS_IN)
        generated.add_question(question)
        packets = generated.packets()
        unicast_out, multicast_out = zc.query_handler.response(
            r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT)
        assert unicast_out is None
        assert multicast_out is not None and multicast_out.answers

        generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
        question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME,
                                 const._TYPE_PTR, const._CLASS_IN)
        generated.add_question(question)
        generated.add_answer_at_time(
            r.DNSPointer(
                const._SERVICE_TYPE_ENUMERATION_NAME,
                const._TYPE_PTR,
                const._CLASS_IN,
                const._DNS_OTHER_TTL,
                type_,
            ),
            now,
        )
        generated.add_answer_at_time(
            r.DNSPointer(
                const._SERVICE_TYPE_ENUMERATION_NAME,
                const._TYPE_PTR,
                const._CLASS_IN,
                const._DNS_OTHER_TTL,
                type_2,
            ),
            now,
        )
        packets = generated.packets()
        unicast_out, multicast_out = zc.query_handler.response(
            r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT)
        assert unicast_out is None
        assert not multicast_out or not multicast_out.answers

        # unregister both services
        zc.unregister_service(info)
        zc.unregister_service(info2)
    finally:
        zc.close()
예제 #16
0
 def test_dns_record_is_expired(self):
     """A record is live at creation and at half its TTL, and expired once the full TTL has passed."""
     ttl_seconds = 8
     record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, ttl_seconds)
     now = current_time_millis()
     # (timestamp in milliseconds, expected is_expired() result)
     checks = (
         (now, False),
         (now + (8 / 2 * 1000), False),
         (now + (8 * 1000), True),
     )
     for when, expired in checks:
         assert record.is_expired(when) is expired
예제 #17
0
def test_known_answer_supression():
    """Each record type is answered normally but suppressed by a matching known answer.

    For PTR, A, SRV and TXT questions about a registered service, a query
    without known answers must get a multicast response. The same query
    listing the service's own record(s) as known answers must get nothing,
    including additionals. The Zeroconf instance is closed even if an
    assertion fails.
    """
    zc = Zeroconf(interfaces=['127.0.0.1'])
    try:
        type_ = "_knownservice._tcp.local."
        name = "knownname"
        registration_name = "%s.%s" % (name, type_)
        desc = {'path': '/~paulsm/'}
        server_name = "ash-2.local."
        info = ServiceInfo(type_,
                           registration_name,
                           80,
                           0,
                           0,
                           desc,
                           server_name,
                           addresses=[socket.inet_aton("10.0.1.2")])
        zc.register_service(info)

        now = current_time_millis()
        _clear_cache(zc)

        def query(qname, qtype, known_answers=()):
            """Send one query, optionally with known answers, and return (unicast, multicast)."""
            generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
            generated.add_question(r.DNSQuestion(qname, qtype, const._CLASS_IN))
            for answer in known_answers:
                generated.add_answer_at_time(answer, now)
            packets = generated.packets()
            return zc.query_handler.response(
                r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT)

        # Test PTR suppression
        unicast_out, multicast_out = query(type_, const._TYPE_PTR)
        assert unicast_out is None
        assert multicast_out is not None and multicast_out.answers

        unicast_out, multicast_out = query(type_, const._TYPE_PTR, [info.dns_pointer()])
        assert unicast_out is None
        # If the answer is suppressed, the additional should be suppressed as well
        assert not multicast_out or not multicast_out.answers

        # Test A suppression
        unicast_out, multicast_out = query(server_name, const._TYPE_A)
        assert unicast_out is None
        assert multicast_out is not None and multicast_out.answers

        unicast_out, multicast_out = query(server_name, const._TYPE_A, info.dns_addresses())
        assert unicast_out is None
        assert not multicast_out or not multicast_out.answers

        # Test SRV suppression
        unicast_out, multicast_out = query(registration_name, const._TYPE_SRV)
        assert unicast_out is None
        assert multicast_out is not None and multicast_out.answers

        unicast_out, multicast_out = query(registration_name, const._TYPE_SRV, [info.dns_service()])
        assert unicast_out is None
        # If the answer is suppressed, the additional should be suppressed as well
        assert not multicast_out or not multicast_out.answers

        # Test TXT suppression
        unicast_out, multicast_out = query(registration_name, const._TYPE_TXT)
        assert unicast_out is None
        assert multicast_out is not None and multicast_out.answers

        unicast_out, multicast_out = query(registration_name, const._TYPE_TXT, [info.dns_text()])
        assert unicast_out is None
        assert not multicast_out or not multicast_out.answers

        # unregister
        zc.unregister_service(info)
    finally:
        zc.close()
예제 #18
0
def test_qu_response_only_sends_additionals_if_sends_answer():
    """Test that a QU response does not send additionals unless it sends the answer as well.

    Two services are registered. The freshness of the cached PTR/A records is
    varied to check when a QU (unicast-response) PTR question is answered by
    unicast and when it falls back to multicast. The additionals (A, TXT, SRV)
    must always travel with the answer they belong to.
    """
    # instantiate a zeroconf instance; always closed in the finally below
    zc = Zeroconf(interfaces=['127.0.0.1'])
    try:
        name = "knownname"
        desc = {'path': '/~paulsm/'}

        type_ = "_addtest1._tcp.local."
        registration_name = "%s.%s" % (name, type_)
        server_name = "ash-2.local."
        info = ServiceInfo(
            type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")]
        )
        zc.registry.add(info)

        type_2 = "_addtest2._tcp.local."
        registration_name2 = "%s.%s" % (name, type_2)
        server_name2 = "ash-3.local."
        info2 = ServiceInfo(
            type_2, registration_name2, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")]
        )
        zc.registry.add(info2)

        def qu_query(*service_types):
            """Build one query holding a QU PTR question per type and return the handler's
            (unicast_out, multicast_out) response."""
            query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
            for service_type in service_types:
                question = r.DNSQuestion(service_type, const._TYPE_PTR, const._CLASS_IN)
                question.unicast = True  # Set the QU bit
                assert question.unicast is True
                query.add_question(question)
            return zc.query_handler.response(
                [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT
            )

        ptr_record = info.dns_pointer()

        # Add the PTR record to the cache
        zc.cache.add(ptr_record)

        # Add the A record to the cache with 50% ttl remaining
        a_record = info.dns_addresses()[0]
        a_record.set_created_ttl(current_time_millis() - (a_record.ttl * 1000 / 2), a_record.ttl)
        assert not a_record.is_recent(current_time_millis())
        zc.cache.add(a_record)

        # Case 1: with QU, respond only by unicast because the answer (PTR) has been
        # recently multicast, even though the additional (A) has not.
        unicast_out, multicast_out = qu_query(info.type)
        assert multicast_out is None
        assert a_record in unicast_out.additionals
        assert unicast_out.answers[0][0] == ptr_record

        # Case 2: replace the 50% A record with a 100% A record. The answer is still
        # recent, so the response stays unicast-only and still carries the A additional.
        zc.cache.remove(a_record)
        a_record = info.dns_addresses()[0]
        assert a_record.is_recent(current_time_millis())
        zc.cache.add(a_record)
        unicast_out, multicast_out = qu_query(info.type)
        assert multicast_out is None
        assert a_record in unicast_out.additionals
        assert unicast_out.answers[0][0] == ptr_record

        # Case 3: replace the 100% PTR record with a 50% PTR record. With QU, respond
        # only by multicast since the PTR record has less than 75% of its ttl remaining.
        # The additionals must follow the answer onto multicast.
        zc.cache.remove(ptr_record)
        ptr_record.set_created_ttl(current_time_millis() - (ptr_record.ttl * 1000 / 2), ptr_record.ttl)
        assert not ptr_record.is_recent(current_time_millis())
        zc.cache.add(ptr_record)
        unicast_out, multicast_out = qu_query(info.type)
        assert multicast_out.answers[0][0] == ptr_record
        assert a_record in multicast_out.additionals
        assert info.dns_text() in multicast_out.additionals
        assert info.dns_service() in multicast_out.additionals

        assert unicast_out is None

        # Case 4: ask 2 QU questions in one query. For info the PTR is at 50%, for info2
        # it is at 100%. info should be multicast since it has less than 75% of its TTL
        # remaining, while info2 should get a unicast reply; each answer keeps its own
        # additionals.
        zc.cache.add(info2.dns_pointer())  # Add 100% TTL for info2 to the cache
        unicast_out, multicast_out = qu_query(info.type, info2.type)
        assert multicast_out.answers[0][0] == info.dns_pointer()
        assert info.dns_addresses()[0] in multicast_out.additionals
        assert info.dns_text() in multicast_out.additionals
        assert info.dns_service() in multicast_out.additionals

        assert unicast_out.answers[0][0] == info2.dns_pointer()
        assert info2.dns_addresses()[0] in unicast_out.additionals
        assert info2.dns_text() in unicast_out.additionals
        assert info2.dns_service() in unicast_out.additionals

        # unregister every service this test added, not just the first one
        zc.registry.remove(info)
        zc.registry.remove(info2)
    finally:
        zc.close()
예제 #19
0
    def test_service_info_rejects_non_matching_updates(self):
        """Verify records with the wrong name are rejected.

        Records whose name matches the service (TXT/SRV) or its server (A) update
        the ServiceInfo. Records for "incorrect.name." must leave properties,
        server and addresses unchanged.
        """

        zc = r.Zeroconf(interfaces=['127.0.0.1'])
        try:
            desc = {'path': '/~paulsm/'}
            service_name = 'name._type._tcp.local.'
            service_type = '_type._tcp.local.'
            service_server = 'ash-1.local.'
            service_address = socket.inet_aton("10.0.1.2")
            ttl = 120
            now = r.current_time_millis()
            info = ServiceInfo(
                service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address]
            )
            # Verify backwards compatibility with calling with None
            info.update_record(zc, now, None)

            # Matching updates
            info.update_record(
                zc,
                now,
                r.DNSText(
                    service_name,
                    const._TYPE_TXT,
                    const._CLASS_IN | const._CLASS_UNIQUE,
                    ttl,
                    b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==',
                ),
            )
            assert info.properties[b"ci"] == b"2"
            info.update_record(
                zc,
                now,
                r.DNSService(
                    service_name,
                    const._TYPE_SRV,
                    const._CLASS_IN | const._CLASS_UNIQUE,
                    ttl,
                    0,
                    0,
                    80,
                    'ASH-2.local.',
                ),
            )
            # server keeps the original case; server_key is the lower-cased form
            assert info.server_key == 'ash-2.local.'
            assert info.server == 'ASH-2.local.'
            new_address = socket.inet_aton("10.0.1.3")
            info.update_record(
                zc,
                now,
                r.DNSAddress(
                    'ASH-2.local.',
                    const._TYPE_A,
                    const._CLASS_IN | const._CLASS_UNIQUE,
                    ttl,
                    new_address,
                ),
            )
            assert new_address in info.addresses

            # Non-matching updates: every field checked above must stay unchanged
            info.update_record(
                zc,
                now,
                r.DNSText(
                    "incorrect.name.",
                    const._TYPE_TXT,
                    const._CLASS_IN | const._CLASS_UNIQUE,
                    ttl,
                    b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==',
                ),
            )
            assert info.properties[b"ci"] == b"2"
            info.update_record(
                zc,
                now,
                r.DNSService(
                    "incorrect.name.",
                    const._TYPE_SRV,
                    const._CLASS_IN | const._CLASS_UNIQUE,
                    ttl,
                    0,
                    0,
                    80,
                    'ASH-2.local.',
                ),
            )
            assert info.server_key == 'ash-2.local.'
            assert info.server == 'ASH-2.local.'
            new_address = socket.inet_aton("10.0.1.4")
            info.update_record(
                zc,
                now,
                r.DNSAddress(
                    "incorrect.name.",
                    const._TYPE_A,
                    const._CLASS_IN | const._CLASS_UNIQUE,
                    ttl,
                    new_address,
                ),
            )
            assert new_address not in info.addresses
        finally:
            # Close even when an assertion fails so the instance is not leaked
            zc.close()
예제 #20
0
def test_tc_bit_defers_last_response_missing():
    """Verify deferred handling of a multi-packet (TC-bit) query whose last packet never arrives.

    The known-answer query is split into four packets, but only the first three are
    delivered from the same source:

    - each new packet is appended to the deferred list and replaces the pending timer;
    - a repeated packet leaves the timer untouched;
    - once the timer fires, the deferred state for the source is cleared.
    """
    zc = Zeroconf(interfaces=['127.0.0.1'])
    try:
        type_ = "_knowndefer._tcp.local."
        desc = {'path': '/~paulsm/'}

        # Three services of the same type, each on its own server name.
        info, info2, info3 = (
            r.ServiceInfo(
                type_,
                "%s.%s" % (name, type_),
                80,
                0,
                0,
                desc,
                server_name,
                addresses=[socket.inet_aton("10.0.1.2")],
            )
            for name, server_name in (
                ("knownname", "ash-2.local."),
                ("knownname2", "ash-3.local."),
                ("knownname3", "ash-4.local."),
            )
        )
        for service_info in (info, info2, info3):
            zc.registry.add(service_info)

        def threadsafe_query(*args):
            """Run zc.handle_query on the zeroconf event loop and block until it completes."""
            async def make_query():
                zc.handle_query(*args)

            asyncio.run_coroutine_threadsafe(make_query(), zc.loop).result()

        now = r.current_time_millis()
        _clear_cache(zc)
        source_ip = '203.0.113.12'

        generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
        question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)
        generated.add_question(question)
        for _ in range(300):
            # Add so many answers we end up with another packet
            generated.add_answer_at_time(info.dns_pointer(), now)
        generated.add_answer_at_time(info2.dns_pointer(), now)
        generated.add_answer_at_time(info3.dns_pointer(), now)
        packets = generated.packets()
        assert len(packets) == 4
        expected_deferred = []

        # First packet: deferred, and a timer is scheduled for the source.
        next_packet = r.DNSIncoming(packets.pop(0))
        expected_deferred.append(next_packet)
        threadsafe_query(next_packet, source_ip, const._MDNS_PORT)
        assert zc._deferred[source_ip] == expected_deferred
        timer1 = zc._timers[source_ip]

        # Second packet: deferred too, and it replaces (cancels) the first timer.
        next_packet = r.DNSIncoming(packets.pop(0))
        expected_deferred.append(next_packet)
        threadsafe_query(next_packet, source_ip, const._MDNS_PORT)
        assert zc._deferred[source_ip] == expected_deferred
        timer2 = zc._timers[source_ip]
        if sys.version_info >= (3, 7):  # TimerHandle.cancelled() is 3.7+
            assert timer1.cancelled()
        assert timer2 != timer1

        # Send the same packet again to similar multi interfaces: the duplicate must
        # not be deferred twice and must not reset the timer.
        threadsafe_query(next_packet, source_ip, const._MDNS_PORT)
        assert zc._deferred[source_ip] == expected_deferred
        assert source_ip in zc._timers
        timer3 = zc._timers[source_ip]
        if sys.version_info >= (3, 7):
            assert not timer3.cancelled()
        assert timer3 == timer2

        # Third packet: deferred, and the timer is replaced again.
        next_packet = r.DNSIncoming(packets.pop(0))
        expected_deferred.append(next_packet)
        threadsafe_query(next_packet, source_ip, const._MDNS_PORT)
        assert zc._deferred[source_ip] == expected_deferred
        assert source_ip in zc._timers
        timer4 = zc._timers[source_ip]
        if sys.version_info >= (3, 7):
            assert timer3.cancelled()
        assert timer4 != timer3

        # The fourth packet is never sent; poll (up to ~0.8s) for the timer to fire.
        for _ in range(8):
            time.sleep(0.1)
            if source_ip not in zc._timers:
                break

        assert source_ip not in zc._deferred
        assert source_ip not in zc._timers

        # unregister every service this test added, not just the first one
        for service_info in (info, info2, info3):
            zc.registry.remove(service_info)
    finally:
        zc.close()