def setup_test(self):
    """Per-test setup.

    Runs the base-class setup, then creates the DUT-side LE ACL manager and
    wraps the cert's LE-subevent and ACL gRPC streams in EventStreams.
    """
    super().setup_test()
    self.dut_le_acl_manager = PyLeAclManager(self.dut)
    le_subevents = self.cert.hci.StreamLeSubevents(empty_proto.Empty())
    self.cert_hci_le_event_stream = EventStream(le_subevents)
    acl_packets = self.cert.hci.StreamAcl(empty_proto.Empty())
    self.cert_acl_data_stream = EventStream(acl_packets)
# Example #2
# 0
    def __init__(self, device):
        """Bind to ``device`` and prepare LE L2CAP bookkeeping.

        Args:
            device: the device handle that is handed to ``PyLeAclManager``.
        """
        self._device = device
        self._le_acl_manager = PyLeAclManager(device)
        # No ACL connection exists until one is established later.
        self._le_acl = None

        # Map each LE signalling command code to the method that handles it.
        handlers = {}
        handlers[LeCommandCode.DISCONNECTION_REQUEST] = (
            self._on_disconnection_request_default)
        handlers[LeCommandCode.DISCONNECTION_RESPONSE] = (
            self._on_disconnection_response_default)
        handlers[LeCommandCode.LE_FLOW_CONTROL_CREDIT] = self._on_credit
        self.control_table = handlers

        self._cid_to_cert_channels = {}
class LeAclManagerTest(GdBaseTestClass):
    """Cert tests for the DUT's LE ACL manager.

    The DUT runs the HCI_INTERFACES module and is driven through
    ``PyLeAclManager``. The cert runs the raw HCI module and is driven by
    enqueuing serialized HCI commands and ACL packets.
    """

    # Random device address the cert advertises with / initiates from.
    _CERT_RANDOM_ADDRESS = '0C:05:04:03:02:01'

    # Byte prefixes identifying a successful connection-complete event in a
    # raw HCI event: event code 0x3E (LE Meta), parameter total length,
    # subevent code, Status 0x00.
    # LE Connection Complete carries 19 (0x13) parameter bytes, while LE
    # Enhanced Connection Complete carries 31 (0x1f). The enhanced prefix
    # must use 0x1f, otherwise it can never match a real event.
    _LE_CONNECTION_COMPLETE_PREFIX = b'\x3e\x13\x01\x00'
    _LE_ENHANCED_CONNECTION_COMPLETE_PREFIX = b'\x3e\x1f\x0a\x00'

    def setup_class(self):
        super().setup_class(dut_module='HCI_INTERFACES', cert_module='HCI')

    def setup_test(self):
        """Create the DUT LE ACL manager and the cert LE-event / ACL streams."""
        super().setup_test()
        self.dut_le_acl_manager = PyLeAclManager(self.dut)
        self.cert_hci_le_event_stream = EventStream(
            self.cert.hci.StreamLeSubevents(empty_proto.Empty()))
        self.cert_acl_data_stream = EventStream(
            self.cert.hci.StreamAcl(empty_proto.Empty()))

    def teardown_test(self):
        """Close everything opened in setup_test, then run base teardown."""
        safeClose(self.cert_hci_le_event_stream)
        safeClose(self.cert_acl_data_stream)
        safeClose(self.dut_le_acl_manager)
        super().teardown_test()

    def set_privacy_policy_static(self):
        """Make the DUT initiate from static random address d0:05:04:03:02:01.

        Also records that address (as ASCII bytes) in ``self.dut_address``.
        """
        self.dut_address = b'd0:05:04:03:02:01'
        private_policy = le_initiator_address_facade.PrivacyPolicy(
            address_policy=le_initiator_address_facade.AddressPolicy.
            USE_STATIC_ADDRESS,
            address_with_type=common.BluetoothAddressWithType(
                address=common.BluetoothAddress(address=self.dut_address),
                type=common.RANDOM_DEVICE_ADDRESS))
        self.dut.hci_le_initiator_address.SetPrivacyPolicyForInitiatorAddress(
            private_policy)

    def register_for_event(self, event_code):
        """Ask the cert HCI facade to forward events with ``event_code``."""
        msg = hci_facade.EventRequest(code=int(event_code))
        self.cert.hci.RequestEvent(msg)

    def register_for_le_event(self, event_code):
        """Ask the cert HCI facade to forward LE subevents with ``event_code``."""
        msg = hci_facade.EventRequest(code=int(event_code))
        self.cert.hci.RequestLeSubevent(msg)

    def enqueue_hci_command(self, command):
        """Serialize a packet builder and send it as an HCI command from the cert."""
        cmd_bytes = bytes(command.Serialize())
        cmd = common.Data(payload=cmd_bytes)
        self.cert.hci.SendCommand(cmd)

    def enqueue_acl_data(self, handle, pb_flag, b_flag, data):
        """Send ``data`` from the cert as one ACL packet on ``handle``."""
        acl = hci_packets.AclBuilder(handle, pb_flag, b_flag, RawBuilder(data))
        self.cert.hci.SendAcl(common.Data(payload=bytes(acl.Serialize())))

    def _parse_connection_complete(self, packet):
        """Extract the connection handle and peer address from a cert LE event.

        Args:
            packet: an event from ``cert_hci_le_event_stream``. Its
                ``payload`` holds the raw HCI event bytes.

        Returns:
            ``(handle, peer_address)`` for a successful LE Connection Complete
            or LE Enhanced Connection Complete event, otherwise ``None``.
        """
        packet_bytes = packet.payload
        if self._LE_CONNECTION_COMPLETE_PREFIX in packet_bytes:
            view_type = hci_packets.LeConnectionCompleteView
        elif self._LE_ENHANCED_CONNECTION_COMPLETE_PREFIX in packet_bytes:
            view_type = hci_packets.LeEnhancedConnectionCompleteView
        else:
            return None
        cc_view = view_type(
            hci_packets.LeMetaEventView(
                hci_packets.EventView(
                    bt_packets.PacketViewLittleEndian(list(packet_bytes)))))
        # Both events carry Peer_Address. In the enhanced event,
        # Peer_Resolvable_Private_Address is all zeros unless the peer used a
        # resolved RPA, so it must not be compared against the DUT address.
        return cc_view.GetConnectionHandle(), cc_view.GetPeerAddress()

    def dut_connects(self, check_address):
        """Have the cert advertise and the DUT connect to it.

        Afterwards ``self.dut_le_acl`` holds the DUT-side connection and
        ``self.cert_handle`` holds the cert-side connection handle.

        Args:
            check_address: when True, assert that the peer address reported
                to the cert equals ``self.dut_address``. This requires
                ``set_privacy_policy_static`` to have been called first.
        """
        self.register_for_le_event(
            hci_packets.SubeventCode.CONNECTION_COMPLETE)
        self.register_for_le_event(
            hci_packets.SubeventCode.ENHANCED_CONNECTION_COMPLETE)

        # Cert Advertises
        advertising_handle = 0
        self.enqueue_hci_command(
            hci_packets.LeSetExtendedAdvertisingLegacyParametersBuilder(
                advertising_handle,
                hci_packets.LegacyAdvertisingProperties.ADV_IND,
                400,  # NOTE(review): presumably interval min/max, then
                450,  # channel map; verify against the builder signature
                7,
                hci_packets.OwnAddressType.RANDOM_DEVICE_ADDRESS,
                hci_packets.PeerAddressType.PUBLIC_DEVICE_OR_IDENTITY_ADDRESS,
                '00:00:00:00:00:00',
                hci_packets.AdvertisingFilterPolicy.ALL_DEVICES,
                0xF8,
                1,  #SID
                hci_packets.Enable.DISABLED  # Scan request notification
            ))

        self.enqueue_hci_command(
            hci_packets.LeSetExtendedAdvertisingRandomAddressBuilder(
                advertising_handle, self._CERT_RANDOM_ADDRESS))

        gap_name = hci_packets.GapData()
        gap_name.data_type = hci_packets.GapDataType.COMPLETE_LOCAL_NAME
        gap_name.data = list(b'Im_A_Cert')

        self.enqueue_hci_command(
            hci_packets.LeSetExtendedAdvertisingDataBuilder(
                advertising_handle,
                hci_packets.Operation.COMPLETE_ADVERTISEMENT,
                hci_packets.FragmentPreference.CONTROLLER_SHOULD_NOT,
                [gap_name]))

        gap_short_name = hci_packets.GapData()
        gap_short_name.data_type = hci_packets.GapDataType.SHORTENED_LOCAL_NAME
        gap_short_name.data = list(b'Im_A_C')

        self.enqueue_hci_command(
            hci_packets.LeSetExtendedAdvertisingScanResponseBuilder(
                advertising_handle,
                hci_packets.Operation.COMPLETE_ADVERTISEMENT,
                hci_packets.FragmentPreference.CONTROLLER_SHOULD_NOT,
                [gap_short_name]))

        enabled_set = hci_packets.EnabledSet()
        enabled_set.advertising_handle = advertising_handle
        enabled_set.duration = 0
        enabled_set.max_extended_advertising_events = 0
        self.enqueue_hci_command(
            hci_packets.LeSetExtendedAdvertisingEnableBuilder(
                hci_packets.Enable.ENABLED, [enabled_set]))

        self.dut_le_acl = self.dut_le_acl_manager.connect_to_remote(
            remote_addr=common.BluetoothAddressWithType(
                address=common.BluetoothAddress(
                    address=self._CERT_RANDOM_ADDRESS.encode('utf8')),
                type=int(hci_packets.AddressType.RANDOM_DEVICE_ADDRESS)))

        # Cert gets ConnectionComplete with a handle and sends ACL data
        handle = 0xfff
        address = hci_packets.Address()

        def get_handle(packet):
            nonlocal handle, address
            parsed = self._parse_connection_complete(packet)
            if parsed is None:
                return False
            handle, address = parsed
            return True

        self.cert_hci_le_event_stream.assert_event_occurs(get_handle)
        self.cert_handle = handle
        dut_address_from_complete = address
        if check_address:
            assertThat(dut_address_from_complete).isEqualTo(
                self.dut_address.decode())

    def send_receive_and_check(self):
        """Exchange one L2CAP frame in each direction and assert both arrive.

        Each payload starts with a 4-byte basic L2CAP header: a little-endian
        length, then CID 0x0007.
        """
        self.enqueue_acl_data(
            self.cert_handle,
            hci_packets.PacketBoundaryFlag.FIRST_NON_AUTOMATICALLY_FLUSHABLE,
            hci_packets.BroadcastFlag.POINT_TO_POINT,
            b'\x19\x00\x07\x00SomeAclData from the Cert')

        self.dut_le_acl.send(b'\x1C\x00\x07\x00SomeMoreAclData from the DUT')
        self.cert_acl_data_stream.assert_event_occurs(
            lambda packet: b'SomeMoreAclData' in packet.payload)
        assertThat(self.dut_le_acl).emits(
            lambda packet: b'SomeAclData' in packet.payload)

    def test_dut_connects(self):
        self.set_privacy_policy_static()
        self.dut_connects(check_address=True)
        self.send_receive_and_check()

    def test_dut_connects_resolvable_address(self):
        privacy_policy = le_initiator_address_facade.PrivacyPolicy(
            address_policy=le_initiator_address_facade.AddressPolicy.
            USE_RESOLVABLE_ADDRESS,
            rotation_irk=
            b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f',
            minimum_rotation_time=7 * 60 * 1000,
            maximum_rotation_time=15 * 60 * 1000)
        self.dut.hci_le_initiator_address.SetPrivacyPolicyForInitiatorAddress(
            privacy_policy)
        self.dut_connects(check_address=False)
        self.send_receive_and_check()

    def test_dut_connects_non_resolvable_address(self):
        privacy_policy = le_initiator_address_facade.PrivacyPolicy(
            address_policy=le_initiator_address_facade.AddressPolicy.
            USE_NON_RESOLVABLE_ADDRESS,
            rotation_irk=
            b'\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f',
            minimum_rotation_time=8 * 60 * 1000,
            maximum_rotation_time=14 * 60 * 1000)
        self.dut.hci_le_initiator_address.SetPrivacyPolicyForInitiatorAddress(
            privacy_policy)
        self.dut_connects(check_address=False)
        self.send_receive_and_check()

    def test_dut_connects_public_address(self):
        self.dut.hci_le_initiator_address.SetPrivacyPolicyForInitiatorAddress(
            le_initiator_address_facade.PrivacyPolicy(
                address_policy=le_initiator_address_facade.AddressPolicy.
                USE_PUBLIC_ADDRESS))
        self.dut_connects(check_address=False)
        self.send_receive_and_check()

    def test_dut_connects_public_address_cancelled(self):
        # NOTE(review): this body is identical to
        # test_dut_connects_public_address and never cancels a connection
        # attempt. Confirm the intended cancel step and add it.
        self.dut.hci_le_initiator_address.SetPrivacyPolicyForInitiatorAddress(
            le_initiator_address_facade.PrivacyPolicy(
                address_policy=le_initiator_address_facade.AddressPolicy.
                USE_PUBLIC_ADDRESS))
        self.dut_connects(check_address=False)
        self.send_receive_and_check()

    def test_cert_connects(self):
        """DUT advertises and accepts; cert initiates an extended connection."""
        self.set_privacy_policy_static()
        self.register_for_le_event(
            hci_packets.SubeventCode.CONNECTION_COMPLETE)

        self.dut_le_acl_manager.listen_for_incoming_connections()

        # DUT Advertises
        gap_name = hci_packets.GapData()
        gap_name.data_type = hci_packets.GapDataType.COMPLETE_LOCAL_NAME
        gap_name.data = list(b'Im_The_DUT')
        gap_data = le_advertising_facade.GapDataMsg(
            data=bytes(gap_name.Serialize()))
        config = le_advertising_facade.AdvertisingConfig(
            advertisement=[gap_data],
            interval_min=512,
            interval_max=768,
            advertising_type=le_advertising_facade.AdvertisingEventType.
            ADV_IND,
            own_address_type=common.USE_RANDOM_DEVICE_ADDRESS,
            peer_address_type=common.PUBLIC_DEVICE_OR_IDENTITY_ADDRESS,
            peer_address=common.BluetoothAddress(
                address=b'A6:A5:A4:A3:A2:A1'),
            channel_map=7,
            filter_policy=le_advertising_facade.AdvertisingFilterPolicy.
            ALL_DEVICES)
        request = le_advertising_facade.CreateAdvertiserRequest(config=config)

        self.dut.hci_le_advertising_manager.CreateAdvertiser(request)

        # Cert Connects
        self.enqueue_hci_command(
            hci_packets.LeSetRandomAddressBuilder(self._CERT_RANDOM_ADDRESS))
        phy_scan_params = hci_packets.LeCreateConnPhyScanParameters()
        phy_scan_params.scan_interval = 0x60
        phy_scan_params.scan_window = 0x30
        phy_scan_params.conn_interval_min = 0x18
        phy_scan_params.conn_interval_max = 0x28
        phy_scan_params.conn_latency = 0
        phy_scan_params.supervision_timeout = 0x1f4
        phy_scan_params.min_ce_length = 0
        phy_scan_params.max_ce_length = 0
        self.enqueue_hci_command(
            hci_packets.LeExtendedCreateConnectionBuilder(
                hci_packets.InitiatorFilterPolicy.USE_PEER_ADDRESS,
                hci_packets.OwnAddressType.RANDOM_DEVICE_ADDRESS,
                hci_packets.AddressType.RANDOM_DEVICE_ADDRESS,
                self.dut_address.decode(), 1, [phy_scan_params]))

        # Cert gets ConnectionComplete with a handle and sends ACL data
        handle = 0xfff

        def get_handle(packet):
            nonlocal handle
            parsed = self._parse_connection_complete(packet)
            if parsed is None:
                return False
            handle, _ = parsed
            return True

        self.cert_hci_le_event_stream.assert_event_occurs(get_handle)
        self.cert_handle = handle

        self.enqueue_acl_data(
            self.cert_handle,
            hci_packets.PacketBoundaryFlag.FIRST_NON_AUTOMATICALLY_FLUSHABLE,
            hci_packets.BroadcastFlag.POINT_TO_POINT,
            b'\x19\x00\x07\x00SomeAclData from the Cert')

        # DUT gets a connection complete event and sends and receives
        self.dut_le_acl = self.dut_le_acl_manager.complete_incoming_connection(
        )

        self.send_receive_and_check()

    def test_recombination_l2cap_packet(self):
        """A 6-byte L2CAP SDU split over two ACL fragments reaches the DUT whole."""
        self.set_privacy_policy_static()
        self.dut_connects(check_address=True)

        self.enqueue_acl_data(
            self.cert_handle,
            hci_packets.PacketBoundaryFlag.FIRST_NON_AUTOMATICALLY_FLUSHABLE,
            hci_packets.BroadcastFlag.POINT_TO_POINT,
            b'\x06\x00\x07\x00Hello')
        self.enqueue_acl_data(
            self.cert_handle,
            hci_packets.PacketBoundaryFlag.CONTINUING_FRAGMENT,
            hci_packets.BroadcastFlag.POINT_TO_POINT, b'!')

        assertThat(
            self.dut_le_acl).emits(lambda packet: b'Hello!' in packet.payload)
# Example #4
# 0
class CertLeL2cap(Closable):
    """Cert-side LE L2CAP helper built on one LE ACL connection.

    Owns a ``PyLeAclManager`` for ``device`` and a control channel on the LE
    signalling CID. Every control packet received on that CID is dispatched
    through ``control_table`` by its LE signalling command code.
    """

    # LE L2CAP signalling channel CID (0x0005).
    _LE_SIGNALLING_CID = 5

    def __init__(self, device):
        self._device = device
        self._le_acl_manager = PyLeAclManager(device)
        self._le_acl = None

        # LE signalling command code -> handler(LeControlView), consulted by
        # _handle_control_packet. Codes with no entry are ignored.
        self.control_table = {
            LeCommandCode.DISCONNECTION_REQUEST:
            self._on_disconnection_request_default,
            LeCommandCode.DISCONNECTION_RESPONSE:
            self._on_disconnection_response_default,
            LeCommandCode.LE_FLOW_CONTROL_CREDIT: self._on_credit,
        }

        # Our source CID -> CertLeL2capChannel for dynamically opened channels.
        self._cid_to_cert_channels = {}

    def close(self):
        """Close the ACL manager and the ACL connection, if there is one.

        The connection is closed even when closing the manager raises.
        """
        try:
            self._le_acl_manager.close()
        finally:
            safeClose(self._le_acl)

    def connect_le_acl(self, remote_addr):
        """Connect to ``remote_addr`` and set up the LE signalling channel."""
        self._le_acl = self._le_acl_manager.connect_to_remote(remote_addr)
        self._setup_control_channel()

    def wait_for_connection(self):
        """Wait for an incoming LE ACL connection and set up signalling."""
        self._le_acl = self._le_acl_manager.wait_for_connection()
        self._setup_control_channel()

    def _setup_control_channel(self):
        """Create the signalling channel on ``self._le_acl`` and route its
        incoming packets to ``_handle_control_packet``."""
        self.control_channel = CertLeL2capChannel(self._device,
                                                  self._LE_SIGNALLING_CID,
                                                  self._LE_SIGNALLING_CID,
                                                  self._get_acl_stream(),
                                                  self._le_acl,
                                                  control_channel=None)
        self._get_acl_stream().register_callback(self._handle_control_packet)

    def open_fixed_channel(self, cid=4):
        """Return a channel that uses ``cid`` as both source and destination.

        The channel has no control channel and 0 initial credits.
        """
        channel = CertLeL2capChannel(self._device, cid, cid,
                                     self._get_acl_stream(), self._le_acl,
                                     None, 0)
        return channel

    def open_channel(self,
                     signal_id,
                     psm,
                     scid,
                     mtu=1000,
                     mps=100,
                     initial_credit=6):
        """Send an LE credit-based connection request and wrap the response.

        Asserts that a credit-based connection response arrives. The returned
        channel uses the response's destination CID and initial credits, and
        is registered under ``scid``.
        """
        self.control_channel.send(
            l2cap_packets.LeCreditBasedConnectionRequestBuilder(
                signal_id, psm, scid, mtu, mps, initial_credit))

        response = L2capCaptures.CreditBasedConnectionResponse()
        assertThat(self.control_channel).emits(response)
        channel = CertLeL2capChannel(self._device, scid,
                                     response.get().GetDestinationCid(),
                                     self._get_acl_stream(), self._le_acl,
                                     self.control_channel,
                                     response.get().GetInitialCredits())
        self._cid_to_cert_channels[scid] = channel
        return channel

    def open_channel_with_expected_result(
            self,
            psm=0x33,
            result=LeCreditBasedConnectionResponseResult.SUCCESS):
        """Request a channel on ``psm`` and assert the response carries ``result``.

        Uses fixed request values: signal id 1, scid 0x40, MTU 1000, MPS 100,
        and 6 initial credits.
        """
        self.control_channel.send(
            l2cap_packets.LeCreditBasedConnectionRequestBuilder(
                1, psm, 0x40, 1000, 100, 6))

        response = L2capMatchers.CreditBasedConnectionResponse(result)
        assertThat(self.control_channel).emits(response)

    def verify_and_respond_open_channel_from_remote(
            self,
            psm=0x33,
            result=LeCreditBasedConnectionResponseResult.SUCCESS,
            our_scid=None):
        """Expect a remote connection request on ``psm``, answer it, and
        return the resulting channel (registered under its source CID)."""
        request = L2capCaptures.CreditBasedConnectionRequest(psm)
        assertThat(self.control_channel).emits(request)
        (scid, dcid) = self._respond_connection_request_default(
            request.get(), result, our_scid)
        channel = CertLeL2capChannel(self._device, scid, dcid,
                                     self._get_acl_stream(), self._le_acl,
                                     self.control_channel,
                                     request.get().GetInitialCredits())
        self._cid_to_cert_channels[scid] = channel
        return channel

    def verify_and_reject_open_channel_from_remote(self, psm=0x33):
        """Expect a remote connection request on ``psm`` and reject it with
        a 'command not understood' reject."""
        request = L2capCaptures.CreditBasedConnectionRequest(psm)
        assertThat(self.control_channel).emits(request)
        sid = request.get().GetIdentifier()
        reject = l2cap_packets.LeCommandRejectNotUnderstoodBuilder(sid)
        self.control_channel.send(reject)

    def verify_le_flow_control_credit(self, channel):
        """Assert that an LE flow-control credit for ``channel._dcid`` arrives."""
        assertThat(self.control_channel).emits(
            L2capMatchers.LeFlowControlCredit(channel._dcid))

    def _respond_connection_request_default(
            self,
            request,
            result=LeCreditBasedConnectionResponseResult.SUCCESS,
            our_scid=None):
        """Answer ``request`` by echoing its MTU, MPS and initial credits.

        Returns:
            ``(our_scid, our_dcid)`` for the new channel.
        """
        sid = request.GetIdentifier()
        their_scid = request.GetSourceCid()
        mtu = request.GetMtu()
        mps = request.GetMps()
        initial_credits = request.GetInitialCredits()
        # Unless the caller chose our_scid, reuse the peer's source CID as our
        # own. The peer's source CID is always our destination CID.
        if our_scid is None:
            our_scid = their_scid
        our_dcid = their_scid
        response = l2cap_packets.LeCreditBasedConnectionResponseBuilder(
            sid, our_scid, mtu, mps, initial_credits, result)
        self.control_channel.send(response)
        return (our_scid, our_dcid)

    def get_control_channel(self):
        return self.control_channel

    def _get_acl_stream(self):
        return self._le_acl.acl_stream

    def _on_disconnection_request_default(self, request):
        """Acknowledge a disconnection request, echoing its DCID/SCID."""
        disconnection_request = l2cap_packets.LeDisconnectionRequestView(
            request)
        sid = disconnection_request.GetIdentifier()
        scid = disconnection_request.GetSourceCid()
        dcid = disconnection_request.GetDestinationCid()
        response = l2cap_packets.LeDisconnectionResponseBuilder(
            sid, dcid, scid)
        self.control_channel.send(response)

    def _on_disconnection_response_default(self, request):
        """Default handler: accept a disconnection response and do nothing."""
        disconnection_response = l2cap_packets.LeDisconnectionResponseView(
            request)

    def _on_credit(self, l2cap_le_control_view):
        """Add received LE flow-control credits to the matching channel.

        Credits for CIDs with no registered channel are ignored.
        """
        credit_view = l2cap_packets.LeFlowControlCreditView(
            l2cap_le_control_view)
        # NOTE(review): channels are keyed by our source CID, but
        # verify_le_flow_control_credit expects the credit packet to carry
        # channel._dcid. This lookup only hits when scid == dcid; confirm
        # which CID the peer's credit packet carries.
        channel = self._cid_to_cert_channels.get(credit_view.GetCid())
        if channel is None:
            return
        channel._credits_left += credit_view.GetCredits()

    def _handle_control_packet(self, l2cap_packet):
        """Dispatch a signalling-CID frame to its ``control_table`` handler.

        Frames on other CIDs, and command codes with no handler, are ignored.
        """
        packet_bytes = l2cap_packet.payload
        l2cap_view = l2cap_packets.BasicFrameView(
            bt_packets.PacketViewLittleEndian(list(packet_bytes)))
        if l2cap_view.GetChannelId() != self._LE_SIGNALLING_CID:
            return
        request = l2cap_packets.LeControlView(l2cap_view.GetPayload())
        handler = self.control_table.get(request.GetCode())
        if handler is not None:
            handler(request)