Ejemplo n.º 1
0
class DxlClientTest(unittest.TestCase):
    """Unit tests for DxlClient with the paho MQTT client class patched out."""

    def setUp(self):
        # Client configured with test certificates and an empty broker list.
        self.config = DxlClientConfig(broker_ca_bundle=get_ca_bundle_pem(),
                                      cert_file=get_cert_file_pem(),
                                      private_key=get_dxl_private_key(),
                                      brokers=[])

        # Replace the paho MQTT client class with a mock; undone in tearDown.
        mqtt_client_patch = patch('paho.mqtt.client.Client')
        mqtt_client_patch.start()

        self.client = DxlClient(self.config)
        # Synchronous request waits return an empty Response immediately.
        self.client._request_manager.wait_for_response = Mock(return_value=Response(request=None))

        self.test_channel = '/test/channel'

    def tearDown(self):
        # Stop every patch started with start() (including the MQTT client).
        patch.stopall()

    def test_client_raises_exception_on_connect_when_already_connecting(self):
        self.client._client.connect.side_effect = Exception("An exception!")

        class MyThread(threading.Thread):
            """Background thread that calls connect() on the given client."""

            def __init__(self, client):
                super(MyThread, self).__init__()
                self._client = client

            def run(self):
                self._client.connect()

        t = MyThread(self.client)
        # Thread.setDaemon() is deprecated; set the attribute instead.
        t.daemon = True
        t.start()
        time.sleep(2)

        self.assertEqual(self.client.connected, False)
        with self.assertRaises(DxlException):
            self.client.connect()

    def test_client_raises_exception_on_connect_when_already_connected(self):
        self.client._client.connect.side_effect = Exception("An exception!")
        # A plain boolean states the intent; a Mock only worked by being truthy.
        self.client._connected = True
        with self.assertRaises(DxlException):
            self.client.connect()

    # The following test is too slow
    def test_client_disconnect_doesnt_raises_exception_on_disconnect_when_disconnected(self):
        self.assertEqual(self.client.connected, False)
        self.client.disconnect()
        self.client.disconnect()

    @parameterized.expand([
        # (connect + retries) * 2 = connect_count
        (0, 2),
        (1, 4),
        (2, 6),
    ])
    def test_client_retries_defines_how_many_times_the_client_retries_connection(self, retries, connect_count):
        # Client won't connect ;)
        self.client._client.connect = Mock(side_effect=Exception('Could not connect'))
        # No delay between retries (faster unit tests)
        self.client.config.reconnect_delay = 0
        self.client._wait_for_policy_delay = 0

        broker = Broker(host_name='localhost')
        broker._parse(UuidGenerator.generate_id_as_string() + ";9999;localhost;127.0.0.1")

        self.client.config.brokers = [broker]
        self.client.config.connect_retries = retries

        with self.assertRaises(DxlException):
            self.client.connect()
        self.assertEqual(self.client._client.connect.call_count, connect_count)

    def test_client_subscribe_adds_subscription_when_not_connected(self):
        self.client._client.subscribe = Mock(return_value=None)
        self.assertFalse(self.client.connected)

        self.client.subscribe(self.test_channel)
        self.assertIn(self.test_channel, self.client.subscriptions)
        # Not connected, so nothing should be sent to the MQTT client.
        self.assertEqual(self.client._client.subscribe.call_count, 0)

    def test_client_unsubscribe_removes_subscription_when_not_connected(self):
        self.client._client.unsubscribe = Mock(return_value=None)
        self.assertFalse(self.client.connected)
        # Add subscription
        self.client.subscribe(self.test_channel)
        self.assertIn(self.test_channel, self.client.subscriptions)
        # Remove subscription
        self.client.unsubscribe(self.test_channel)
        self.assertNotIn(self.test_channel, self.client.subscriptions)

    def test_client_subscribe_doesnt_add_twice_same_channel(self):
        # Mock client.subscribe and is_connected
        self.client._client.subscribe = Mock(return_value=None)
        self.client._connected = Mock(return_value=True)

        # We always have the default (myself) channel
        self.assertEqual(len(self.client.subscriptions), 1)
        self.client.subscribe(self.test_channel)
        self.assertEqual(len(self.client.subscriptions), 2)
        self.client.subscribe(self.test_channel)
        self.assertEqual(len(self.client.subscriptions), 2)
        self.assertEqual(self.client._client.subscribe.call_count, 1)

    def test_client_handle_message_with_event_calls_event_callback(self):
        event_callback = EventCallback()
        event_callback.on_event = Mock()
        self.client.add_event_callback(self.test_channel, event_callback)
        # Create and process Event
        evt = Event(destination_topic=self.test_channel)._to_bytes()
        self.client._handle_message(self.test_channel, evt)
        # Check that callback was called
        self.assertEqual(event_callback.on_event.call_count, 1)

    def test_client_handle_message_with_request_calls_request_callback(self):
        req_callback = RequestCallback()
        req_callback.on_request = Mock()
        self.client.add_request_callback(self.test_channel, req_callback)
        # Create and process Request
        req = Request(destination_topic=self.test_channel)._to_bytes()
        self.client._handle_message(self.test_channel, req)
        # Check that callback was called
        self.assertEqual(req_callback.on_request.call_count, 1)

    def test_client_handle_message_with_response_calls_response_callback(self):
        callback = ResponseCallback()
        callback.on_response = Mock()
        self.client.add_response_callback(self.test_channel, callback)
        # Create and process Response
        msg = Response(request=None)._to_bytes()
        self.client._handle_message(self.test_channel, msg)
        # Check that callback was called
        self.assertEqual(callback.on_response.call_count, 1)

    def test_client_send_event_publishes_message_to_dxl_fabric(self):
        self.client._client.publish = Mock(return_value=None)
        # Create and send Event
        msg = Event(destination_topic="")
        self.client.send_event(msg)
        # Check that the message was published
        self.assertEqual(self.client._client.publish.call_count, 1)

    def test_client_send_request_publishes_message_to_dxl_fabric(self):
        self.client._client.publish = Mock(return_value=None)
        # Create and send Request
        msg = Request(destination_topic="")
        self.client._send_request(msg)
        # Check that the message was published
        self.assertEqual(self.client._client.publish.call_count, 1)

    def test_client_send_response_publishes_message_to_dxl_fabric(self):
        self.client._client.publish = Mock(return_value=None)
        # Create and send Response
        msg = Response(request=None)
        self.client.send_response(msg)
        # Check that the message was published
        self.assertEqual(self.client._client.publish.call_count, 1)

    def test_client_handles_error_response_and_fire_response_handler(self):
        self.client._fire_response = Mock(return_value=None)
        # Create ErrorResponse
        msg = ErrorResponse(request=None, error_code=666, error_message="test message")
        payload = msg._to_bytes()
        # Handle error response message
        self.client._handle_message(self.test_channel, payload)
        # Check that message response was properly delivered to handler
        self.assertEqual(self.client._fire_response.call_count, 1)

    # Service unit tests

    def test_client_register_service_subscribes_client_to_channel(self):
        channel1 = '/mcafee/service/unittest/one'
        channel2 = '/mcafee/service/unittest/two'
        # Create dummy service
        service_info = dxlclient.service.ServiceRegistrationInfo(
            service_type='/mcafee/service/unittest', client=self.client)
        service_info.add_topic(channel1, RequestCallback())
        service_info.add_topic(channel2, RequestCallback())

        # Register service in client
        self.client.register_service_async(service_info)
        # Check subscribed channels
        subscriptions = self.client.subscriptions
        self.assertIn(channel1, subscriptions,
                      "Client wasn't subscribed to service channel")
        self.assertIn(channel2, subscriptions,
                      "Client wasn't subscribed to service channel")

    def test_client_wont_register_the_same_service_twice(self):
        service_info = dxlclient.service.ServiceRegistrationInfo(
            service_type='/mcafee/service/unittest', client=self.client)

        # Register service in client
        self.client.register_service_async(service_info)
        with self.assertRaises(dxlclient.DxlException):
            # Re-register service
            self.client.register_service_async(service_info)

    def test_client_register_service_sends_register_request_to_broker(self):
        service_info = dxlclient.service.ServiceRegistrationInfo(
            service_type='/mcafee/service/unittest', client=self.client)

        self.client._send_request = Mock(return_value=True)
        self.client._connected = Mock(return_value=True)

        # Register service in client
        self.client.register_service_async(service_info)
        # Registration is asynchronous; allow time for the request to be sent.
        time.sleep(2)
        # Check that method has been called
        self.assertTrue(self.client._send_request.called)

    def test_client_register_service_unsubscribes_client_to_channel(self):
        channel1 = '/mcafee/service/unittest/one'
        channel2 = '/mcafee/service/unittest/two'
        # Create dummy service
        service_info = dxlclient.service.ServiceRegistrationInfo(
            service_type='/mcafee/service/unittest', client=self.client)
        service_info.add_topic(channel1, RequestCallback())
        service_info.add_topic(channel2, RequestCallback())

        # Register service in client
        self.client.register_service_async(service_info)
        # Check subscribed channels
        subscriptions = self.client.subscriptions
        self.assertIn(channel1, subscriptions,
                      "Client wasn't subscribed to service channel")
        self.assertIn(channel2, subscriptions,
                      "Client wasn't subscribed to service channel")

        self.client.unregister_service_async(service_info)
        subscriptions = self.client.subscriptions
        self.assertNotIn(channel1, subscriptions,
                         "Client wasn't unsubscribed to service channel")
        self.assertNotIn(channel2, subscriptions,
                         "Client wasn't unsubscribed to service channel")

    def test_client_register_service_unsuscribes_from_channel_by_guid(self):
        channel1 = '/mcafee/service/unittest/one'
        channel2 = '/mcafee/service/unittest/two'

        # Create dummy service
        service_info = dxlclient.service.ServiceRegistrationInfo(
            service_type='/mcafee/service/unittest', client=self.client)
        service_info.add_topic(channel1, RequestCallback())
        service_info.add_topic(channel2, RequestCallback())

        # Create same dummy service - a different object that shares the
        # first service's id, so unregistration must match by id, not identity.
        service_info2 = dxlclient.service.ServiceRegistrationInfo(
            service_type='/mcafee/service/unittest', client=self.client)
        service_info2._service_id = service_info.service_id
        service_info2.add_topic(channel1, RequestCallback())
        service_info2.add_topic(channel2, RequestCallback())

        # Register the original service in client
        self.client.register_service_async(service_info)

        # Check subscribed channels
        subscriptions = self.client.subscriptions
        self.assertIn(channel1, subscriptions,
                      "Client wasn't subscribed to service channel")
        self.assertIn(channel2, subscriptions,
                      "Client wasn't subscribed to service channel")

        # Unregister through the second object (same id)
        self.client.unregister_service_async(service_info2)
        subscriptions = self.client.subscriptions
        self.assertNotIn(channel1, subscriptions,
                         "Client wasn't unsubscribed to service channel")
        self.assertNotIn(channel2, subscriptions,
                         "Client wasn't unsubscribed to service channel")
Ejemplo n.º 2
0
class DxlClientTest(unittest.TestCase):
    """Unit tests for DxlClient with the paho MQTT client class patched out."""

    def setUp(self):
        # Client configured with test certificates and an empty broker list.
        self.config = DxlClientConfig(broker_ca_bundle=get_ca_bundle_pem(),
                                      cert_file=get_cert_file_pem(),
                                      private_key=get_dxl_private_key(),
                                      brokers=[])

        # Replace the paho MQTT client class with a mock; undone in tearDown.
        mqtt_client_patch = patch('paho.mqtt.client.Client')
        mqtt_client_patch.start()

        self.client = DxlClient(self.config)
        # Synchronous request waits return an empty Response immediately.
        self.client._request_manager.wait_for_response = Mock(
            return_value=Response(request=None))

        self.test_channel = '/test/channel'

    def tearDown(self):
        # Some tests force the client into a "connected" state; reset it
        # before destroying the client, then stop every started patch.
        self.client._connected = False
        self.client.destroy()
        patch.stopall()

    def test_client_raises_exception_on_connect_when_already_connecting(self):
        self.client._client.connect.side_effect = Exception("An exception!")
        # A non-None _thread makes the client look like it is connecting.
        self.client._thread = threading.Thread()
        try:
            self.assertEqual(self.client.connected, False)
            with self.assertRaises(DxlException):
                self.client.connect()
        finally:
            # Always clear the fake thread so tearDown's destroy() is not
            # affected by it, even when an assertion above fails.
            self.client._thread = None

    def test_client_raises_exception_on_connect_when_already_connected(self):
        self.client._client.connect.side_effect = Exception("An exception!")
        self.client._connected = True
        with self.assertRaises(DxlException):
            self.client.connect()

    # The following test is too slow
    def test_client_disconnect_doesnt_raises_exception_on_disconnect_when_disconnected(
            self):
        self.assertEqual(self.client.connected, False)
        self.client.disconnect()
        self.client.disconnect()

    @parameterized.expand([
        # (connect + retries) * 2 = connect_count
        (0, 2),
        (1, 4),
        (2, 6),
    ])
    def test_client_retries_defines_how_many_times_the_client_retries_connection(
            self, retries, connect_count):
        # Client won't connect ;)
        self.client._client.connect = Mock(
            side_effect=Exception('Could not connect'))
        # No delay between retries (faster unit tests)
        self.client.config.reconnect_delay = 0
        self.client._wait_for_policy_delay = 0

        broker = Broker(host_name='localhost')
        broker._parse(UuidGenerator.generate_id_as_string() +
                      ";9999;localhost;127.0.0.1")

        self.client.config.brokers = [broker]
        self.client.config.connect_retries = retries

        with self.assertRaises(DxlException):
            self.client.connect()
        self.assertEqual(self.client._client.connect.call_count, connect_count)

    def test_client_subscribe_adds_subscription_when_not_connected(self):
        self.client._client.subscribe = Mock(return_value=None)
        self.assertFalse(self.client.connected)

        self.client.subscribe(self.test_channel)
        self.assertIn(self.test_channel, self.client.subscriptions)
        # Not connected, so nothing should be sent to the MQTT client.
        self.assertEqual(self.client._client.subscribe.call_count, 0)

    def test_client_unsubscribe_removes_subscription_when_not_connected(self):
        self.client._client.unsubscribe = Mock(return_value=None)
        self.assertFalse(self.client.connected)
        # Add subscription
        self.client.subscribe(self.test_channel)
        self.assertIn(self.test_channel, self.client.subscriptions)
        # Remove subscription
        self.client.unsubscribe(self.test_channel)
        self.assertNotIn(self.test_channel, self.client.subscriptions)

    def test_client_subscribe_doesnt_add_twice_same_channel(self):
        # Mock client.subscribe and is_connected; skip waiting for the ack
        self.client._client.subscribe = Mock(
            return_value=(mqtt.MQTT_ERR_SUCCESS, 2))
        self.client._connected = Mock(return_value=True)
        self.client._wait_packet_acked = Mock(return_value=None)

        # We always have the default (myself) channel
        self.assertEqual(len(self.client.subscriptions), 1)
        self.client.subscribe(self.test_channel)
        self.assertEqual(len(self.client.subscriptions), 2)
        self.client.subscribe(self.test_channel)
        self.assertEqual(len(self.client.subscriptions), 2)
        self.assertEqual(self.client._client.subscribe.call_count, 1)

    def test_client_handle_message_with_event_calls_event_callback(self):
        event_callback = EventCallback()
        event_callback.on_event = Mock()
        self.client.add_event_callback(self.test_channel, event_callback)
        # Create and process Event
        evt = Event(destination_topic=self.test_channel)._to_bytes()
        self.client._handle_message(self.test_channel, evt)
        # Check that callback was called
        self.assertEqual(event_callback.on_event.call_count, 1)
        self.client.remove_event_callback(self.test_channel, event_callback)
        self.client._handle_message(self.test_channel, evt)
        # Check that callback was not called again - because the event
        # callback was unregistered
        self.assertEqual(event_callback.on_event.call_count, 1)

    def test_client_handle_message_with_request_calls_request_callback(self):
        req_callback = RequestCallback()
        req_callback.on_request = Mock()
        self.client.add_request_callback(self.test_channel, req_callback)
        # Create and process Request
        req = Request(destination_topic=self.test_channel)._to_bytes()
        self.client._handle_message(self.test_channel, req)
        # Check that callback was called
        self.assertEqual(req_callback.on_request.call_count, 1)
        self.client.remove_request_callback(self.test_channel, req_callback)
        self.client._handle_message(self.test_channel, req)
        # Check that callback was not called again - because the request
        # callback was unregistered
        self.assertEqual(req_callback.on_request.call_count, 1)

    def test_client_handle_message_with_response_calls_response_callback(self):
        callback = ResponseCallback()
        callback.on_response = Mock()
        self.client.add_response_callback(self.test_channel, callback)
        # Create and process Response
        msg = Response(request=None)._to_bytes()
        self.client._handle_message(self.test_channel, msg)
        # Check that callback was called
        self.assertEqual(callback.on_response.call_count, 1)
        self.client.remove_response_callback(self.test_channel, callback)
        self.client._handle_message(self.test_channel, msg)
        # Check that callback was not called again - because the response
        # callback was unregistered
        self.assertEqual(callback.on_response.call_count, 1)

    def test_client_remove_call_for_unregistered_callback_does_not_error(self):
        callback = EventCallback()
        callback.on_event = Mock()
        callback2 = EventCallback()
        callback2.on_event = Mock()
        self.client.add_event_callback(self.test_channel, callback)
        self.client.add_event_callback(self.test_channel, callback2)
        self.client.remove_event_callback(self.test_channel, callback)
        # Removing the same (now unregistered) callback again must not raise
        self.client.remove_event_callback(self.test_channel, callback)

    def test_client_send_event_publishes_message_to_dxl_fabric(self):
        self.client._client.publish = Mock(return_value=None)
        # Create and send Event
        msg = Event(destination_topic="")
        self.client.send_event(msg)
        # Check that the message was published
        self.assertEqual(self.client._client.publish.call_count, 1)

    def test_client_send_request_publishes_message_to_dxl_fabric(self):
        self.client._client.publish = Mock(return_value=None)
        # Create and send Request
        msg = Request(destination_topic="")
        self.client._send_request(msg)
        # Check that the message was published
        self.assertEqual(self.client._client.publish.call_count, 1)

    def test_client_send_response_publishes_message_to_dxl_fabric(self):
        self.client._client.publish = Mock(return_value=None)
        # Create and send Response
        msg = Response(request=None)
        self.client.send_response(msg)
        # Check that the message was published
        self.assertEqual(self.client._client.publish.call_count, 1)

    def test_client_handles_error_response_and_fire_response_handler(self):
        self.client._fire_response = Mock(return_value=None)
        # Create ErrorResponse
        msg = ErrorResponse(request=None,
                            error_code=666,
                            error_message="test message")
        payload = msg._to_bytes()
        # Handle error response message
        self.client._handle_message(self.test_channel, payload)
        # Check that message response was properly delivered to handler
        self.assertEqual(self.client._fire_response.call_count, 1)

    def test_client_subscribe_no_ack_raises_timeout(self):
        self.client._client.subscribe = Mock(
            return_value=(mqtt.MQTT_ERR_SUCCESS, 2))
        self.client._connected = Mock(return_value=True)
        # Shrink the ack wait so the (never-arriving) ack times out quickly
        with patch.object(DxlClient, '_MAX_PACKET_ACK_WAIT', 0.01):
            with self.assertRaises(WaitTimeoutException):
                self.client.subscribe(self.test_channel)

    def test_client_unsubscribe_no_ack_raises_timeout(self):
        self.client._client.subscribe = Mock(
            return_value=(mqtt.MQTT_ERR_SUCCESS, 2))
        self.client._client.unsubscribe = Mock(
            return_value=(mqtt.MQTT_ERR_SUCCESS, 3))
        self.client._connected = Mock(return_value=True)
        # Subscribe without waiting for an ack; the patch is undone on exit
        # (even if subscribe raises) so unsubscribe uses the real wait.
        with patch.object(self.client, '_wait_packet_acked',
                          return_value=None):
            self.client.subscribe(self.test_channel)
        with patch.object(DxlClient, '_MAX_PACKET_ACK_WAIT', 0.01):
            with self.assertRaises(WaitTimeoutException):
                self.client.unsubscribe(self.test_channel)

    # Service unit tests

    def test_client_register_service_subscribes_client_to_channel(self):
        channel = '/mcafee/service/unittest'

        # Create dummy service
        service_info = dxlclient.service.ServiceRegistrationInfo(
            service_type='/mcafee/service/unittest', client=self.client)

        # Add topics to the service
        service_info.add_topic(channel + "1", RequestCallback())
        service_info.add_topic(channel + "2", RequestCallback())
        service_info.add_topics(
            {channel + str(i): RequestCallback()
             for i in range(3, 6)})

        subscriptions_before_registration = self.client.subscriptions
        expected_subscriptions_after_registration = \
            sorted(subscriptions_before_registration +
                   tuple(channel + str(i) for i in range(1, 6)))

        # Register service in client
        self.client.register_service_async(service_info)
        # Check subscribed channels
        subscriptions_after_registration = self.client.subscriptions

        self.assertEqual(expected_subscriptions_after_registration,
                         sorted(subscriptions_after_registration))

    def test_client_wont_register_the_same_service_twice(self):
        service_info = dxlclient.service.ServiceRegistrationInfo(
            service_type='/mcafee/service/unittest', client=self.client)

        # Register service in client
        self.client.register_service_async(service_info)
        with self.assertRaises(dxlclient.DxlException):
            # Re-register service
            self.client.register_service_async(service_info)

    def test_client_register_service_sends_register_request_to_broker(self):
        service_info = dxlclient.service.ServiceRegistrationInfo(
            service_type='/mcafee/service/unittest', client=self.client)

        self.client._send_request = Mock(return_value=True)
        self.client._connected = Mock(return_value=True)

        # Register service in client
        self.client.register_service_async(service_info)
        # Registration is asynchronous; allow time for the request to be sent.
        time.sleep(2)
        # Check that method has been called
        self.assertTrue(self.client._send_request.called)

    def test_client_register_service_unsubscribes_client_to_channel(self):
        channel1 = '/mcafee/service/unittest/one'
        channel2 = '/mcafee/service/unittest/two'
        # Create dummy service
        service_info = dxlclient.service.ServiceRegistrationInfo(
            service_type='/mcafee/service/unittest', client=self.client)
        service_info.add_topic(channel1, RequestCallback())
        service_info.add_topic(channel2, RequestCallback())

        # Register service in client
        self.client.register_service_async(service_info)
        # Check subscribed channels
        subscriptions = self.client.subscriptions
        self.assertIn(channel1, subscriptions,
                      "Client wasn't subscribed to service channel")
        self.assertIn(channel2, subscriptions,
                      "Client wasn't subscribed to service channel")

        self.client.unregister_service_async(service_info)
        subscriptions = self.client.subscriptions
        self.assertNotIn(channel1, subscriptions,
                         "Client wasn't unsubscribed to service channel")
        self.assertNotIn(channel2, subscriptions,
                         "Client wasn't unsubscribed to service channel")

    def test_client_register_service_unsuscribes_from_channel_by_guid(self):
        channel1 = '/mcafee/service/unittest/one'
        channel2 = '/mcafee/service/unittest/two'

        # Create dummy service
        service_info = dxlclient.service.ServiceRegistrationInfo(
            service_type='/mcafee/service/unittest', client=self.client)
        service_info.add_topic(channel1, RequestCallback())
        service_info.add_topic(channel2, RequestCallback())

        # Create same dummy service - a different object that shares the
        # first service's id, so unregistration must match by id, not identity.
        service_info2 = dxlclient.service.ServiceRegistrationInfo(
            service_type='/mcafee/service/unittest', client=self.client)
        service_info2._service_id = service_info.service_id
        service_info2.add_topic(channel1, RequestCallback())
        service_info2.add_topic(channel2, RequestCallback())

        # Register the original service in client
        self.client.register_service_async(service_info)

        # Check subscribed channels
        subscriptions = self.client.subscriptions
        self.assertIn(channel1, subscriptions,
                      "Client wasn't subscribed to service channel")
        self.assertIn(channel2, subscriptions,
                      "Client wasn't subscribed to service channel")

        # Unregister through the second object (same id)
        self.client.unregister_service_async(service_info2)
        subscriptions = self.client.subscriptions
        self.assertNotIn(channel1, subscriptions,
                         "Client wasn't unsubscribed to service channel")
        self.assertNotIn(channel2, subscriptions,
                         "Client wasn't unsubscribed to service channel")