Exemplo n.º 1
0
 def test_raises_value_error_invalid_topic(self, topic):
     """Unsubscribing with the supplied ``topic`` must raise ``ValueError``."""
     # A real provider is built by hand and the paho client is deliberately
     # left unmocked: paho itself is what generates the error under test.
     unmocked_provider = MQTTProvider(
         username=fake_username, hostname=fake_hostname, client_id=fake_device_id
     )
     with pytest.raises(ValueError):
         unmocked_provider.unsubscribe(topic)
Exemplo n.º 2
0
 def test_raises_value_error_invalid_payload(self, payload):
     """Publishing the supplied ``payload`` must raise ``ValueError``."""
     # A real provider is built by hand and the paho client is deliberately
     # left unmocked: paho itself is what generates the error under test.
     unmocked_provider = MQTTProvider(
         username=fake_username, hostname=fake_hostname, client_id=fake_device_id
     )
     with pytest.raises(ValueError):
         unmocked_provider.publish(topic=fake_topic, payload=payload, qos=fake_qos)
Exemplo n.º 3
0
    def test_configures_tls_context_with_default_certs(self, mocker, mock_mqtt_client, provider):
        """Connecting without a CA cert calls ``load_default_certs()`` exactly once."""
        # Replace ssl.SSLContext so that any context created through it is a mock
        # whose method calls can be inspected afterwards.
        ssl_context_cls = mocker.patch.object(ssl, "SSLContext")
        ssl_context = ssl_context_cls.return_value

        # A fresh provider is created once the patch is in place; the injected
        # ``provider`` fixture argument is not used in the body of this test.
        tls_provider = MQTTProvider(
            client_id=fake_device_id, hostname=fake_hostname, username=fake_username
        )
        tls_provider.connect(fake_password)

        load_default_certs = ssl_context.load_default_certs
        assert load_default_certs.call_count == 1
        assert load_default_certs.call_args == mocker.call()
Exemplo n.º 4
0
    def test_configures_tls_context_with_ca_certs(self, mocker, mock_mqtt_client, provider):
        """Connecting with a CA cert passes it to ``load_verify_locations`` exactly once."""
        ca_cert = "dummy_certificate"

        # Replace ssl.SSLContext so that any context created through it is a mock
        # whose method calls can be inspected afterwards.
        ssl_context_cls = mocker.patch.object(ssl, "SSLContext")
        ssl_context = ssl_context_cls.return_value

        # A fresh provider is created once the patch is in place; the injected
        # ``provider`` fixture argument is not used in the body of this test.
        tls_provider = MQTTProvider(
            client_id=fake_device_id,
            hostname=fake_hostname,
            username=fake_username,
            ca_cert=ca_cert,
        )
        tls_provider.connect(fake_password)

        load_verify_locations = ssl_context.load_verify_locations
        assert load_verify_locations.call_count == 1
        assert load_verify_locations.call_args == mocker.call(cadata=ca_cert)
Exemplo n.º 5
0
    def test_operation_infrastructure_set_up(self, mocker):
        """A newly constructed provider starts with empty operation bookkeeping."""
        fresh_provider = MQTTProvider(
            username=fake_username, hostname=fake_hostname, client_id=fake_device_id
        )

        # Neither bookkeeping mapping may hold an entry right after construction.
        assert fresh_provider._pending_operation_callbacks == {}
        assert fresh_provider._unknown_operation_responses == {}
Exemplo n.º 6
0
    def test_instantiates_mqtt_client(self, mocker):
        """Constructing a provider calls ``mqtt.Client`` once with the expected arguments."""
        client_ctor = mocker.patch.object(mqtt, "Client")
        expected_call = mocker.call(
            client_id=fake_device_id, clean_session=False, protocol=mqtt.MQTTv311
        )

        # Only the constructor's side effect matters, so the instance is discarded.
        MQTTProvider(client_id=fake_device_id, hostname=fake_hostname, username=fake_username)

        assert client_ctor.call_count == 1
        assert client_ctor.call_args == expected_call
Exemplo n.º 7
0
    def test_handler_callbacks_set_to_none(self, mocker):
        """All user-facing handler attributes are ``None`` on a new provider."""
        mocker.patch.object(mqtt, "Client")

        new_provider = MQTTProvider(
            client_id=fake_device_id, hostname=fake_hostname, username=fake_username
        )

        # Checked in the same order as the original one-assert-per-handler form.
        handler_names = (
            "on_mqtt_connected",
            "on_mqtt_disconnected",
            "on_mqtt_message_received",
        )
        for handler_name in handler_names:
            assert getattr(new_provider, handler_name) is None
Exemplo n.º 8
0
    def test_sets_paho_callbacks(self, mocker):
        """After construction every checked callback on the mocked client is callable."""
        client_mock = mocker.patch.object(mqtt, "Client").return_value

        # Only the constructor's side effect matters, so the instance is discarded.
        MQTTProvider(client_id=fake_device_id, hostname=fake_hostname, username=fake_username)

        # Checked in the same order as the original one-assert-per-callback form.
        callback_names = (
            "on_connect",
            "on_disconnect",
            "on_subscribe",
            "on_unsubscribe",
            "on_publish",
            "on_message",
        )
        for callback_name in callback_names:
            assert callable(getattr(client_mock, callback_name))
Exemplo n.º 9
0
def provider(mock_mqtt_client):
    """Return an ``MQTTProvider`` built from the module's fake connection values.

    ``mock_mqtt_client`` is requested only for its side effect: it implicitly
    brings in the mocked Paho MQTT Client before the provider is constructed.
    """
    return MQTTProvider(
        client_id=fake_device_id,
        hostname=fake_hostname,
        username=fake_username,
    )
    def _run_op(self, op):
        """
        Run a single pipeline operation against the MQTT provider.

        SetConnectionArgs creates the MQTTProvider and wires up its handlers.
        SetSasToken stores the token for a later Connect. Connect, Disconnect,
        Publish, Subscribe and Unsubscribe are forwarded to the provider, and the
        op is completed from the matching provider callback. Any other operation
        is handed to continue_op.

        :param op: The pipeline operation to run.
        :raises: Whatever the provider's connect() or disconnect() raises; the
            persistent handler is restored before the exception propagates.
        """
        if isinstance(op, pipeline_ops_mqtt.SetConnectionArgs):
            # SetConnectionArgs is where we create our MQTTProvider object and set
            # all of its properties.
            logger.info("{}({}): got connection args".format(self.name, op.name))
            self.hostname = op.hostname
            self.username = op.username
            self.client_id = op.client_id
            self.ca_cert = op.ca_cert
            self.provider = MQTTProvider(
                client_id=self.client_id,
                hostname=self.hostname,
                username=self.username,
                ca_cert=self.ca_cert,
            )
            self.provider.on_mqtt_connected = self.on_connected
            self.provider.on_mqtt_disconnected = self.on_disconnected
            self.provider.on_mqtt_message_received = self._on_message_received
            self.pipeline_root.provider = self.provider
            self.complete_op(op)

        elif isinstance(op, pipeline_ops_base.SetSasToken):
            # When we get a sas token from above, we just save it for later
            logger.info("{}({}): got password".format(self.name, op.name))
            self.sas_token = op.sas_token
            self.complete_op(op)

        elif isinstance(op, pipeline_ops_base.Connect):
            logger.info("{}({}): connecting".format(self.name, op.name))

            def on_connected():
                logger.info("{}({}): on_connected.  completing op.".format(self.name, op.name))
                # Put the persistent handler back before notifying, so later
                # connection events no longer go through this one-shot closure.
                self.provider.on_mqtt_connected = self.on_connected
                self.on_connected()
                self.complete_op(op)

            # Temporarily install a one-shot handler that completes this op.
            self.provider.on_mqtt_connected = on_connected
            try:
                self.provider.connect(self.sas_token)
            except Exception:
                # connect() failed synchronously: restore the persistent handler
                # so a later connection event cannot complete this stale op.
                self.provider.on_mqtt_connected = self.on_connected
                raise

        elif isinstance(op, pipeline_ops_base.Disconnect):
            logger.info("{}({}): disconnecting".format(self.name, op.name))

            def on_disconnected():
                logger.info("{}({}): on_disconnected.  completing op.".format(self.name, op.name))
                # Put the persistent handler back before notifying, so later
                # disconnection events no longer go through this one-shot closure.
                self.provider.on_mqtt_disconnected = self.on_disconnected
                self.on_disconnected()
                self.complete_op(op)

            # Temporarily install a one-shot handler that completes this op.
            self.provider.on_mqtt_disconnected = on_disconnected
            try:
                self.provider.disconnect()
            except Exception:
                # disconnect() failed synchronously: restore the persistent handler
                # so a later disconnection event cannot complete this stale op.
                self.provider.on_mqtt_disconnected = self.on_disconnected
                raise

        elif isinstance(op, pipeline_ops_mqtt.Publish):
            logger.info("{}({}): publishing on {}".format(self.name, op.name, op.topic))

            def on_published():
                logger.info("{}({}): PUBACK received. completing op.".format(self.name, op.name))
                self.complete_op(op)

            self.provider.publish(topic=op.topic, payload=op.payload, callback=on_published)

        elif isinstance(op, pipeline_ops_mqtt.Subscribe):
            logger.info("{}({}): subscribing to {}".format(self.name, op.name, op.topic))

            def on_subscribed():
                logger.info("{}({}): SUBACK received. completing op.".format(self.name, op.name))
                self.complete_op(op)

            self.provider.subscribe(topic=op.topic, callback=on_subscribed)

        elif isinstance(op, pipeline_ops_mqtt.Unsubscribe):
            logger.info("{}({}): unsubscribing from {}".format(self.name, op.name, op.topic))

            def on_unsubscribed():
                logger.info("{}({}): UNSUBACK received.  completing op.".format(self.name, op.name))
                self.complete_op(op)

            self.provider.unsubscribe(topic=op.topic, callback=on_unsubscribed)

        else:
            # Not an operation this stage handles itself.
            self.continue_op(op)
class Provider(PipelineStage):
    """
    PipelineStage object which is responsible for interfacing with the MQTT provider object.
    This stage handles all MQTT operations and any other operations (such as Connect) which
    are not in the MQTT group of operations, but can only be run at the protocol level.
    """

    def _run_op(self, op):
        """
        Run a single pipeline operation against the MQTT provider.

        SetConnectionArgs creates the MQTTProvider and wires up its handlers.
        SetSasToken stores the token for a later Connect. Connect, Disconnect,
        Publish, Subscribe and Unsubscribe are forwarded to the provider, and the
        op is completed from the matching provider callback. Any other operation
        is handed to continue_op.

        :param op: The pipeline operation to run.
        :raises: Whatever the provider's connect() or disconnect() raises; the
            persistent handler is restored before the exception propagates.
        """
        if isinstance(op, pipeline_ops_mqtt.SetConnectionArgs):
            # SetConnectionArgs is where we create our MQTTProvider object and set
            # all of its properties.
            logger.info("{}({}): got connection args".format(self.name, op.name))
            self.hostname = op.hostname
            self.username = op.username
            self.client_id = op.client_id
            self.ca_cert = op.ca_cert
            self.provider = MQTTProvider(
                client_id=self.client_id,
                hostname=self.hostname,
                username=self.username,
                ca_cert=self.ca_cert,
            )
            self.provider.on_mqtt_connected = self.on_connected
            self.provider.on_mqtt_disconnected = self.on_disconnected
            self.provider.on_mqtt_message_received = self._on_message_received
            self.pipeline_root.provider = self.provider
            self.complete_op(op)

        elif isinstance(op, pipeline_ops_base.SetSasToken):
            # When we get a sas token from above, we just save it for later
            logger.info("{}({}): got password".format(self.name, op.name))
            self.sas_token = op.sas_token
            self.complete_op(op)

        elif isinstance(op, pipeline_ops_base.Connect):
            logger.info("{}({}): connecting".format(self.name, op.name))

            def on_connected():
                logger.info("{}({}): on_connected.  completing op.".format(self.name, op.name))
                # Put the persistent handler back before notifying, so later
                # connection events no longer go through this one-shot closure.
                self.provider.on_mqtt_connected = self.on_connected
                self.on_connected()
                self.complete_op(op)

            # Temporarily install a one-shot handler that completes this op.
            self.provider.on_mqtt_connected = on_connected
            try:
                self.provider.connect(self.sas_token)
            except Exception:
                # connect() failed synchronously: restore the persistent handler
                # so a later connection event cannot complete this stale op.
                self.provider.on_mqtt_connected = self.on_connected
                raise

        elif isinstance(op, pipeline_ops_base.Disconnect):
            logger.info("{}({}): disconnecting".format(self.name, op.name))

            def on_disconnected():
                logger.info("{}({}): on_disconnected.  completing op.".format(self.name, op.name))
                # Put the persistent handler back before notifying, so later
                # disconnection events no longer go through this one-shot closure.
                self.provider.on_mqtt_disconnected = self.on_disconnected
                self.on_disconnected()
                self.complete_op(op)

            # Temporarily install a one-shot handler that completes this op.
            self.provider.on_mqtt_disconnected = on_disconnected
            try:
                self.provider.disconnect()
            except Exception:
                # disconnect() failed synchronously: restore the persistent handler
                # so a later disconnection event cannot complete this stale op.
                self.provider.on_mqtt_disconnected = self.on_disconnected
                raise

        elif isinstance(op, pipeline_ops_mqtt.Publish):
            logger.info("{}({}): publishing on {}".format(self.name, op.name, op.topic))

            def on_published():
                logger.info("{}({}): PUBACK received. completing op.".format(self.name, op.name))
                self.complete_op(op)

            self.provider.publish(topic=op.topic, payload=op.payload, callback=on_published)

        elif isinstance(op, pipeline_ops_mqtt.Subscribe):
            logger.info("{}({}): subscribing to {}".format(self.name, op.name, op.topic))

            def on_subscribed():
                logger.info("{}({}): SUBACK received. completing op.".format(self.name, op.name))
                self.complete_op(op)

            self.provider.subscribe(topic=op.topic, callback=on_subscribed)

        elif isinstance(op, pipeline_ops_mqtt.Unsubscribe):
            logger.info("{}({}): unsubscribing from {}".format(self.name, op.name, op.topic))

            def on_unsubscribed():
                logger.info("{}({}): UNSUBACK received.  completing op.".format(self.name, op.name))
                self.complete_op(op)

            self.provider.unsubscribe(topic=op.topic, callback=on_unsubscribed)

        else:
            # Not an operation this stage handles itself.
            self.continue_op(op)

    def _on_message_received(self, topic, payload):
        """
        Handler that gets called by the protocol library when an incoming message arrives.
        Convert that message into a pipeline event and pass it up for someone to handle.

        :param topic: The topic the message arrived on.
        :param payload: The payload of the message.
        """
        self.handle_pipeline_event(
            pipeline_events_mqtt.IncomingMessage(topic=topic, payload=payload)
        )