Exemplo n.º 1
0
    async def _broadcast_loop(self):
        """
        Dispatch every queued broadcast to the sessions subscribed to its topic.

        Runs forever. Each item received from ``self._broadcast_queue_r`` is a
        mapping with ``topic`` (a "/"-separated string or an already-split
        sequence of levels), ``data``, ``session`` and an optional ``qos``.
        For each session owning at least one matching subscription, the
        delivery QoS is the highest of the subscription QoS values, the
        message QoS and any QoS already chosen for that session.  Connected
        sessions get the message published from a task spawned in this loop's
        task group; other sessions get a ``RetainedApplicationMessage`` queued
        on their ``retained_messages`` queue.
        """

        def collect_targets(levels, message_qos):
            # Map each session with a matching subscription to its delivery QoS.
            chosen = {}
            for topic_filter, subscribers in self._subscriptions.items():
                if not match_topic(levels, topic_filter):
                    continue
                for subscriber, sub_qos in subscribers:
                    chosen[subscriber] = max(
                        sub_qos,
                        message_qos,
                        chosen.get(subscriber, QOS_0),
                    )
            return chosen

        async with anyio.create_task_group() as tg:
            while True:
                broadcast = await self._broadcast_queue_r.receive()
                raw_topic = broadcast["topic"]
                levels = raw_topic.split("/") if isinstance(raw_topic, str) else raw_topic
                targets = collect_targets(levels, broadcast.get("qos", QOS_0))

                for target_session, qos in targets.items():
                    if target_session.transitions.state != "connected":
                        # Session is offline: queue the message for later delivery.
                        if self.logger.isEnabledFor(logging.DEBUG):
                            self.logger.debug(
                                "retaining application message from %s on topic '%s' to client '%s'",
                                format_client_message(session=broadcast["session"]),
                                broadcast["topic"],
                                format_client_message(session=target_session),
                            )
                        pending = RetainedApplicationMessage(
                            broadcast["session"],
                            broadcast["topic"],
                            broadcast["data"],
                            qos,
                        )
                        await target_session.retained_messages.put(pending)
                        continue

                    # Per-delivery debug logging is switched off by the
                    # constant ``False`` (presumably to keep this hot path cheap).
                    if False and self.logger.isEnabledFor(logging.DEBUG):
                        self.logger.debug(
                            "broadcasting application message from %s on topic '%s' to %s",
                            format_client_message(session=broadcast["session"]),
                            broadcast["topic"],
                            format_client_message(session=target_session),
                        )
                    publisher = self._get_handler(target_session)
                    await tg.spawn(
                        partial(
                            publisher.mqtt_publish,
                            broadcast["topic"],
                            broadcast["data"],
                            qos,
                            retain=False,
                        )
                    )
Exemplo n.º 2
0
 def _del_subscription(self, a_filter, session):
     """
     Delete a session subscription on a given topic
     :param a_filter:
     :param session:
     :return:
     """
     deleted = 0
     if isinstance(a_filter, str):
         a_filter = tuple(a_filter.split("/"))
     try:
         subscriptions = self._subscriptions[a_filter]
         for index, (sub_session, _) in enumerate(
             subscriptions
         ):  # pylint: disable=unused-variable
             if sub_session.client_id == session.client_id:
                 if self.logger.isEnabledFor(logging.DEBUG):
                     self.logger.debug(
                         "Removing subscription on topic '%s' for client %s",
                         a_filter,
                         format_client_message(session=session),
                     )
                 subscriptions.pop(index)
                 deleted += 1
                 break
     except KeyError:
         # Unsubscribe topic not found in current subscribed topics
         pass
     return deleted
Exemplo n.º 3
0
 async def publish_retained_messages_for_subscription(
         self, subscription, session):
     """
     Send the broker's retained messages that match a newly added subscription.

     Every retained message whose topic matches the subscription's filter is
     published to ``session`` (with the retain flag set) from a task spawned
     in a task group; this coroutine returns once all of them have finished.

     :param subscription: ``(topic_filter, requested_qos)`` pair as received
         in the SUBSCRIBE packet
     :param session: the subscribing client's session
     """
     topic_filter = subscription[0]
     sub = topic_filter.split("/")
     # Use the QoS actually granted: add_subscription() caps the requested
     # QoS at the configured 'max-qos', so publishing at subscription[1]
     # could exceed what the client was granted.
     qos = subscription[1]
     if 'max-qos' in self.config and qos > self.config['max-qos']:
         qos = self.config['max-qos']
     handler = self._get_handler(session)
     async with anyio.create_task_group() as tg:
         # Iterate over a snapshot: ``await tg.spawn(...)`` may let other tasks
         # add or remove retained messages, and mutating a dict while it is
         # being iterated raises RuntimeError.
         for d_topic, retained in list(self._retained_messages.items()):
             topic = d_topic.split("/")
             self.logger.debug("matching : %s %s", d_topic, topic_filter)
             if match_topic(topic, sub):
                 self.logger.debug("%s and %s match", d_topic, topic_filter)
                 await tg.spawn(
                     handler.mqtt_publish,
                     retained.topic,
                     retained.data,
                     qos,
                     True,
                 )
     if self.logger.isEnabledFor(logging.DEBUG):
         self.logger.debug(
             "End broadcasting messages retained due to subscription on '%s' from %s",
             topic_filter,
             format_client_message(session=session),
         )
Exemplo n.º 4
0
    async def add_subscription(self, subscription, session):
        """
        Register ``session``'s subscription to a topic filter.

        The filter is validated level by level, then checked with
        ``self.topic_filtering``.  If the session's client already subscribes
        to the identical filter, that subscription is replaced (so a new QoS
        takes effect); otherwise a new one is appended.

        :param subscription: ``(topic_filter, requested_qos)`` pair from a
            SUBSCRIBE packet
        :param session: the subscribing client's session
        :return: the granted QoS (the requested QoS capped at the configured
            ``max-qos``), or ``0x80`` (SUBACK failure) if the filter is
            malformed or the client is not permitted to use it
        """
        a_filter = subscription[0]
        if not a_filter:
            # [MQTT-4.7.3-1] Topic filters must be at least one character long
            return 0x80
        levels = a_filter.split('/')
        last = len(levels) - 1
        for index, level in enumerate(levels):
            if '#' in level and (level != '#' or index != last):
                # [MQTT-4.7.1-2] '#' must occupy a whole level and be the last level
                return 0x80
            if '+' in level and level != '+':
                # [MQTT-4.7.1-3] '+' wildcard character must occupy an entire level
                return 0x80

        # Check if the client is authorised to connect to the topic
        permitted = await self.topic_filtering(session, topic=a_filter)
        if not permitted:
            return 0x80

        a_filter = tuple(levels)
        qos = subscription[1]
        if 'max-qos' in self.config and qos > self.config['max-qos']:
            qos = self.config['max-qos']
        if a_filter not in self._subscriptions:
            self._subscriptions[a_filter] = []
        subscriptions = self._subscriptions[a_filter]
        for index, (sub_session, _) in enumerate(subscriptions):
            if sub_session.client_id == session.client_id:
                if self.logger.isEnabledFor(logging.DEBUG):
                    self.logger.debug("Client %s has already subscribed to %s",
                                      format_client_message(session=session),
                                      a_filter)
                # [MQTT-3.8.4-3] An identical filter replaces the existing
                # subscription, so the newly granted QoS takes effect.
                subscriptions[index] = (session, qos)
                break
        else:
            subscriptions.append((session, qos))
        return qos
Exemplo n.º 5
0
 async def handle_connect(self, connect: ConnectPacket):
     """
     Reject a CONNECT packet received on an already-established connection.

     CONNECT packets are handled by the broker when the client connects, so
     one arriving during normal message handling violates [MQTT-3.1.0-2]:
     the event is logged as an error and this handler is stopped.

     :param connect: the unexpected CONNECT packet (not inspected)
     """
     client_id = self.session.client_id
     client_text = format_client_message(self.session)
     self.logger.error(
         '%s [MQTT-3.1.0-2] %s : CONNECT message received during messages handling',
         client_id,
         client_text,
     )
     await self.stop()
Exemplo n.º 6
0
    async def _broadcast_loop(self):
        """
        Fan out queued broadcasts to every subscribed session, forever.

        Each item taken from ``self._broadcast_queue`` is a mapping with a
        "/"-separated string ``topic``, ``data``, ``session`` and an optional
        ``qos``.  For each session owning at least one matching subscription,
        the delivery QoS is the highest of the subscription QoS values, the
        message QoS and any QoS already chosen for that session.  Connected
        sessions get the message published from a task spawned in this loop's
        task group; other sessions get a ``RetainedApplicationMessage`` queued
        on their ``retained_messages`` queue.
        """

        def sessions_for(levels, message_qos):
            # Highest delivery QoS per matching session across all filters.
            picked = {}
            for topic_filter, subscribers in self._subscriptions.items():
                if match_topic(levels, topic_filter):
                    for subscriber, sub_qos in subscribers:
                        picked[subscriber] = max(
                            sub_qos, message_qos, picked.get(subscriber, QOS_0))
            return picked

        async with anyio.create_task_group() as tg:
            while True:
                broadcast = await self._broadcast_queue.get()
                self.logger.debug("broadcasting %r", broadcast)
                targets = sessions_for(broadcast['topic'].split('/'),
                                       broadcast.get('qos', QOS_0))

                for target_session, qos in targets.items():
                    if target_session.transitions.state != 'connected':
                        # Session is offline: queue the message for later delivery.
                        if self.logger.isEnabledFor(logging.DEBUG):
                            self.logger.debug(
                                "retaining application message from %s on topic '%s' to client '%s'",
                                format_client_message(session=broadcast['session']),
                                broadcast['topic'],
                                format_client_message(session=target_session))
                        pending = RetainedApplicationMessage(
                            broadcast['session'], broadcast['topic'],
                            broadcast['data'], qos)
                        await target_session.retained_messages.put(pending)
                        continue

                    if self.logger.isEnabledFor(logging.DEBUG):
                        self.logger.debug(
                            "broadcasting application message from %s on topic '%s' to %s",
                            format_client_message(session=broadcast['session']),
                            broadcast['topic'],
                            format_client_message(session=target_session))
                    publisher = self._get_handler(target_session)
                    await tg.spawn(
                        partial(publisher.mqtt_publish,
                                broadcast['topic'],
                                broadcast['data'],
                                qos,
                                retain=False))
Exemplo n.º 7
0
 async def publish_session_retained_messages(self, session):
     """
     Deliver the messages queued on ``session`` while it was not connected.

     ``session.retained_messages`` is drained and each message is published
     through the session's handler from a task spawned in a task group; this
     coroutine returns once all of them have finished.

     The messages are queued there by ``_broadcast_loop`` because they matched
     one of the session's subscriptions while it was offline.  MQTT 3.1.1
     [MQTT-3.3.1-9] requires the RETAIN flag to be 0 for such deliveries, so
     they are published with ``retain=False`` (previously ``True``).

     :param session: the (re)connected session whose queue is drained
     """
     pending = session.retained_messages
     if self.logger.isEnabledFor(logging.DEBUG):
         self.logger.debug("Publishing %d messages retained for session %s",
                           pending.qsize(),
                           format_client_message(session=session))
     handler = self._get_handler(session)
     async with anyio.create_task_group() as tg:
         while not pending.empty():
             message = await pending.get()
             await tg.spawn(handler.mqtt_publish, message.topic,
                            message.data, message.qos, False)
Exemplo n.º 8
0
    async def client_connected_(self, listener_name, adapter: BaseAdapter):
        """
        Serve one client connection from its CONNECT packet until it disconnects.

        Reads the CONNECT packet, creates or resumes the client session,
        authenticates it, answers with CONNACK, then runs the
        subscribe/unsubscribe handling loops until the handler reports a
        disconnect.  On an unclean disconnect the session's will message, if
        any, is broadcast.

        :param listener_name: name of the listener that accepted the
            connection (only used for logging here)
        :param adapter: adapter wrapping the client's connection
        """
        # Wait for connection available on listener

        remote_address, remote_port = adapter.get_peer_info()
        self.logger.info(
            "Connection from %s:%d on listener '%s'", remote_address, remote_port, listener_name,
        )

        # Wait for first packet and expect a CONNECT
        try:
            handler, client_session = await BrokerProtocolHandler.init_from_connect(
                adapter, self.plugins_manager
            )
        except DistMQTTException as exc:
            self.logger.warning(
                "[MQTT-3.1.0-1] %s: Can't read first packet as CONNECT: %s",
                format_client_message(address=remote_address, port=remote_port),
                exc,
            )
            # The connection is unusable without a valid CONNECT: release it,
            # as the MQTTException branch below does, instead of leaking it.
            await adapter.close()
            self.logger.debug("Connection closed")
            return
        except MQTTException as me:
            self.logger.error(
                "Invalid connection from %s : %s",
                format_client_message(address=remote_address, port=remote_port),
                me,
            )
            await adapter.close()
            self.logger.debug("Connection closed")
            return

        if client_session.clean_session:
            # Delete existing session and create a new one
            if client_session.client_id is not None and client_session.client_id != "":
                await self.delete_session(client_session.client_id)
            else:
                client_session.client_id = gen_client_id()
            client_session.parent = 0
        else:
            # Get session from cache
            if client_session.client_id in self._sessions:
                self.logger.debug("Found old session %r", self._sessions[client_session.client_id])
                # NOTE(review): the cached session object replaces the one
                # built from this CONNECT, so its will/keep-alive/credential
                # fields are the ones from the earlier connection.
                client_session = self._sessions[client_session.client_id][0]
                client_session.parent = 1
            else:
                client_session.parent = 0
        if not client_session.parent:
            await client_session.start(self)
        if client_session.keep_alive > 0 and not client_session.parent:
            # MQTT 3.1.2.10: one and a half keepalive times, plus configurable grace
            client_session.keep_alive += (
                client_session.keep_alive / 2 + self.config["timeout-disconnect-delay"]
            )
        self.logger.debug("Keep-alive timeout=%d", client_session.keep_alive)

        await handler.attach(client_session, adapter)
        self._sessions[client_session.client_id] = (client_session, handler)

        authenticated = await self.authenticate(client_session)
        if not authenticated:
            await adapter.close()
            return

        while True:
            try:
                client_session.transitions.connect()
                break
            except (MachineError, ValueError) as exc:
                # Backwards compat: MachineError is raised by transitions < 0.5.0.
                self.logger.warning(
                    "Client %s is reconnecting too quickly, make it wait",
                    client_session.client_id,
                    exc_info=exc,
                )
                # Wait a bit may be client is reconnecting too fast
                await anyio.sleep(1)
        await handler.mqtt_connack_authorize(authenticated)

        await self.plugins_manager.fire_event(
            EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id
        )

        self.logger.debug("%s Start messages handling", client_session.client_id)
        await handler.start()
        if self._do_retain:
            self.logger.debug(
                "Retained messages queue size: %d", client_session.retained_messages.qsize(),
            )
            await self.publish_session_retained_messages(client_session)

        # Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect)
        async with anyio.create_task_group() as tg:

            async def handle_unsubscribe():
                # Process UNSUBSCRIBE requests until cancelled.
                while True:
                    unsubscription = await handler.get_next_pending_unsubscription()
                    self.logger.debug("%s handling unsubscription", client_session.client_id)
                    for topic in unsubscription["topics"]:
                        self._del_subscription(topic, client_session)
                        await self.plugins_manager.fire_event(
                            EVENT_BROKER_CLIENT_UNSUBSCRIBED,
                            client_id=client_session.client_id,
                            topic=topic,
                        )
                    await handler.mqtt_acknowledge_unsubscription(unsubscription["packet_id"])

            async def handle_subscribe():
                # Process SUBSCRIBE requests until cancelled.
                while True:
                    subscriptions = await handler.get_next_pending_subscription()
                    self.logger.debug("%s handling subscription", client_session.client_id)
                    return_codes = []
                    for subscription in subscriptions["topics"]:
                        result = await self.add_subscription(subscription, client_session)
                        return_codes.append(result)
                    await handler.mqtt_acknowledge_subscription(
                        subscriptions["packet_id"], return_codes
                    )
                    for index, subscription in enumerate(subscriptions["topics"]):
                        if return_codes[index] != 0x80:
                            await self.plugins_manager.fire_event(
                                EVENT_BROKER_CLIENT_SUBSCRIBED,
                                client_id=client_session.client_id,
                                topic=subscription[0],
                                qos=subscription[1],
                            )
                            if self._do_retain:
                                await self.publish_retained_messages_for_subscription(
                                    subscription, client_session
                                )
                    # Lazy formatting: the repr is only built when DEBUG is enabled.
                    self.logger.debug("%r", self._subscriptions)

            await tg.spawn(handle_unsubscribe)
            await tg.spawn(handle_subscribe)

            try:
                await handler.wait_disconnect()
                self.logger.debug(
                    "%s wait_disconnect: %sclean",
                    client_session.client_id,
                    "" if handler.clean_disconnect else "un",
                )

                if not handler.clean_disconnect:
                    # Connection closed abnormally, send will message
                    self.logger.debug("Will flag: %s", client_session.will_flag)
                    if client_session.will_flag:
                        if self.logger.isEnabledFor(logging.DEBUG):
                            self.logger.debug(
                                "Client %s disconnected abnormally, sending will message",
                                format_client_message(session=client_session),
                            )
                        await self.broadcast_message(
                            client_session,
                            client_session.will_topic,
                            client_session.will_message,
                            client_session.will_qos,
                            retain=client_session.will_retain,
                        )
                self.logger.debug("%s Disconnecting session", client_session.client_id)
                await self._stop_handler(handler)
                client_session.transitions.disconnect()
                await self.plugins_manager.fire_event(
                    EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id
                )
            finally:
                # Stop the subscribe/unsubscribe loops, shielded so the
                # cancellation itself completes even if we are being cancelled.
                async with anyio.fail_after(2, shield=True):
                    await tg.cancel_scope.cancel()
            pass  # end taskgroup

        self.logger.debug("%s Client disconnected", client_session.client_id)
Exemplo n.º 9
0
    async def init_from_connect(cls, stream: StreamAdapter, plugins_manager):
        """
        Read and validate the first (CONNECT) packet of a new connection.

        :param stream: adapter for the freshly accepted client connection
        :param plugins_manager: plugin manager notified of the CONNECT packet
            received and of any CONNACK sent from here
        :return: ``(handler, session)``: a new handler built with
            ``cls(plugins_manager)`` and a ``Session`` populated from the
            CONNECT packet (client id, clean-session flag, will settings,
            credentials, keep-alive)
        :raises MQTTException: if the client closes the connection before a
            CONNECT arrives, or the CONNECT is malformed or refused.  When the
            refusal has a CONNACK return code, that CONNACK is sent and the
            stream is closed before raising.
        """
        remote_address, remote_port = stream.get_peer_info()
        try:
            connect = await ConnectPacket.from_stream(stream)
        except NoDataException as exc:
            raise MQTTException("Client closed the connection") from exc
        logger.debug("< B %r", connect)
        await plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED,
                                         packet=connect)
        # this shouldn't be required anymore since broker generates for each
        # client a random client_id if not provided [MQTT-3.1.3-6]
        if connect.payload.client_id is None:
            raise MQTTException(
                "[[MQTT-3.1.3-3]] : Client identifier must be present")

        if connect.variable_header.will_flag:
            if connect.payload.will_topic is None or connect.payload.will_message is None:
                raise MQTTException(
                    "will flag set, but will topic/message not present in payload"
                )

        if connect.variable_header.reserved_flag:
            raise MQTTException(
                "[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0")
        if connect.proto_name != "MQTT":
            raise MQTTException(
                '[MQTT-3.1.2-1] Incorrect protocol name: "%s"' %
                connect.proto_name)

        connack = None
        error_msg = None
        if connect.proto_level != 4:
            # only MQTT 3.1.1 supported
            error_msg = "Invalid protocol from %s: %d" % (
                format_client_message(address=remote_address,
                                      port=remote_port),
                connect.proto_level,
            )
            # [MQTT-3.2.2-4] session_parent=0
            connack = ConnackPacket.build(0, UNACCEPTABLE_PROTOCOL_VERSION)
        elif not connect.username_flag and connect.password_flag:
            # [MQTT-3.1.2-22] a password flag requires the username flag
            error_msg = "Password without username from %s" % (
                format_client_message(address=remote_address,
                                      port=remote_port))
            connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD)
        elif connect.username_flag and not connect.password_flag:
            # NOTE(review): MQTT 3.1.1 permits a username without a password;
            # this broker rejects such clients anyway. Confirm this is intended.
            error_msg = "Username without password from %s" % (
                format_client_message(address=remote_address,
                                      port=remote_port))
            connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD)
        elif connect.username_flag and connect.username is None:
            error_msg = "Invalid username from %s" % (format_client_message(
                address=remote_address, port=remote_port))
            connack = ConnackPacket.build(
                0, BAD_USERNAME_PASSWORD)  # [MQTT-3.2.2-4] session_parent=0
        elif connect.password_flag and connect.password is None:
            error_msg = "Invalid password %s" % (format_client_message(
                address=remote_address, port=remote_port))
            connack = ConnackPacket.build(
                0, BAD_USERNAME_PASSWORD)  # [MQTT-3.2.2-4] session_parent=0
        elif connect.clean_session_flag is False and (
                connect.payload.client_id_is_random):
            error_msg = (
                "[MQTT-3.1.3-8] [MQTT-3.1.3-9] %s: No client Id provided (cleansession=0)"
                % (format_client_message(address=remote_address,
                                         port=remote_port)))
            connack = ConnackPacket.build(0, IDENTIFIER_REJECTED)
        if connack is not None:
            # Tell the client why it is refused, then drop the connection.
            logger.debug("B > %r", connack)
            await plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT,
                                             packet=connack)
            await connack.to_stream(stream)

            await stream.close()
            raise MQTTException(error_msg)

        incoming_session = Session(plugins_manager)
        incoming_session.client_id = connect.client_id
        incoming_session.clean_session = connect.clean_session_flag
        incoming_session.will_flag = connect.will_flag
        incoming_session.will_retain = connect.will_retain_flag
        incoming_session.will_qos = connect.will_qos
        incoming_session.will_topic = connect.will_topic
        incoming_session.will_message = connect.will_message
        incoming_session.username = connect.username
        incoming_session.password = connect.password
        if connect.keep_alive > 0:
            incoming_session.keep_alive = connect.keep_alive
        else:
            incoming_session.keep_alive = 0

        handler = cls(plugins_manager)
        return handler, incoming_session