Ejemplo n.º 1
0
 async def publish_retained_messages_for_subscription(self, subscription, session):
     """
     Publish the broker's retained messages that match a subscription.

     Every topic stored in ``self._retained_messages`` is tested against the
     subscription's filter with ``self.matches``. Each matching message is
     published through the session's handler, using the subscription's QoS and
     passing ``True`` as the fourth ``mqtt_publish`` argument. The coroutine
     returns once all of those publications have finished.

     :param subscription: sequence whose item 0 is the topic filter and whose
         item 1 is the QoS to publish with
     :param session: session of the subscribing client
     """
     self.logger.debug(
         "Begin broadcasting messages retained due to subscription on '%s' from %s"
         % (subscription[0], format_client_message(session=session))
     )
     publish_tasks = []
     handler = self._get_handler(session)
     for d_topic in self._retained_messages:
         self.logger.debug("matching : %s %s" % (d_topic, subscription[0]))
         if self.matches(d_topic, subscription[0]):
             self.logger.debug("%s and %s match" % (d_topic, subscription[0]))
             retained = self._retained_messages[d_topic]
             publish_tasks.append(
                 asyncio.ensure_future(
                     handler.mqtt_publish(
                         retained.topic, retained.data, subscription[1], True
                     ),
                     loop=self._loop,
                 )
             )
     if publish_tasks:
         # asyncio.wait() lost its ``loop`` parameter in Python 3.10; the tasks
         # above are already bound to the broker loop.
         await asyncio.wait(publish_tasks)
     self.logger.debug(
         "End broadcasting messages retained due to subscription on '%s' from %s"
         % (subscription[0], format_client_message(session=session))
     )
Ejemplo n.º 2
0
 async def handle_connect(self, connect: ConnectPacket):
     """
     React to a CONNECT packet received while messages are being handled.

     The broker handles CONNECT itself when a client connects, so a CONNECT
     arriving here is logged as an error. The pending disconnect waiter, if
     there is one, is then resolved with ``None``.

     :param connect: the unexpected CONNECT packet (its content is not read)
     """
     message = (
         "%s [MQTT-3.1.0-2] %s : CONNECT message received during messages handling"
         % (self.session.client_id, format_client_message(self.session))
     )
     self.logger.error(message)
     # Nothing to resolve when no waiter exists or it has already completed.
     if self._disconnect_waiter is None or self._disconnect_waiter.done():
         return
     self._disconnect_waiter.set_result(None)
Ejemplo n.º 3
0
 async def publish_session_retained_messages(self, session):
     """
     Publish the messages queued in ``session.retained_messages``.

     The queue is drained and every message is published through the session's
     handler with its own topic, data and QoS, passing ``True`` as the fourth
     ``mqtt_publish`` argument. The coroutine returns once all of those
     publications have finished.

     :param session: session whose ``retained_messages`` queue is drained
     """
     self.logger.debug(
         "Publishing %d messages retained for session %s"
         % (
             session.retained_messages.qsize(),
             format_client_message(session=session),
         )
     )
     publish_tasks = []
     handler = self._get_handler(session)
     while not session.retained_messages.empty():
         retained = await session.retained_messages.get()
         publish_tasks.append(
             asyncio.ensure_future(
                 handler.mqtt_publish(
                     retained.topic, retained.data, retained.qos, True
                 ),
                 loop=self._loop,
             )
         )
     if publish_tasks:
         # asyncio.wait() lost its ``loop`` parameter in Python 3.10; the tasks
         # above are already bound to the broker loop.
         await asyncio.wait(publish_tasks)
Ejemplo n.º 4
0
 async def add_subscription(self, subscription, session):
     """
     Register a subscription for a session and return its SUBACK return code.

     The topic filter is validated level by level, then checked with
     ``self.topic_filtering``. The requested QoS is capped to
     ``self.config["max-qos"]`` when that key is configured. If the client
     (matched by ``client_id``) already holds a subscription on the same
     filter, that entry is replaced so the stored QoS matches the value
     returned to the client.

     :param subscription: sequence whose item 0 is the topic filter and whose
         item 1 is the requested QoS
     :param session: session of the subscribing client
     :return: the granted QoS, or ``0x80`` when the subscription is refused
     """
     try:
         a_filter = subscription[0]
         levels = a_filter.split("/")
         last_level = len(levels) - 1
         for position, level in enumerate(levels):
             if "#" in level and (level != "#" or position != last_level):
                 # [MQTT-4.7.1-2] '#' must be the last character of the filter
                 # and must occupy an entire level
                 return 0x80
             if "+" in level and level != "+":
                 # [MQTT-4.7.1-3] + wildcard character must occupy entire level
                 return 0x80
         # Check if the client is authorised to connect to the topic
         permitted = await self.topic_filtering(session, topic=a_filter)
         if not permitted:
             return 0x80
         qos = subscription[1]
         if "max-qos" in self.config and qos > self.config["max-qos"]:
             qos = self.config["max-qos"]
         subscribers = self._subscriptions.setdefault(a_filter, [])
         for index, (sub_session, _) in enumerate(subscribers):
             if sub_session.client_id == session.client_id:
                 self.logger.debug(
                     "Client %s has already subscribed to %s"
                     % (format_client_message(session=session), a_filter)
                 )
                 # [MQTT-3.8.4-3] a subscription on an identical filter
                 # replaces the existing one (its QoS may differ)
                 subscribers[index] = (session, qos)
                 break
         else:
             subscribers.append((session, qos))
         return qos
     except KeyError:
         return 0x80
Ejemplo n.º 5
0
 def _del_subscription(self, a_filter: str, session: Session) -> int:
     """
     Delete a session subscription on a given topic
     :param a_filter:
     :param session:
     :return:
     """
     deleted = 0
     try:
         subscriptions = self._subscriptions[a_filter]
         for index, (sub_session, qos) in enumerate(subscriptions):
             if sub_session.client_id == session.client_id:
                 self.logger.debug(
                     "Removing subscription on topic '%s' for client %s"
                     % (a_filter, format_client_message(session=session))
                 )
                 subscriptions.pop(index)
                 deleted += 1
                 break
     except KeyError:
         # Unsubscribe topic not found in current subscribed topics
         pass
     finally:
         return deleted
Ejemplo n.º 6
0
    async def init_from_connect(
        cls, reader: ReaderAdapter, writer: WriterAdapter, plugins_manager, loop=None
    ):
        """
        Read the CONNECT packet from a new connection and build a handler and session.

        The packet is read from ``reader`` and validated. When it is rejected
        with a CONNACK, that CONNACK is written to ``writer``, the writer is
        closed and an ``MQTTException`` carrying the reason is raised.

        :param reader: stream adapter the CONNECT packet is read from
        :param writer: stream adapter used for peer info and for a refusal CONNACK
        :param plugins_manager: receives the packet received/sent events and is
            handed to the new handler
        :param loop: event loop passed to the new ``Session`` and handler
        :return: ``(handler, incoming_session)`` — a new ``cls`` instance and a
            ``Session`` filled in from the CONNECT packet
        :raises MQTTException: if the CONNECT packet is malformed or refused
        """
        remote_address, remote_port = writer.get_peer_info()
        connect = await ConnectPacket.from_stream(reader)
        await plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connect)
        # this shouldn't be required anymore since broker generates for each client a random client_id if not provided
        # [MQTT-3.1.3-6]
        if connect.payload.client_id is None:
            raise MQTTException("[[MQTT-3.1.3-3]] : Client identifier must be present")

        if connect.variable_header.will_flag:
            if (
                connect.payload.will_topic is None
                or connect.payload.will_message is None
            ):
                raise MQTTException(
                    "will flag set, but will topic/message not present in payload"
                )

        if connect.variable_header.reserved_flag:
            raise MQTTException("[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0")
        if connect.proto_name != "MQTT":
            raise MQTTException(
                '[MQTT-3.1.2-1] Incorrect protocol name: "%s"' % connect.proto_name
            )

        # The first matching check below selects the refusal CONNACK and the
        # message of the exception raised after it has been sent.
        connack = None
        error_msg = None
        if connect.proto_level != 4:
            # only MQTT 3.1.1 supported
            error_msg = "Invalid protocol from %s: %d" % (
                format_client_message(address=remote_address, port=remote_port),
                connect.proto_level,
            )
            connack = ConnackPacket.build(
                0, UNACCEPTABLE_PROTOCOL_VERSION
            )  # [MQTT-3.2.2-4] session_parent=0
        elif not connect.username_flag and connect.password_flag:
            # Previously no message was set here, so MQTTException(None) was raised.
            error_msg = "Password flag set without username flag from %s" % (
                format_client_message(address=remote_address, port=remote_port)
            )
            connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD)  # [MQTT-3.1.2-22]
        elif connect.username_flag and not connect.password_flag:
            # NOTE(review): MQTT 3.1.1 appears to allow a user name without a
            # password; the refusal is kept as-is, only the message is new.
            error_msg = "Username flag set without password flag from %s" % (
                format_client_message(address=remote_address, port=remote_port)
            )
            connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD)  # [MQTT-3.1.2-22]
        elif connect.username_flag and connect.username is None:
            error_msg = "Invalid username from %s" % (
                format_client_message(address=remote_address, port=remote_port)
            )
            connack = ConnackPacket.build(
                0, BAD_USERNAME_PASSWORD
            )  # [MQTT-3.2.2-4] session_parent=0
        elif connect.password_flag and connect.password is None:
            error_msg = "Invalid password %s" % (
                format_client_message(address=remote_address, port=remote_port)
            )
            connack = ConnackPacket.build(
                0, BAD_USERNAME_PASSWORD
            )  # [MQTT-3.2.2-4] session_parent=0
        elif connect.clean_session_flag is False and (
            connect.payload.client_id_is_random
        ):
            error_msg = (
                "[MQTT-3.1.3-8] [MQTT-3.1.3-9] %s: No client Id provided (cleansession=0)"
                % (format_client_message(address=remote_address, port=remote_port))
            )
            connack = ConnackPacket.build(0, IDENTIFIER_REJECTED)
        if connack is not None:
            await plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack)
            await connack.to_stream(writer)
            await writer.close()
            raise MQTTException(error_msg)

        incoming_session = Session(loop)
        incoming_session.client_id = connect.client_id
        incoming_session.clean_session = connect.clean_session_flag
        incoming_session.will_flag = connect.will_flag
        incoming_session.will_retain = connect.will_retain_flag
        incoming_session.will_qos = connect.will_qos
        incoming_session.will_topic = connect.will_topic
        incoming_session.will_message = connect.will_message
        incoming_session.username = connect.username
        incoming_session.password = connect.password
        # Non-positive keep-alive values are stored as 0.
        incoming_session.keep_alive = (
            connect.keep_alive if connect.keep_alive > 0 else 0
        )

        handler = cls(plugins_manager, loop=loop)
        return handler, incoming_session
Ejemplo n.º 7
0
 async def _broadcast_loop(self):
     """
     Deliver messages taken from ``self._broadcast_queue`` to matching subscribers.

     The loop runs until it is cancelled. For each queued broadcast, every
     subscription whose filter matches the topic is visited. A session in the
     "connected" state gets the message published through its handler in a
     background task; any other session gets the message queued in its
     ``retained_messages``. On cancellation the loop waits for the publish
     tasks still running and then re-raises.
     """
     running_tasks = deque()
     try:
         while True:
             # Reap finished publish tasks from the head of the queue.
             while running_tasks and running_tasks[0].done():
                 task = running_tasks.popleft()
                 try:
                     task.result()  # make asyncio happy and collect results
                 except Exception as exc:
                     # Best effort: a failed publish must not stop the loop,
                     # but leave a trace instead of discarding it silently.
                     self.logger.debug("Broadcast publish task failed: %r" % exc)
             broadcast = await self._broadcast_queue.get()
             if self.logger.isEnabledFor(logging.DEBUG):
                 self.logger.debug("broadcasting %r" % broadcast)
             for k_filter in self._subscriptions:
                 if broadcast["topic"].startswith("$") and (
                     k_filter.startswith("+") or k_filter.startswith("#")
                 ):
                     self.logger.debug(
                         "[MQTT-4.7.2-1] - ignoring brodcasting $ topic to subscriptions starting with + or #"
                     )
                 elif self.matches(broadcast["topic"], k_filter):
                     subscriptions = self._subscriptions[k_filter]
                     for (target_session, qos) in subscriptions:
                         # A QoS carried by the broadcast overrides the
                         # subscription's QoS.
                         if "qos" in broadcast:
                             qos = broadcast["qos"]
                         if target_session.transitions.state == "connected":
                             self.logger.debug(
                                 "broadcasting application message from %s on topic '%s' to %s"
                                 % (
                                     format_client_message(
                                         session=broadcast["session"]
                                     ),
                                     broadcast["topic"],
                                     format_client_message(session=target_session),
                                 )
                             )
                             handler = self._get_handler(target_session)
                             task = asyncio.ensure_future(
                                 handler.mqtt_publish(
                                     broadcast["topic"],
                                     broadcast["data"],
                                     qos,
                                     retain=False,
                                 ),
                                 loop=self._loop,
                             )
                             running_tasks.append(task)
                         else:
                             self.logger.debug(
                                 "retaining application message from %s on topic '%s' to client '%s'"
                                 % (
                                     format_client_message(
                                         session=broadcast["session"]
                                     ),
                                     broadcast["topic"],
                                     format_client_message(session=target_session),
                                 )
                             )
                             retained_message = RetainedApplicationMessage(
                                 broadcast["session"],
                                 broadcast["topic"],
                                 broadcast["data"],
                                 qos,
                             )
                             await target_session.retained_messages.put(
                                 retained_message
                             )
                             if self.logger.isEnabledFor(logging.DEBUG):
                                 self.logger.debug(
                                     f"target_session.retained_messages={target_session.retained_messages.qsize()}"
                                 )
     except CancelledError:
         # Wait until current broadcasting tasks end
         if running_tasks:
             # asyncio.wait() lost its ``loop`` parameter in Python 3.10.
             await asyncio.wait(running_tasks)
         raise  # reraise per CancelledError semantics
Ejemplo n.º 8
0
    async def client_connected(
        self, listener_name, reader: ReaderAdapter, writer: WriterAdapter
    ):
        """
        Serve one client connection accepted on a listener until it ends.

        Steps, in order: acquire a connection slot on the listener's server,
        read the CONNECT packet, create or restore the session, authenticate,
        send the CONNACK, start the protocol handler and publish the session's
        retained messages. The coroutine then loops on disconnection,
        subscription, unsubscription and incoming-message events. The
        connection slot is released on every return path of this coroutine.

        :param listener_name: key of the listener in ``self._servers`` and
            ``self.listeners_config``
        :param reader: stream adapter used to read from the client
        :param writer: stream adapter used to write to the client
        :raises BrokerException: if ``listener_name`` is not a known listener
        """
        # Wait for connection available on listener
        server = self._servers.get(listener_name, None)
        if not server:
            raise BrokerException("Invalid listener name '%s'" % listener_name)
        await server.acquire_connection()

        remote_address, remote_port = writer.get_peer_info()
        self.logger.info(
            "Connection from %s:%d on listener '%s'"
            % (remote_address, remote_port, listener_name)
        )

        # Wait for first packet and expect a CONNECT
        try:
            handler, client_session = await BrokerProtocolHandler.init_from_connect(
                reader, writer, self.plugins_manager, loop=self._loop
            )
        except HBMQTTException as exc:
            self.logger.warning(
                "[MQTT-3.1.0-1] %s: Can't read first packet an CONNECT: %s"
                % (format_client_message(address=remote_address, port=remote_port), exc)
            )
            # await writer.close()
            self.logger.debug("Connection closed")
            # Give the slot back: it was acquired above and would otherwise leak.
            server.release_connection()
            return
        except MQTTException as me:
            self.logger.error(
                "Invalid connection from %s : %s"
                % (format_client_message(address=remote_address, port=remote_port), me)
            )
            await writer.close()
            # Give the slot back: it was acquired above and would otherwise leak.
            server.release_connection()
            self.logger.debug("Connection closed")
            return

        if client_session.clean_session:
            # Delete existing session and create a new one
            if client_session.client_id is not None and client_session.client_id != "":
                self.delete_session(client_session.client_id)
            else:
                client_session.client_id = gen_client_id()
            client_session.parent = 0
        else:
            # Get session from cache
            if client_session.client_id in self._sessions:
                self.logger.debug(
                    "Found old session %s"
                    % repr(self._sessions[client_session.client_id])
                )
                (client_session, h) = self._sessions[client_session.client_id]
                client_session.parent = 1
            else:
                client_session.parent = 0
        if client_session.keep_alive > 0:
            client_session.keep_alive += self.config["timeout-disconnect-delay"]
        self.logger.debug("Keep-alive timeout=%d" % client_session.keep_alive)

        handler.attach(client_session, reader, writer)
        self._sessions[client_session.client_id] = (client_session, handler)

        authenticated = await self.authenticate(
            client_session, self.listeners_config[listener_name]
        )
        if not authenticated:
            await writer.close()
            server.release_connection()  # Delete client from connections list
            return

        while True:
            try:
                client_session.transitions.connect()
                break
            except (MachineError, ValueError):
                # Backwards compat: MachineError is raised by transitions < 0.5.0.
                self.logger.warning(
                    "Client %s is reconnecting too quickly, make it wait"
                    % client_session.client_id
                )
                # Wait a bit may be client is reconnecting too fast
                # (asyncio.sleep() lost its ``loop`` parameter in Python 3.10)
                await asyncio.sleep(1)
        await handler.mqtt_connack_authorize(authenticated)

        await self.plugins_manager.fire_event(
            EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id
        )

        self.logger.debug("%s Start messages handling" % client_session.client_id)
        await handler.start()
        self.logger.debug(
            "Retained messages queue size: %d"
            % client_session.retained_messages.qsize()
        )
        await self.publish_session_retained_messages(client_session)

        # Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect)
        disconnect_waiter = asyncio.ensure_future(
            handler.wait_disconnect(), loop=self._loop
        )
        subscribe_waiter = asyncio.ensure_future(
            handler.get_next_pending_subscription(), loop=self._loop
        )
        unsubscribe_waiter = asyncio.ensure_future(
            handler.get_next_pending_unsubscription(), loop=self._loop
        )
        wait_deliver = asyncio.ensure_future(
            handler.mqtt_deliver_next_message(), loop=self._loop
        )
        connected = True
        while connected:
            try:
                # asyncio.wait() lost its ``loop`` parameter in Python 3.10; the
                # waiters are already bound to the broker loop.
                done, pending = await asyncio.wait(
                    [
                        disconnect_waiter,
                        subscribe_waiter,
                        unsubscribe_waiter,
                        wait_deliver,
                    ],
                    return_when=asyncio.FIRST_COMPLETED,
                )
                if disconnect_waiter in done:
                    result = disconnect_waiter.result()
                    self.logger.debug(
                        "%s Result from wait_diconnect: %s"
                        % (client_session.client_id, result)
                    )
                    if result is None:
                        self.logger.debug("Will flag: %s" % client_session.will_flag)
                        # Connection closed abnormally, send will message
                        if client_session.will_flag:
                            self.logger.debug(
                                "Client %s disconnected abnormally, sending will message"
                                % format_client_message(client_session)
                            )
                            await self._broadcast_message(
                                client_session,
                                client_session.will_topic,
                                client_session.will_message,
                                client_session.will_qos,
                            )
                            if client_session.will_retain:
                                self.retain_message(
                                    client_session,
                                    client_session.will_topic,
                                    client_session.will_message,
                                    client_session.will_qos,
                                )
                    self.logger.debug(
                        "%s Disconnecting session" % client_session.client_id
                    )
                    await self._stop_handler(handler)
                    client_session.transitions.disconnect()
                    await self.plugins_manager.fire_event(
                        EVENT_BROKER_CLIENT_DISCONNECTED,
                        client_id=client_session.client_id,
                    )
                    connected = False
                if unsubscribe_waiter in done:
                    self.logger.debug(
                        "%s handling unsubscription" % client_session.client_id
                    )
                    unsubscription = unsubscribe_waiter.result()
                    for topic in unsubscription["topics"]:
                        self._del_subscription(topic, client_session)
                        await self.plugins_manager.fire_event(
                            EVENT_BROKER_CLIENT_UNSUBSCRIBED,
                            client_id=client_session.client_id,
                            topic=topic,
                        )
                    await handler.mqtt_acknowledge_unsubscription(
                        unsubscription["packet_id"]
                    )
                    # Re-arm the waiter for the next unsubscription request.
                    unsubscribe_waiter = asyncio.ensure_future(
                        handler.get_next_pending_unsubscription(), loop=self._loop
                    )
                if subscribe_waiter in done:
                    self.logger.debug(
                        "%s handling subscription" % client_session.client_id
                    )
                    subscriptions = subscribe_waiter.result()
                    return_codes = []
                    for subscription in subscriptions["topics"]:
                        result = await self.add_subscription(
                            subscription, client_session
                        )
                        return_codes.append(result)
                    await handler.mqtt_acknowledge_subscription(
                        subscriptions["packet_id"], return_codes
                    )
                    # Only accepted subscriptions (return code != 0x80) trigger the
                    # event and the delivery of matching retained messages.
                    for index, subscription in enumerate(subscriptions["topics"]):
                        if return_codes[index] != 0x80:
                            await self.plugins_manager.fire_event(
                                EVENT_BROKER_CLIENT_SUBSCRIBED,
                                client_id=client_session.client_id,
                                topic=subscription[0],
                                qos=subscription[1],
                            )
                            await self.publish_retained_messages_for_subscription(
                                subscription, client_session
                            )
                    # Re-arm the waiter for the next subscription request.
                    subscribe_waiter = asyncio.ensure_future(
                        handler.get_next_pending_subscription(), loop=self._loop
                    )
                    self.logger.debug(repr(self._subscriptions))
                if wait_deliver in done:
                    if self.logger.isEnabledFor(logging.DEBUG):
                        self.logger.debug(
                            "%s handling message delivery" % client_session.client_id
                        )
                    app_message = wait_deliver.result()
                    if not app_message.topic:
                        self.logger.warning(
                            "[MQTT-4.7.3-1] - %s invalid TOPIC sent in PUBLISH message, closing connection"
                            % client_session.client_id
                        )
                        break
                    if "#" in app_message.topic or "+" in app_message.topic:
                        self.logger.warning(
                            "[MQTT-3.3.2-2] - %s invalid TOPIC sent in PUBLISH message, closing connection"
                            % client_session.client_id
                        )
                        break
                    await self.plugins_manager.fire_event(
                        EVENT_BROKER_MESSAGE_RECEIVED,
                        client_id=client_session.client_id,
                        message=app_message,
                    )
                    await self._broadcast_message(
                        client_session, app_message.topic, app_message.data
                    )
                    if app_message.publish_packet.retain_flag:
                        self.retain_message(
                            client_session,
                            app_message.topic,
                            app_message.data,
                            app_message.qos,
                        )
                    # Re-arm the waiter for the next incoming message.
                    wait_deliver = asyncio.ensure_future(
                        handler.mqtt_deliver_next_message(), loop=self._loop
                    )
            except asyncio.CancelledError:
                self.logger.debug("Client loop cancelled")
                break
        disconnect_waiter.cancel()
        subscribe_waiter.cancel()
        unsubscribe_waiter.cancel()
        wait_deliver.cancel()

        self.logger.debug("%s Client disconnected" % client_session.client_id)
        server.release_connection()
Ejemplo n.º 9
0
def test_format_client_message(client_id):
    """A session that carries a client id is rendered as ``(client id=<id>)``."""
    session = Session()
    session.client_id = client_id

    rendered = utils.format_client_message(session=session)
    expected = f"(client id={client_id})"

    assert rendered == expected
Ejemplo n.º 10
0
def test_format_client_message_unknown():
    """With no arguments the client is rendered as unknown."""
    assert utils.format_client_message() == "(unknown client)"
Ejemplo n.º 11
0
def test_format_client_message_valid(url, port):
    """An address and port are rendered as ``(client @=<address>:<port>)``."""
    expected = f"(client @={url}:{port})"
    rendered = utils.format_client_message(address=url, port=port)
    assert rendered == expected