Example #1
0
class MQTTTransportStage(PipelineStage):
    """
    PipelineStage object which is responsible for interfacing with the MQTT protocol wrapper object.
    This stage handles all MQTT operations and any other operations (such as ConnectOperation) which
    are not in the MQTT group of operations, but can only be run at the protocol level.

    The transport object and the ``_active_connect_op`` / ``_active_disconnect_op`` tracking
    attributes are only created when a SetMQTTConnectionArgsOperation is executed, so that
    operation has to run before any connect, reconnect or disconnect operation reaches this stage.
    """

    @pipeline_thread.runs_on_pipeline_thread
    def _cancel_active_connect_disconnect_ops(self):
        """
        Cancel any running connect or disconnect op.  Since our ability to "cancel" is fairly limited,
        all this does (for now) is to fail the operation.

        Both tracking attributes are cleared before any op is completed, so code that runs as
        part of completing an op sees this stage with no active connect/disconnect op.
        """

        ops_to_cancel = []
        if self._active_connect_op:
            # TODO: should this actually run a cancel call on the op?
            # ops_to_cancel is a list, so ops are collected with append().
            ops_to_cancel.append(self._active_connect_op)
            self._active_connect_op = None
        if self._active_disconnect_op:
            ops_to_cancel.append(self._active_disconnect_op)
            self._active_disconnect_op = None

        for op in ops_to_cancel:
            op.error = errors.PipelineError(
                "Cancelling because new ConnectOperation, DisconnectOperation, or ReconnectOperation was issued"
            )
            operation_flow.complete_op(stage=self, op=op)

    @pipeline_thread.runs_on_pipeline_thread
    def _execute_op(self, op):
        """
        Run a single pipeline operation, dispatching on the operation's type.

        Connect, reconnect and disconnect operations are not completed here; they are stored
        as the active op and completed by the transport handlers defined below.  Publish,
        subscribe and unsubscribe operations are completed from the callback handed to the
        transport.  Operations this stage does not recognize are passed to the next stage.

        :param op: The pipeline operation to run.
        :raises Exception: Whatever the transport raises from connect(), reconnect() or
            disconnect() is re-raised after the active-op tracking has been reset.
        """
        if isinstance(op, pipeline_ops_mqtt.SetMQTTConnectionArgsOperation):
            # pipeline_ops_mqtt.SetMQTTConnectionArgsOperation is where we create our MQTTTransport object and set
            # all of its properties.
            logger.info("{}({}): got connection args".format(
                self.name, op.name))
            self.hostname = op.hostname
            self.username = op.username
            self.client_id = op.client_id
            self.ca_cert = op.ca_cert
            self.sas_token = op.sas_token
            self.client_cert = op.client_cert

            self.transport = MQTTTransport(
                client_id=self.client_id,
                hostname=self.hostname,
                username=self.username,
                ca_cert=self.ca_cert,
                x509_cert=self.client_cert,
            )
            self.transport.on_mqtt_connected_handler = self._on_mqtt_connected
            self.transport.on_mqtt_connection_failure_handler = self._on_mqtt_connection_failure
            self.transport.on_mqtt_disconnected_handler = self._on_mqtt_disconnected
            self.transport.on_mqtt_message_received_handler = self._on_mqtt_message_received
            # A fresh transport starts with no connect/disconnect op in flight.
            self._active_connect_op = None
            self._active_disconnect_op = None
            self.pipeline_root.transport = self.transport
            operation_flow.complete_op(self, op)

        elif isinstance(op, pipeline_ops_base.ConnectOperation):
            logger.info("{}({}): connecting".format(self.name, op.name))

            self._cancel_active_connect_disconnect_ops()
            self._active_connect_op = op
            try:
                self.transport.connect(password=self.sas_token)
            except Exception:
                # The connect never started, so no handler will complete this op.
                self._active_connect_op = None
                # Bare raise keeps the original traceback on every Python version.
                raise

        elif isinstance(op, pipeline_ops_base.ReconnectOperation):
            logger.info("{}({}): reconnecting".format(self.name, op.name))

            # We set _active_connect_op here because a reconnect is the same as a connect for "active operation" tracking purposes.
            self._cancel_active_connect_disconnect_ops()
            self._active_connect_op = op
            try:
                self.transport.reconnect(password=self.sas_token)
            except Exception:
                self._active_connect_op = None
                raise

        elif isinstance(op, pipeline_ops_base.DisconnectOperation):
            logger.info("{}({}): disconnecting".format(self.name, op.name))

            self._cancel_active_connect_disconnect_ops()
            self._active_disconnect_op = op
            try:
                self.transport.disconnect()
            except Exception:
                self._active_disconnect_op = None
                raise

        elif isinstance(op, pipeline_ops_mqtt.MQTTPublishOperation):
            logger.info("{}({}): publishing on {}".format(
                self.name, op.name, op.topic))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_published():
                logger.info("{}({}): PUBACK received. completing op.".format(
                    self.name, op.name))
                operation_flow.complete_op(self, op)

            self.transport.publish(topic=op.topic,
                                   payload=op.payload,
                                   callback=on_published)

        elif isinstance(op, pipeline_ops_mqtt.MQTTSubscribeOperation):
            logger.info("{}({}): subscribing to {}".format(
                self.name, op.name, op.topic))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_subscribed():
                logger.info("{}({}): SUBACK received. completing op.".format(
                    self.name, op.name))
                operation_flow.complete_op(self, op)

            self.transport.subscribe(topic=op.topic, callback=on_subscribed)

        elif isinstance(op, pipeline_ops_mqtt.MQTTUnsubscribeOperation):
            logger.info("{}({}): unsubscribing from {}".format(
                self.name, op.name, op.topic))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_unsubscribed():
                logger.info(
                    "{}({}): UNSUBACK received.  completing op.".format(
                        self.name, op.name))
                operation_flow.complete_op(self, op)

            self.transport.unsubscribe(topic=op.topic,
                                       callback=on_unsubscribed)

        else:
            operation_flow.pass_op_to_next_stage(self, op)

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_message_received(self, topic, payload):
        """
        Handler that gets called by the protocol library when an incoming message arrives.
        Convert that message into a pipeline event and pass it up for someone to handle.

        :param topic: Topic the message arrived on.
        :param payload: Payload of the message.
        """
        operation_flow.pass_event_to_previous_stage(
            stage=self,
            event=pipeline_events_mqtt.IncomingMQTTMessageEvent(
                topic=topic, payload=payload),
        )

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_connected(self):
        """
        Handler that gets called by the transport when it connects.

        Completes the active connect (or reconnect) op, if there is one.
        """
        logger.info("_on_mqtt_connected called")
        # self.on_connected() tells other pipeline stages that we're connected.  Do this before
        # we do anything else (in case upper stages have any "are we connected" logic).
        self.on_connected()
        if self._active_connect_op:
            logger.info("completing connect op")
            # Clear the tracking attribute before completing the op.
            op = self._active_connect_op
            self._active_connect_op = None
            operation_flow.complete_op(stage=self, op=op)
        else:
            logger.warning("Connection was unexpected")

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_connection_failure(self, cause):
        """
        Handler that gets called by the transport when a connection fails.

        :param cause: The error reported by the transport.  It becomes the error of the active
            connect op, or is reported as an unhandled background exception when no connect
            op is active.
        """

        logger.error("{}: _on_mqtt_connection_failure called: {}".format(
            self.name, cause))
        if self._active_connect_op:
            logger.info("{}: failing connect op".format(self.name))
            op = self._active_connect_op
            self._active_connect_op = None
            op.error = cause
            operation_flow.complete_op(stage=self, op=op)
        else:
            logger.warning("{}: Connection failure was unexpected".format(
                self.name))
            unhandled_exceptions.exception_caught_in_background_thread(cause)

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_disconnected(self, cause):
        """
        Handler that gets called by the transport when the transport disconnects.

        :param cause: The error that caused the disconnection.  A falsy value is passed through
            unchanged; anything else is wrapped in a ConnectionDroppedError.
        """
        logger.error("{}: _on_mqtt_disconnect called: {}".format(
            self.name, cause))

        # self.on_disconnected() tells other pipeline stages that we're disconnected.  Do this before
        # we do anything else (in case upper stages have any "are we connected" logic).
        self.on_disconnected()

        # regardless of the cause, we wrap it in a ConnectionDroppedError object because that's
        # the real problem at this point.
        if cause:
            # Raising and immediately catching is how the chained exception object gets built.
            try:
                six.raise_from(errors.ConnectionDroppedError, cause)
            except errors.ConnectionDroppedError as e:
                cause = e

        if self._active_disconnect_op:
            logger.info("{}: completing disconnect op".format(self.name))
            op = self._active_disconnect_op
            self._active_disconnect_op = None
            op.error = cause
            operation_flow.complete_op(stage=self, op=op)
        else:
            logger.warning("{}: disconnection was unexpected".format(
                self.name))
            # NOTE(review): cause may still be falsy here -- confirm the helper accepts that.
            unhandled_exceptions.exception_caught_in_background_thread(cause)
class MQTTTransportStage(PipelineStage):
    """
    Pipeline stage that talks to the MQTT protocol wrapper object (MQTTTransport).

    It runs the MQTT operations (publish, subscribe, unsubscribe) together with the operations
    that are not part of the MQTT group but can only be carried out at the protocol level
    (initialize, connect, disconnect, reauthorize), and it converts transport callbacks into
    pipeline events.
    """

    def __init__(self):
        super(MQTTTransportStage, self).__init__()

        # At most one Connect / Disconnect / ReauthorizeConnection op is tracked at a time.
        self._pending_connection_op = None

        # Stays None until InitializePipelineOperation is run.
        self.transport = None

    @pipeline_thread.runs_on_pipeline_thread
    def _cancel_pending_connection_op(self, error=None):
        """
        Fail the pending connect, disconnect or reauthorize_connection op, if there is one.

        There is no real way to "cancel" such an op, so it is completed with an error instead.

        :param error: Error to complete the op with.  When falsy, an OperationCancelled error
            is used.
        """
        pending = self._pending_connection_op
        if not pending:
            return

        # NOTE: This code path should NOT execute in normal flow. There should never already be a pending
        # connection op when another is added, due to the SerializeConnectOps stage.
        # If this code does execute, there is a bug in the codebase.
        cancel_error = error or pipeline_exceptions.OperationCancelled(
            "Cancelling because new ConnectOperation, DisconnectOperation, or ReauthorizeConnectionOperation was issued"
        )
        self._cancel_connection_watchdog(pending)
        pending.complete(error=cancel_error)
        self._pending_connection_op = None

    @pipeline_thread.runs_on_pipeline_thread
    def _start_connection_watchdog(self, connection_op):
        """
        Start a daemon timer that fails `connection_op` if it is still pending after
        WATCHDOG_INTERVAL.

        :param connection_op: The op to watch.  The timer is stored on it as `watchdog_timer`.
        """
        logger.debug("{}({}): Starting watchdog".format(
            self.name, connection_op.name))

        # Weak references, so the timer callback does not hold strong references to the
        # stage or to the op.
        stage_ref = weakref.ref(self)
        watched_op_ref = weakref.ref(connection_op)

        @pipeline_thread.invoke_on_pipeline_thread
        def watchdog_function():
            this = stage_ref()
            op = watched_op_ref()
            # Only act when both objects still exist and the op is still the pending one.
            if not (this and op and this._pending_connection_op is op):
                return

            logger.info(
                "{}({}): Connection watchdog expired.  Cancelling op".format(
                    this.name, op.name))
            this.transport.disconnect()
            if this.pipeline_root.connected:
                logger.info(
                    "{}({}): Pipeline is still connected on watchdog expiration.  Sending DisconnectedEvent"
                    .format(this.name, op.name))
                this.send_event_up(pipeline_events_base.DisconnectedEvent())
            timeout_error = pipeline_exceptions.OperationCancelled(
                "Transport timeout on connection operation")
            this._cancel_pending_connection_op(error=timeout_error)

        timer = threading.Timer(WATCHDOG_INTERVAL, watchdog_function)
        connection_op.watchdog_timer = timer
        timer.daemon = True
        timer.start()

    @pipeline_thread.runs_on_pipeline_thread
    def _cancel_connection_watchdog(self, op):
        """
        Cancel and clear the watchdog timer stored on `op`.

        :param op: The op whose `watchdog_timer` should be cancelled.  Any AttributeError
            raised along the way (for example, the op has no `watchdog_timer`) is ignored.
        """
        try:
            timer = op.watchdog_timer
            if timer:
                logger.debug("{}({}): cancelling watchdog".format(
                    self.name, op.name))
                timer.cancel()
                op.watchdog_timer = None
        except AttributeError:
            pass

    @pipeline_thread.runs_on_pipeline_thread
    def _run_op(self, op):
        """
        Run one pipeline operation, dispatching on its type.

        Connection ops (connect, disconnect, reauthorize) are stored as the pending connection
        op and completed later by the transport handlers, unless the transport call itself
        raises, in which case the op is completed with that error right away.

        :param op: The pipeline operation to run.
        """
        if isinstance(op, pipeline_ops_base.InitializePipelineOperation):
            config = self.pipeline_root.pipeline_configuration

            # A gateway hostname, when configured, takes the place of the regular hostname
            # for the connection.
            if config.gateway_hostname:
                logger.debug(
                    "Gateway Hostname Present. Setting Hostname to: {}".format(
                        config.gateway_hostname))
                hostname = config.gateway_hostname
            else:
                logger.debug(
                    "Gateway Hostname not present. Setting Hostname to: {}".
                    format(config.hostname))
                hostname = config.hostname

            # Build the transport, then wire up its handlers.
            logger.debug("{}({}): got connection args".format(
                self.name, op.name))
            self.transport = MQTTTransport(
                client_id=op.client_id,
                hostname=hostname,
                username=op.username,
                server_verification_cert=config.server_verification_cert,
                x509_cert=config.x509,
                websockets=config.websockets,
                cipher=config.cipher,
                proxy_options=config.proxy_options,
            )
            handler_bindings = (
                ("on_mqtt_connected_handler", "_on_mqtt_connected"),
                ("on_mqtt_connection_failure_handler",
                 "_on_mqtt_connection_failure"),
                ("on_mqtt_disconnected_handler", "_on_mqtt_disconnected"),
                ("on_mqtt_message_received_handler",
                 "_on_mqtt_message_received"),
            )
            for handler_attr, method_name in handler_bindings:
                setattr(self.transport, handler_attr,
                        CallableWeakMethod(self, method_name))

            # Only one connection operation (Connect, ReauthorizeConnection, Disconnect) can be
            # pending at a time; the existing one has to be completed or cancelled before a new
            # one is set.
            #
            # As things stand, if a connect op is pending and another connection op begins
            # before the CONNACK arrives, the original op is cancelled, yet its CONNACK still
            # arrives and completes the NEW op. Not desirable, but that is the current behavior.
            #
            # The op type is checked, though, so the CONNACK of a cancelled Connect cannot
            # successfully complete a Disconnect op.
            self._pending_connection_op = None

            op.complete()

        elif isinstance(op, pipeline_ops_base.ConnectOperation):
            logger.info("{}({}): connecting".format(self.name, op.name))

            self._cancel_pending_connection_op()
            self._pending_connection_op = op
            self._start_connection_watchdog(op)
            # The SasToken, when present, is the password. Without one (e.g. X509) no password
            # is needed because auth is handled via other means.
            sastoken = self.pipeline_root.pipeline_configuration.sastoken
            password = str(sastoken) if sastoken else None
            try:
                self.transport.connect(password=password)
            except Exception as connect_error:
                logger.error("transport.connect raised error")
                logger.error(traceback.format_exc())
                self._cancel_connection_watchdog(op)
                self._pending_connection_op = None
                op.complete(error=connect_error)

        elif isinstance(op, (pipeline_ops_base.DisconnectOperation,
                             pipeline_ops_base.ReauthorizeConnectionOperation)):
            logger.info("{}({}): disconnecting or reauthorizing".format(
                self.name, op.name))

            self._cancel_pending_connection_op()
            self._pending_connection_op = op
            # No watchdog for a disconnect: there is no callback to wait for, and a watchdog
            # timeout is answered by calling disconnect, which is already happening here.

            try:
                self.transport.disconnect()
            except Exception as disconnect_error:
                logger.error("transport.disconnect raised error")
                logger.error(traceback.format_exc())
                self._pending_connection_op = None
                op.complete(error=disconnect_error)

        elif isinstance(op, pipeline_ops_mqtt.MQTTPublishOperation):
            logger.info("{}({}): publishing on {}".format(
                self.name, op.name, op.topic))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_published():
                logger.debug("{}({}): PUBACK received. completing op.".format(
                    self.name, op.name))
                op.complete()

            try:
                self.transport.publish(
                    topic=op.topic, payload=op.payload, callback=on_published)
            except transport_exceptions.ConnectionDroppedError:
                # Tell the stages above about the dropped connection, then let the error
                # propagate.
                self.send_event_up(pipeline_events_base.DisconnectedEvent())
                raise

        elif isinstance(op, pipeline_ops_mqtt.MQTTSubscribeOperation):
            logger.info("{}({}): subscribing to {}".format(
                self.name, op.name, op.topic))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_subscribed():
                logger.debug("{}({}): SUBACK received. completing op.".format(
                    self.name, op.name))
                op.complete()

            try:
                self.transport.subscribe(
                    topic=op.topic, callback=on_subscribed)
            except transport_exceptions.ConnectionDroppedError:
                self.send_event_up(pipeline_events_base.DisconnectedEvent())
                raise

        elif isinstance(op, pipeline_ops_mqtt.MQTTUnsubscribeOperation):
            logger.info("{}({}): unsubscribing from {}".format(
                self.name, op.name, op.topic))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_unsubscribed():
                logger.debug(
                    "{}({}): UNSUBACK received.  completing op.".format(
                        self.name, op.name))
                op.complete()

            try:
                self.transport.unsubscribe(
                    topic=op.topic, callback=on_unsubscribed)
            except transport_exceptions.ConnectionDroppedError:
                self.send_event_up(pipeline_events_base.DisconnectedEvent())
                raise

        else:
            # This code block should not be reached in correct program flow.
            # This will raise an error when executed.
            self.send_op_down(op)

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_message_received(self, topic, payload):
        """
        Called by the protocol library when an incoming message arrives; the message is turned
        into a pipeline event and sent up for someone to handle.

        :param topic: Topic the message arrived on.
        :param payload: Payload of the message.
        """
        logger.debug("{}: message received on topic {}".format(
            self.name, topic))
        incoming_event = pipeline_events_mqtt.IncomingMQTTMessageEvent(
            topic=topic, payload=payload)
        self.send_event_up(incoming_event)

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_connected(self):
        """
        Called by the transport when it connects; completes a pending ConnectOperation.
        """
        logger.info("_on_mqtt_connected called")
        # The ConnectedEvent goes up first, in case upper stages have any
        # "are we connected" logic.
        self.send_event_up(pipeline_events_base.ConnectedEvent())

        pending = self._pending_connection_op
        if not isinstance(pending, pipeline_ops_base.ConnectOperation):
            # Something odd is going on: either a connect completed with no pending op,
            # OR it completed while a disconnect op was pending.
            logger.info("Connection was unexpected")
            return

        logger.debug("completing connect op")
        self._cancel_connection_watchdog(pending)
        self._pending_connection_op = None
        pending.complete()

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_connection_failure(self, cause):
        """
        Called by the transport when a connection fails; fails a pending ConnectOperation.

        :param Exception cause: The Exception that caused the connection failure.
        """
        logger.info("{}: _on_mqtt_connection_failure called: {}".format(
            self.name, cause))

        pending = self._pending_connection_op
        if not isinstance(pending, pipeline_ops_base.ConnectOperation):
            logger.info("{}: Connection failure was unexpected".format(
                self.name))
            handle_exceptions.swallow_unraised_exception(
                cause,
                log_msg="Unexpected connection failure.  Safe to ignore.",
                log_lvl="info")
            return

        logger.debug("{}: failing connect op".format(self.name))
        self._cancel_connection_watchdog(pending)
        self._pending_connection_op = None
        pending.complete(error=cause)

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_disconnected(self, cause=None):
        """
        Called by the transport when the transport disconnects.

        A pending Disconnect / ReauthorizeConnection op completes successfully; any other
        pending connection op fails with `cause`, or with a ConnectionDroppedError when no
        cause was given.

        :param Exception cause: The Exception that caused the disconnection, if any (optional)
        """
        if cause:
            logger.info("{}: _on_mqtt_disconnect called: {}".format(
                self.name, cause))
        else:
            logger.info("{}: _on_mqtt_disconnect called".format(self.name))

        # The DisconnectedEvent goes up first, in case upper stages have any
        # "are we connected" logic.
        self.send_event_up(pipeline_events_base.DisconnectedEvent())

        pending = self._pending_connection_op
        if not pending:
            logger.info("{}: disconnection was unexpected".format(self.name))
            # Whatever the cause, this is now a ConnectionDroppedError. Log and swallow it;
            # higher layers will see that we're disconnected and reconnect as necessary.
            dropped = transport_exceptions.ConnectionDroppedError(cause=cause)
            handle_exceptions.swallow_unraised_exception(
                dropped,
                log_msg="Unexpected disconnection.  Safe to ignore since other stages will reconnect.",
                log_lvl="info",
            )
            return

        # A disconnect completes whatever connection op is pending. This is how Paho behaves
        # on a connection error, and it makes sense for a pending connect to fail here.
        logger.debug("{}: completing pending {} op".format(
            self.name, pending.name))
        self._cancel_connection_watchdog(pending)
        self._pending_connection_op = None

        if isinstance(pending,
                      (pipeline_ops_base.DisconnectOperation,
                       pipeline_ops_base.ReauthorizeConnectionOperation)):
            # We intended to disconnect, so swallow any error - even if something went wrong,
            # we reached the state we wanted to be in.
            if cause:
                handle_exceptions.swallow_unraised_exception(
                    cause,
                    log_msg="Unexpected disconnect with error while disconnecting - swallowing error",
                )
            pending.complete()
        elif cause:
            pending.complete(error=cause)
        else:
            pending.complete(
                error=transport_exceptions.ConnectionDroppedError(
                    "transport disconnected"))
class MQTTTransportStage(PipelineStage):
    """
    PipelineStage object which is responsible for interfacing with the MQTT protocol wrapper object.
    This stage handles all MQTT operations and any other operations (such as ConnectOperation) which
    is not in the MQTT group of operations, but can only be run at the protocol level.
    """
    @pipeline_thread.runs_on_pipeline_thread
    def _cancel_pending_connection_op(self):
        """
        Fail the pending connect, disconnect or reconnect op, if there is one.

        There is no real way to "cancel" such an op, so (for now) it is completed with a
        PipelineError and the pending slot is cleared afterwards.
        """
        pending = self._pending_connection_op
        if not pending:
            return

        # TODO: should this actually run a cancel call on the op?
        pending.error = errors.PipelineError(
            "Cancelling because new ConnectOperation, DisconnectOperation, or ReconnectOperation was issued"
        )
        operation_flow.complete_op(stage=self, op=pending)
        self._pending_connection_op = None

    @pipeline_thread.runs_on_pipeline_thread
    def _execute_op(self, op):
        """
        Run a single pipeline operation, dispatching on the operation's type.

        - SetMQTTConnectionArgsOperation: stores the connection arguments on this stage,
          creates the MQTTTransport, registers this stage's handlers on it, publishes it on
          the pipeline root and completes the op.
        - UpdateSasTokenOperation: stores the new sas token and completes the op.
        - ConnectOperation / ReconnectOperation / DisconnectOperation: cancels any pending
          connection op, records this op as the pending one and calls the transport.  The op
          is not completed here on success (presumably the transport handlers registered
          above do that -- verify against the handler code).  If the transport call raises,
          the pending slot is cleared and the op is completed with that error.
        - MQTTPublishOperation / MQTTSubscribeOperation / MQTTUnsubscribeOperation: calls the
          transport with a callback that completes the op.
        - Anything else is passed to the next stage.

        The order of the isinstance checks is kept as is; it would matter if any of these
        operation classes derive from one another.

        :param op: The pipeline operation to run.
        """
        if isinstance(op, pipeline_ops_mqtt.SetMQTTConnectionArgsOperation):
            # pipeline_ops_mqtt.SetMQTTConnectionArgsOperation is where we create our MQTTTransport object and set
            # all of its properties.
            logger.info("{}({}): got connection args".format(
                self.name, op.name))
            self.hostname = op.hostname
            self.username = op.username
            self.client_id = op.client_id
            self.ca_cert = op.ca_cert
            self.sas_token = op.sas_token
            self.client_cert = op.client_cert

            self.transport = MQTTTransport(
                client_id=self.client_id,
                hostname=self.hostname,
                username=self.username,
                ca_cert=self.ca_cert,
                x509_cert=self.client_cert,
            )
            self.transport.on_mqtt_connected_handler = self._on_mqtt_connected
            self.transport.on_mqtt_connection_failure_handler = self._on_mqtt_connection_failure
            self.transport.on_mqtt_disconnected_handler = self._on_mqtt_disconnected
            self.transport.on_mqtt_message_received_handler = self._on_mqtt_message_received

            # There can only be one pending connection operation (Connect, Reconnect, Disconnect)
            # at a time. The existing one must be completed or canceled before a new one is set.

            # Currently, this means that if, say, a connect operation is the pending op and is executed
            # but another connection op begins by the time the CONNACK is received, the original
            # operation will be cancelled, but the CONNACK for it will still be received, and complete the
            # NEW operation. This is not desirable, but it is how things currently work.

            # We are however, checking the type, so the CONNACK from a cancelled Connect, cannot successfully
            # complete a Disconnect operation.
            self._pending_connection_op = None

            self.pipeline_root.transport = self.transport
            operation_flow.complete_op(self, op)

        elif isinstance(op, pipeline_ops_base.UpdateSasTokenOperation):
            logger.info("{}({}): saving sas token and completing".format(
                self.name, op.name))
            self.sas_token = op.sas_token
            operation_flow.complete_op(self, op)

        elif isinstance(op, pipeline_ops_base.ConnectOperation):
            logger.info("{}({}): connecting".format(self.name, op.name))

            self._cancel_pending_connection_op()
            self._pending_connection_op = op
            try:
                self.transport.connect(password=self.sas_token)
            except Exception as e:
                # The connect never started, so fail the op here instead of leaving it pending.
                logger.error("transport.connect raised error", exc_info=True)
                self._pending_connection_op = None
                op.error = e
                operation_flow.complete_op(self, op)

        elif isinstance(op, pipeline_ops_base.ReconnectOperation):
            logger.info("{}({}): reconnecting".format(self.name, op.name))

            # We set _pending_connection_op here because a reconnect is the same as a connect for "pending operation" tracking purposes.
            self._cancel_pending_connection_op()
            self._pending_connection_op = op
            try:
                self.transport.reconnect(password=self.sas_token)
            except Exception as e:
                logger.error("transport.reconnect raised error", exc_info=True)
                self._pending_connection_op = None
                op.error = e
                operation_flow.complete_op(self, op)

        elif isinstance(op, pipeline_ops_base.DisconnectOperation):
            logger.info("{}({}): disconnecting".format(self.name, op.name))

            self._cancel_pending_connection_op()
            self._pending_connection_op = op
            try:
                self.transport.disconnect()
            except Exception as e:
                logger.error("transport.disconnect raised error",
                             exc_info=True)
                self._pending_connection_op = None
                op.error = e
                operation_flow.complete_op(self, op)

        elif isinstance(op, pipeline_ops_mqtt.MQTTPublishOperation):
            logger.info("{}({}): publishing on {}".format(
                self.name, op.name, op.topic))

            # Handed to the transport as the publish callback; completes the op.
            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_published():
                logger.info("{}({}): PUBACK received. completing op.".format(
                    self.name, op.name))
                operation_flow.complete_op(self, op)

            self.transport.publish(topic=op.topic,
                                   payload=op.payload,
                                   callback=on_published)

        elif isinstance(op, pipeline_ops_mqtt.MQTTSubscribeOperation):
            logger.info("{}({}): subscribing to {}".format(
                self.name, op.name, op.topic))

            # Handed to the transport as the subscribe callback; completes the op.
            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_subscribed():
                logger.info("{}({}): SUBACK received. completing op.".format(
                    self.name, op.name))
                operation_flow.complete_op(self, op)

            self.transport.subscribe(topic=op.topic, callback=on_subscribed)

        elif isinstance(op, pipeline_ops_mqtt.MQTTUnsubscribeOperation):
            logger.info("{}({}): unsubscribing from {}".format(
                self.name, op.name, op.topic))

            # Handed to the transport as the unsubscribe callback; completes the op.
            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_unsubscribed():
                logger.info(
                    "{}({}): UNSUBACK received.  completing op.".format(
                        self.name, op.name))
                operation_flow.complete_op(self, op)

            self.transport.unsubscribe(topic=op.topic,
                                       callback=on_unsubscribed)

        else:
            # Not an operation this stage handles; hand it to the next stage.
            operation_flow.pass_op_to_next_stage(self, op)

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_message_received(self, topic, payload):
        """
        Transport handler invoked whenever an incoming MQTT message arrives.

        Wraps the topic and payload in an IncomingMQTTMessageEvent and hands that event
        to the previous pipeline stage so that someone above can handle it.

        :param topic: Topic the message was received on.
        :param payload: Payload of the received message.
        """
        incoming_event = pipeline_events_mqtt.IncomingMQTTMessageEvent(
            topic=topic, payload=payload)
        operation_flow.pass_event_to_previous_stage(stage=self,
                                                    event=incoming_event)

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_connected(self):
        """
        Transport handler invoked once the MQTT connection has been established.

        Calls self.on_connected() and then completes the pending ConnectOperation or
        ReconnectOperation, if one is pending. Any other situation is only logged.
        """
        logger.info("_on_mqtt_connected called")
        # Announce the connected state before anything else, in case upper stages have
        # any "are we connected" logic.
        self.on_connected()

        pending_op = self._pending_connection_op
        connecting_op_types = (
            pipeline_ops_base.ConnectOperation,
            pipeline_ops_base.ReconnectOperation,
        )
        if not isinstance(pending_op, connecting_op_types):
            # Something odd is going on: either we connected while no op was pending,
            # or we connected while a disconnect op was pending.
            logger.warning("Connection was unexpected")
            return

        logger.info("completing connect op")
        # The op stops being pending before it is completed.
        self._pending_connection_op = None
        operation_flow.complete_op(stage=self, op=pending_op)

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_connection_failure(self, cause):
        """
        Transport handler invoked when a connection attempt fails.

        If a ConnectOperation or ReconnectOperation is pending, it is completed with
        `cause` as its error. Otherwise the failure is reported as an exception caught
        in a background thread.

        :param Exception cause: The Exception that caused the connection failure.
        """
        logger.error("{}: _on_mqtt_connection_failure called: {}".format(
            self.name, cause))

        pending_op = self._pending_connection_op
        connecting_op_types = (
            pipeline_ops_base.ConnectOperation,
            pipeline_ops_base.ReconnectOperation,
        )
        if not isinstance(pending_op, connecting_op_types):
            # Nobody is waiting on a connection, so there is no op to fail.
            logger.warning("{}: Connection failure was unexpected".format(
                self.name))
            unhandled_exceptions.exception_caught_in_background_thread(cause)
            return

        logger.info("{}: failing connect op".format(self.name))
        # The op stops being pending before it is completed.
        self._pending_connection_op = None
        pending_op.error = cause
        operation_flow.complete_op(stage=self, op=pending_op)

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_disconnected(self, cause=None):
        """
        Transport handler invoked when the transport disconnects.

        Calls self.on_disconnected(), then either completes the pending
        DisconnectOperation or, when no disconnect was pending, reports a
        ConnectionDroppedError as an exception caught in a background thread.

        :param Exception cause: The Exception that caused the disconnection, if any (optional)
        """
        logger.error("{}: _on_mqtt_disconnect called: {}".format(
            self.name, cause))

        # Announce the disconnected state before anything else, in case upper stages
        # have any "are we connected" logic.
        self.on_disconnected()

        pending_op = self._pending_connection_op
        if not isinstance(pending_op, pipeline_ops_base.DisconnectOperation):
            logger.warning("{}: disconnection was unexpected".format(
                self.name))
            # Regardless of cause, this is now a ConnectionDroppedError. It is raised and
            # immediately caught so that it is chained to the cause.
            try:
                six.raise_from(errors.ConnectionDroppedError, cause)
            except errors.ConnectionDroppedError as dropped:
                unhandled_exceptions.exception_caught_in_background_thread(
                    dropped)
            return

        logger.info("{}: completing disconnect op".format(self.name))
        # The op stops being pending before it is completed.
        self._pending_connection_op = None

        if cause:
            # Only attach a ConnectionDroppedError when there is a cause,
            # i.e. the disconnect did not go cleanly.
            try:
                six.raise_from(errors.ConnectionDroppedError, cause)
            except errors.ConnectionDroppedError as dropped:
                pending_op.error = dropped
        operation_flow.complete_op(stage=self, op=pending_op)
Example #4
0
class MQTTClientStage(PipelineStage):
    """
    PipelineStage object which is responsible for interfacing with the MQTT protocol wrapper object.
    This stage handles all MQTT operations and any other operations (such as ConnectOperation) which
    are not in the MQTT group of operations, but can only be run at the protocol level.

    Operations that this stage does not recognize are passed on to the next stage.
    """
    @pipeline_thread.runs_on_pipeline_thread
    def _run_op(self, op):
        """
        Run a single pipeline operation against the MQTT transport.

        Handles connection-args, SAS token and certificate setup ops, Connect / Reconnect /
        Disconnect ops, and MQTT publish / subscribe / unsubscribe ops. Anything else is
        passed to the next stage.

        Exceptions raised by the transport are deliberately allowed to propagate to the
        caller (see the note on exception handling in the ConnectOperation branch).

        :param op: The pipeline operation to run.
        """
        if isinstance(op, pipeline_ops_mqtt.SetMQTTConnectionArgsOperation):
            # pipeline_ops_mqtt.SetMQTTConnectionArgsOperation is where we create our MQTTTransport object and set
            # all of its properties.
            logger.info("{}({}): got connection args".format(
                self.name, op.name))
            self.hostname = op.hostname
            self.username = op.username
            self.client_id = op.client_id
            self.ca_cert = op.ca_cert
            # Credentials arrive later through their own operations.
            self.sas_token = None
            self.trusted_certificate_chain = None
            self.transport = MQTTTransport(
                client_id=self.client_id,
                hostname=self.hostname,
                username=self.username,
                ca_cert=self.ca_cert,
            )
            self.transport.on_mqtt_connected = self.on_connected
            self.transport.on_mqtt_disconnected = self.on_disconnected
            self.transport.on_mqtt_message_received = self._on_message_received
            self.pipeline_root.transport = self.transport
            operation_flow.complete_op(self, op)

        elif isinstance(op, pipeline_ops_base.SetSasTokenOperation):
            # When we get a sas token from above, we just save it for later
            logger.info("{}({}): got password".format(self.name, op.name))
            self.sas_token = op.sas_token
            operation_flow.complete_op(self, op)

        elif isinstance(
                op,
                pipeline_ops_base.SetClientAuthenticationCertificateOperation):
            # When we get a certificate from above, we just save it for later
            logger.info("{}({}): got certificate".format(self.name, op.name))
            self.trusted_certificate_chain = op.certificate
            operation_flow.complete_op(self, op)

        elif isinstance(op, pipeline_ops_base.ConnectOperation):
            logger.info("{}({}): conneting".format(self.name, op.name))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_connected():
                logger.info("{}({}): on_connected.  completing op.".format(
                    self.name, op.name))
                # Put the regular handler back before notifying the pipeline.
                self.transport.on_mqtt_connected = self.on_connected
                self.on_connected()
                operation_flow.complete_op(self, op)

            # A note on exception handling in Connect, Disconnect, and Reconnect:
            #
            # All calls into self.transport can raise an exception, and this is OK.
            # The exception handler in PipelineStage.run_op() will catch these errors
            # and propagate them to the caller.  This is an intentional design of the
            # pipeline, that stages, etc, don't need to worry about catching exceptions
            # except for special cases.
            #
            # The code right below this comment is a special case.  In addition
            # to this "normal" exception handling, we add another exception handler
            # into this class' Connect, Reconnect, and Disconnect code.  We need to
            # do this because transport.on_mqtt_connected and transport.on_mqtt_disconnected
            # are both _handler_ functions instead of _callbacks_.
            #
            # Because they're handlers instead of callbacks, we need to change the
            # handlers while the connection is established.  We do this so we can
            # know when the protocol is connected so we can move on to the next step.
            # Once the connection is established, we change the handler back to its
            # old value before finishing.
            #
            # The exception handling below is to reset the handler back to its original
            # value in the case where transport.connect raises an exception.  Again,
            # this extra exception handling is only necessary in the Connect, Disconnect,
            # and Reconnect case because they're the only cases that use handlers instead
            # of callbacks.
            #
            self.transport.on_mqtt_connected = on_connected
            try:
                self.transport.connect(
                    password=self.sas_token,
                    client_certificate=self.trusted_certificate_chain)
            except Exception:
                self.transport.on_mqtt_connected = self.on_connected
                # Bare raise re-raises the active exception with its original traceback.
                raise

        elif isinstance(op, pipeline_ops_base.ReconnectOperation):
            logger.info("{}({}): reconnecting".format(self.name, op.name))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_connected():
                logger.info("{}({}): on_connected.  completing op.".format(
                    self.name, op.name))
                # Put the regular handler back before notifying the pipeline.
                self.transport.on_mqtt_connected = self.on_connected
                self.on_connected()
                operation_flow.complete_op(self, op)

            # See "A note on exception handling" above
            self.transport.on_mqtt_connected = on_connected
            try:
                self.transport.reconnect(self.sas_token)
            except Exception:
                self.transport.on_mqtt_connected = self.on_connected
                # Bare raise re-raises the active exception with its original traceback.
                raise

        elif isinstance(op, pipeline_ops_base.DisconnectOperation):
            logger.info("{}({}): disconnecting".format(self.name, op.name))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_disconnected():
                logger.info("{}({}): on_disconnected.  completing op.".format(
                    self.name, op.name))
                # Put the regular handler back before notifying the pipeline.
                self.transport.on_mqtt_disconnected = self.on_disconnected
                self.on_disconnected()
                operation_flow.complete_op(self, op)

            # See "A note on exception handling" above
            self.transport.on_mqtt_disconnected = on_disconnected
            try:
                self.transport.disconnect()
            except Exception:
                self.transport.on_mqtt_disconnected = self.on_disconnected
                # Bare raise re-raises the active exception with its original traceback.
                raise

        elif isinstance(op, pipeline_ops_mqtt.MQTTPublishOperation):
            logger.info("{}({}): publishing on {}".format(
                self.name, op.name, op.topic))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_published():
                logger.info("{}({}): PUBACK received. completing op.".format(
                    self.name, op.name))
                operation_flow.complete_op(self, op)

            self.transport.publish(topic=op.topic,
                                   payload=op.payload,
                                   callback=on_published)

        elif isinstance(op, pipeline_ops_mqtt.MQTTSubscribeOperation):
            logger.info("{}({}): subscribing to {}".format(
                self.name, op.name, op.topic))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_subscribed():
                logger.info("{}({}): SUBACK received. completing op.".format(
                    self.name, op.name))
                operation_flow.complete_op(self, op)

            self.transport.subscribe(topic=op.topic, callback=on_subscribed)

        elif isinstance(op, pipeline_ops_mqtt.MQTTUnsubscribeOperation):
            logger.info("{}({}): unsubscribing from {}".format(
                self.name, op.name, op.topic))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_unsubscribed():
                logger.info(
                    "{}({}): UNSUBACK received.  completing op.".format(
                        self.name, op.name))
                operation_flow.complete_op(self, op)

            self.transport.unsubscribe(topic=op.topic,
                                       callback=on_unsubscribed)

        else:
            # Not an operation this stage knows about.
            operation_flow.pass_op_to_next_stage(self, op)

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_message_received(self, topic, payload):
        """
        Handler that gets called by the protocol library when an incoming message arrives.
        Convert that message into a pipeline event and pass it up for someone to handle.

        :param topic: Topic the message was received on.
        :param payload: Payload of the received message.
        """
        operation_flow.pass_event_to_previous_stage(
            stage=self,
            event=pipeline_events_mqtt.IncomingMQTTMessageEvent(
                topic=topic, payload=payload),
        )

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def on_connected(self):
        """
        Default transport "connected" handler: delegates to PipelineStage.on_connected().
        """
        super(MQTTClientStage, self).on_connected()

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def on_disconnected(self):
        """
        Default transport "disconnected" handler: delegates to PipelineStage.on_disconnected().
        """
        super(MQTTClientStage, self).on_disconnected()
class MQTTTransportStage(PipelineStage):
    """
    PipelineStage object which is responsible for interfacing with the MQTT protocol wrapper object.
    This stage handles all MQTT operations and any other operations (such as ConnectOperation) which
    are not in the MQTT group of operations, but can only be run at the protocol level.
    """
    def __init__(self):
        super().__init__()

        # The transport will be instantiated upon receiving the InitializePipelineOperation
        self.transport = None
        # The current in-progress op that affects connection state (Connect or Disconnect).
        # A ReauthorizeConnectionOperation is never stored here; it spawns separate
        # Disconnect and Connect operations instead.
        self._pending_connection_op = None

    @pipeline_thread.runs_on_pipeline_thread
    def _cancel_pending_connection_op(self, error=None):
        """
        Cancel any running connect or disconnect op. Since our ability to "cancel" is fairly limited,
        all this does (for now) is to fail the operation.

        :param error: Error to complete the pending op with. If not provided, an
            OperationCancelled error is used.
        """

        op = self._pending_connection_op
        if op:
            # NOTE: This code path should NOT execute in normal flow. There should never already be a pending
            # connection op when another is added, due to the ConnectionLock stage.
            # If this block does execute, there is a bug in the codebase.
            if not error:
                error = pipeline_exceptions.OperationCancelled(
                    "Cancelling because new ConnectOperation or DisconnectOperation was issued"
                )
            self._cancel_connection_watchdog(op)
            # Clear the pending op before completing it
            self._pending_connection_op = None
            op.complete(error=error)

    @pipeline_thread.runs_on_pipeline_thread
    def _start_connection_watchdog(self, connection_op):
        """
        Start a watchdog on the connection operation. This protects against cases where transport.connect()
        succeeds but the CONNACK never arrives. This is like a timeout, but it is handled at this level
        because specific cleanup needs to take place on timeout (see below), and this cleanup doesn't
        belong anywhere else since it is very specific to this stage.

        The timer is stored on the op as `watchdog_timer` and fires after WATCHDOG_INTERVAL.

        :param connection_op: The connection operation to watch.
        """
        logger.debug("{}({}): Starting watchdog".format(
            self.name, connection_op.name))

        # Weak references, so the timer's function does not hold strong references to the stage or the op
        self_weakref = weakref.ref(self)
        op_weakref = weakref.ref(connection_op)

        @pipeline_thread.invoke_on_pipeline_thread
        def watchdog_function():
            this = self_weakref()
            op = op_weakref()
            # Only act if the watched op is still the one that is pending
            if this and op and this._pending_connection_op is op:
                logger.info(
                    "{}({}): Connection watchdog expired.  Cancelling op".
                    format(this.name, op.name))
                try:
                    this.transport.disconnect()
                except Exception:
                    # If we don't catch this, the pending connection op might not ever be cancelled.
                    # Most likely, the transport isn't actually connected, but other failures are theoretically
                    # possible. Either way, if disconnect fails, we should assume that we're disconnected.
                    logger.info(
                        "transport.disconnect raised error while disconnecting in watchdog.  Safe to ignore."
                    )
                    logger.info(traceback.format_exc())

                if this.pipeline_root.connected:
                    logger.info(
                        "{}({}): Pipeline is still connected on watchdog expiration.  Sending DisconnectedEvent"
                        .format(this.name, op.name))
                    this.send_event_up(
                        pipeline_events_base.DisconnectedEvent())
                this._cancel_pending_connection_op(
                    error=pipeline_exceptions.OperationTimeout(
                        "Transport timeout on connection operation"))
            else:
                logger.debug(
                    "Connection watchdog expired, but pending op is not the same op"
                )

        connection_op.watchdog_timer = threading.Timer(WATCHDOG_INTERVAL,
                                                       watchdog_function)
        connection_op.watchdog_timer.daemon = True
        connection_op.watchdog_timer.start()

    @pipeline_thread.runs_on_pipeline_thread
    def _cancel_connection_watchdog(self, op):
        """
        Cancel the watchdog timer attached to the given op, if there is one.

        Ops that never had a watchdog started (no `watchdog_timer` attribute) are ignored.

        :param op: The operation whose watchdog should be cancelled.
        """
        # Only the attribute lookup is allowed to fail quietly; an AttributeError raised
        # while cancelling the timer would indicate a real problem and is not suppressed.
        timer = getattr(op, "watchdog_timer", None)
        if timer:
            logger.debug("{}({}): cancelling watchdog".format(
                self.name, op.name))
            timer.cancel()
            op.watchdog_timer = None

    @pipeline_thread.runs_on_pipeline_thread
    def _run_op(self, op):
        """
        Run a single pipeline operation against the MQTT transport.

        Handles InitializePipelineOperation (creates the transport), ShutdownPipelineOperation,
        ConnectOperation, DisconnectOperation, ReauthorizeConnectionOperation (implemented as a
        disconnect followed by a connect) and the MQTT publish / subscribe / unsubscribe
        operations. Any other operation is sent down.

        Exceptions raised by the transport calls wrapped in try/except below are reported
        by completing the op with that error.

        :param op: The pipeline operation to run.
        """
        if isinstance(op, pipeline_ops_base.InitializePipelineOperation):

            # If there is a gateway hostname, use that as the hostname for connection,
            # rather than the hostname itself
            if self.pipeline_root.pipeline_configuration.gateway_hostname:
                logger.debug(
                    "Gateway Hostname Present. Setting Hostname to: {}".format(
                        self.pipeline_root.pipeline_configuration.
                        gateway_hostname))
                hostname = self.pipeline_root.pipeline_configuration.gateway_hostname
            else:
                logger.debug(
                    "Gateway Hostname not present. Setting Hostname to: {}".
                    format(self.pipeline_root.pipeline_configuration.hostname))
                hostname = self.pipeline_root.pipeline_configuration.hostname

            # Create the Transport object, set its handlers
            logger.debug("{}({}): got connection args".format(
                self.name, op.name))
            self.transport = MQTTTransport(
                client_id=op.client_id,
                hostname=hostname,
                username=op.username,
                server_verification_cert=self.pipeline_root.
                pipeline_configuration.server_verification_cert,
                x509_cert=self.pipeline_root.pipeline_configuration.x509,
                websockets=self.pipeline_root.pipeline_configuration.
                websockets,
                cipher=self.pipeline_root.pipeline_configuration.cipher,
                proxy_options=self.pipeline_root.pipeline_configuration.
                proxy_options,
                keep_alive=self.pipeline_root.pipeline_configuration.
                keep_alive,
            )
            self.transport.on_mqtt_connected_handler = self._on_mqtt_connected
            self.transport.on_mqtt_connection_failure_handler = self._on_mqtt_connection_failure
            self.transport.on_mqtt_disconnected_handler = self._on_mqtt_disconnected
            self.transport.on_mqtt_message_received_handler = self._on_mqtt_message_received

            # There can only be one pending connection operation (Connect, Disconnect)
            # at a time. The existing one must be completed or canceled before a new one is set.

            # Currently, this means that if, say, a connect operation is the pending op and is executed
            # but another connection op begins by the time the CONNACK is received, the original
            # operation will be cancelled, but the CONNACK for it will still be received, and complete the
            # NEW operation. This is not desirable, but it is how things currently work.

            # We are however, checking the type, so the CONNACK from a cancelled Connect, cannot successfully
            # complete a Disconnect operation.

            # Note that a ReauthorizeConnectionOperation will never be pending because it will
            # instead spawn separate Connect and Disconnect operations.
            self._pending_connection_op = None

            op.complete()

        elif isinstance(op, pipeline_ops_base.ShutdownPipelineOperation):
            try:
                self.transport.shutdown()
            except Exception as e:
                logger.info("transport.shutdown raised error")
                logger.info(traceback.format_exc())
                op.complete(error=e)
            else:
                op.complete()

        elif isinstance(op, pipeline_ops_base.ConnectOperation):
            logger.debug("{}({}): connecting".format(self.name, op.name))

            self._cancel_pending_connection_op()
            self._pending_connection_op = op
            self._start_connection_watchdog(op)
            # Use SasToken as password if present. If not present (e.g. using X509),
            # then no password is required because auth is handled via other means.
            if self.pipeline_root.pipeline_configuration.sastoken:
                password = str(
                    self.pipeline_root.pipeline_configuration.sastoken)
            else:
                password = None
            try:
                self.transport.connect(password=password)
            except Exception as e:
                logger.info("transport.connect raised error")
                logger.info(traceback.format_exc())
                self._cancel_connection_watchdog(op)
                self._pending_connection_op = None
                op.complete(error=e)

        elif isinstance(op, pipeline_ops_base.DisconnectOperation):
            logger.debug("{}({}): disconnecting".format(self.name, op.name))

            self._cancel_pending_connection_op()
            self._pending_connection_op = op
            # We don't need a watchdog on disconnect because there's no callback to wait for
            # and we respond to a watchdog timeout by calling disconnect, which is what we're
            # already doing.

            try:
                # The connect after the disconnect will be triggered upon completion of the
                # disconnect in the on_disconnected handler
                self.transport.disconnect(clear_inflight=op.hard)
            except Exception as e:
                logger.info(
                    "transport.disconnect raised error while disconnecting")
                logger.info(traceback.format_exc())
                self._pending_connection_op = None
                op.complete(error=e)

        elif isinstance(op, pipeline_ops_base.ReauthorizeConnectionOperation):
            logger.debug(
                "{}({}): reauthorizing. Will issue disconnect and then a connect"
                .format(self.name, op.name))
            self_weakref = weakref.ref(self)
            reauth_op = op  # rename for clarity

            def on_disconnect_complete(op, error):
                this = self_weakref()
                if this is None:
                    # The stage no longer exists, so the connect half of the reauthorization
                    # cannot be run. Fail the reauthorize op rather than raising an
                    # AttributeError from inside the callback.
                    logger.debug(
                        "{}: stage no longer exists, cannot connect to finish reauthorization"
                        .format(reauth_op.name))
                    reauth_op.complete(
                        error=pipeline_exceptions.OperationCancelled(
                            "Stage no longer exists, reauthorization cannot continue"
                        ))
                    return
                if error:
                    # Failing a disconnect should still get us disconnected, so can proceed anyway
                    logger.debug(
                        "Disconnect failed during reauthorization, continuing with connect"
                    )
                connect_op = reauth_op.spawn_worker_op(
                    pipeline_ops_base.ConnectOperation)

                # NOTE: this relies on the fact that before the disconnect is completed it is
                # unset as the pending connection op. Otherwise there would be issues here.
                this.run_op(connect_op)

            disconnect_op = pipeline_ops_base.DisconnectOperation(
                callback=on_disconnect_complete)
            disconnect_op.hard = False

            self.run_op(disconnect_op)

        elif isinstance(op, pipeline_ops_mqtt.MQTTPublishOperation):
            logger.debug("{}({}): publishing on {}".format(
                self.name, op.name, op.topic))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_complete(cancelled=False):
                if cancelled:
                    op.complete(error=pipeline_exceptions.OperationCancelled(
                        "Operation cancelled before PUBACK received"))
                else:
                    logger.debug(
                        "{}({}): PUBACK received. completing op.".format(
                            self.name, op.name))
                    op.complete()

            try:
                self.transport.publish(topic=op.topic,
                                       payload=op.payload,
                                       callback=on_complete)
            except Exception as e:
                op.complete(error=e)

        elif isinstance(op, pipeline_ops_mqtt.MQTTSubscribeOperation):
            logger.debug("{}({}): subscribing to {}".format(
                self.name, op.name, op.topic))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_complete(cancelled=False):
                if cancelled:
                    op.complete(error=pipeline_exceptions.OperationCancelled(
                        "Operation cancelled before SUBACK received"))
                else:
                    logger.debug(
                        "{}({}): SUBACK received. completing op.".format(
                            self.name, op.name))
                    op.complete()

            try:
                self.transport.subscribe(topic=op.topic, callback=on_complete)
            except Exception as e:
                op.complete(error=e)

        elif isinstance(op, pipeline_ops_mqtt.MQTTUnsubscribeOperation):
            logger.debug("{}({}): unsubscribing from {}".format(
                self.name, op.name, op.topic))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_complete(cancelled=False):
                if cancelled:
                    op.complete(error=pipeline_exceptions.OperationCancelled(
                        "Operation cancelled before UNSUBACK received"))
                else:
                    logger.debug(
                        "{}({}): UNSUBACK received.  completing op.".format(
                            self.name, op.name))
                    op.complete()

            try:
                self.transport.unsubscribe(topic=op.topic,
                                           callback=on_complete)
            except Exception as e:
                op.complete(error=e)

        else:
            # This code block should not be reached in correct program flow.
            # This will raise an error when executed.
            self.send_op_down(op)

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_message_received(self, topic, payload):
        """
        Handler that gets called by the protocol library when an incoming message arrives.
        Convert that message into a pipeline event and pass it up for someone to handle.

        :param topic: Topic the message was received on.
        :param payload: Payload of the received message.
        """
        logger.debug("{}: message received on topic {}".format(
            self.name, topic))
        self.send_event_up(
            pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic,
                                                          payload=payload))

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_connected(self):
        """
        Handler that gets called by the transport when it connects.

        Sends a ConnectedEvent up the pipeline, then completes the pending ConnectOperation
        (if there is one) and cancels its watchdog.
        """
        logger.info("_on_mqtt_connected called")
        # Send an event to tell other pipeline stages that we're connected. Do this before
        # we do anything else (in case upper stages have any "are we connected" logic.)
        self.send_event_up(pipeline_events_base.ConnectedEvent())

        if isinstance(self._pending_connection_op,
                      pipeline_ops_base.ConnectOperation):
            logger.debug("{}: completing connect op".format(self.name))
            op = self._pending_connection_op
            self._cancel_connection_watchdog(op)
            self._pending_connection_op = None
            op.complete()
        else:
            # This should indicate something odd is going on.
            # If this occurs, either a connect was completed while there was no pending op,
            # OR that a connect was completed while a disconnect op was pending
            logger.info(
                "{}: Connection was unexpected (no connection op pending)".
                format(self.name))

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_connection_failure(self, cause):
        """
        Handler that gets called by the transport when a connection fails.

        Completes the pending ConnectOperation with `cause` if there is one; otherwise the
        failure is logged and swallowed.

        :param Exception cause: The Exception that caused the connection failure.
        """

        logger.info("{}: _on_mqtt_connection_failure called: {}".format(
            self.name, cause))

        if isinstance(self._pending_connection_op,
                      pipeline_ops_base.ConnectOperation):
            logger.debug("{}: failing connect op".format(self.name))
            op = self._pending_connection_op
            self._cancel_connection_watchdog(op)
            self._pending_connection_op = None
            op.complete(error=cause)
        else:
            logger.debug("{}: Connection failure was unexpected".format(
                self.name))
            handle_exceptions.swallow_unraised_exception(
                cause,
                log_msg=
                "Unexpected connection failure (no pending operation). Safe to ignore.",
                log_lvl="info",
            )

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_disconnected(self, cause=None):
        """
        Handler that gets called by the transport when the transport disconnects.

        Sends a DisconnectedEvent up the pipeline and then resolves whatever connection op
        is pending: a pending DisconnectOperation completes successfully, any other pending
        op is failed, and if nothing is pending the disconnect is reported as a background
        ConnectionDroppedError.

        :param Exception cause: The Exception that caused the disconnection, if any (optional)
        """
        if cause:
            logger.info("{}: _on_mqtt_disconnect called: {}".format(
                self.name, cause))
        else:
            logger.info("{}: _on_mqtt_disconnect called".format(self.name))

        # Send an event to tell other pipeline stages that we're disconnected. Do this before
        # we do anything else (in case upper stages have any "are we connected" logic.)
        # NOTE: Other stages rely on the fact that this occurs before any op that may be in
        # progress is completed. Be careful with changing the order things occur here.
        self.send_event_up(pipeline_events_base.DisconnectedEvent())

        if self._pending_connection_op:

            op = self._pending_connection_op

            if isinstance(op, pipeline_ops_base.DisconnectOperation):
                logger.debug(
                    "{}: Expected disconnect - completing pending disconnect op"
                    .format(self.name))
                # Swallow any errors if we intended to disconnect - even if something went wrong, we
                # got to the state we wanted to be in!
                if cause:
                    handle_exceptions.swallow_unraised_exception(
                        cause,
                        log_msg=
                        "Unexpected error while disconnecting - swallowing error",
                    )
                # Disconnect complete, no longer pending
                self._pending_connection_op = None
                op.complete()

            else:
                logger.debug(
                    "{}: Unexpected disconnect - completing pending {} operation"
                    .format(self.name, op.name))
                # Cancel any potential connection watchdog, and clear the pending op
                self._cancel_connection_watchdog(op)
                self._pending_connection_op = None
                # Complete with the cause if there is one, otherwise with a generic
                # ConnectionDroppedError
                if cause:
                    op.complete(error=cause)
                else:
                    op.complete(
                        error=transport_exceptions.ConnectionDroppedError(
                            "transport disconnected"))
        else:
            logger.info(
                "{}: Unexpected disconnect (no pending connection op)".format(
                    self.name))

            # If there is no connection retry, cancel any transport operations waiting on response
            # so that they do not get stuck there.
            if not self.pipeline_root.pipeline_configuration.connection_retry:
                logger.debug(
                    "{}: Connection Retry disabled - cancelling in-flight operations"
                    .format(self.name))
                # TODO: Remove private access to the op manager (this layer shouldn't know about it)
                # This is a stopgap. I didn't want to invest too much infrastructure into a cancel flow
                # given that future development of individual operation cancels might affect the
                # approach to cancelling inflight ops waiting in the transport.
                self.transport._op_manager.cancel_all_operations()

            # Regardless of cause, it is now a ConnectionDroppedError. Log it and swallow it.
            # Higher layers will see that we're disconnected and may reconnect as necessary.
            e = transport_exceptions.ConnectionDroppedError(
                "Unexpected disconnection")
            e.__cause__ = cause
            self.report_background_exception(e)
class MQTTTransportStage(PipelineStage):
    """
    PipelineStage object which is responsible for interfacing with the MQTT protocol wrapper object.
    This stage handles all MQTT operations and any other operations (such as ConnectOperation) which
    is not in the MQTT group of operations, but can only be run at the protocol level.
    """
    def __init__(self):
        """
        Create the stage with no SAS token, no transport and no pending connection op.
        """
        super(MQTTTransportStage, self).__init__()

        # All three attributes start out empty:
        #   sas_token              - stored when connection args (or a SAS token update) are received
        #   transport              - instantiated when connection args are received
        #   _pending_connection_op - the connect / reauthorize / disconnect op currently awaiting
        #                            a response from the transport, if any
        self.sas_token = self.transport = self._pending_connection_op = None

    @pipeline_thread.runs_on_pipeline_thread
    def _cancel_pending_connection_op(self):
        """
        Cancel any pending connect, disconnect or reauthorize_connection op.

        Since our ability to "cancel" is fairly limited, all this does (for now) is fail the
        pending operation with an OperationCancelled error. Does nothing if no connection
        op is pending.
        """
        op = self._pending_connection_op
        if not op:
            return

        # NOTE: This code path should NOT execute in normal flow. There should never already be a pending
        # connection op when another is added, due to the SerializeConnectOps stage.
        # If this block does execute, there is a bug in the codebase.
        error = pipeline_exceptions.OperationCancelled(
            "Cancelling because new ConnectOperation, DisconnectOperation, or ReauthorizeConnectionOperation was issued"
        )  # TODO: should this actually somehow cancel the operation?

        # Clear the pending op BEFORE completing it, the same way the _on_mqtt_* handlers do.
        # Completing an op runs its callbacks, so the stage state must already be consistent
        # in case one of them re-enters this stage.
        self._pending_connection_op = None
        op.complete(error=error)

    @pipeline_thread.runs_on_pipeline_thread
    def _run_op(self, op):
        """
        Run a pipeline operation against the MQTT transport.

        Handled operation types:
          * SetMQTTConnectionArgsOperation - creates the MQTTTransport and hooks up its handlers
          * UpdateSasTokenOperation - stores the new SAS token
          * ConnectOperation / ReauthorizeConnectionOperation / DisconnectOperation - become the
            pending connection op and are completed later by the _on_mqtt_* handlers
          * MQTTPublishOperation / MQTTSubscribeOperation / MQTTUnsubscribeOperation - completed
            from the callback handed to the transport

        Any other operation is passed to send_op_down.

        :param op: The pipeline operation to run.
        """
        if isinstance(op, pipeline_ops_mqtt.SetMQTTConnectionArgsOperation):
            # pipeline_ops_mqtt.SetMQTTConnectionArgsOperation is where we create our MQTTTransport object and set
            # all of its properties.
            logger.debug("{}({}): got connection args".format(
                self.name, op.name))
            self.sas_token = op.sas_token
            self.transport = MQTTTransport(
                client_id=op.client_id,
                hostname=op.hostname,
                username=op.username,
                server_verification_cert=op.server_verification_cert,
                x509_cert=op.client_cert,
                websockets=self.pipeline_root.pipeline_configuration.
                websockets,
            )
            # Handlers are wrapped in CallableWeakMethod and referenced by method name
            self.transport.on_mqtt_connected_handler = CallableWeakMethod(
                self, "_on_mqtt_connected")
            self.transport.on_mqtt_connection_failure_handler = CallableWeakMethod(
                self, "_on_mqtt_connection_failure")
            self.transport.on_mqtt_disconnected_handler = CallableWeakMethod(
                self, "_on_mqtt_disconnected")
            self.transport.on_mqtt_message_received_handler = CallableWeakMethod(
                self, "_on_mqtt_message_received")

            # There can only be one pending connection operation (Connect, ReauthorizeConnection, Disconnect)
            # at a time. The existing one must be completed or canceled before a new one is set.

            # Currently, this means that if, say, a connect operation is the pending op and is executed
            # but another connection op begins by the time the CONNACK is received, the original
            # operation will be cancelled, but the CONNACK for it will still be received, and complete the
            # NEW operation. This is not desirable, but it is how things currently work.

            # We are however, checking the type, so the CONNACK from a cancelled Connect, cannot successfully
            # complete a Disconnect operation.
            self._pending_connection_op = None

            op.complete()

        elif isinstance(op, pipeline_ops_base.UpdateSasTokenOperation):
            logger.debug("{}({}): saving sas token and completing".format(
                self.name, op.name))
            self.sas_token = op.sas_token
            op.complete()

        elif isinstance(op, pipeline_ops_base.ConnectOperation):
            logger.info("{}({}): connecting".format(self.name, op.name))
            self._start_connection_op(
                op,
                lambda: self.transport.connect(password=self.sas_token),
                "transport.connect raised error",
            )

        elif isinstance(op, pipeline_ops_base.ReauthorizeConnectionOperation):
            logger.info("{}({}): reauthorizing".format(self.name, op.name))

            # The op becomes _pending_connection_op because reauthorizing the connection is the
            # same as a connect for "pending operation" tracking purposes.
            self._start_connection_op(
                op,
                lambda: self.transport.reauthorize_connection(
                    password=self.sas_token),
                "transport.reauthorize_connection raised error",
            )

        elif isinstance(op, pipeline_ops_base.DisconnectOperation):
            logger.info("{}({}): disconnecting".format(self.name, op.name))
            self._start_connection_op(
                op,
                lambda: self.transport.disconnect(),
                "transport.disconnect raised error",
            )

        elif isinstance(op, pipeline_ops_mqtt.MQTTPublishOperation):
            logger.info("{}({}): publishing on {}".format(
                self.name, op.name, op.topic))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_published():
                logger.debug("{}({}): PUBACK received. completing op.".format(
                    self.name, op.name))
                op.complete()

            self.transport.publish(topic=op.topic,
                                   payload=op.payload,
                                   callback=on_published)

        elif isinstance(op, pipeline_ops_mqtt.MQTTSubscribeOperation):
            logger.info("{}({}): subscribing to {}".format(
                self.name, op.name, op.topic))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_subscribed():
                logger.debug("{}({}): SUBACK received. completing op.".format(
                    self.name, op.name))
                op.complete()

            self.transport.subscribe(topic=op.topic, callback=on_subscribed)

        elif isinstance(op, pipeline_ops_mqtt.MQTTUnsubscribeOperation):
            logger.info("{}({}): unsubscribing from {}".format(
                self.name, op.name, op.topic))

            @pipeline_thread.invoke_on_pipeline_thread_nowait
            def on_unsubscribed():
                logger.debug(
                    "{}({}): UNSUBACK received.  completing op.".format(
                        self.name, op.name))
                op.complete()

            self.transport.unsubscribe(topic=op.topic,
                                       callback=on_unsubscribed)

        else:
            # This code block should not be reached in correct program flow.
            # This will raise an error when executed.
            self.send_op_down(op)

    @pipeline_thread.runs_on_pipeline_thread
    def _start_connection_op(self, op, transport_call, failure_log_msg):
        """
        Make op the pending connection op and kick off the matching transport call.

        Any previously pending connection op is cancelled first. If the transport call raises,
        the error and traceback are logged, the pending op is cleared, and op is completed
        with the raised error. Otherwise op stays pending until one of the _on_mqtt_* handlers
        completes it.

        :param op: The connect, reauthorize or disconnect operation being started.
        :param transport_call: Zero-argument callable that invokes the transport. It is called
            inside the try block so that failures while looking up the transport or its
            arguments are reported on the op as well.
        :param str failure_log_msg: Message logged at error level if transport_call raises.
        """
        self._cancel_pending_connection_op()
        self._pending_connection_op = op
        try:
            transport_call()
        except Exception as e:
            logger.error(failure_log_msg)
            logger.error(traceback.format_exc())
            self._pending_connection_op = None
            op.complete(error=e)

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_message_received(self, topic, payload):
        """
        Callback for the protocol library, invoked when an incoming message arrives.

        The message is wrapped in an IncomingMQTTMessageEvent and sent up the pipeline
        for some other stage to handle.

        :param topic: Topic of the incoming message.
        :param payload: Payload of the incoming message.
        """
        incoming_event = pipeline_events_mqtt.IncomingMQTTMessageEvent(
            topic=topic, payload=payload)
        self.send_event_up(incoming_event)

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_connected(self):
        """
        Callback for the transport, invoked when it connects.

        Sends a ConnectedEvent up the pipeline, then completes the pending op if it is a
        ConnectOperation or ReauthorizeConnectionOperation.
        """
        logger.info("_on_mqtt_connected called")

        # Tell the other pipeline stages we're connected before doing anything else, in case
        # upper stages have any "are we connected" logic.
        self.send_event_up(pipeline_events_base.ConnectedEvent())

        pending_op = self._pending_connection_op
        connecting_op_types = (
            pipeline_ops_base.ConnectOperation,
            pipeline_ops_base.ReauthorizeConnectionOperation,
        )

        if not isinstance(pending_op, connecting_op_types):
            # Something odd is going on: either the connect completed while no op was
            # pending, or it completed while a disconnect op was pending.
            logger.warning("Connection was unexpected")
            return

        logger.debug("completing connect op")
        self._pending_connection_op = None
        pending_op.complete()

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_connection_failure(self, cause):
        """
        Callback for the transport, invoked when a connection attempt fails.

        A pending ConnectOperation or ReauthorizeConnectionOperation is completed with the
        failure as its error. Otherwise the failure is logged and swallowed.

        :param Exception cause: The Exception that caused the connection failure.
        """
        logger.info("{}: _on_mqtt_connection_failure called: {}".format(
            self.name, cause))

        pending_op = self._pending_connection_op
        connecting_op_types = (
            pipeline_ops_base.ConnectOperation,
            pipeline_ops_base.ReauthorizeConnectionOperation,
        )

        if not isinstance(pending_op, connecting_op_types):
            # No connect-type op is waiting on this result, so there is nothing to fail
            logger.warning("{}: Connection failure was unexpected".format(
                self.name))
            handle_exceptions.swallow_unraised_exception(
                cause,
                log_msg="Unexpected connection failure.  Safe to ignore.",
                log_lvl="info")
            return

        logger.debug("{}: failing connect op".format(self.name))
        self._pending_connection_op = None
        pending_op.complete(error=cause)

    @pipeline_thread.invoke_on_pipeline_thread_nowait
    def _on_mqtt_disconnected(self, cause=None):
        """
        Callback for the transport, invoked when the transport disconnects.

        Sends a DisconnectedEvent up the pipeline and then resolves the pending connection
        op, if there is one:
          * a pending DisconnectOperation completes successfully (any cause is swallowed)
          * any other pending op completes with the cause, or with a ConnectionDroppedError
            when no cause was given
        With no pending op, a ConnectionDroppedError is created, logged and swallowed.

        :param Exception cause: The Exception that caused the disconnection, if any (optional)
        """
        if cause:
            called_msg = "{}: _on_mqtt_disconnect called: {}".format(
                self.name, cause)
        else:
            called_msg = "{}: _on_mqtt_disconnect called".format(self.name)
        logger.info(called_msg)

        # Tell the other pipeline stages we're disconnected before doing anything else, in case
        # upper stages have any "are we connected" logic.
        self.send_event_up(pipeline_events_base.DisconnectedEvent())

        pending_op = self._pending_connection_op
        if not pending_op:
            logger.warning("{}: disconnection was unexpected".format(
                self.name))
            # Regardless of cause, it is now a ConnectionDroppedError. Log it and swallow it.
            # Higher layers will see that we're disconnected and reconnect as necessary.
            dropped_error = transport_exceptions.ConnectionDroppedError(
                cause=cause)
            handle_exceptions.swallow_unraised_exception(
                dropped_error,
                log_msg=
                "Unexpected disconnection.  Safe to ignore since other stages will reconnect.",
                log_lvl="info",
            )
            return

        # A disconnect resolves whatever connection op is pending. This is how Paho behaves
        # when there is a connection error, and it also makes sense that a disconnect would
        # cause a pending connection op to fail.
        logger.debug("{}: completing pending {} op".format(
            self.name, pending_op.name))
        self._pending_connection_op = None

        if isinstance(pending_op, pipeline_ops_base.DisconnectOperation):
            # We intended to disconnect, so swallow any error - even if something went wrong,
            # we got to the state we wanted to be in!
            if cause:
                handle_exceptions.swallow_unraised_exception(
                    cause,
                    log_msg=
                    "Unexpected disconnect with error while disconnecting - swallowing error",
                )
            pending_op.complete()
            return

        # Any other pending op fails: with the reported cause when there is one, otherwise
        # with a generic dropped-connection error.
        if cause:
            failure = cause
        else:
            failure = transport_exceptions.ConnectionDroppedError(
                "transport disconnected")
        pending_op.complete(error=failure)