class MQTTClient(APITransport):
    """MQTT client component that doubles as an API transport.

    Wraps a ``paho_mqtt.Client``, keeps track of topic subscriptions and
    of futures awaiting broker acknowledgement, registers the
    ``/server/mqtt/publish`` and ``/server/mqtt/subscribe`` endpoints and,
    when enabled, dispatches JSON-RPC requests received on the API
    request topic.
    """

    def __init__(self, config: ConfigHelper) -> None:
        """Read the ``[mqtt]`` options, create the client and register
        endpoints.

        Raises ``config.error`` when the password file is missing, the
        protocol name is unknown, the instance name contains a wildcard,
        or ``default_qos`` is outside 0..2.
        """
        self.server = config.get_server()
        self.address: str = config.get('address')
        self.port: int = config.getint('port', 1883)
        self.user_name = config.get('username', None)
        pw_file_path = config.get('password_file', None)
        self.password: Optional[str] = None
        if pw_file_path is not None:
            pw_file = pathlib.Path(pw_file_path).expanduser().absolute()
            if not pw_file.exists():
                raise config.error(
                    f"Password file '{pw_file}' does not exist")
            self.password = pw_file.read_text().strip()
        protocol = config.get('mqtt_protocol', "v3.1.1")
        self.protocol = MQTT_PROTOCOLS.get(protocol, None)
        if self.protocol is None:
            # Report the accepted option names (the mapping keys); the
            # mapped values are not what the user types in the config.
            raise config.error(
                f"Invalid value '{protocol}' for option 'mqtt_protocol' "
                "in section [mqtt]. Must be one of "
                f"{list(MQTT_PROTOCOLS)}")
        self.instance_name = config.get('instance_name', socket.gethostname())
        if '+' in self.instance_name or '#' in self.instance_name:
            raise config.error(
                "Option 'instance_name' in section [mqtt] cannot "
                "contain a wildcard.")
        self.qos = config.getint("default_qos", 0)
        if self.qos > 2 or self.qos < 0:
            raise config.error(
                "Option 'default_qos' in section [mqtt] must be "
                "between 0 and 2")
        self.client = paho_mqtt.Client(protocol=self.protocol)
        self.client.on_connect = self._on_connect
        self.client.on_message = self._on_message
        self.client.on_disconnect = self._on_disconnect
        self.client.on_publish = self._on_publish
        self.client.on_subscribe = self._on_subscribe
        self.client.on_unsubscribe = self._on_unsubscribe
        self.connect_evt: asyncio.Event = asyncio.Event()
        self.disconnect_evt: Optional[asyncio.Event] = None
        self.reconnect_task: Optional[asyncio.Task] = None
        # topic -> (qos, [subscription handles])
        self.subscribed_topics: SubscribedDict = {}
        # Futures resolved by an incoming message on a response topic
        self.pending_responses: List[asyncio.Future] = []
        # Futures resolved by a broker ack, keyed by paho message id
        self.pending_acks: Dict[int, asyncio.Future] = {}

        self.server.register_endpoint("/server/mqtt/publish", ["POST"],
                                      self._handle_publish_request,
                                      transports=["http", "websocket"])
        self.server.register_endpoint("/server/mqtt/subscribe", ["POST"],
                                      self._handle_subscription_request,
                                      transports=["http", "websocket"])

        # Subscribe to API requests
        self.json_rpc = JsonRPC(transport="MQTT")
        self.api_request_topic = f"{self.instance_name}/moonraker/api/request"
        self.api_resp_topic = f"{self.instance_name}/moonraker/api/response"
        # Recently seen "mqtt_timestamp" values, used to drop duplicates
        self.timestamp_deque: Deque = deque(maxlen=20)
        self.api_qos = config.getint('api_qos', self.qos)
        if config.getboolean("enable_moonraker_api", True):
            api_cache = self.server.register_api_transport("mqtt", self)
            for api_def in api_cache.values():
                if "mqtt" in api_def.supported_transports:
                    self.register_api_handler(api_def)
            self.subscribe_topic(self.api_request_topic,
                                 self._process_api_request,
                                 self.api_qos)
            logging.info(
                f"Moonraker API topics - Request: {self.api_request_topic}, "
                f"Response: {self.api_resp_topic}")

        IOLoop.current().spawn_callback(self._initialize)

    async def _initialize(self) -> None:
        """Connect to the broker, retrying while the connection is refused.

        Up to 15 attempts are made, 2 seconds apart.  If every attempt is
        refused the component is flagged as failed and a warning is added
        to the server.
        """
        # We must wait for the IOLoop (asyncio event loop) to start
        # prior to retrieving it
        self.helper = AIOHelper(self.client)
        if self.user_name is not None:
            self.client.username_pw_set(self.user_name, self.password)
        retries = 15
        while retries:
            try:
                self.client.connect(self.address, self.port)
            except ConnectionRefusedError:
                retries -= 1
                if retries:
                    logging.info("Unable to connect to MQTT broker, "
                                 f"retries remaining: {retries}")
                    await asyncio.sleep(2.)
                    continue
                self.server.set_failed_component("mqtt")
                self.server.add_warning(
                    f"MQTT Broker Connection at ({self.address}, {self.port}) "
                    "refused. Check your client and broker configuration.")
                return
            break
        self.client.socket().setsockopt(socket.SOL_SOCKET,
                                        socket.SO_SNDBUF, 2048)

    def _on_message(self,
                    client: str,
                    user_data: Any,
                    message: paho_mqtt.MQTTMessage
                    ) -> None:
        """Paho callback: fan an incoming payload out to topic handlers."""
        topic = message.topic
        if topic in self.subscribed_topics:
            cb_hdls = self.subscribed_topics[topic][1]
            for hdl in cb_hdls:
                IOLoop.current().spawn_callback(
                    hdl.callback, message.payload)  # type: ignore
        else:
            logging.debug(f"Unregistered MQTT Topic Received: {topic}, "
                          f"payload: {message.payload.decode()}")

    def _on_connect(self,
                    client: paho_mqtt.Client,
                    user_data: Any,
                    flags: Dict[str, Any],
                    reason_code: Union[int, paho_mqtt.ReasonCodes],
                    properties: Optional[paho_mqtt.Properties] = None
                    ) -> None:
        """Paho callback: (re)subscribe to known topics on success, or
        flag the component as failed when the broker rejects us."""
        if reason_code == 0:
            # Only report success once the broker has actually accepted
            # the connection.
            logging.info("MQTT Client Connected")
            subs = [(k, v[0]) for k, v in self.subscribed_topics.items()]
            if subs:
                _, msg_id = client.subscribe(subs)
                if msg_id is not None:
                    sub_fut: asyncio.Future = asyncio.Future()
                    topics = list(self.subscribed_topics.keys())
                    sub_fut.add_done_callback(
                        BrokerAckLogger(topics, "subscribe"))
                    self.pending_acks[msg_id] = sub_fut
            self.connect_evt.set()
        else:
            if isinstance(reason_code, int):
                err_str = paho_mqtt.connack_string(reason_code)
            else:
                err_str = reason_code.getName()
            self.server.set_failed_component("mqtt")
            self.server.add_warning(f"MQTT Connection Failed: {err_str}")

    def _on_disconnect(self,
                       client: paho_mqtt.Client,
                       user_data: Any,
                       reason_code: int,
                       properties: Optional[paho_mqtt.Properties] = None
                       ) -> None:
        """Paho callback: signal a requested disconnect, or start a
        reconnect task when the connection dropped unexpectedly."""
        if self.disconnect_evt is not None:
            # close() is waiting on this event
            self.disconnect_evt.set()
        elif self.is_connected():
            # The server connection was dropped, attempt to reconnect
            logging.info("MQTT Server Disconnected, reason: "
                         f"{paho_mqtt.error_string(reason_code)}")
            if self.reconnect_task is None:
                self.reconnect_task = asyncio.create_task(
                    self._do_reconnect())
        self.connect_evt.clear()

    def _on_publish(self,
                    client: paho_mqtt.Client,
                    user_data: Any,
                    msg_id: int
                    ) -> None:
        """Paho callback: resolve the future waiting on a publish ack."""
        pub_fut = self.pending_acks.pop(msg_id, None)
        if pub_fut is not None and not pub_fut.done():
            pub_fut.set_result(None)

    def _on_subscribe(self,
                      client: paho_mqtt.Client,
                      user_data: Any,
                      msg_id: int,
                      flex: Union[List[int], List[paho_mqtt.ReasonCodes]],
                      properties: Optional[paho_mqtt.Properties] = None
                      ) -> None:
        """Paho callback: resolve the future waiting on a subscribe ack
        with the value paho passed as ``flex``."""
        sub_fut = self.pending_acks.pop(msg_id, None)
        if sub_fut is not None and not sub_fut.done():
            sub_fut.set_result(flex)

    def _on_unsubscribe(self,
                        client: paho_mqtt.Client,
                        user_data: Any,
                        msg_id: int,
                        properties: Optional[paho_mqtt.Properties] = None,
                        reasoncodes: Optional[paho_mqtt.ReasonCodes] = None
                        ) -> None:
        """Paho callback: resolve the future waiting on an unsubscribe
        ack."""
        unsub_fut = self.pending_acks.pop(msg_id, None)
        if unsub_fut is not None and not unsub_fut.done():
            unsub_fut.set_result(None)

    async def _do_reconnect(self) -> None:
        """Retry ``client.reconnect()`` every 2 seconds until it succeeds
        or the task is cancelled."""
        logging.info("Attempting MQTT Reconnect")
        while True:
            try:
                await asyncio.sleep(2.)
            except asyncio.CancelledError:
                break
            try:
                self.client.reconnect()
            except OSError:
                # ConnectionRefusedError is an OSError, as are timeouts
                # and name resolution failures.  Keep retrying on all of
                # them; otherwise the task would die with reconnect_task
                # still set and no further reconnect would be scheduled.
                continue
            self.client.socket().setsockopt(socket.SOL_SOCKET,
                                            socket.SO_SNDBUF, 2048)
            break
        self.reconnect_task = None

    async def wait_connection(self, timeout: Optional[float] = None) -> bool:
        """Wait for the connect event; return False if ``timeout``
        elapses first, True otherwise."""
        try:
            await asyncio.wait_for(self.connect_evt.wait(), timeout)
        except asyncio.TimeoutError:
            return False
        return True

    def is_connected(self) -> bool:
        """Return True while the connect event is set."""
        return self.connect_evt.is_set()

    def subscribe_topic(self,
                        topic: str,
                        callback: FlexCallback,
                        qos: Optional[int] = None
                        ) -> SubscriptionHandle:
        """Register ``callback`` for ``topic`` and return its handle.

        ``qos`` defaults to the configured default when None.  When the
        topic already has subscribers the highest requested QoS wins, and
        the broker is only re-subscribed if that QoS changed.

        Raises ``self.server.error`` for wildcard topics or a QoS outside
        0..2.
        """
        if '#' in topic or '+' in topic:
            raise self.server.error("Wildcards may not be used")
        # An explicit qos of 0 is a valid request and must not be
        # replaced by the default.
        if qos is None:
            qos = self.qos
        if qos > 2 or qos < 0:
            raise self.server.error("QOS must be between 0 and 2")
        hdl = SubscriptionHandle(topic, callback)
        sub_handles = [hdl]
        need_sub = True
        if topic in self.subscribed_topics:
            prev_qos, sub_handles = self.subscribed_topics[topic]
            qos = max(qos, prev_qos)
            sub_handles.append(hdl)
            need_sub = qos != prev_qos
        self.subscribed_topics[topic] = (qos, sub_handles)
        if self.is_connected() and need_sub:
            _, msg_id = self.client.subscribe(topic, qos)
            if msg_id is not None:
                sub_fut: asyncio.Future = asyncio.Future()
                sub_fut.add_done_callback(
                    BrokerAckLogger([topic], "subscribe"))
                self.pending_acks[msg_id] = sub_fut
        return hdl

    def unsubscribe(self, hdl: SubscriptionHandle) -> None:
        """Drop ``hdl``; unsubscribe from the broker once the topic has
        no handles left."""
        topic = hdl.topic
        if topic in self.subscribed_topics:
            sub_hdls = self.subscribed_topics[topic][1]
            try:
                sub_hdls.remove(hdl)
            except ValueError:
                # Handle was already removed; nothing to do
                pass
            if not sub_hdls:
                del self.subscribed_topics[topic]
                _, msg_id = self.client.unsubscribe(topic)
                if msg_id is not None:
                    unsub_fut: asyncio.Future = asyncio.Future()
                    unsub_fut.add_done_callback(
                        BrokerAckLogger([topic], "unsubscribe"))
                    self.pending_acks[msg_id] = unsub_fut

    def publish_topic(self,
                      topic: str,
                      payload: Any = None,
                      qos: Optional[int] = None,
                      retain: bool = False
                      ) -> Awaitable[None]:
        """Publish ``payload`` to ``topic``.

        Dicts and lists are JSON encoded and booleans are sent as
        ``true``/``false``.  Returns a future that completes when the
        message is published, or carries a server error if publishing
        failed.

        Raises ``self.server.error`` immediately for a QoS outside 0..2
        or a dict/list that cannot be JSON encoded.
        """
        # An explicit qos of 0 is a valid request and must not be
        # replaced by the default.
        if qos is None:
            qos = self.qos
        if qos > 2 or qos < 0:
            raise self.server.error("QOS must be between 0 and 2")
        pub_fut: asyncio.Future = asyncio.Future()
        if isinstance(payload, (dict, list)):
            try:
                payload = json.dumps(payload)
            except (TypeError, ValueError):
                # json.dumps raises TypeError for unsupported objects and
                # ValueError for circular references
                raise self.server.error(
                    "Dict or List is not json encodable") from None
        elif isinstance(payload, bool):
            payload = str(payload).lower()
        try:
            msg_info = self.client.publish(topic, payload, qos, retain)
            if msg_info.is_published():
                pub_fut.set_result(None)
            else:
                if qos == 0:
                    # There is no delivery guarantee for qos == 0, so
                    # it is possible that the on_publish event will
                    # not be called if paho mqtt encounters an error
                    # during publication.  Return immediately as
                    # a workaround.
                    if msg_info.rc != paho_mqtt.MQTT_ERR_SUCCESS:
                        err_str = paho_mqtt.error_string(msg_info.rc)
                        pub_fut.set_exception(
                            self.server.error(
                                f"MQTT Publish Error: {err_str}", 503))
                    else:
                        pub_fut.set_result(None)
                    return pub_fut
                # Resolved later by _on_publish
                self.pending_acks[msg_info.mid] = pub_fut
        except ValueError:
            pub_fut.set_exception(
                self.server.error("MQTT Message Queue Full", 529))
        except Exception as e:
            pub_fut.set_exception(
                self.server.error(f"MQTT Publish Error: {e}", 503))
        return pub_fut

    async def publish_topic_with_response(self,
                                          topic: str,
                                          response_topic: str,
                                          payload: Any = None,
                                          qos: Optional[int] = None,
                                          retain: bool = False,
                                          timeout: Optional[float] = None
                                          ) -> bytes:
        """Publish to ``topic`` and return the first payload received on
        ``response_topic``.

        ``timeout`` is applied separately to the publish and to the wait
        for the response.  Raises ``self.server.error`` (504) on timeout.
        """
        # An explicit qos of 0 is a valid request and must not be
        # replaced by the default.
        if qos is None:
            qos = self.qos
        if qos > 2 or qos < 0:
            raise self.server.error("QOS must be between 0 and 2")
        resp_fut: asyncio.Future = asyncio.Future()
        resp_hdl = self.subscribe_topic(
            response_topic, resp_fut.set_result, qos)
        self.pending_responses.append(resp_fut)
        try:
            await asyncio.wait_for(
                self.publish_topic(topic, payload, qos, retain), timeout)
            await asyncio.wait_for(resp_fut, timeout)
        except asyncio.TimeoutError:
            logging.info(f"Response to request {topic} timed out")
            raise self.server.error("MQTT Request Timed Out", 504)
        finally:
            try:
                self.pending_responses.remove(resp_fut)
            except ValueError:
                pass
            self.unsubscribe(resp_hdl)
        return resp_fut.result()

    async def _handle_publish_request(self,
                                      web_request: WebRequest
                                      ) -> Dict[str, Any]:
        """Endpoint handler for ``/server/mqtt/publish``."""
        topic: str = web_request.get_str("topic")
        payload: Any = web_request.get("payload", None)
        qos: int = web_request.get_int("qos", self.qos)
        retain: bool = web_request.get_boolean("retain", False)
        timeout: Optional[float] = web_request.get_float('timeout', None)
        try:
            await asyncio.wait_for(
                self.publish_topic(topic, payload, qos, retain), timeout)
        except asyncio.TimeoutError:
            raise self.server.error("MQTT Publish Timed Out", 504)
        return {"topic": topic}

    async def _handle_subscription_request(self,
                                           web_request: WebRequest
                                           ) -> Dict[str, Any]:
        """Endpoint handler for ``/server/mqtt/subscribe``.

        Waits for one message on the requested topic and returns it,
        JSON decoded when possible and as decoded text otherwise.
        """
        topic: str = web_request.get_str("topic")
        qos: int = web_request.get_int("qos", self.qos)
        timeout: Optional[float] = web_request.get_float('timeout', None)
        resp: asyncio.Future = asyncio.Future()
        hdl: Optional[SubscriptionHandle] = None
        try:
            hdl = self.subscribe_topic(topic, resp.set_result, qos)
            self.pending_responses.append(resp)
            await asyncio.wait_for(resp, timeout)
            ret: bytes = resp.result()
        except asyncio.TimeoutError:
            raise self.server.error("MQTT Subscribe Timed Out", 504)
        finally:
            try:
                self.pending_responses.remove(resp)
            except ValueError:
                pass
            if hdl is not None:
                self.unsubscribe(hdl)
        try:
            payload = json.loads(ret)
        except json.JSONDecodeError:
            payload = ret.decode()
        return {
            'topic': topic,
            'payload': payload
        }

    async def _process_api_request(self, payload: bytes) -> None:
        """Dispatch a JSON-RPC request received over MQTT and publish
        the response, if there is one, to the API response topic."""
        response = await self.json_rpc.dispatch(payload.decode())
        if response is not None:
            await self.publish_topic(self.api_resp_topic, response,
                                     self.api_qos)

    def register_api_handler(self, api_def: APIDefinition) -> None:
        """Register the JSON-RPC methods of ``api_def`` with the MQTT
        JSON-RPC dispatcher."""
        if api_def.callback is None:
            # Remote API, uses RPC to reach out to Klippy
            mqtt_method = api_def.jrpc_methods[0]
            rpc_cb = self._generate_remote_callback(api_def.endpoint)
            self.json_rpc.register_method(mqtt_method, rpc_cb)
        else:
            # Local API, uses local callback
            for mqtt_method, req_method in \
                    zip(api_def.jrpc_methods, api_def.request_methods):
                rpc_cb = self._generate_local_callback(
                    api_def.endpoint, req_method, api_def.callback)
                self.json_rpc.register_method(mqtt_method, rpc_cb)
        logging.info(
            "Registering MQTT JSON-RPC methods: "
            f"{', '.join(api_def.jrpc_methods)}")

    def remove_api_handler(self, api_def: APIDefinition) -> None:
        """Remove every JSON-RPC method that ``api_def`` registered."""
        for jrpc_method in api_def.jrpc_methods:
            self.json_rpc.remove_method(jrpc_method)

    def _generate_local_callback(self,
                                 endpoint: str,
                                 request_method: str,
                                 callback: Callable[[WebRequest], Coroutine]
                                 ) -> RPCCallback:
        """Build an RPC callback that forwards to a local ``callback``."""
        async def func(**kwargs) -> Any:
            self._check_timestamp(kwargs)
            result = await callback(
                WebRequest(endpoint, kwargs, request_method))
            return result
        return func

    def _generate_remote_callback(self, endpoint: str) -> RPCCallback:
        """Build an RPC callback that forwards via
        ``self.server.make_request``."""
        async def func(**kwargs) -> Any:
            self._check_timestamp(kwargs)
            result = await self.server.make_request(
                WebRequest(endpoint, kwargs))
            return result
        return func

    def _check_timestamp(self, args: Dict[str, Any]) -> None:
        """Pop ``mqtt_timestamp`` from ``args`` and reject the request
        if that timestamp was seen recently.

        Raises ``self.server.error`` with ``DUP_API_REQ_CODE`` for a
        duplicate.
        """
        ts = args.pop("mqtt_timestamp", None)
        if ts is not None:
            if ts in self.timestamp_deque:
                logging.debug("Duplicate MQTT API request received")
                raise self.server.error(
                    "Duplicate MQTT Request", DUP_API_REQ_CODE)
            else:
                self.timestamp_deque.append(ts)

    async def close(self) -> None:
        """Cancel any reconnect attempt, disconnect from the broker and
        fail every future that is still pending."""
        if self.reconnect_task is not None:
            self.reconnect_task.cancel()
            self.reconnect_task = None
        if not self.is_connected():
            return
        # Set by _on_disconnect once the disconnect is processed
        self.disconnect_evt = asyncio.Event()
        self.client.disconnect()
        try:
            await asyncio.wait_for(self.disconnect_evt.wait(), 2.)
        except asyncio.TimeoutError:
            logging.info("MQTT Disconnect Timeout")
        futs = list(self.pending_acks.values())
        futs.extend(self.pending_responses)
        for fut in futs:
            if fut.done():
                continue
            fut.set_exception(
                self.server.error("Moonraker Shutdown", 503))
def __init__(self, config: ConfigHelper) -> None:
    """Read the ``[mqtt]`` options, create the client and register
    endpoints.

    Raises ``config.error`` when the password file is missing, the
    protocol name is unknown, the instance name contains a wildcard,
    or ``default_qos`` is outside 0..2.
    """
    self.server = config.get_server()
    self.address: str = config.get('address')
    self.port: int = config.getint('port', 1883)
    self.user_name = config.get('username', None)
    pw_file_path = config.get('password_file', None)
    self.password: Optional[str] = None
    if pw_file_path is not None:
        pw_file = pathlib.Path(pw_file_path).expanduser().absolute()
        if not pw_file.exists():
            raise config.error(
                f"Password file '{pw_file}' does not exist")
        self.password = pw_file.read_text().strip()
    protocol = config.get('mqtt_protocol', "v3.1.1")
    self.protocol = MQTT_PROTOCOLS.get(protocol, None)
    if self.protocol is None:
        # Report the accepted option names (the mapping keys); the
        # mapped values are not what the user types in the config.
        raise config.error(
            f"Invalid value '{protocol}' for option 'mqtt_protocol' "
            "in section [mqtt]. Must be one of "
            f"{list(MQTT_PROTOCOLS)}")
    self.instance_name = config.get('instance_name', socket.gethostname())
    if '+' in self.instance_name or '#' in self.instance_name:
        raise config.error(
            "Option 'instance_name' in section [mqtt] cannot "
            "contain a wildcard.")
    self.qos = config.getint("default_qos", 0)
    if self.qos > 2 or self.qos < 0:
        raise config.error(
            "Option 'default_qos' in section [mqtt] must be "
            "between 0 and 2")
    self.client = paho_mqtt.Client(protocol=self.protocol)
    self.client.on_connect = self._on_connect
    self.client.on_message = self._on_message
    self.client.on_disconnect = self._on_disconnect
    self.client.on_publish = self._on_publish
    self.client.on_subscribe = self._on_subscribe
    self.client.on_unsubscribe = self._on_unsubscribe
    self.connect_evt: asyncio.Event = asyncio.Event()
    self.disconnect_evt: Optional[asyncio.Event] = None
    self.reconnect_task: Optional[asyncio.Task] = None
    self.subscribed_topics: SubscribedDict = {}
    self.pending_responses: List[asyncio.Future] = []
    self.pending_acks: Dict[int, asyncio.Future] = {}

    self.server.register_endpoint("/server/mqtt/publish", ["POST"],
                                  self._handle_publish_request,
                                  transports=["http", "websocket"])
    self.server.register_endpoint("/server/mqtt/subscribe", ["POST"],
                                  self._handle_subscription_request,
                                  transports=["http", "websocket"])

    # Subscribe to API requests
    self.json_rpc = JsonRPC(transport="MQTT")
    self.api_request_topic = f"{self.instance_name}/moonraker/api/request"
    self.api_resp_topic = f"{self.instance_name}/moonraker/api/response"
    self.timestamp_deque: Deque = deque(maxlen=20)
    self.api_qos = config.getint('api_qos', self.qos)
    if config.getboolean("enable_moonraker_api", True):
        api_cache = self.server.register_api_transport("mqtt", self)
        for api_def in api_cache.values():
            if "mqtt" in api_def.supported_transports:
                self.register_api_handler(api_def)
        self.subscribe_topic(self.api_request_topic,
                             self._process_api_request,
                             self.api_qos)
        logging.info(
            f"Moonraker API topics - Request: {self.api_request_topic}, "
            f"Response: {self.api_resp_topic}")

    IOLoop.current().spawn_callback(self._initialize)
def __init__(self, config: ConfigHelper) -> None:
    """Read the ``[mqtt]`` options, create the client and register
    endpoints, event handlers and remote methods.

    Raises ``config.error`` when the password file is missing, the
    protocol name is unknown, the instance name contains a wildcard,
    or ``default_qos`` is outside 0..2.
    """
    self.server = config.get_server()
    self.event_loop = self.server.get_event_loop()
    self.address: str = config.get('address')
    self.port: int = config.getint('port', 1883)
    user = config.gettemplate('username', None)
    self.user_name: Optional[str] = None
    if user:
        self.user_name = user.render()
    pw_file_path = config.get('password_file', None, deprecate=True)
    pw_template = config.gettemplate('password', None)
    self.password: Optional[str] = None
    if pw_file_path is not None:
        pw_file = pathlib.Path(pw_file_path).expanduser().absolute()
        if not pw_file.exists():
            raise config.error(
                f"Password file '{pw_file}' does not exist")
        self.password = pw_file.read_text().strip()
    if pw_template is not None:
        # When both are configured, 'password' overrides the value read
        # from 'password_file'
        self.password = pw_template.render()
    protocol = config.get('mqtt_protocol', "v3.1.1")
    self.protocol = MQTT_PROTOCOLS.get(protocol, None)
    if self.protocol is None:
        # Report the accepted option names (the mapping keys); the
        # mapped values are not what the user types in the config.
        raise config.error(
            f"Invalid value '{protocol}' for option 'mqtt_protocol' "
            "in section [mqtt]. Must be one of "
            f"{list(MQTT_PROTOCOLS)}")
    self.instance_name = config.get('instance_name', socket.gethostname())
    if '+' in self.instance_name or '#' in self.instance_name:
        raise config.error(
            "Option 'instance_name' in section [mqtt] cannot "
            "contain a wildcard.")
    self.qos = config.getint("default_qos", 0)
    if self.qos > 2 or self.qos < 0:
        raise config.error(
            "Option 'default_qos' in section [mqtt] must be "
            "between 0 and 2")
    self.client = paho_mqtt.Client(protocol=self.protocol)
    self.client.on_connect = self._on_connect
    self.client.on_message = self._on_message
    self.client.on_disconnect = self._on_disconnect
    self.client.on_publish = self._on_publish
    self.client.on_subscribe = self._on_subscribe
    self.client.on_unsubscribe = self._on_unsubscribe
    self.connect_evt: asyncio.Event = asyncio.Event()
    self.disconnect_evt: Optional[asyncio.Event] = None
    self.reconnect_task: Optional[asyncio.Task] = None
    self.subscribed_topics: SubscribedDict = {}
    self.pending_responses: List[asyncio.Future] = []
    self.pending_acks: Dict[int, asyncio.Future] = {}

    self.server.register_endpoint(
        "/server/mqtt/publish", ["POST"],
        self._handle_publish_request,
        transports=["http", "websocket", "internal"])
    self.server.register_endpoint(
        "/server/mqtt/subscribe", ["POST"],
        self._handle_subscription_request,
        transports=["http", "websocket", "internal"])

    # Subscribe to API requests
    self.json_rpc = JsonRPC(transport="MQTT")
    self.api_request_topic = f"{self.instance_name}/moonraker/api/request"
    self.api_resp_topic = f"{self.instance_name}/moonraker/api/response"
    self.klipper_status_topic = f"{self.instance_name}/klipper/status"
    self.moonraker_status_topic = f"{self.instance_name}/moonraker/status"
    status_cfg: Dict[str, Any] = config.getdict("status_objects", {},
                                                allow_empty_fields=True)
    # Each configured key maps to a list of the non-empty,
    # comma-separated entries of its value, or to None when the key was
    # given without a value
    self.status_objs: Dict[str, Any] = {}
    for key, val in status_cfg.items():
        if val is not None:
            self.status_objs[key] = [v.strip() for v in val.split(',')
                                     if v.strip()]
        else:
            self.status_objs[key] = None
    if status_cfg:
        logging.debug(f"MQTT: Status Objects Set: {self.status_objs}")
        self.server.register_event_handler("server:klippy_identified",
                                           self._handle_klippy_identified)

    self.timestamp_deque: Deque = deque(maxlen=20)
    self.api_qos = config.getint('api_qos', self.qos)
    if config.getboolean("enable_moonraker_api", True):
        api_cache = self.server.register_api_transport("mqtt", self)
        for api_def in api_cache.values():
            if "mqtt" in api_def.supported_transports:
                self.register_api_handler(api_def)
        self.subscribe_topic(self.api_request_topic,
                             self._process_api_request,
                             self.api_qos)

    self.server.register_remote_method("publish_mqtt_topic",
                                       self._publish_from_klipper)
    logging.info(
        f"\nReserved MQTT topics:\n"
        f"API Request: {self.api_request_topic}\n"
        f"API Response: {self.api_resp_topic}\n"
        f"Moonraker Status: {self.moonraker_status_topic}\n"
        f"Klipper Status: {self.klipper_status_topic}")