def __init__(self, reactor):
    """
    Set up the MQTT protocol driver and the empty bookkeeping tables used to
    correlate MQTT packets with WAMP requests.

    :param reactor: handed straight through to MQTTServerTwistedProtocol.
    """
    # This object acts as the handler for the MQTT protocol it owns.
    self._mqtt = MQTTServerTwistedProtocol(self, reactor)

    # Not yet connected / no WAMP session yet.
    self._waiting_for_connect = None
    self._wamp_session = None

    # Request / subscription correlation tables (all start empty).
    self._request_to_packetid = {}
    self._inflight_subscriptions = {}
    self._subrequest_to_mqtt_subrequest = {}
    self._subrequest_callbacks = {}
    self._topic_lookup = {}
def make_test_items(handler):
    """
    Build the usual fixtures for driving an MQTT server protocol in a test.

    :param handler: the handler object given to MQTTServerTwistedProtocol.
    :returns: ``(clock, transport, protocol, client_parser)``: a fake clock
        used as the reactor, the StringTransport the protocol is connected
        to, the protocol itself, and a fresh MQTTClientParser for decoding
        what the protocol wrote.
    """
    clock = Clock()
    transport = StringTransport()
    protocol = MQTTServerTwistedProtocol(handler, clock)
    protocol.makeConnection(transport)
    client_parser = MQTTClientParser()
    return clock, transport, protocol, client_parser
def test_keepalive_canceled_on_lost_connection(self):
    """
    A client that connects with a keep-alive and then disconnects on its own
    has its pending keep-alive timeout cancelled.
    """
    clock = Clock()
    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(BasicHandler(), clock, {})
    proto.makeConnection(transport)

    # CONNECT packet carrying a keep-alive of 2 seconds
    connect_packet = unhexlify(b"101300044d51545404020002000774657374313233")
    for byte in iterbytes(connect_packet):
        proto.dataReceived(byte)

    # Exactly one delayed call, due at t=3.0
    self.assertEqual(len(clock.calls), 1)
    pending = clock.calls[0]
    self.assertEqual(pending.getTime(), 3.0)

    # The client goes away cleanly
    proto.connectionLost(None)

    self.assertEqual(len(clock.calls), 0)
    self.assertTrue(pending.cancelled)
    self.assertFalse(pending.called)
def test_subscribe_always_gets_packet(self):
    """
    Subscriptions always get a SubACK, even if none of the subscriptions
    were successful.

    Compliance statements MQTT-3.8.4-1
    """
    sessions = {}

    class SubHandler(BasicHandler):
        def process_subscribe(self, event):
            # Refuse every topic request (128 is the failure return code).
            return succeed([128])

    h = SubHandler()
    r = Clock()
    t = StringTransport()
    p = MQTTServerTwistedProtocol(h, r, sessions)
    cp = MQTTClientParser()

    p.makeConnection(t)

    data = (
        Connect(client_id=u"test123",
                flags=ConnectFlags(clean_session=True)).serialise() +
        Subscribe(packet_identifier=1234,
                  topic_requests=[SubscriptionTopicRequest(u"a", 0)]).serialise()
    )

    for x in iterbytes(data):
        p.dataReceived(x)

    events = cp.data_received(t.value())
    self.assertEqual(len(events), 2)
    # events[0] answers the CONNECT; events[1] must be the SubACK that the
    # compliance statement requires, carrying the refusal code.
    self.assertIsInstance(events[1], SubACK)
    self.assertEqual(events[1].return_codes, [128])
def test_non_zero_connect_code_must_have_no_present_session(self):
    """
    A non-zero connect code in a CONNACK must be paired with no session
    present.

    Compliance statement MQTT-3.2.2-4
    """
    sessions = {}
    # self.connect_code is supplied by the enclosing (parameterised) test
    # class; BasicHandler answers every CONNECT with it.
    h = BasicHandler(self.connect_code)
    r = Clock()
    t = StringTransport()
    p = MQTTServerTwistedProtocol(h, r, sessions)
    cp = MQTTClientParser()

    p.makeConnection(t)

    # clean_session=False: the client asks to resume, yet a refused connect
    # must still report session_present=False.
    data = (
        Connect(client_id=u"test123",
                flags=ConnectFlags(clean_session=False)).serialise()
    )

    for x in iterbytes(data):
        p.dataReceived(x)

    events = cp.data_received(t.value())
    self.assertEqual(len(events), 1)
    self.assertEqual(
        attr.asdict(events[0]),
        {
            'return_code': self.connect_code,
            'session_present': False,
        })
def test_lose_conn_on_protocol_violation(self):
    """
    A malformed packet from the client is a protocol violation: the server
    logs an error and drops the connection without replying.

    Compliance statement MQTT-4.8.0-1
    """
    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(BasicHandler(), Clock(), {})
    proto.makeConnection(transport)

    # Same bytes as a valid CONNECT except the first byte is 0x11 instead
    # of 0x10, which makes it an invalid CONNECT.
    bad_connect = unhexlify(b"111300044d51545404020002000774657374313233")

    with LogCapturer("trace") as logs:
        for byte in iterbytes(bad_connect):
            proto.dataReceived(byte)

    violation_logs = logs.get_category("MQ401")
    self.assertEqual(len(violation_logs), 1)
    self.assertEqual(violation_logs[0]["log_level"], LogLevel.error)
    self.assertIn("Connect", logs.log_text.getvalue())

    # Nothing written back; connection is being torn down
    self.assertEqual(transport.value(), b'')
    self.assertTrue(transport.disconnecting)
def test_unknown_connect_code_must_lose_connection(self):
    """
    A non-zero, and non-1-to-5 connect code from the handler must result in
    a lost connection, and no CONNACK.

    Compliance statements MQTT-3.2.2-4, MQTT-3.2.2-5
    """
    sessions = {}
    # 6 is outside the range of CONNACK return codes defined by the spec.
    h = BasicHandler(6)
    r = Clock()
    t = StringTransport()
    p = MQTTServerTwistedProtocol(h, r, sessions)

    p.makeConnection(t)

    data = (
        Connect(client_id=u"test123",
                flags=ConnectFlags(clean_session=False)).serialise()
    )

    for x in iterbytes(data):
        p.dataReceived(x)

    self.assertTrue(t.disconnecting)
    # No CONNACK (or anything else) may have been written.
    self.assertEqual(t.value(), b'')
def test_transport_paused_while_processing(self):
    """
    While the handler is still working on a packet (its Deferred has not
    fired), the protocol pauses the transport; it resumes once the handler
    finishes.
    """
    connect_result = Deferred()
    handler = BasicHandler()
    # The CONNECT stays "in progress" until connect_result is fired below.
    handler.process_connect = lambda x: connect_result

    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(handler, Clock(), {})
    transport.connected = True
    proto.makeConnection(transport)

    connect_bytes = Connect(
        client_id=u"test123",
        flags=ConnectFlags(clean_session=False)).serialise()

    self.assertEqual(transport.producerState, 'producing')

    for byte in iterbytes(connect_bytes):
        proto.dataReceived(byte)
    self.assertEqual(transport.producerState, 'paused')

    connect_result.callback(0)
    self.assertEqual(transport.producerState, 'producing')
def test_send_packet(self):
    """
    Every packet the server sends produces one trace-flagged debug log
    entry (category MQ101) naming the packet.
    """
    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(BasicHandler(), Clock(), {})
    parser = MQTTClientParser()
    proto.makeConnection(transport)

    connect_packet = unhexlify(b"101300044d51545404020002000774657374313233")

    with LogCapturer("trace") as logs:
        for byte in iterbytes(connect_packet):
            proto.dataReceived(byte)

    sent = logs.get_category("MQ101")
    self.assertEqual(len(sent), 1)
    entry = sent[0]
    self.assertEqual(entry["log_level"], LogLevel.debug)
    self.assertEqual(entry["txaio_trace"], True)
    self.assertIn("ConnACK", logs.log_text.getvalue())

    # The logged packet really went out on the wire
    events = parser.data_received(transport.value())
    self.assertEqual(len(events), 1)
    self.assertIsInstance(events[0], ConnACK)
def test_keepalive(self):
    """
    A client that connects with a keep-alive and then stays silent for
    keep_alive * 1.5 seconds is disconnected.

    Compliance statement MQTT-3.1.2-24
    """
    clock = Clock()
    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(BasicHandler(), clock, {})
    proto.makeConnection(transport)

    # CONNECT, keep-alive of 2 seconds
    connect_packet = unhexlify(b"101300044d51545404020002000774657374313233")
    for byte in iterbytes(connect_packet):
        proto.dataReceived(byte)

    self.assertEqual(len(clock.calls), 1)
    deadline = clock.calls[0]
    self.assertEqual(deadline.func, proto._lose_connection)
    self.assertEqual(deadline.getTime(), 3.0)
    self.assertFalse(transport.disconnecting)

    # Just before the deadline: still connected
    clock.advance(2.9)
    self.assertFalse(transport.disconnecting)

    # Deadline reached: dropped
    clock.advance(0.1)
    self.assertTrue(transport.disconnecting)
def test_subscribe_same_id(self):
    """
    A SubACK carries the same packet identifier as the SUBSCRIBE it
    answers.

    Compliance statements MQTT-3.8.4-2
    """
    class SubHandler(BasicHandler):
        def process_subscribe(self, event):
            return succeed([0])

    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(SubHandler(), Clock(), {})
    parser = MQTTClientParser()
    proto.makeConnection(transport)

    connect = Connect(client_id=u"test123",
                      flags=ConnectFlags(clean_session=True))
    subscribe = Subscribe(packet_identifier=1234,
                          topic_requests=[SubscriptionTopicRequest(u"a", 0)])

    for byte in iterbytes(connect.serialise() + subscribe.serialise()):
        proto.dataReceived(byte)

    events = parser.data_received(transport.value())
    self.assertEqual(len(events), 2)
    sub_ack = events[1]
    self.assertEqual(sub_ack.return_codes, [0])
    self.assertEqual(sub_ack.packet_identifier, 1234)
def test_keepalive_requires_full_packet(self):
    """
    Only complete packets count as keep-alive activity. A client that sends
    just part of a packet is still disconnected when the original
    keep_alive * 1.5 deadline passes.

    Compliance statement MQTT-3.1.2-24
    """
    clock = Clock()
    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(BasicHandler(), clock, {})
    proto.makeConnection(transport)

    def feed(hex_bytes):
        for byte in iterbytes(unhexlify(hex_bytes)):
            proto.dataReceived(byte)

    def assert_single_deadline(expected_time):
        self.assertEqual(len(clock.calls), 1)
        self.assertEqual(clock.calls[0].func, proto._lose_connection)
        self.assertEqual(clock.calls[0].getTime(), expected_time)

    # CONNECT, keep-alive of 2 seconds
    feed(b"101300044d51545404020002000774657374313233")
    assert_single_deadline(3.0)
    self.assertFalse(transport.disconnecting)

    clock.advance(2.9)
    self.assertFalse(transport.disconnecting)

    # PINGREQ fixed-header byte only; the packet is never completed
    feed(b"c0")

    # Deadline untouched. Had raw data reset the timer, it would have
    # moved to 2.9 + 3 (the time the data was received plus the timeout).
    assert_single_deadline(3.0)

    clock.advance(0.1)
    self.assertTrue(transport.disconnecting)
def test_keepalive_full_packet_resets_timeout(self):
    """
    A complete packet received within keep_alive * 1.5 keeps the connection
    open and pushes the disconnect deadline forward.
    """
    clock = Clock()
    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(BasicHandler(), clock, {})
    proto.makeConnection(transport)

    def feed(hex_bytes):
        for byte in iterbytes(unhexlify(hex_bytes)):
            proto.dataReceived(byte)

    def assert_single_deadline(expected_time):
        self.assertEqual(len(clock.calls), 1)
        self.assertEqual(clock.calls[0].func, proto._lose_connection)
        self.assertEqual(clock.calls[0].getTime(), expected_time)

    # CONNECT, keep-alive of 2 seconds
    feed(b"101300044d51545404020002000774657374313233")
    assert_single_deadline(3.0)
    self.assertFalse(transport.disconnecting)

    clock.advance(2.9)
    self.assertFalse(transport.disconnecting)

    # A complete PINGREQ (header + zero remaining length)
    feed(b"c000")

    # New deadline: time the packet was received (2.9) plus the 3s timeout
    assert_single_deadline(2.9 + 3.0)

    # The old deadline (3.0) passes without a disconnect
    clock.advance(0.1)
    self.assertFalse(transport.disconnecting)
def test_exception_in_subscribe_drops_connection(self):
    """
    An unexpected error from the handler (here, from process_subscribe) is
    logged as critical and drops the connection it happened on.

    Compliance statement MQTT-4.8.0-2
    """
    class SubHandler(BasicHandler):
        # NOTE(review): this function contains no ``yield``, so it is not a
        # generator. Twisted's inlineCallbacks then lets the exception
        # escape synchronously from the call instead of returning a failed
        # Deferred. Confirm which failure path this test is meant to cover.
        @inlineCallbacks
        def process_subscribe(self, event):
            raise Exception("boom!")

    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(SubHandler(), Clock(), {})
    parser = MQTTClientParser()
    proto.makeConnection(transport)

    connect = Connect(client_id=u"test123",
                      flags=ConnectFlags(clean_session=True))
    subscribe = Subscribe(packet_identifier=1234,
                          topic_requests=[SubscriptionTopicRequest(u"a", 0)])
    wire_bytes = connect.serialise() + subscribe.serialise()

    with LogCapturer("trace") as logs:
        for byte in iterbytes(wire_bytes):
            proto.dataReceived(byte)

    crash_logs = logs.get_category("MQ500")
    self.assertEqual(len(crash_logs), 1)
    self.assertEqual(crash_logs[0]["log_level"], LogLevel.critical)
    self.assertEqual(crash_logs[0]["log_failure"].value.args[0], "boom!")

    # Only the CONNACK made it out; then the connection is dropped
    events = parser.data_received(transport.value())
    self.assertEqual(len(events), 1)
    self.assertTrue(transport.disconnecting)

    # The failure was logged as an error; flush it so trial does not mark
    # the test as errored.
    self.flushLoggedErrors()
def test_lose_conn_on_unimplemented_packet(self):
    """
    A packet that parses correctly but that a server never handles (e.g. a
    SubACK, which only the server sends) is a protocol violation and drops
    the connection.

    Compliance statement: MQTT-4.8.0-1
    """
    # Temporarily teach the server-side parser to decode SubACK, so the
    # packet reaches dispatch instead of failing to parse.
    from crossbar.adapter.mqtt import protocol
    protocol.server_packet_handlers[protocol.P_SUBACK] = SubACK
    self.addCleanup(protocol.server_packet_handlers.pop, protocol.P_SUBACK)

    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(BasicHandler(), Clock(), {})
    proto.makeConnection(transport)

    connect = Connect(client_id=u"test123",
                      flags=ConnectFlags(clean_session=False))
    wire_bytes = connect.serialise() + SubACK(1, [1]).serialise()

    with LogCapturer("trace") as logs:
        for byte in iterbytes(wire_bytes):
            proto.dataReceived(byte)

    unhandled = logs.get_category("MQ402")
    self.assertEqual(len(unhandled), 1)
    self.assertEqual(unhandled[0]["log_level"], LogLevel.error)
    self.assertEqual(unhandled[0]["packet_id"], "SubACK")
    self.assertTrue(transport.disconnecting)
def __init__(self, reactor, mqtt_sessions):
    """
    Create the MQTT protocol driver, with this object as its handler.

    :param reactor: handed through to MQTTServerTwistedProtocol.
    :param mqtt_sessions: handed through to MQTTServerTwistedProtocol.
    """
    mqtt_protocol = MQTTServerTwistedProtocol(self, reactor, mqtt_sessions)
    self._mqtt = mqtt_protocol
class WampMQTTServerProtocol(Protocol):
    """
    Bridge one MQTT client connection onto a WAMP ApplicationSession.

    Incoming bytes go to an ``MQTTServerTwistedProtocol``, which calls back
    into the ``process_*`` methods of this object.
    """

    def __init__(self, reactor, mqtt_sessions):
        self._mqtt = MQTTServerTwistedProtocol(self, reactor, mqtt_sessions)

    def connectionMade(self):
        # The MQTT protocol writes directly to our transport.
        self._mqtt.transport = self.transport

    def new_wamp_session(self, event):
        """
        Create a WAMP ApplicationSession for this client and add it to the
        factory's session factory under the configured role.
        """
        session_config = ComponentConfig(realm=self.factory._config['realm'],
                                         extra=None)
        session = ApplicationSession(session_config)

        self.factory._session_factory.add(
            session, authrole=self.factory._config.get('role', u'anonymous'))

        self._wamp_session = session
        return session

    def existing_wamp_session(self, session):
        """
        Adopt *session*, an MQTT session object exposing ``wamp_session``
        and a ``subscriptions`` mapping (topic filter -> WAMP subscription).
        """
        self._full_session = session
        self._wamp_session = session.wamp_session

    def process_connect(self, packet):
        # Should add some authorisation here?
        return succeed(0)

    def _publish(self, event, options):
        payload = {'mqtt_message': event.payload.decode('utf8'),
                   'mqtt_qos': event.qos_level}
        return self._wamp_session.publish(event.topic_name, options=options,
                                          **payload)

    def process_publish_qos_0(self, event):
        return self._publish(event, options=PublishOptions(exclude_me=False))

    def process_publish_qos_1(self, event):
        return self._publish(event,
                             options=PublishOptions(acknowledge=True,
                                                    exclude_me=False))

    def process_puback(self, event):
        return

    def process_pubrec(self, event):
        return

    def process_pubrel(self, event):
        return

    def process_pubcomp(self, event):
        return

    @inlineCallbacks
    def process_subscribe(self, packet):
        """
        Subscribe the WAMP session to each topic filter in *packet*.

        Fires with a list of SUBACK return codes, one per topic request.
        Each code is the granted QoS (0 or 1; QoS 2 is downgraded to 1), or
        128 on failure. Filters containing ``$``, ``#``, ``+`` or ``*`` are
        refused with 128.
        """
        def handle_publish(topic, qos, *args, **kwargs):
            # If there's a single kwarg which is mqtt_message, then just send
            # that, so that CB can be 'drop in'
            if not args and set(kwargs) == {"mqtt_message", "mqtt_qos"}:
                body = kwargs["mqtt_message"].encode('utf8')
                if kwargs["mqtt_qos"] < qos:
                    # If the QoS of the message is lower than our max QoS, use
                    # the lower QoS. Otherwise, bracket it at our QoS.
                    qos = kwargs["mqtt_qos"]
            else:
                body = json.dumps({"args": args, "kwargs": kwargs}).encode('utf8')

            self._mqtt.send_publish(topic, qos, body)

        responses = []

        for request in packet.topic_requests:
            topic_filter = request.topic_filter

            if any(c in topic_filter for c in ("$", "#", "+", "*")):
                responses.append(128)
                continue

            try:
                # Subscriptions are tracked on the MQTT session object. This
                # class never defines ``self._subscriptions``, so reading it
                # made every subscription fail with 128.
                subscriptions = self._full_session.subscriptions

                if topic_filter in subscriptions:
                    # Replace an existing subscription; pop it so a failed
                    # re-subscribe does not leave a dead entry behind.
                    yield subscriptions.pop(topic_filter).unsubscribe()

                sub = yield self._wamp_session.subscribe(
                    partial(handle_publish, topic_filter, request.max_qos),
                    topic_filter)
                subscriptions[topic_filter] = sub

                # We don't allow QoS 2 subscriptions
                responses.append(min(request.max_qos, 1))
            except Exception:
                print("Failed subscribing to topic %s" % (topic_filter,))
                responses.append(128)

        returnValue(responses)

    @inlineCallbacks
    def process_unsubscribe(self, packet):
        """
        Unsubscribe from every topic filter in *packet* that we currently
        hold a subscription for.
        """
        subscriptions = self._full_session.subscriptions
        for topic in packet.topics:
            if topic in subscriptions:
                yield subscriptions.pop(topic).unsubscribe()

    def dataReceived(self, data):
        self._mqtt.dataReceived(data)
def test_only_unique(self):
    """
    A second connection may not use a client ID that is already held by a
    live connection; it is refused with return code 2.

    Compliance statement MQTT-3.1.3-2
    """
    sessions = {}
    handler = BasicHandler()

    # Raw CONNECT for client ID "test123"
    connect_bytes = b"\x10\x13\x00\x04MQTT\x04\x02\x00x\x00\x07test123"

    # First connection: accepted
    first_transport = StringTransport()
    first = MQTTServerTwistedProtocol(handler, Clock(), sessions)
    first_parser = MQTTClientParser()
    first.makeConnection(first_transport)

    for byte in iterbytes(connect_bytes):
        first.dataReceived(byte)

    self.assertFalse(first_transport.disconnecting)
    replies = first_parser.data_received(first_transport.value())
    self.assertEqual(len(replies), 1)
    self.assertEqual(attr.asdict(replies[0]),
                     {'return_code': 0, 'session_present': False})

    # Second connection, same client ID, while the first is still up
    second_transport = StringTransport()
    second = MQTTServerTwistedProtocol(handler, Clock(), sessions)
    second_parser = MQTTClientParser()
    second.makeConnection(second_transport)

    for byte in iterbytes(connect_bytes):
        second.dataReceived(byte)

    replies = second_parser.data_received(second_transport.value())
    self.assertEqual(len(replies), 1)
    self.assertEqual(attr.asdict(replies[0]),
                     {'return_code': 2, 'session_present': True})
class WampMQTTServerProtocol(Protocol):
    """
    Bridge one MQTT client connection onto a WAMP ApplicationSession.

    Incoming bytes go to an ``MQTTServerTwistedProtocol``, which calls back
    into the handler methods of this object.
    """

    def __init__(self, reactor, mqtt_sessions):
        self._mqtt = MQTTServerTwistedProtocol(self, reactor, mqtt_sessions)

    def connectionMade(self):
        # The MQTT protocol writes directly to our transport.
        self._mqtt.transport = self.transport

    def new_wamp_session(self, event):
        """
        Create a WAMP ApplicationSession for this client and add it to the
        factory's session factory under the configured role.
        """
        session_config = ComponentConfig(realm=self.factory._config['realm'],
                                         extra=None)
        session = ApplicationSession(session_config)

        self.factory._session_factory.add(session,
                                          authrole=self.factory._config.get(
                                              'role', u'anonymous'))

        self._wamp_session = session
        return session

    def existing_wamp_session(self, session):
        """
        Adopt *session*, an MQTT session object exposing ``wamp_session``
        and a ``subscriptions`` mapping (topic filter -> WAMP subscription).
        """
        self._full_session = session
        self._wamp_session = session.wamp_session

    def process_connect(self, packet):
        # Should add some authorisation here?
        return succeed(0)

    def _publish(self, event, options):
        payload = {
            'mqtt_message': event.payload.decode('utf8'),
            'mqtt_qos': event.qos_level
        }
        return self._wamp_session.publish(event.topic_name,
                                          options=options,
                                          **payload)

    def publish_qos_0(self, event):
        return self._publish(event, options=PublishOptions(exclude_me=False))

    def publish_qos_1(self, event):
        return self._publish(event,
                             options=PublishOptions(acknowledge=True,
                                                    exclude_me=False))

    @inlineCallbacks
    def process_subscribe(self, packet):
        """
        Subscribe the WAMP session to each topic filter in *packet*.

        Fires with a list of SUBACK return codes, one per topic request.
        Each code is the granted QoS (0 or 1; QoS 2 is downgraded to 1), or
        128 on failure. Filters containing ``$``, ``#``, ``+`` or ``*`` are
        refused with 128.
        """
        def handle_publish(topic, qos, *args, **kwargs):
            # If there's a single kwarg which is mqtt_message, then just send
            # that, so that CB can be 'drop in'
            if not args and set(kwargs) == {"mqtt_message", "mqtt_qos"}:
                body = kwargs["mqtt_message"].encode('utf8')
                if kwargs["mqtt_qos"] < qos:
                    # If the QoS of the message is lower than our max QoS, use
                    # the lower QoS. Otherwise, bracket it at our QoS.
                    qos = kwargs["mqtt_qos"]
            else:
                body = json.dumps({
                    "args": args,
                    "kwargs": kwargs
                }).encode('utf8')

            self._mqtt.send_publish(topic, qos, body)

        responses = []

        for request in packet.topic_requests:
            topic_filter = request.topic_filter

            if any(c in topic_filter for c in ("$", "#", "+", "*")):
                responses.append(128)
                continue

            try:
                # Subscriptions are tracked on the MQTT session object. This
                # class never defines ``self._subscriptions``, so reading it
                # made every subscription fail with 128.
                subscriptions = self._full_session.subscriptions

                if topic_filter in subscriptions:
                    # Replace an existing subscription; pop it so a failed
                    # re-subscribe does not leave a dead entry behind.
                    yield subscriptions.pop(topic_filter).unsubscribe()

                sub = yield self._wamp_session.subscribe(
                    partial(handle_publish, topic_filter, request.max_qos),
                    topic_filter)
                subscriptions[topic_filter] = sub

                # We don't allow QoS 2 subscriptions
                responses.append(min(request.max_qos, 1))
            except Exception:
                print("Failed subscribing to topic %s" % (topic_filter, ))
                responses.append(128)

        returnValue(responses)

    @inlineCallbacks
    def process_unsubscribe(self, packet):
        """
        Unsubscribe from every topic filter in *packet* that we currently
        hold a subscription for.
        """
        subscriptions = self._full_session.subscriptions
        for topic in packet.topics:
            if topic in subscriptions:
                yield subscriptions.pop(topic).unsubscribe()

    def dataReceived(self, data):
        self._mqtt.dataReceived(data)
class WampMQTTServerProtocol(Protocol):
    """
    Bridge one MQTT client connection onto an in-process WAMP RouterSession.

    MQTT packets are parsed by ``self._mqtt``, which calls the ``process_*``
    methods below. Those translate the packets into WAMP messages fed to
    ``self._wamp_session.onMessage``. Messages the router sends back arrive
    in :meth:`on_message` and are translated into MQTT packets.
    """

    log = make_logger()

    def __init__(self, reactor):
        self._mqtt = MQTTServerTwistedProtocol(self, reactor)
        # WAMP publish request id -> MQTT packet id (QoS > 0 publishes only)
        self._request_to_packetid = {}
        # Deferred fired with (return_code, session_present) on Welcome/Abort
        self._waiting_for_connect = None
        # MQTT packet id -> OrderedDict(WAMP request id -> {"response", "topic"})
        self._inflight_subscriptions = {}
        # WAMP subscribe request id -> MQTT packet id
        self._subrequest_to_mqtt_subrequest = {}
        # MQTT packet id -> Deferred that sends the SUBACK when fired
        self._subrequest_callbacks = {}
        # WAMP subscription id -> MQTT topic filter
        self._topic_lookup = {}
        self._wamp_session = None

    def on_message(self, inc_msg):
        """
        Entry point for WAMP messages from the router (via WampTransport).
        """
        try:
            self._on_message(inc_msg)
        except Exception:
            self.log.failure()

    def _on_message(self, inc_msg):
        if isinstance(inc_msg, message.Challenge):
            assert inc_msg.method == u"ticket"

            msg = message.Authenticate(signature=self._pw_challenge)
            del self._pw_challenge

            self._wamp_session.onMessage(msg)

        elif isinstance(inc_msg, message.Welcome):
            self._waiting_for_connect.callback((0, False))

        elif isinstance(inc_msg, message.Abort):
            self._waiting_for_connect.callback((1, False))

        elif isinstance(inc_msg, message.Subscribed):
            # Successful subscription!
            mqtt_id = self._subrequest_to_mqtt_subrequest[inc_msg.request]
            watch = self._inflight_subscriptions[mqtt_id]
            watch[inc_msg.request]["response"] = 0
            self._topic_lookup[inc_msg.subscription] = watch[inc_msg.request]["topic"]

            if -1 not in [x["response"] for x in watch.values()]:
                self._subrequest_callbacks[mqtt_id].callback(None)

        elif (isinstance(inc_msg, message.Error) and
              inc_msg.request_type == message.Subscribe.MESSAGE_TYPE):
            # Failed subscription :(
            mqtt_id = self._subrequest_to_mqtt_subrequest[inc_msg.request]
            watch = self._inflight_subscriptions[mqtt_id]
            watch[inc_msg.request]["response"] = 128

            if -1 not in [x["response"] for x in watch.values()]:
                self._subrequest_callbacks[mqtt_id].callback(None)

        elif isinstance(inc_msg, message.Unsubscribed):
            # Reply to an Unsubscribe sent by process_unsubscribe; the local
            # mapping was already removed there.
            pass

        elif isinstance(inc_msg, message.Event):
            topic = inc_msg.topic or self._topic_lookup.get(inc_msg.subscription)
            if topic is None:
                # Presumably an event that was in flight when we unsubscribed;
                # there is no MQTT subscription left to deliver it to.
                return

            body = wamp_payload_transform(
                self._wamp_session._router._mqtt_payload_format, inc_msg)

            self._mqtt.send_publish(u"/".join(tokenise_wamp_topic(topic)), 0,
                                    body,
                                    retained=inc_msg.retained or False)

        elif isinstance(inc_msg, message.Goodbye):
            if self._mqtt.transport:
                self._mqtt.transport.loseConnection()
                self._mqtt.transport = None

        else:
            print("Got something we don't understand yet:")
            print(inc_msg)

    def connectionMade(self, ignore_handshake=False):
        # For TLS transports, wait for handshakeCompleted instead.
        if ignore_handshake or not ISSLTransport.providedBy(self.transport):
            self._when_ready()

    def connectionLost(self, reason):
        if self._wamp_session:
            msg = message.Goodbye()
            self._wamp_session.onMessage(msg)
            del self._wamp_session

    def handshakeCompleted(self):
        self._when_ready()

    def _when_ready(self):
        if self._wamp_session:
            return

        self._mqtt.transport = self.transport

        self._wamp_session = RouterSession(
            self.factory._wamp_session_factory._routerFactory)
        self._wamp_session._is_mqtt = True
        self._wamp_transport = WampTransport(self.on_message, self.transport)
        self._wamp_transport.factory = self.factory
        self._wamp_session.onOpen(self._wamp_transport)

    def process_connect(self, packet):
        """
        Translate an MQTT CONNECT into a WAMP Hello.

        Returns a Deferred fired with ``(return_code, session_present)``
        when the router answers with Welcome or Abort.
        """
        self._waiting_for_connect = Deferred()

        roles = {
            u"subscriber": role.RoleSubscriberFeatures(payload_transparency=True),
            u"publisher": role.RolePublisherFeatures(
                payload_transparency=True,
                x_acknowledged_event_delivery=True)
        }

        # Will be autoassigned
        realm = None
        methods = []

        if ISSLTransport.providedBy(self.transport):
            methods.append(u"tls")

        if packet.username and packet.password:
            methods.append(u"ticket")
            msg = message.Hello(realm=realm, roles=roles, authmethods=methods,
                                authid=packet.username)
            self._pw_challenge = packet.password

        else:
            methods.append(u"anonymous")
            msg = message.Hello(realm=realm, roles=roles, authmethods=methods,
                                authid=packet.client_id)

        self._wamp_session.onMessage(msg)

        if packet.flags.will:

            @self._waiting_for_connect.addCallback
            def process_will(res):

                akw = mqtt_payload_transform(
                    self._wamp_session._router._mqtt_payload_format,
                    packet.will_message)

                if not akw:
                    # Drop it I guess :(
                    return res

                args, kwargs = akw

                msg = message.Call(
                    request=util.id(),
                    procedure=u"wamp.session.add_testament",
                    args=[
                        u".".join(tokenise_mqtt_topic(packet.will_topic)),
                        args, kwargs,
                        {"retain": bool(packet.flags.will_retain)}
                    ])

                self._wamp_session.onMessage(msg)
                return res

        return self._waiting_for_connect

    def _publish(self, event, options):
        request = util.id()
        msg = message.Publish(request=request,
                              topic=u".".join(tokenise_mqtt_topic(event.topic_name)),
                              payload=event.payload,
                              **options.message_attr())
        msg._mqtt_publish = True
        self._wamp_session.onMessage(msg)

        if event.qos_level > 0:
            self._request_to_packetid[request] = event.packet_identifier

        return succeed(0)

    def process_publish_qos_0(self, event):
        return self._publish(event, options=PublishOptions(exclude_me=False,
                                                           retain=event.retain))

    def process_publish_qos_1(self, event):
        return self._publish(event, options=PublishOptions(acknowledge=True,
                                                           exclude_me=False,
                                                           retain=event.retain))

    def process_puback(self, event):
        return

    def process_pubrec(self, event):
        return

    def process_pubrel(self, event):
        return

    def process_pubcomp(self, event):
        return

    def process_subscribe(self, packet):
        """
        Forward each topic request of an MQTT SUBSCRIBE as a WAMP Subscribe.

        The SUBACK is sent from a Deferred callback once every request is
        resolved: 0 on Subscribed, 128 on Error or on a local send failure.
        """
        mqtt_id = packet.packet_identifier
        packet_watch = OrderedDict()
        d = Deferred()

        @d.addCallback
        def _(ign):
            self._mqtt.send_suback(mqtt_id,
                                   [x["response"] for x in packet_watch.values()])
            del self._inflight_subscriptions[mqtt_id]
            del self._subrequest_callbacks[mqtt_id]
            # All requests of this packet are resolved; drop their reverse
            # mappings so they do not accumulate.
            for request_id in packet_watch:
                self._subrequest_to_mqtt_subrequest.pop(request_id, None)

        self._subrequest_callbacks[mqtt_id] = d
        self._inflight_subscriptions[mqtt_id] = packet_watch

        # Register *all* requests before sending any. The router may answer
        # from inside onMessage. With a partially-filled packet_watch,
        # _on_message would fire the SUBACK early with too few return codes,
        # then KeyError on the remaining replies.
        pending = []
        for x in packet.topic_requests:
            # fixme
            match_type = u"exact"
            request_id = util.id()
            msg = message.Subscribe(
                request=request_id,
                topic=u".".join(tokenise_mqtt_topic(x.topic_filter)),
                match=match_type,
                get_retained=True,
            )
            packet_watch[request_id] = {"response": -1, "topic": x.topic_filter}
            self._subrequest_to_mqtt_subrequest[request_id] = mqtt_id
            pending.append((request_id, msg))

        for request_id, msg in pending:
            try:
                self._wamp_session.onMessage(msg)
            except Exception:
                self.log.failure()
                packet_watch[request_id]["response"] = 128

        # Locally failed requests never get a router reply. If those were
        # the last ones outstanding, nothing else would fire d.
        if not d.called and -1 not in [x["response"] for x in packet_watch.values()]:
            d.callback(None)

    def process_unsubscribe(self, packet):
        """
        Drop the WAMP subscriptions backing each MQTT topic filter in *packet*.

        Subscriptions are only tracked in ``self._topic_lookup`` (this class
        has no ``_subscriptions`` attribute). For each match, a WAMP
        Unsubscribe is sent and the mapping is removed. Returns a Deferred
        fired with None.
        """
        for topic in packet.topics:
            subscription_ids = [sub_id for sub_id, sub_topic
                                in self._topic_lookup.items()
                                if sub_topic == topic]
            for subscription_id in subscription_ids:
                del self._topic_lookup[subscription_id]
                msg = message.Unsubscribe(request=util.id(),
                                          subscription=subscription_id)
                try:
                    self._wamp_session.onMessage(msg)
                except Exception:
                    self.log.failure()
        return succeed(None)

    def dataReceived(self, data):
        self._mqtt.dataReceived(data)
def test_allow_connects_with_same_id_if_disconnected(self):
    """
    Once the connection holding a client ID has gone away, a new connection
    may reuse that ID and is attached to the same stored session.
    """
    sessions = {}
    handler = BasicHandler()
    connect_bytes = Connect(
        client_id=u"test123",
        flags=ConnectFlags(clean_session=False)).serialise()

    # First connection
    first_transport = StringTransport()
    first = MQTTServerTwistedProtocol(handler, Clock(), sessions)
    first_parser = MQTTClientParser()
    first.makeConnection(first_transport)

    for byte in iterbytes(connect_bytes):
        first.dataReceived(byte)

    self.assertFalse(first_transport.disconnecting)
    replies = first_parser.data_received(first_transport.value())
    self.assertEqual(len(replies), 1)
    self.assertEqual(attr.asdict(replies[0]),
                     {'return_code': 0, 'session_present': False})

    first.connectionLost(None)

    # Second connection with the same client ID
    second_transport = StringTransport()
    second = MQTTServerTwistedProtocol(handler, Clock(), sessions)
    second_parser = MQTTClientParser()
    second.makeConnection(second_transport)

    for byte in iterbytes(connect_bytes):
        second.dataReceived(byte)

    # Accepted, and the stored session is reported as present
    replies = second_parser.data_received(second_transport.value())
    self.assertEqual(len(replies), 1)
    self.assertEqual(attr.asdict(replies[0]),
                     {'return_code': 0, 'session_present': True})

    self.assertEqual(first.session, second.session)
class WampMQTTServerProtocol(Protocol):
    """
    Bridge one MQTT client connection onto an in-process WAMP RouterSession.

    MQTT packets are parsed by ``self._mqtt``, which calls the ``process_*``
    methods below. Those translate the packets into WAMP messages fed to
    ``self._wamp_session.onMessage``. Messages the router sends back arrive
    in :meth:`on_message`. Payload/topic mapping in both directions goes
    through ``self.factory.transform_mqtt`` / ``transform_wamp``.
    """

    log = make_logger()

    def __init__(self, reactor):
        self._mqtt = MQTTServerTwistedProtocol(self, reactor)
        # WAMP publish request id -> MQTT packet id (QoS > 0 publishes only)
        self._request_to_packetid = {}
        # Deferred fired with (return_code, session_present) on Welcome/Abort
        self._waiting_for_connect = None
        # MQTT packet id -> OrderedDict(WAMP request id -> {"response", "topic"})
        self._inflight_subscriptions = {}
        # WAMP subscribe request id -> MQTT packet id
        self._subrequest_to_mqtt_subrequest = {}
        # MQTT packet id -> Deferred that sends the SUBACK when fired
        self._subrequest_callbacks = {}
        # WAMP subscription id -> MQTT topic filter
        self._topic_lookup = {}
        self._wamp_session = None

    def on_message(self, inc_msg):
        """
        Entry point for WAMP messages from the router (via WampTransport).
        """
        # _on_message is wrapped by inlineCallbacks, so errors inside it end
        # up in its returned Deferred rather than being raised here.
        try:
            self._on_message(inc_msg)
        except Exception:
            self.log.failure()

    @inlineCallbacks
    def _on_message(self, inc_msg):
        self.log.debug('WampMQTTServerProtocol._on_message(inc_msg={inc_msg})', inc_msg=inc_msg)

        if isinstance(inc_msg, message.Challenge):
            assert inc_msg.method == u"ticket"

            msg = message.Authenticate(signature=self._pw_challenge)
            del self._pw_challenge

            self._wamp_session.onMessage(msg)

        elif isinstance(inc_msg, message.Welcome):
            self._waiting_for_connect.callback((0, False))

        elif isinstance(inc_msg, message.Abort):
            self._waiting_for_connect.callback((1, False))

        elif isinstance(inc_msg, message.Subscribed):
            # Successful subscription!
            mqtt_id = self._subrequest_to_mqtt_subrequest[inc_msg.request]
            self._inflight_subscriptions[mqtt_id][inc_msg.request]["response"] = 0
            self._topic_lookup[inc_msg.subscription] = self._inflight_subscriptions[mqtt_id][inc_msg.request]["topic"]
            if -1 not in [x["response"] for x in self._inflight_subscriptions[mqtt_id].values()]:
                self._subrequest_callbacks[mqtt_id].callback(None)

        elif (isinstance(inc_msg, message.Error) and
              inc_msg.request_type == message.Subscribe.MESSAGE_TYPE):
            # Failed subscription :(
            mqtt_id = self._subrequest_to_mqtt_subrequest[inc_msg.request]
            self._inflight_subscriptions[mqtt_id][inc_msg.request]["response"] = 128
            if -1 not in [x["response"] for x in self._inflight_subscriptions[mqtt_id].values()]:
                self._subrequest_callbacks[mqtt_id].callback(None)

        elif isinstance(inc_msg, message.Unsubscribed):
            # Reply to an Unsubscribe sent by process_unsubscribe; the local
            # mapping was already removed there.
            pass

        elif isinstance(inc_msg, message.Event):
            topic = inc_msg.topic or self._topic_lookup.get(inc_msg.subscription)
            if topic is None:
                # Presumably an event that was in flight when we unsubscribed;
                # there is no MQTT subscription left to deliver it to.
                self.log.debug('dropping event for unknown subscription {subscription}',
                               subscription=inc_msg.subscription)
                return
            try:
                payload_format, mapped_topic, payload = yield self.factory.transform_wamp(topic, inc_msg)
            except Exception:
                self.log.failure()
            else:
                self._mqtt.send_publish(mapped_topic, 0, payload, retained=inc_msg.retained or False)

        elif isinstance(inc_msg, message.Goodbye):
            if self._mqtt.transport:
                self._mqtt.transport.loseConnection()
                self._mqtt.transport = None

        else:
            self.log.warn('cannot process unimplemented message: {inc_msg}', inc_msg=inc_msg)

    def connectionMade(self, ignore_handshake=False):
        # For TLS transports, wait for handshakeCompleted instead.
        if ignore_handshake or not ISSLTransport.providedBy(self.transport):
            self._when_ready()

    def connectionLost(self, reason):
        if self._wamp_session:
            msg = message.Goodbye()
            self._wamp_session.onMessage(msg)
            del self._wamp_session

    def handshakeCompleted(self):
        self._when_ready()

    def _when_ready(self):
        if self._wamp_session:
            return

        self._mqtt.transport = self.transport

        self._wamp_session = RouterSession(self.factory._router_session_factory._routerFactory)
        self._wamp_transport = WampTransport(self.factory, self.on_message, self.transport)
        self._wamp_session.onOpen(self._wamp_transport)
        self._wamp_session._transport_config = self.factory._options

    def process_connect(self, packet):
        """
        Process the initial Connect message from the MQTT client.

        This returns a Deferred that fires with a pair
        `(accept_conn, session_present)`, where `accept_conn` is a return
        code:

          0: connection accepted
          1-5: connection refused (see MQTT spec 3.2.2.3)
        """

        # Connect(client_id='paho/4E23D8C09DD9C6CF2C',
        #         flags=ConnectFlags(username=False,
        #                            password=False,
        #                            will=False,
        #                            will_retain=False,
        #                            will_qos=0,
        #                            clean_session=True,
        #                            reserved=False),
        #         keep_alive=60,
        #         will_topic=None,
        #         will_message=None,
        #         username=None,
        #         password=None)

        # Do not log the whole packet: its repr includes the password.
        self.log.info('WampMQTTServerProtocol.process_connect(client_id={client_id}, username={username})',
                      client_id=packet.client_id, username=packet.username)

        # we don't support session resumption: https://github.com/crossbario/crossbar/issues/892
        if not packet.flags.clean_session:
            self.log.warn('denying MQTT connect from {peer}, as the clients wants to resume a session (which we do not support)', peer=peer2str(self.transport.getPeer()))
            return succeed((1, False))

        # we won't support QoS 2: https://github.com/crossbario/crossbar/issues/1046
        if packet.flags.will and packet.flags.will_qos not in [0, 1]:
            self.log.warn('denying MQTT connect from {peer}, as the clients wants to provide a "last will" event with QoS {will_qos} (and we only support QoS 0/1 here)', peer=peer2str(self.transport.getPeer()), will_qos=packet.flags.will_qos)
            return succeed((1, False))

        # this will be resolved when the MQTT connect handshake is completed
        self._waiting_for_connect = Deferred()

        roles = {
            u"subscriber": role.RoleSubscriberFeatures(
                payload_transparency=True,
                pattern_based_subscription=True),
            u"publisher": role.RolePublisherFeatures(
                payload_transparency=True,
                x_acknowledged_event_delivery=True)
        }

        realm = self.factory._options.get(u'realm', None)

        authmethods = []
        authextra = {
            u'mqtt': {
                u'client_id': packet.client_id,
                u'will': bool(packet.flags.will),
                u'will_topic': packet.will_topic
            }
        }

        if ISSLTransport.providedBy(self.transport):
            authmethods.append(u"tls")

        if packet.username and packet.password:
            authmethods.append(u"ticket")
            msg = message.Hello(
                realm=realm,
                roles=roles,
                authmethods=authmethods,
                authid=packet.username,
                authextra=authextra)
            self._pw_challenge = packet.password

        else:
            authmethods.append(u"anonymous")
            msg = message.Hello(
                realm=realm,
                roles=roles,
                authmethods=authmethods,
                authid=packet.client_id,
                authextra=authextra)

        self._wamp_session.onMessage(msg)

        if packet.flags.will:

            # it's unclear from the MQTT spec whether a) the publication of the last will
            # is to happen in-band during "connect", and if it fails, deny the connection,
            # or b) the last will publication happens _after_ "connect", and the connection
            # succeeds regardless whether the last will publication succeeds or not.
            #
            # we opt for b) here!
            #
            # Decorator order matters: addCallback must receive the
            # inlineCallbacks-wrapped function. In the reverse order the raw
            # generator function is registered, so the connect Deferred
            # resolves to a generator object instead of
            # (return_code, session_present), and the testament is never
            # registered.
            @self._waiting_for_connect.addCallback
            @inlineCallbacks
            def process_will(res):
                # only register a testament for a session the router accepted
                if res[0] != 0:
                    returnValue(res)

                self.log.info('WampMQTTServerProtocol.process_connect: registering last will for topic {will_topic}',
                              will_topic=packet.will_topic)

                try:
                    payload_format, mapped_topic, options = yield self.factory.transform_mqtt(packet.will_topic, packet.will_message)
                except Exception:
                    # option b) above: a broken will does not fail the connect
                    self.log.failure()
                    returnValue(res)

                request = util.id()

                msg = message.Call(
                    request=request,
                    procedure=u"wamp.session.add_testament",
                    args=[
                        mapped_topic,
                        options.get('args', None),
                        options.get('kwargs', None),
                        {
                            # specifiy "retain" for when the testament (last will)
                            # will be auto-published by the broker later
                            u'retain': bool(packet.flags.will_retain)
                        }
                    ])

                self._wamp_session.onMessage(msg)

                returnValue(res)

        return self._waiting_for_connect

    @inlineCallbacks
    def _publish(self, event, acknowledge=None):
        """
        Given a MQTT event, create a WAMP Publish message and forward that on
        the forwarding WAMP session.
        """
        try:
            payload_format, mapped_topic, options = yield self.factory.transform_mqtt(event.topic_name, event.payload)
        except Exception:
            self.log.failure()
            return

        request = util.id()

        msg = message.Publish(
            request=request,
            topic=mapped_topic,
            exclude_me=False,
            acknowledge=acknowledge,
            retain=event.retain,
            **options)

        self._wamp_session.onMessage(msg)

        if event.qos_level > 0:
            self._request_to_packetid[request] = event.packet_identifier

        returnValue(0)

    def process_publish_qos_0(self, event):
        try:
            return self._publish(event)
        except Exception:
            self.log.failure()

    def process_publish_qos_1(self, event):
        try:
            return self._publish(event, acknowledge=True)
        except Exception:
            self.log.failure()

    def process_puback(self, event):
        return

    def process_pubrec(self, event):
        return

    def process_pubrel(self, event):
        return

    def process_pubcomp(self, event):
        return

    def process_subscribe(self, packet):
        """
        Forward each topic request of an MQTT SUBSCRIBE as a WAMP Subscribe.

        The SUBACK is sent from a Deferred callback once every request is
        resolved: 0 on Subscribed, 128 on Error or on a local send failure.
        """
        mqtt_id = packet.packet_identifier
        packet_watch = OrderedDict()
        d = Deferred()

        @d.addCallback
        def _(ign):
            self._mqtt.send_suback(mqtt_id, [x["response"] for x in packet_watch.values()])
            del self._inflight_subscriptions[mqtt_id]
            del self._subrequest_callbacks[mqtt_id]
            # All requests of this packet are resolved; drop their reverse
            # mappings so they do not accumulate.
            for request_id in packet_watch:
                self._subrequest_to_mqtt_subrequest.pop(request_id, None)

        self._subrequest_callbacks[mqtt_id] = d
        self._inflight_subscriptions[mqtt_id] = packet_watch

        # Register *all* requests before sending any. The router may answer
        # from inside onMessage. With a partially-filled packet_watch,
        # _on_message would fire the SUBACK early with too few return codes,
        # then KeyError on the remaining replies.
        pending = []
        for x in packet.topic_requests:
            topic, match = _mqtt_topicfilter_to_wamp(x.topic_filter)

            self.log.info('process_subscribe -> topic={topic}, match={match}', topic=topic, match=match)

            request_id = util.id()

            msg = message.Subscribe(
                request=request_id,
                topic=topic,
                match=match,
                get_retained=True,
            )

            packet_watch[request_id] = {"response": -1, "topic": x.topic_filter}
            self._subrequest_to_mqtt_subrequest[request_id] = mqtt_id
            pending.append((request_id, msg))

        for request_id, msg in pending:
            try:
                self._wamp_session.onMessage(msg)
            except Exception:
                self.log.failure()
                packet_watch[request_id]["response"] = 128

        # Locally failed requests never get a router reply. If those were
        # the last ones outstanding, nothing else would fire d.
        if not d.called and -1 not in [x["response"] for x in packet_watch.values()]:
            d.callback(None)

    def process_unsubscribe(self, packet):
        """
        Drop the WAMP subscriptions backing each MQTT topic filter in *packet*.

        Subscriptions are only tracked in ``self._topic_lookup`` (this class
        has no ``_subscriptions`` attribute). For each match, a WAMP
        Unsubscribe is sent and the mapping is removed. Returns a Deferred
        fired with None.
        """
        for topic in packet.topics:
            subscription_ids = [sub_id for sub_id, sub_topic
                                in self._topic_lookup.items()
                                if sub_topic == topic]
            for subscription_id in subscription_ids:
                del self._topic_lookup[subscription_id]
                msg = message.Unsubscribe(request=util.id(),
                                          subscription=subscription_id)
                try:
                    self._wamp_session.onMessage(msg)
                except Exception:
                    self.log.failure()
        return succeed(None)

    def dataReceived(self, data):
        self._mqtt.dataReceived(data)
def test_clean_session_destroys_session(self):
    """
    A CONNECT carrying clean_session=True throws away whatever session was
    previously stored for that client ID, and the fresh session it creates
    does not outlive the connection.

    Compliance statement MQTT-3.2.2-1
    """
    sessions = {}
    client_id = u"test123"

    def connect_bytes(clean):
        # Serialised CONNECT for ``client_id`` with the given clean_session flag.
        return Connect(client_id=client_id,
                       flags=ConnectFlags(clean_session=clean)).serialise()

    def feed(proto, payload):
        # Deliver the payload one byte at a time, like a slow network would.
        for single in iterbytes(payload):
            proto.dataReceived(single)

    handler = BasicHandler()

    # --- first connection: persistent session (clean_session=False) ---
    first_clock = Clock()
    first_transport = StringTransport()
    first_proto = MQTTServerTwistedProtocol(handler, first_clock, sessions)
    first_proto.makeConnection(first_transport)

    feed(first_proto, connect_bytes(False))

    self.assertFalse(first_transport.disconnecting)
    self.assertEqual(list(sessions), [client_id])
    old_session = sessions[client_id]

    # Close the connection
    first_proto.connectionLost(None)

    # --- second connection: same client ID, clean_session=True ---
    clean_connect = connect_bytes(True)

    second_clock = Clock()
    second_transport = StringTransport()
    second_proto = MQTTServerTwistedProtocol(handler, second_clock, sessions)
    client_parser = MQTTClientParser()
    second_proto.makeConnection(second_transport)

    feed(second_proto, clean_connect)

    # The connection is accepted, without a session present
    events = client_parser.data_received(second_transport.value())
    self.assertEqual(len(events), 1)
    self.assertEqual(attr.asdict(events[0]), {
        'return_code': 0,
        'session_present': False,
    })

    self.assertEqual(list(sessions), [client_id])
    new_session = sessions[client_id]

    # A brand new session, which will not survive the connection
    self.assertIsNot(old_session, new_session)
    self.assertFalse(new_session.survives)

    # Dropping the connection destroys the session
    second_proto.connectionLost(None)
    self.assertEqual(list(sessions), [])
class WampMQTTServerProtocol(Protocol):
    """
    Bridge one MQTT client connection onto a WAMP router session.

    Bytes received from the client are fed to an embedded
    ``MQTTServerTwistedProtocol`` that is given this object as its handler,
    and which calls back into the ``process_*`` methods below. Those methods
    translate MQTT packets into WAMP messages for a ``RouterSession``; WAMP
    messages the router sends back arrive in :meth:`on_message` and are
    translated into MQTT packets.
    """

    log = make_logger()

    def __init__(self, reactor):
        # MQTT wire protocol; ``self`` is its packet handler.
        self._mqtt = MQTTServerTwistedProtocol(self, reactor)
        # WAMP publish request ID -> MQTT packet identifier (QoS > 0 only).
        # Entries are dropped when the router answers the publish.
        self._request_to_packetid = {}
        # Deferred created by process_connect; fired with
        # (connack_return_code, session_present) on WAMP Welcome/Abort.
        self._waiting_for_connect = None
        # MQTT SUBSCRIBE packet identifier -> OrderedDict mapping each WAMP
        # subscribe request ID to {"response": ..., "topic": ...}; a response
        # of -1 means "no answer from the router yet".
        self._inflight_subscriptions = {}
        # WAMP subscribe request ID -> MQTT SUBSCRIBE packet identifier.
        self._subrequest_to_mqtt_subrequest = {}
        # MQTT SUBSCRIBE packet identifier -> Deferred that sends the SUBACK.
        self._subrequest_callbacks = {}
        # WAMP subscription ID -> MQTT topic filter it was created for.
        self._topic_lookup = {}
        self._wamp_session = None

    def on_message(self, inc_msg):
        """
        Receive a WAMP message sent to this client by the router.

        ``_on_message`` is an ``inlineCallbacks`` function, so its errors are
        delivered through the Deferred it returns rather than raised; they are
        logged by an errback instead of surfacing as unhandled Deferred
        failures.
        """
        try:
            d = self._on_message(inc_msg)
        except Exception:
            self.log.failure()
        else:
            d.addErrback(self._log_deferred_failure)

    def _log_deferred_failure(self, fail):
        """Errback: log ``fail`` and swallow it."""
        self.log.failure(
            "Error while processing WAMP message: {log_failure.value}",
            failure=fail)

    @inlineCallbacks
    def _on_message(self, inc_msg):
        """
        Translate one WAMP message from the router into MQTT-side actions.
        """
        self.log.debug('WampMQTTServerProtocol._on_message(inc_msg={inc_msg})',
                       inc_msg=inc_msg)

        if isinstance(inc_msg, message.Challenge):
            # process_connect only requests "ticket" auth when the client sent
            # a password; answer the challenge with that password.
            assert inc_msg.method == u"ticket"
            msg = message.Authenticate(signature=self._pw_challenge)
            del self._pw_challenge
            self._wamp_session.onMessage(msg)

        elif isinstance(inc_msg, message.Welcome):
            # Session established -> CONNACK return code 0, no session present.
            self._waiting_for_connect.callback((0, False))

        elif isinstance(inc_msg, message.Abort):
            # Session refused -> CONNACK return code 1.
            # NOTE(review): MQTT 3.1.1 defines 1 as "unacceptable protocol
            # version"; 4/5 may describe an auth rejection better -- confirm.
            self._waiting_for_connect.callback((1, False))

        elif isinstance(inc_msg, message.Subscribed):
            # Successful subscription!
            self._record_subscription_response(
                inc_msg.request, 0, subscription=inc_msg.subscription)

        elif (isinstance(inc_msg, message.Error) and
              inc_msg.request_type == message.Subscribe.MESSAGE_TYPE):
            # Failed subscription :(
            self._record_subscription_response(inc_msg.request, 128)

        elif isinstance(inc_msg, message.Published):
            # Acknowledgement of a publish sent with acknowledge=True (QoS 1);
            # the bookkeeping entry is no longer needed.
            self._request_to_packetid.pop(inc_msg.request, None)

        elif isinstance(inc_msg, message.Unsubscribed):
            self.log.debug('WAMP unsubscribe request {request} acknowledged',
                           request=inc_msg.request)

        elif isinstance(inc_msg, message.Error):
            # Error reply to a non-Subscribe request: a publish, an
            # unsubscribe, or the testament call made by process_connect.
            self._request_to_packetid.pop(inc_msg.request, None)
            self.log.warn(
                'WAMP error reply to request {request} '
                '(message type {request_type}): {error}',
                request=inc_msg.request,
                request_type=inc_msg.request_type,
                error=inc_msg.error)

        elif isinstance(inc_msg, message.Event):
            topic = inc_msg.topic or self._topic_lookup.get(inc_msg.subscription)
            if topic is None:
                # E.g. an event still in flight for a subscription removed by
                # process_unsubscribe.
                self.log.debug('dropping event for unknown subscription {subscription}',
                               subscription=inc_msg.subscription)
                return
            try:
                payload_format, mapped_topic, payload = yield self.factory.transform_wamp(
                    topic, inc_msg)
            except Exception:
                self.log.failure()
            else:
                self._mqtt.send_publish(mapped_topic, 0, payload,
                                        retained=inc_msg.retained or False)

        elif isinstance(inc_msg, message.Goodbye):
            if self._mqtt.transport:
                self._mqtt.transport.loseConnection()
                self._mqtt.transport = None

        else:
            self.log.warn('cannot process unimplemented message: {inc_msg}',
                          inc_msg=inc_msg)

    def _record_subscription_response(self, request_id, response, subscription=None):
        """
        Store the router's answer (0 = granted, 128 = failure) to one WAMP
        subscribe request. Once every topic of the owning MQTT SUBSCRIBE packet
        has an answer, fire the Deferred that sends the SUBACK.
        """
        mqtt_id = self._subrequest_to_mqtt_subrequest.pop(request_id, None)
        packet_watch = self._inflight_subscriptions.get(mqtt_id)
        if packet_watch is None or request_id not in packet_watch:
            self.log.warn('subscribe reply for unknown request {request}',
                          request=request_id)
            return

        entry = packet_watch[request_id]
        entry["response"] = response
        if subscription is not None and "topic" in entry:
            self._topic_lookup[subscription] = entry["topic"]

        if all(x["response"] != -1 for x in packet_watch.values()):
            self._subrequest_callbacks[mqtt_id].callback(None)

    def connectionMade(self, ignore_handshake=False):
        # For TLS transports, startup waits for handshakeCompleted.
        if ignore_handshake or not ISSLTransport.providedBy(self.transport):
            self._when_ready()

    def connectionLost(self, reason):
        if self._wamp_session:
            msg = message.Goodbye()
            self._wamp_session.onMessage(msg)
            # Reset to the initial value instead of deleting the attribute, so
            # later reads see None instead of raising AttributeError.
            self._wamp_session = None

    def handshakeCompleted(self):
        self._when_ready()

    def _when_ready(self):
        """Create and open the WAMP router session (at most once)."""
        if self._wamp_session:
            return

        self._mqtt.transport = self.transport

        self._wamp_session = RouterSession(
            self.factory._router_session_factory._routerFactory)
        self._wamp_transport = WampTransport(self.factory, self.on_message,
                                             self.transport)
        self._wamp_session.onOpen(self._wamp_transport)
        self._wamp_session._transport_config = self.factory._options

    def process_connect(self, packet):
        """
        Handle an MQTT CONNECT by opening the WAMP session with a Hello.

        Returns a Deferred that fires with ``(return_code, session_present)``
        when the router welcomes (``(0, False)``) or aborts (``(1, False)``)
        the session. Returns None if building or sending the Hello raised; the
        error is logged.
        """
        try:
            self.log.debug(
                'WampMQTTServerProtocol.process_connect(packet={packet})',
                packet=packet)

            self._waiting_for_connect = Deferred()

            roles = {
                u"subscriber": role.RoleSubscriberFeatures(
                    payload_transparency=True),
                u"publisher": role.RolePublisherFeatures(
                    payload_transparency=True,
                    x_acknowledged_event_delivery=True),
            }

            realm = self.factory._options.get('realm', None)

            methods = []
            if ISSLTransport.providedBy(self.transport):
                methods.append(u"tls")

            if packet.username and packet.password:
                # Ticket auth: keep the password until the router's Challenge.
                methods.append(u"ticket")
                msg = message.Hello(realm=realm, roles=roles,
                                    authmethods=methods,
                                    authid=packet.username)
                self._pw_challenge = packet.password
            else:
                methods.append(u"anonymous")
                msg = message.Hello(realm=realm, roles=roles,
                                    authmethods=methods,
                                    authid=packet.client_id)

            self._wamp_session.onMessage(msg)

            if packet.flags.will:

                # Decorator order matters. inlineCallbacks must wrap the
                # generator first, and the *wrapped* function is registered as
                # the callback. In the reverse order the bare generator
                # function was registered: the testament was never sent, and
                # the CONNECT result was replaced by a generator object.
                @self._waiting_for_connect.addCallback
                @inlineCallbacks
                def process_will(res):
                    # Only attach the testament once the session was welcomed
                    # (return code 0). Presumably the router would reject a
                    # Call on a session it aborted.
                    if res[0] == 0:
                        payload_format, mapped_topic, options = yield self.factory.transform_mqtt(
                            packet.will_topic, packet.will_message)

                        request = util.id()

                        msg = message.Call(
                            request=request,
                            procedure=u"wamp.session.add_testament",
                            args=[
                                mapped_topic,
                                options.get('args', None),
                                options.get('kwargs', None),
                                {
                                    'retain': bool(packet.flags.will_retain)
                                }
                            ])

                        self._wamp_session.onMessage(msg)
                    # Pass the CONNECT result through unchanged.
                    returnValue(res)

            return self._waiting_for_connect
        except Exception:
            self.log.failure()

    @inlineCallbacks
    def _publish(self, event, acknowledge=None):
        """
        Given a MQTT event, create a WAMP Publish message and forward that on
        the forwarding WAMP session.

        Fires with 0 once forwarded. Fires with None (after logging) if the
        payload could not be transformed.
        """
        try:
            payload_format, mapped_topic, options = yield self.factory.transform_mqtt(
                event.topic_name, event.payload)
        except Exception:
            self.log.failure()
            return

        request = util.id()

        msg = message.Publish(request=request,
                              topic=mapped_topic,
                              exclude_me=False,
                              acknowledge=acknowledge,
                              retain=event.retain,
                              **options)

        self._wamp_session.onMessage(msg)

        if event.qos_level > 0:
            # Removed again when the router answers (Published / Error).
            self._request_to_packetid[request] = event.packet_identifier

        returnValue(0)

    def process_publish_qos_0(self, event):
        # _publish is an inlineCallbacks function: it never raises
        # synchronously; errors travel through the returned Deferred.
        return self._publish(event)

    def process_publish_qos_1(self, event):
        # Ask the router for an acknowledgement (a Published reply).
        return self._publish(event, acknowledge=True)

    def process_puback(self, event):
        return

    def process_pubrec(self, event):
        return

    def process_pubrel(self, event):
        return

    def process_pubcomp(self, event):
        return

    def process_subscribe(self, packet):
        """
        Forward every topic filter of an MQTT SUBSCRIBE as a WAMP Subscribe.

        The SUBACK is sent once every topic has an answer: 0 if granted, 128
        if it failed.
        """
        packet_id = packet.packet_identifier
        packet_watch = OrderedDict()
        d = Deferred()

        @d.addCallback
        def _(ign):
            self._mqtt.send_suback(
                packet_id, [x["response"] for x in packet_watch.values()])
            self._inflight_subscriptions.pop(packet_id, None)
            self._subrequest_callbacks.pop(packet_id, None)

        self._subrequest_callbacks[packet_id] = d
        self._inflight_subscriptions[packet_id] = packet_watch

        # Register every request as pending *before* sending any of them. The
        # router may answer synchronously from inside onMessage; the SUBACK
        # must not go out until every topic of this packet has an answer.
        outgoing = []
        for x in packet.topic_requests:
            # FIXME: every filter is subscribed with exact matching; MQTT
            # wildcards ('+', '#') are presumably not translated.
            match_type = u"exact"
            request_id = util.id()
            outgoing.append(message.Subscribe(
                request=request_id,
                topic=u".".join(tokenise_mqtt_topic(x.topic_filter)),
                match=match_type,
                get_retained=True,
            ))
            packet_watch[request_id] = {
                "response": -1,
                "topic": x.topic_filter
            }
            self._subrequest_to_mqtt_subrequest[request_id] = packet_id

        for msg in outgoing:
            try:
                self._wamp_session.onMessage(msg)
            except Exception:
                self.log.failure()
                self._subrequest_to_mqtt_subrequest.pop(msg.request, None)
                packet_watch[msg.request] = {"response": 128}

        # Nothing else fires the SUBACK when every answer is already in: all
        # replies were synchronous, every send failed, or there were no topics.
        if not d.called and all(x["response"] != -1 for x in packet_watch.values()):
            d.callback(None)

    def process_unsubscribe(self, packet):
        """
        Handle an MQTT UNSUBSCRIBE.

        Sends a WAMP Unsubscribe for every WAMP subscription that was created
        for one of ``packet.topics``. Returns an already-fired Deferred, failed
        if forwarding raised.
        """
        d = Deferred()
        try:
            for topic in packet.topics:
                matching = [sub_id for sub_id, topic_filter in self._topic_lookup.items()
                            if topic_filter == topic]
                for sub_id in matching:
                    del self._topic_lookup[sub_id]
                    msg = message.Unsubscribe(request=util.id(),
                                              subscription=sub_id)
                    self._wamp_session.onMessage(msg)
        except Exception:
            d.errback()
        else:
            d.callback(None)
        return d

    def dataReceived(self, data):
        # Raw bytes from the client go straight to the MQTT parser.
        self._mqtt.dataReceived(data)