def test_allow_connects_with_same_id_if_disconnected(self):
    """
    A client may connect with a client ID whose existing session is no
    longer connected; that (non-clean) session is then resumed.
    """
    sessions = {}
    handler = BasicHandler()

    first_clock = Clock()
    first_transport = StringTransport()
    first_proto = MQTTServerTwistedProtocol(handler, first_clock, sessions)
    first_parser = MQTTClientParser()
    first_proto.makeConnection(first_transport)

    connect_bytes = Connect(
        client_id=u"test123",
        flags=ConnectFlags(clean_session=False),
    ).serialise()

    # First connection: accepted, no earlier session exists.
    for byte in iterbytes(connect_bytes):
        first_proto.dataReceived(byte)

    self.assertFalse(first_transport.disconnecting)
    replies = first_parser.data_received(first_transport.value())
    self.assertEqual(len(replies), 1)
    self.assertEqual(attr.asdict(replies[0]), {
        'return_code': 0,
        'session_present': False,
    })

    # Drop the first connection; its session stays in `sessions`.
    first_proto.connectionLost(None)

    # Second connection, sending the identical CONNECT (same client ID).
    second_clock = Clock()
    second_transport = StringTransport()
    second_proto = MQTTServerTwistedProtocol(handler, second_clock, sessions)
    second_parser = MQTTClientParser()
    second_proto.makeConnection(second_transport)

    for byte in iterbytes(connect_bytes):
        second_proto.dataReceived(byte)

    # Accepted, and the client is told the earlier session is present.
    replies = second_parser.data_received(second_transport.value())
    self.assertEqual(len(replies), 1)
    self.assertEqual(attr.asdict(replies[0]), {
        'return_code': 0,
        'session_present': True,
    })

    # Both connections are bound to the same session.
    self.assertEqual(first_proto.session, second_proto.session)
def test_only_unique(self):
    """
    Two simultaneously connected clients may not share a client ID: while
    the first connection is still up, a second CONNECT with the same ID
    is refused with return code 2.

    Compliance statement MQTT-3.1.3-2
    """
    sessions = {}
    handler = BasicHandler()

    first_clock = Clock()
    first_transport = StringTransport()
    first_proto = MQTTServerTwistedProtocol(handler, first_clock, sessions)
    first_parser = MQTTClientParser()
    first_proto.makeConnection(first_transport)

    # Raw CONNECT: connect flags 0x02, keep-alive bytes b"\x00x"
    # (0x0078 == 120), client ID "test123".
    connect_bytes = b"\x10\x13\x00\x04MQTT\x04\x02\x00x\x00\x07test123"

    for byte in iterbytes(connect_bytes):
        first_proto.dataReceived(byte)

    self.assertFalse(first_transport.disconnecting)
    replies = first_parser.data_received(first_transport.value())
    self.assertEqual(len(replies), 1)
    self.assertEqual(attr.asdict(replies[0]), {
        'return_code': 0,
        'session_present': False,
    })

    # Second connection while the first is still connected.
    second_clock = Clock()
    second_transport = StringTransport()
    second_proto = MQTTServerTwistedProtocol(handler, second_clock, sessions)
    second_parser = MQTTClientParser()
    second_proto.makeConnection(second_transport)

    for byte in iterbytes(connect_bytes):
        second_proto.dataReceived(byte)

    replies = second_parser.data_received(second_transport.value())
    self.assertEqual(len(replies), 1)
    # NOTE(review): MQTT-3.2.2-4 requires Session Present to be 0 whenever
    # the CONNACK return code is non-zero; expecting True here appears to
    # pin non-compliant behaviour. Confirm against the server before
    # changing this expectation.
    self.assertEqual(attr.asdict(replies[0]), {
        'return_code': 2,
        'session_present': True,
    })
def test_unknown_connect_code_must_lose_connection(self):
    """
    A non-zero connect code from the handler that is not in 1-5 must
    result in a lost connection, with no CONNACK written to the client.

    Compliance statements MQTT-3.2.2-4, MQTT-3.2.2-5
    """
    sessions = {}

    # The handler answers the CONNECT with code 6, outside the valid range.
    h = BasicHandler(6)
    r = Clock()
    t = StringTransport()
    p = MQTTServerTwistedProtocol(h, r, sessions)

    p.makeConnection(t)

    data = (Connect(client_id=u"test123",
                    flags=ConnectFlags(clean_session=False)).serialise())

    for x in iterbytes(data):
        p.dataReceived(x)

    # Connection is being closed and nothing (no CONNACK) was sent.
    self.assertTrue(t.disconnecting)
    self.assertEqual(t.value(), b'')
def test_transport_paused_while_processing(self):
    """
    While the handler is still processing a packet (its Deferred has not
    fired), the transport is paused; it resumes producing once the
    handler's result arrives.
    """
    sessions = {}
    pending = Deferred()

    handler = BasicHandler()
    # CONNECT processing now waits on a Deferred the test controls.
    handler.process_connect = lambda x: pending

    clock = Clock()
    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(handler, clock, sessions)

    transport.connected = True
    proto.makeConnection(transport)

    connect_bytes = Connect(
        client_id=u"test123",
        flags=ConnectFlags(clean_session=False),
    ).serialise()

    self.assertEqual(transport.producerState, 'producing')

    for byte in iterbytes(connect_bytes):
        proto.dataReceived(byte)

    # The handler has not answered yet, so reading is paused.
    self.assertEqual(transport.producerState, 'paused')

    pending.callback(0)

    # The handler answered, so reading resumes.
    self.assertEqual(transport.producerState, 'producing')
def test_subscribe_always_gets_packet(self):
    """
    A SUBSCRIBE is always answered with a SubACK, even if none of the
    requested subscriptions were successful.

    Compliance statements MQTT-3.8.4-1
    """
    sessions = {}

    class SubHandler(BasicHandler):
        def process_subscribe(self, event):
            # Refuse the only topic request: 128 (0x80) is the SUBACK
            # failure return code.
            return succeed([128])

    h = SubHandler()
    r = Clock()
    t = StringTransport()
    p = MQTTServerTwistedProtocol(h, r, sessions)
    cp = MQTTClientParser()

    p.makeConnection(t)

    data = (
        Connect(client_id=u"test123",
                flags=ConnectFlags(clean_session=True)).serialise() +
        Subscribe(packet_identifier=1234,
                  topic_requests=[SubscriptionTopicRequest(u"a", 0)]
                  ).serialise()
    )

    for x in iterbytes(data):
        p.dataReceived(x)

    events = cp.data_received(t.value())
    # The CONNACK, then the SUBACK we are interested in.
    self.assertEqual(len(events), 2)
    self.assertIsInstance(events[1], SubACK)
    self.assertEqual(events[1].return_codes, [128])
def test_subscribe_same_id(self):
    """
    A SubACK carries the same packet identifier as the SUBSCRIBE it
    acknowledges.

    Compliance statements MQTT-3.8.4-2
    """
    sessions = {}

    class SubHandler(BasicHandler):
        def process_subscribe(self, event):
            # Report return code 0 for the single topic request.
            return succeed([0])

    handler = SubHandler()
    clock = Clock()
    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(handler, clock, sessions)
    client_parser = MQTTClientParser()
    proto.makeConnection(transport)

    connect_packet = Connect(client_id=u"test123",
                             flags=ConnectFlags(clean_session=True))
    subscribe_packet = Subscribe(
        packet_identifier=1234,
        topic_requests=[SubscriptionTopicRequest(u"a", 0)],
    )
    wire = connect_packet.serialise() + subscribe_packet.serialise()

    for byte in iterbytes(wire):
        proto.dataReceived(byte)

    replies = client_parser.data_received(transport.value())
    self.assertEqual(len(replies), 2)

    suback = replies[1]
    self.assertEqual(suback.return_codes, [0])
    # The acknowledgement echoes the SUBSCRIBE's packet identifier.
    self.assertEqual(suback.packet_identifier, 1234)
def test_send_packet(self):
    """
    Sending a packet emits exactly one trace log event (category MQ101,
    logged at debug level and flagged with ``txaio_trace``) that
    describes the sent packet.
    """
    sessions = {}
    clock = Clock()
    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(BasicHandler(), clock, sessions)
    client_parser = MQTTClientParser()
    proto.makeConnection(transport)

    # CONNECT, client ID "test123"
    connect_hex = b"101300044d51545404020002000774657374313233"

    with LogCapturer("trace") as logs:
        for byte in iterbytes(unhexlify(connect_hex)):
            proto.dataReceived(byte)

    sent = logs.get_category("MQ101")
    self.assertEqual(len(sent), 1)
    self.assertEqual(sent[0]["log_level"], LogLevel.debug)
    self.assertEqual(sent[0]["txaio_trace"], True)
    self.assertIn("ConnACK", logs.log_text.getvalue())

    # The logged packet is the CONNACK that went out on the wire.
    replies = client_parser.data_received(transport.value())
    self.assertEqual(len(replies), 1)
    self.assertIsInstance(replies[0], ConnACK)
def test_keepalive_canceled_on_lost_connection(self):
    """
    When a client that negotiated a keep-alive disconnects by itself,
    the pending keep-alive timeout is cancelled instead of firing.
    """
    sessions = {}
    clock = Clock()
    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(BasicHandler(), clock, sessions)
    proto.makeConnection(transport)

    # CONNECT with a keep-alive of 2 seconds.
    connect_hex = b"101300044d51545404020002000774657374313233"
    for byte in iterbytes(unhexlify(connect_hex)):
        proto.dataReceived(byte)

    # One pending timeout, due at 1.5 * 2 == 3 seconds.
    self.assertEqual(len(clock.calls), 1)
    timeout = clock.calls[0]
    self.assertEqual(timeout.getTime(), 3.0)

    # A clean connection loss must remove the timeout.
    proto.connectionLost(None)

    self.assertEqual(len(clock.calls), 0)
    self.assertTrue(timeout.cancelled)
    self.assertFalse(timeout.called)
def test_non_zero_connect_code_must_have_no_present_session(self):
    """
    A non-zero connect code in a CONNACK must be paired with no session
    present.

    Compliance statement MQTT-3.2.2-4
    """
    sessions = {}

    # self.connect_code is supplied by the enclosing test class (presumably
    # one class per non-zero return code); the handler answers CONNECT
    # with it.
    h = BasicHandler(self.connect_code)
    r = Clock()
    t = StringTransport()
    p = MQTTServerTwistedProtocol(h, r, sessions)
    cp = MQTTClientParser()

    p.makeConnection(t)

    data = (Connect(client_id=u"test123",
                    flags=ConnectFlags(clean_session=False)).serialise())

    for x in iterbytes(data):
        p.dataReceived(x)

    events = cp.data_received(t.value())
    self.assertEqual(len(events), 1)
    self.assertEqual(attr.asdict(events[0]), {
        'return_code': self.connect_code,
        'session_present': False,
    })
def test_lose_conn_on_protocol_violation(self):
    """
    A protocol violation makes the server drop the connection without
    writing anything back, and log an error (category MQ401).

    Compliance statement MQTT-4.8.0-1
    """
    sessions = {}
    clock = Clock()
    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(BasicHandler(), clock, sessions)
    proto.makeConnection(transport)

    # Invalid CONNECT: the first byte is 0x11 instead of 0x10, i.e. the
    # fixed-header flag bits are non-zero.
    bad_connect_hex = b"111300044d51545404020002000774657374313233"

    with LogCapturer("trace") as logs:
        for byte in iterbytes(unhexlify(bad_connect_hex)):
            proto.dataReceived(byte)

    errors = logs.get_category("MQ401")
    self.assertEqual(len(errors), 1)
    self.assertEqual(errors[0]["log_level"], LogLevel.error)
    self.assertIn("Connect", logs.log_text.getvalue())

    # Nothing was written and the connection is being closed.
    self.assertEqual(transport.value(), b'')
    self.assertTrue(transport.disconnecting)
def make_test_items(handler):
    """
    Build a server protocol for *handler*, connected to an in-memory
    transport.

    Returns a ``(clock, transport, protocol, client_parser)`` tuple. Feed
    ``transport.value()`` to ``client_parser`` to decode what the server
    wrote.
    """
    clock = Clock()
    transport = StringTransport()
    server_proto = MQTTServerTwistedProtocol(handler, clock)
    client_parser = MQTTClientParser()

    server_proto.makeConnection(transport)

    return clock, transport, server_proto, client_parser
def test_keepalive(self):
    """
    A client that announces a keep-alive and then sends no data for
    1.5 * keep_alive seconds is disconnected.

    Compliance statement MQTT-3.1.2-24
    """
    sessions = {}
    clock = Clock()
    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(BasicHandler(), clock, sessions)
    proto.makeConnection(transport)

    # CONNECT with a keep-alive of 2 seconds.
    connect_hex = b"101300044d51545404020002000774657374313233"
    for byte in iterbytes(unhexlify(connect_hex)):
        proto.dataReceived(byte)

    # Exactly one pending call: drop the connection at 1.5 * 2 == 3s.
    self.assertEqual(len(clock.calls), 1)
    deadline = clock.calls[0]
    self.assertEqual(deadline.func, proto._lose_connection)
    self.assertEqual(deadline.getTime(), 3.0)
    self.assertFalse(transport.disconnecting)

    # Just before the deadline the client is still connected...
    clock.advance(2.9)
    self.assertFalse(transport.disconnecting)

    # ...and once it passes, the connection is dropped.
    clock.advance(0.1)
    self.assertTrue(transport.disconnecting)
def test_exception_in_subscribe_drops_connection(self):
    """
    Transient failures (like an exception from handler.process_subscribe)
    will cause the connection they happened on to be dropped, and a
    critical error (category MQ500) to be logged.

    Compliance statement MQTT-4.8.0-2
    """
    sessions = {}

    class SubHandler(BasicHandler):
        def process_subscribe(self, event):
            # Fail synchronously from inside the handler.
            raise Exception("boom!")

    h = SubHandler()
    r = Clock()
    t = StringTransport()
    p = MQTTServerTwistedProtocol(h, r, sessions)
    cp = MQTTClientParser()

    p.makeConnection(t)

    data = (
        Connect(client_id=u"test123",
                flags=ConnectFlags(clean_session=True)).serialise() +
        Subscribe(packet_identifier=1234,
                  topic_requests=[SubscriptionTopicRequest(u"a", 0)]
                  ).serialise()
    )

    with LogCapturer("trace") as logs:
        for x in iterbytes(data):
            p.dataReceived(x)

    sent_logs = logs.get_category("MQ500")
    self.assertEqual(len(sent_logs), 1)
    self.assertEqual(sent_logs[0]["log_level"], LogLevel.critical)
    self.assertEqual(sent_logs[0]["log_failure"].value.args[0], "boom!")

    # Only the CONNACK was sent (no SUBACK), and the connection is closing.
    events = cp.data_received(t.value())
    self.assertEqual(len(events), 1)
    self.assertTrue(t.disconnecting)

    # We got the error; flush it so it doesn't make the test error.
    self.flushLoggedErrors()
def test_keepalive_requires_full_packet(self):
    """
    With a keep-alive negotiated, only FULL packets count as activity. A
    client that sends no complete packet within 1.5 * keep_alive seconds
    is disconnected, even if partial data arrived in the meantime.

    Compliance statement MQTT-3.1.2-24
    """
    sessions = {}
    clock = Clock()
    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(BasicHandler(), clock, sessions)
    proto.makeConnection(transport)

    def feed(hex_bytes):
        # Deliver the decoded bytes one at a time.
        for byte in iterbytes(unhexlify(hex_bytes)):
            proto.dataReceived(byte)

    def assert_single_deadline(when):
        # Exactly one pending call: dropping the connection at `when`.
        self.assertEqual(len(clock.calls), 1)
        self.assertEqual(clock.calls[0].func, proto._lose_connection)
        self.assertEqual(clock.calls[0].getTime(), when)

    # CONNECT with a keep-alive of 2 seconds -> deadline at 3 seconds.
    feed(b"101300044d51545404020002000774657374313233")
    assert_single_deadline(3.0)
    self.assertFalse(transport.disconnecting)

    clock.advance(2.9)
    self.assertFalse(transport.disconnecting)

    # PINGREQ header with no body: an incomplete packet.
    feed(b"c0")

    # The deadline has not moved. Had any received data reset it, the
    # delayed call would now trigger at 2.9 + 3.
    assert_single_deadline(3.0)

    clock.advance(0.1)
    self.assertTrue(transport.disconnecting)
def test_keepalive_full_packet_resets_timeout(self):
    """
    A client with a keep-alive that sends a complete packet within
    1.5 * keep_alive seconds stays connected, and the timeout is
    restarted from the moment that packet was received.
    """
    sessions = {}
    clock = Clock()
    transport = StringTransport()
    proto = MQTTServerTwistedProtocol(BasicHandler(), clock, sessions)
    proto.makeConnection(transport)

    def feed(hex_bytes):
        # Deliver the decoded bytes one at a time.
        for byte in iterbytes(unhexlify(hex_bytes)):
            proto.dataReceived(byte)

    def assert_single_deadline(when):
        # Exactly one pending call: dropping the connection at `when`.
        self.assertEqual(len(clock.calls), 1)
        self.assertEqual(clock.calls[0].func, proto._lose_connection)
        self.assertEqual(clock.calls[0].getTime(), when)

    # CONNECT with a keep-alive of 2 seconds -> deadline at 3 seconds.
    feed(b"101300044d51545404020002000774657374313233")
    assert_single_deadline(3.0)
    self.assertFalse(transport.disconnecting)

    clock.advance(2.9)
    self.assertFalse(transport.disconnecting)

    # A complete PINGREQ packet.
    feed(b"c000")

    # The deadline moved to 2.9 (when the packet was received) + 3.
    assert_single_deadline(2.9 + 3.0)

    # Past the original deadline, but still connected.
    clock.advance(0.1)
    self.assertFalse(transport.disconnecting)
def test_lose_conn_on_unimplemented_packet(self):
    """
    If we get a valid packet that is unimplemented for the server role
    (e.g. a SubACK, which we only ever send, so receiving one is a
    protocol violation), we drop the connection and log an error
    (category MQ402).

    Compliance statement: MQTT-4.8.0-1
    """
    sessions = {}

    # Register a parser for SubACK so the server can decode a well-formed
    # packet that its role does not implement. This mutates a module-level
    # table, so the cleanup must restore exactly the previous state
    # (including "no entry at all") rather than blindly popping the key.
    from crossbar.adapter.mqtt import protocol
    handlers = protocol.server_packet_handlers
    had_previous = protocol.P_SUBACK in handlers
    previous = handlers.get(protocol.P_SUBACK)
    handlers[protocol.P_SUBACK] = SubACK

    def _restore_handlers():
        if had_previous:
            handlers[protocol.P_SUBACK] = previous
        else:
            handlers.pop(protocol.P_SUBACK, None)

    self.addCleanup(_restore_handlers)

    h = BasicHandler()
    r = Clock()
    t = StringTransport()
    p = MQTTServerTwistedProtocol(h, r, sessions)

    p.makeConnection(t)

    data = (
        Connect(client_id=u"test123",
                flags=ConnectFlags(clean_session=False)).serialise() +
        SubACK(1, [1]).serialise()
    )

    with LogCapturer("trace") as logs:
        for x in iterbytes(data):
            p.dataReceived(x)

    sent_logs = logs.get_category("MQ402")
    self.assertEqual(len(sent_logs), 1)
    self.assertEqual(sent_logs[0]["log_level"], LogLevel.error)
    self.assertEqual(sent_logs[0]["packet_id"], "SubACK")

    self.assertTrue(t.disconnecting)
def test_publish_traced_events_batched(self):
    """
    With message tracing on and events dispatched in chunks, only the
    last event sent for a traced publication has ``correlation_is_last``
    set.

    Five sessions subscribe to the topic: four recipients plus the
    publisher (session0), which is ordered last in the observation list.
    """
    # We want to trigger a deeply-nested condition in processPublish in
    # class Broker -- without refactoring anything.
    class TestSession(ApplicationSession):
        pass

    session0 = TestSession()
    session1 = TestSession()
    session2 = TestSession()
    session3 = TestSession()
    session4 = TestSession()

    # NOTE! "session0" (the publishing session) is deliberately *last* in
    # the observation list, to trigger a (now fixed) edge case.
    sessions = [session1, session2, session3, session4, session0]

    router = mock.MagicMock()
    router.send = mock.Mock()
    router.new_correlation_id = lambda: u'fake correlation id'
    router.is_traced = True

    clock = Clock()
    with replace_loop(clock):
        broker = Broker(router, clock)
        # Chunk size 2, so the 4 recipients are (presumably) dispatched
        # in more than one batch.
        broker._options.event_dispatching_chunk_size = 2

        # To ensure we get "session0" last, turn on ordering in the
        # observations.
        broker._subscription_map._ordered = 1

        # "Cheat" our way to the right state by injecting the
        # subscriptions directly, instead of faking out an entire
        # Subscribe flow: _subscription_map needs at least one
        # subscription for the topic we publish to.
        for session in sessions:
            broker._subscription_map.add_observer(session, u'test.topic')

        for i, sess in enumerate(sessions):
            sess._session_id = 1000 + i
            sess._transport = mock.MagicMock()
            sess._transport.get_channel_id = mock.MagicMock(
                return_value=b'deadbeef')

        # The main "cheat": fake out router.authorize because we need it
        # to callback immediately.
        router.authorize = mock.MagicMock(
            return_value=txaio.create_future_success(
                dict(allow=True, cache=False, disclose=True)))

        # Now we can call processPublish such that we reach the condition
        # we're interested in; the event should go to every session
        # except session0.
        pubmsg = message.Publish(123, u'test.topic')
        broker.processPublish(session0, pubmsg)

        clock.advance(1)
        clock.advance(1)

        # Collect the message argument of every router.send(session, msg)
        # call addressed to one of our sessions.
        events = [
            call[1][1]
            for call in router.send.mock_calls
            if call[1][0] in sessions
        ]

        # All except session0 should have gotten an event, and session4's
        # (the last one sent) should have the "last" flag set.
        self.assertEqual(4, len(events))
        self.assertFalse(events[0].correlation_is_last)
        self.assertFalse(events[1].correlation_is_last)
        self.assertFalse(events[2].correlation_is_last)
        self.assertTrue(events[3].correlation_is_last)
def test_publish_traced_events_batched(self):
    """
    With message tracing on and events dispatched in chunks, only the
    last event sent for a traced publication has ``correlation_is_last``
    set.

    Five sessions subscribe to the topic: four recipients plus the
    publisher (session0), which is ordered last in the observation list.
    """
    # We want to trigger a deeply-nested condition in processPublish in
    # class Broker -- without refactoring anything.
    class TestSession(ApplicationSession):
        pass

    session0 = TestSession()
    session1 = TestSession()
    session2 = TestSession()
    session3 = TestSession()
    session4 = TestSession()

    # NOTE! "session0" (the publishing session) is deliberately *last* in
    # the observation list, to trigger a (now fixed) edge case.
    sessions = [session1, session2, session3, session4, session0]

    router = mock.MagicMock()
    router.send = mock.Mock()
    router.new_correlation_id = lambda: 'fake correlation id'
    router.is_traced = True

    clock = Clock()
    with replace_loop(clock):
        broker = Broker(router, clock)
        # Chunk size 2, so the 4 recipients are (presumably) dispatched
        # in more than one batch.
        broker._options.event_dispatching_chunk_size = 2

        # To ensure we get "session0" last, turn on ordering in the
        # observations.
        broker._subscription_map._ordered = 1

        # "Cheat" our way to the right state by injecting the
        # subscriptions directly, instead of faking out an entire
        # Subscribe flow: _subscription_map needs at least one
        # subscription for the topic we publish to.
        for session in sessions:
            broker._subscription_map.add_observer(session, 'test.topic')

        for i, sess in enumerate(sessions):
            sess._session_id = 1000 + i
            sess._transport = mock.MagicMock()
            sess._transport.get_channel_id = mock.MagicMock(
                return_value=b'deadbeef')

        # The main "cheat": fake out router.authorize because we need it
        # to callback immediately.
        router.authorize = mock.MagicMock(
            return_value=txaio.create_future_success(
                dict(allow=True, cache=False, disclose=True)))

        # Now we can call processPublish such that we reach the condition
        # we're interested in; the event should go to every session
        # except session0.
        pubmsg = message.Publish(123, 'test.topic')
        broker.processPublish(session0, pubmsg)

        clock.advance(1)
        clock.advance(1)

        # Collect the message argument of every router.send(session, msg)
        # call addressed to one of our sessions.
        events = [
            call[1][1]
            for call in router.send.mock_calls
            if call[1][0] in sessions
        ]

        # All except session0 should have gotten an event, and session4's
        # (the last one sent) should have the "last" flag set.
        self.assertEqual(4, len(events))
        self.assertFalse(events[0].correlation_is_last)
        self.assertFalse(events[1].correlation_is_last)
        self.assertFalse(events[2].correlation_is_last)
        self.assertTrue(events[3].correlation_is_last)
def build_mqtt_server():
    """
    Assemble an in-memory router with an "mqtt" realm and an MQTT server
    factory bound to it.

    The "mqtt" realm gets a permissive ``mqttrole``. An authenticator
    session registers two procedures:

    * ``com.example.auth`` -- accepts authid ``test123`` with ticket
      ``password``;
    * ``com.example.tls`` -- accepts a client certificate only if its
      SHA1 value is in a fixed allow-list.

    Returns ``(reactor, router, server_factory, session_factory)``, where
    the reactor is a fake ``Clock``.
    """
    clock = Clock()
    router_factory, server_factory, session_factory = make_router()

    # A default realm, plus the realm used by the MQTT bridge.
    add_realm_to_router(router_factory, session_factory)
    router = add_realm_to_router(router_factory, session_factory,
                                 realm_name=u'mqtt',
                                 realm_options={})

    # allow everything
    allow_all = {
        u'uri': u'',
        u'match': u'prefix',
        u'allow': {
            u'call': True,
            u'register': True,
            u'publish': True,
            u'subscribe': True
        }
    }
    router.add_role(RouterRoleStaticAuth(router, u'mqttrole',
                                         default_permissions=allow_all))

    class AuthenticatorSession(ApplicationSession):

        @inlineCallbacks
        def onJoin(self, details):

            def authenticate(realm, authid, details):
                if authid != u"test123":
                    raise ApplicationError(u'com.example.no_such_user',
                                           u'nah')
                if details["ticket"] != u'password':
                    raise ApplicationError(u'com.example.invalid_ticket',
                                           u'nope')
                return {
                    u'realm': u'mqtt',
                    u'role': u'mqttrole',
                    u'extra': {}
                }

            yield self.register(authenticate, u'com.example.auth')

            def tls(realm, authid, details):
                allowed_sha1s = {
                    u'95:1C:A9:6B:CD:8D:D2:BD:F4:73:82:01:55:89:41:12:9C:F8:AF:8E'
                }
                transport_info = details['transport']
                if ('client_cert' not in transport_info or
                        not transport_info['client_cert']):
                    raise ApplicationError(u"com.example.no_cert",
                                           u"no client certificate presented")

                cert = transport_info['client_cert']
                sha1 = cert['sha1']
                subject_cn = cert['subject']['cn']

                if sha1 in allowed_sha1s:
                    return {
                        u'authid': subject_cn,
                        u'role': u'mqttrole',
                        u'realm': u'mqtt'
                    }
                raise ApplicationError(
                    u"com.example.invalid_cert",
                    u"certificate with SHA1 {} denied".format(sha1))

            yield self.register(tls, u'com.example.tls')

    # NOTE(review): the component config names realm "default" (also the
    # "authenticator-realm" below), yet the session is added with the
    # "mqtt" realm's router -- confirm which realm the authenticators are
    # meant to be registered on.
    authsession = AuthenticatorSession(ComponentConfig(u"default", {}))
    session_factory.add(authsession, router, authrole=u"trusted")

    auth_methods = {
        u"ticket": {
            u"type": u"dynamic",
            u"authenticator": u"com.example.auth",
            u"authenticator-realm": u"default",
        },
        u"tls": {
            u"type": u"dynamic",
            u"authenticator": u"com.example.tls",
            u"authenticator-realm": u"default",
        }
    }
    options = {
        u"options": {
            u"realm": u"mqtt",
            u"role": u"mqttrole",
            u"payload_mapping": {
                u"": {
                    u"type": u"native",
                    u"serializer": u"json"
                }
            },
            u"auth": auth_methods,
        }
    }

    server_factory._mqtt_factory = WampMQTTServerFactory(
        session_factory, options, clock)

    return clock, router, server_factory, session_factory
def test_clean_session_destroys_session(self): """ Setting the clean_session flag to True when connecting means that any existing session for that user ID will be destroyed. Compliance statement MQTT-3.2.2-1 """ sessions = {} h = BasicHandler() r = Clock() t = StringTransport() p = MQTTServerTwistedProtocol(h, r, sessions) p.makeConnection(t) data = (Connect(client_id=u"test123", flags=ConnectFlags(clean_session=False)).serialise()) for x in iterbytes(data): p.dataReceived(x) self.assertFalse(t.disconnecting) self.assertEqual(list(sessions.keys()), [u"test123"]) old_session = sessions[u"test123"] # Close the connection p.connectionLost(None) # New session, clean_session=True data = (Connect(client_id=u"test123", flags=ConnectFlags(clean_session=True)).serialise()) r2 = Clock() t2 = StringTransport() p2 = MQTTServerTwistedProtocol(h, r2, sessions) cp2 = MQTTClientParser() p2.makeConnection(t2) # Send the same connect, with the same client ID for x in iterbytes(data): p2.dataReceived(x) # Connection allowed events = cp2.data_received(t2.value()) self.assertEqual(len(events), 1) self.assertEqual(attr.asdict(events[0]), { 'return_code': 0, 'session_present': False, }) self.assertEqual(list(sessions.keys()), [u"test123"]) new_session = sessions[u"test123"] # Brand new session, that won't survive self.assertIsNot(old_session, new_session) self.assertFalse(new_session.survives) # We close the connection, the session is destroyed p2.connectionLost(None) self.assertEqual(list(sessions.keys()), [])