def setUp(self):
    '''
    Create a fake transport and a simulated clock, point
    MQTTBaseProtocol.callLater at that clock, create a SUBSCRIBER-profile
    factory and build the protocol through self._rebuild().
    '''
    clock = task.Clock()
    fake_transport = proto_helpers.StringTransportWithDisconnection()
    self.transport = fake_transport
    self.clock = clock
    # Protocol timers are scheduled on the simulated clock.
    MQTTBaseProtocol.callLater = clock.callLater
    profile = MQTTFactory.SUBSCRIBER
    self.factory = MQTTFactory(profile)
    self._rebuild()
 def setUp(self):
     '''
     Create a fake transport and a simulated clock, point
     MQTTBaseProtocol.callLater at that clock, remember a fixed
     IPv4Address peer, create a PUBLISHER-profile factory and build the
     protocol through self._rebuild().
     '''
     self.clock = task.Clock()
     self.transport = proto_helpers.StringTransportWithDisconnection()
     # Protocol timers are scheduled on the simulated clock.
     MQTTBaseProtocol.callLater = self.clock.callLater
     self.addr = IPv4Address('TCP', 'localhost', 1880)
     self.factory = MQTTFactory(MQTTFactory.PUBLISHER)
     self._rebuild()
# Example #3
# 0
class TestMQTTBaseProtocol2(unittest.TestCase):
    '''
    Connection, keep-alive ping and disconnect behaviour of a protocol built
    from a SUBSCRIBER-profile factory.
    '''

    def setUp(self):
        '''
        Wire a fake transport and a simulated clock, then build (but do not
        connect) the protocol.
        '''
        self.clock = task.Clock()
        self.transport = proto_helpers.StringTransportWithDisconnection()
        # Protocol timers are scheduled on the simulated clock.
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory = MQTTFactory(MQTTFactory.SUBSCRIBER)
        self._rebuild()

    def tearDown(self):
        '''
        Drop the connection. Needed because ping sets up a LoopingCall,
        outside the Clock-simulated callLater().
        '''
        self.transport.loseConnection()

    def _connect(self, keepalive=0, cleanStart=True):
        '''
        Send CONNECT with the given keepalive/cleanStart, discard what was
        written, then feed back a CONNACK (session=False, resultCode=0).
        '''
        connack = CONNACK()
        connack.session = False
        connack.resultCode = 0
        connack.encode()
        self.protocol.connect("TwistedMQTT-sub", keepalive=keepalive,
                              cleanStart=cleanStart, version=v31)
        self.transport.clear()
        self.protocol.dataReceived(connack.encoded)

    def _rebuild(self):
        '''Build a new protocol from the factory and attach the fake transport.'''
        proto = self.factory.buildProtocol(0)
        self.protocol = proto
        self.transport.protocol = proto
        proto.makeConnection(self.transport)

    def test_disconnect(self):
        '''After connecting the state is CONNECTED; then disconnect() is issued.'''
        self._connect()
        self.assertEqual(self.protocol.state, MQTTBaseProtocol.CONNECTED)
        self.protocol.disconnect()
        self.transport.clear()

    def test_ping(self):
        '''Receiving a PINGRES while connected leaves the state CONNECTED.'''
        self._connect(keepalive=5)
        pingres = PINGRES().encode()
        self.protocol.dataReceived(pingres)
        self.transport.clear()
        self.assertEqual(self.protocol.state, MQTTBaseProtocol.CONNECTED)

    def test_ping_timeout(self):
        '''An unanswered ping drops the protocol back to IDLE.'''
        self._connect(keepalive=5)
        self.protocol.ping()
        self.transport.clear()
        # Advance 6 simulated seconds (beyond keepalive=5) with no PINGRES.
        self.clock.advance(6)
        self.assertEqual(self.protocol.state, MQTTBaseProtocol.IDLE)
class TestMQTTPublisherForbiddenOps(unittest.TestCase):
    '''
    Subscriber-side operations (subscribe, unsubscribe, onPublish) must be
    rejected with MQTTStateError by a connected PUBLISHER-profile protocol.
    '''

    def setUp(self):
        '''
        Set up a connected state: fake transport, simulated clock, a
        PUBLISHER-profile factory and a protocol that completed CONNECT/CONNACK.
        '''
        self.transport = proto_helpers.StringTransportWithDisconnection()
        self.clock     = task.Clock()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory   = MQTTFactory(MQTTFactory.PUBLISHER)
        # Build exactly one protocol. A duplicated _rebuild() used to create
        # and connect a throwaway protocol first.
        self._rebuild()
        self._connect()

    def _connect(self, cleanStart=True):
        '''
        Go to connected state: send CONNECT, discard the written bytes and
        feed back a CONNACK (session=False, resultCode=0).
        '''
        ack = CONNACK()
        ack.session = False
        ack.resultCode = 0
        ack.encode()
        self.protocol.connect("TwistedMQTT-pub", keepalive=0, cleanStart=cleanStart, version=v31)
        self.transport.clear()
        self.protocol.dataReceived(ack.encoded)

    def _rebuild(self):
        '''Build a new protocol from the factory and attach the fake transport.'''
        self.protocol  = self.factory.buildProtocol(0)
        self.transport.protocol = self.protocol
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.protocol.makeConnection(self.transport)

    def test_forbidden_subscribe(self):
        '''subscribe() on a publisher returns a deferred failing with MQTTStateError.'''
        d = self.protocol.subscribe("foo/bar/baz1", 2)
        self.failureResultOf(d).trap(MQTTStateError)

    def test_forbidden_unsubscribe(self):
        '''unsubscribe() on a publisher returns a deferred failing with MQTTStateError.'''
        d = self.protocol.unsubscribe("foo/bar/baz1")
        self.failureResultOf(d).trap(MQTTStateError)

    def test_forbidden_publish_callback(self):
        '''Calling onPublish() with a callback on a publisher raises MQTTStateError.'''
        def onPublish(topic, payload, qos, dup, retain, msgId):
            pass
        self.assertRaises(MQTTStateError, self.protocol.onPublish, onPublish)
# Example #5
# 0
	def __init__(self, uploader):
		MQTTFactory.__init__(self, profile=MQTTFactory.PUBLISHER)
		self.uploader = uploader
class TestMQTTSubscriber1(unittest.TestCase):
    '''
    Subscribe, unsubscribe and inbound PUBLISH handling for a connected
    SUBSCRIBER-profile protocol bound to a fixed IPv4 peer address.
    '''

    def setUp(self):
        '''
        Set up a connected state: fake transport, simulated clock, a
        SUBSCRIBER-profile factory and a protocol that completed CONNECT/CONNACK.
        '''
        self.transport = proto_helpers.StringTransportWithDisconnection()
        self.clock     = task.Clock()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory   = MQTTFactory(MQTTFactory.SUBSCRIBER)
        self.addr = IPv4Address('TCP', 'localhost', 1880)
        self._rebuild()
        self._connect()

    def _connect(self, cleanStart=True):
        '''
        Go to connected state: send CONNECT, discard the written bytes and
        feed back a CONNACK (session=False, resultCode=0).
        '''
        ack = CONNACK()
        ack.session = False
        ack.resultCode = 0
        ack.encode()
        self.protocol.connect("TwistedMQTT-sub", keepalive=0, cleanStart=cleanStart, version=v31)
        self.transport.clear()
        self.protocol.dataReceived(ack.encoded)

    def _serverDown(self):
        '''Lose the transport, discard written bytes and drop the protocol reference.'''
        self.transport.loseConnection()
        self.transport.clear()
        del self.protocol

    def _rebuild(self):
        '''Build a new protocol for self.addr and attach the fake transport.'''
        self.protocol  = self.factory.buildProtocol(self.addr)
        self.transport.protocol = self.protocol
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.protocol.makeConnection(self.transport)

    def _subscribe(self, n, qos, topic):
        '''
        Set the window to *n* and subscribe to "<topic>0" .. "<topic>{n-1}"
        at *qos*. Assert none of the deferreds has fired and return them.
        '''
        self.protocol.setWindowSize(n)
        dl = [self.protocol.subscribe("{0}{1}".format(topic, i), qos) for i in range(n)]
        self.transport.clear()
        for d in dl:
            self.assertNoResult(d)
        return dl

    def _unsubscribe(self, n, topic):
        '''
        Set the window to *n* and unsubscribe from "<topic>0" .. "<topic>{n-1}".
        Assert none of the deferreds has fired and return them.
        '''
        self.protocol.setWindowSize(n)
        dl = [self.protocol.unsubscribe("{0}{1}".format(topic, i)) for i in range(n)]
        self.transport.clear()
        for d in dl:
            self.assertNoResult(d)
        return dl

    def test_subscribe_single(self):
        d = self.protocol.subscribe("foo/bar/baz1", 2)
        self.transport.clear()
        ack = SUBACK()
        ack.msgId = d.msgId
        ack.granted = [(2, False)]
        self.protocol.dataReceived(ack.encode())
        self.assertEqual([(2, False)], self.successResultOf(d))

    def test_subscribe_single_large_qos(self):
        '''A QoS above 2 is rejected with ValueError.'''
        d = self.protocol.subscribe("foo/bar/baz1", 3)
        self.transport.clear()
        self.failureResultOf(d).trap(ValueError)

    def test_subscribe_single_negative_qos(self):
        '''A negative QoS is rejected with ValueError.'''
        d = self.protocol.subscribe("foo/bar/baz1", -1)
        self.transport.clear()
        self.failureResultOf(d).trap(ValueError)

    def test_subscribe_tuple(self):
        d = self.protocol.subscribe(("foo/bar/baz1", 2))
        self.transport.clear()
        ack = SUBACK()
        ack.msgId = d.msgId
        ack.granted = [(2, False)]
        self.protocol.dataReceived(ack.encode())
        self.assertEqual([(2, False)], self.successResultOf(d))

    def test_subscribe_list(self):
        expected = [(2, False), (1, False), (0, False)]
        d = self.protocol.subscribe([("foo/bar/baz1", 2), ("foo/bar/baz2", 1), ("foo/bar/baz3", 0)])
        # Do not chain self.assertEqual as a callback on d: it returns None,
        # which would replace the deferred's result and make the final
        # successResultOf() check below always fail.
        self.transport.clear()
        ack = SUBACK()
        ack.msgId = d.msgId
        ack.granted = list(expected)
        self.protocol.dataReceived(ack.encode())
        self.assertEqual(expected, self.successResultOf(d))

    def test_subscribe_several_fail(self):
        '''Pending subscriptions errback with ConnectionDone when the server goes down.'''
        dl = self._subscribe(n=3, qos=2, topic="foo/bar/baz")
        self.assertEqual(len(self.protocol.factory.windowSubscribe[self.addr]), 3)
        self._serverDown()
        for d in dl:
            self.failureResultOf(d).trap(error.ConnectionDone)

    def test_subscribe_several_window_fail(self):
        '''A subscription beyond a full window fails with MQTTWindowError.'''
        self.protocol.setWindowSize(3)
        dl = self._subscribe(n=3, qos=2, topic="foo/bar/baz")
        self.assertEqual(len(self.protocol.factory.windowSubscribe[self.addr]), 3)
        d4 = self.protocol.subscribe("foo/bar/baz3", 2)
        self.assertEqual(len(self.protocol.factory.windowSubscribe[self.addr]), 3)
        self.failureResultOf(d4).trap(MQTTWindowError)
        self._serverDown()
        for d in dl:
            self.failureResultOf(d).trap(error.ConnectionDone)

    def test_unsubscribe_single(self):
        d = self.protocol.unsubscribe("foo/bar/baz1")
        self.transport.clear()
        ack = UNSUBACK()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.assertEqual(ack.msgId, self.successResultOf(d))

    def test_unsubscribe_list(self):
        d = self.protocol.unsubscribe(["foo/bar/baz1", "foo/bar/baz2", "foo/bar/baz3"])
        self.transport.clear()
        ack = UNSUBACK()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.assertEqual(ack.msgId, self.successResultOf(d))

    def test_unsubscribe_several_fail(self):
        '''Pending unsubscriptions errback with ConnectionDone when the server goes down.'''
        dl = self._unsubscribe(n=3, topic="foo/bar/baz")
        self.assertEqual(len(self.protocol.factory.windowUnsubscribe[self.addr]), 3)
        self._serverDown()
        for d in dl:
            self.failureResultOf(d).trap(error.ConnectionDone)

    def test_unsubscribe_several_window_fail(self):
        '''An unsubscription beyond a full window fails with MQTTWindowError.'''
        dl = self._unsubscribe(n=3, topic="foo/bar/baz")
        self.assertEqual(len(self.protocol.factory.windowUnsubscribe[self.addr]), 3)
        d4 = self.protocol.unsubscribe("foo/bar/baz4")
        self.assertEqual(len(self.protocol.factory.windowUnsubscribe[self.addr]), 3)
        self.failureResultOf(d4).trap(MQTTWindowError)
        self._serverDown()
        for d in dl:
            self.failureResultOf(d).trap(error.ConnectionDone)

    def test_publish_recv_qos0(self):
        '''An inbound QoS 0 PUBLISH is delivered to onPublish immediately.'''
        def onPublish(topic, payload, qos, dup, retain, msgId):
            self.topic   = topic
            self.payload = payload.decode('utf-8')
            self.qos     = qos
            self.retain  = retain
            self.msgId   = msgId
        self.protocol.onPublish = onPublish
        pub = PUBLISH()
        pub.qos     = 0
        pub.dup     = False
        pub.retain  = False
        pub.topic   = "foo/bar/baz0"
        pub.msgId   = None
        pub.payload = "Hello world 0"
        self.protocol.dataReceived(pub.encode())
        self.assertEqual(self.topic,   pub.topic)
        self.assertEqual(self.payload, pub.payload)
        self.assertEqual(self.qos,     pub.qos)
        self.assertEqual(self.retain,  pub.retain)
        self.assertEqual(self.msgId,   pub.msgId)

    def test_publish_recv_qos1(self):
        '''An inbound QoS 1 PUBLISH is delivered to onPublish immediately.'''
        def onPublish(topic, payload, qos, dup, retain, msgId):
            self.topic   = topic
            self.payload = payload.decode('utf-8')
            self.qos     = qos
            self.retain  = retain
            self.msgId   = msgId
            self.dup     = dup
        self.protocol.onPublish = onPublish
        pub = PUBLISH()
        pub.qos     = 1
        pub.dup     = False
        pub.retain  = False
        pub.topic   = "foo/bar/baz1"
        pub.msgId   = 1
        pub.payload = "Hello world 1"
        self.protocol.dataReceived(pub.encode())
        self.transport.clear()
        self.assertEqual(self.topic,   pub.topic)
        self.assertEqual(self.payload, pub.payload)
        self.assertEqual(self.qos,     pub.qos)
        self.assertEqual(self.retain,  pub.retain)
        self.assertEqual(self.msgId,   pub.msgId)
        self.assertEqual(self.dup,     pub.dup)

    def test_publish_recv_qos2(self):
        '''An inbound QoS 2 PUBLISH reaches onPublish only after the matching PUBREL.'''
        self.called = False
        def onPublish(topic, payload, qos, dup, retain, msgId):
            self.called  = True
            self.topic   = topic
            self.payload = payload.decode('utf-8')
            self.qos     = qos
            self.retain  = retain
            self.msgId   = msgId
            self.dup     = dup
        self.protocol.onPublish = onPublish
        pub = PUBLISH()
        pub.qos     = 2
        pub.dup     = False
        pub.retain  = False
        pub.topic   = "foo/bar/baz2"
        pub.msgId   = 1
        pub.payload = "Hello world 2"
        self.protocol.dataReceived(pub.encode())
        self.transport.clear()
        self.assertEqual(self.called, False)
        rel = PUBREL()
        rel.msgId = pub.msgId
        self.protocol.dataReceived(rel.encode())
        self.assertEqual(self.topic,   pub.topic)
        self.assertEqual(self.payload, pub.payload)
        self.assertEqual(self.qos,     pub.qos)
        self.assertEqual(self.retain,  pub.retain)
        self.assertEqual(self.msgId,   pub.msgId)
        self.assertEqual(self.dup,     pub.dup)
class TestMQTTSubscriber2(unittest.TestCase):
    '''
    Subscribing works whether the onPublish handler is installed before or
    after connecting a SUBSCRIBER-profile protocol.
    '''

    def setUp(self):
        '''
        Fake transport, simulated clock and a SUBSCRIBER-profile protocol.
        The protocol is built but not connected here.
        '''
        self.clock = task.Clock()
        self.transport = proto_helpers.StringTransportWithDisconnection()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory = MQTTFactory(MQTTFactory.SUBSCRIBER)
        self._rebuild()

    def _connect(self, cleanStart=True):
        '''
        Send CONNECT, discard the written bytes, then feed back a CONNACK
        (session=False, resultCode=0).
        '''
        connack = CONNACK()
        connack.session = False
        connack.resultCode = 0
        connack.encode()
        self.protocol.connect("TwistedMQTT-sub", keepalive=0,
                              cleanStart=cleanStart, version=v31)
        self.transport.clear()
        self.protocol.dataReceived(connack.encoded)

    def _serverDown(self):
        '''Lose the transport, discard written bytes and forget the protocol.'''
        self.transport.loseConnection()
        self.transport.clear()
        del self.protocol

    def _rebuild(self):
        '''Build a new protocol from the factory and attach the fake transport.'''
        proto = self.factory.buildProtocol(0)
        self.protocol = proto
        self.transport.protocol = proto
        MQTTBaseProtocol.callLater = self.clock.callLater
        proto.makeConnection(self.transport)

    def _subscribe(self, n, qos, topic):
        '''Window *n*; subscribe "<topic>0".."<topic>{n-1}"; all must be pending.'''
        self.protocol.setWindowSize(n)
        pending = [self.protocol.subscribe("{0}{1}".format(topic, idx), qos)
                   for idx in range(n)]
        self.transport.clear()
        for deferred in pending:
            self.assertNoResult(deferred)
        return pending

    def _unsubscribe(self, n, topic):
        '''Window *n*; unsubscribe "<topic>0".."<topic>{n-1}"; all must be pending.'''
        self.protocol.setWindowSize(n)
        pending = [self.protocol.unsubscribe("{0}{1}".format(topic, idx))
                   for idx in range(n)]
        self.transport.clear()
        for deferred in pending:
            self.assertNoResult(deferred)
        return pending

    def _grantQos2(self, deferred):
        '''
        Discard written bytes, answer *deferred* with a SUBACK granting
        [(2, False)] and check that is its result.
        '''
        self.transport.clear()
        suback = SUBACK()
        suback.msgId = deferred.msgId
        suback.granted = [(2, False)]
        self.protocol.dataReceived(suback.encode())
        self.assertEqual([(2, False)], self.successResultOf(deferred))

    def test_subscribe_setPubishHandler1(self):
        '''Handler installed before connecting.'''
        def onPublish(topic, payload, qos, dup, retain, msgId):
            self.called = True
        self.protocol.onPublish = onPublish
        self._connect()
        self._grantQos2(self.protocol.subscribe("foo/bar/baz1", 2))

    def test_subscribe_setPubishHandler2(self):
        '''Handler installed after connecting.'''
        def onPublish(topic, payload, qos, dup, retain, msgId):
            self.called = True
        self._connect()
        self.protocol.onPublish = onPublish
        self._grantQos2(self.protocol.subscribe("foo/bar/baz1", 2))
class TestMQTTSubscriberDisconnect(unittest.TestCase):
    '''
    When and how the onDisconnection callback fires for a SUBSCRIBER-profile
    protocol, and what happens to pending subscriptions.
    '''

    def setUp(self):
        '''
        Fake transport, simulated clock and a SUBSCRIBER-profile protocol
        (not yet connected). Resets the disconnection flag.
        '''
        self.transport = proto_helpers.StringTransportWithDisconnection()
        self.clock = task.Clock()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory = MQTTFactory(MQTTFactory.SUBSCRIBER)
        self._rebuild()
        self.disconnected = False

    def _connect(self, cleanStart=True):
        '''
        Send CONNECT, discard the written bytes, then feed back a CONNACK
        (session=False, resultCode=0).
        '''
        connack = CONNACK()
        connack.session = False
        connack.resultCode = 0
        connack.encode()
        self.protocol.connect("TwistedMQTT-sub", keepalive=0,
                              cleanStart=cleanStart, version=v31)
        self.transport.clear()
        self.protocol.dataReceived(connack.encoded)

    def _disconnected(self, reason):
        '''onDisconnection callback: just records that it was invoked.'''
        self.disconnected = True

    def _rebuild(self):
        '''Build a new protocol from the factory and attach the fake transport.'''
        proto = self.factory.buildProtocol(0)
        self.protocol = proto
        self.transport.protocol = proto
        MQTTBaseProtocol.callLater = self.clock.callLater
        proto.makeConnection(self.transport)

    def _serverDown(self):
        '''Lose the transport, discard written bytes and forget the protocol.'''
        self.transport.loseConnection()
        self.transport.clear()
        del self.protocol

    def _connectWatched(self, cleanStart=True):
        '''Connect, then install self._disconnected as onDisconnection.'''
        self._connect(cleanStart=cleanStart)
        self.protocol.onDisconnection = self._disconnected

    def test_disconnect_1(self):
        '''Connect, then lose the transport.'''
        self._connectWatched()
        self.transport.loseConnection()
        self.assertEqual(True, self.disconnected)

    def test_disconnect_2(self):
        '''Connect, then disconnect() explicitly.'''
        self._connectWatched()
        self.protocol.disconnect()
        self.assertEqual(True, self.disconnected)

    def test_disconnect_3(self):
        '''Connect, leave a subscription pending, then lose the transport.'''
        self._connectWatched()
        pending = self.protocol.subscribe("foo/bar/baz1", 2)
        self.transport.clear()
        self.transport.loseConnection()
        self.assertEqual(True, self.disconnected)
        self.failureResultOf(pending).trap(error.ConnectionDone)

    def test_disconnect_4(self):
        '''Connect, leave a subscription pending, then disconnect() explicitly.'''
        self._connectWatched()
        pending = self.protocol.subscribe("foo/bar/baz1", 2)
        self.transport.clear()
        self.protocol.disconnect()
        self.assertEqual(True, self.disconnected)
        self.failureResultOf(pending).trap(error.ConnectionDone)

    def test_disconnect_5(self):
        '''
        Connect with a persistent session (cleanStart=False), leave a
        subscription pending and disconnect(); the deferred must not errback.
        '''
        self._connectWatched(cleanStart=False)
        pending = self.protocol.subscribe("foo/bar/baz1", 2)
        self.transport.clear()
        self.protocol.disconnect()
        self.assertEqual(True, self.disconnected)
        self.assertNoResult(pending)

    def test_disconnect_6(self):
        '''
        Connect with a persistent session (cleanStart=False), leave a
        subscription pending, bring the server down and rebuild the protocol;
        the deferred must still be pending.
        '''
        self._connectWatched(cleanStart=False)
        pending = self.protocol.subscribe("foo/bar/baz1", 2)
        self._serverDown()
        self._rebuild()
        self.assertEqual(True, self.disconnected)
        self.assertNoResult(pending)
# Example #9
# 0
class TestMQTTPublisher1(unittest.TestCase):
    '''
    QoS 0/1/2 publish flows and clean vs. persistent session handling for a
    PUBLISHER-profile protocol.
    '''

    def setUp(self):
        '''
        Fake transport, simulated clock and a PUBLISHER-profile protocol.
        The protocol is built but not connected here.
        '''
        self.transport = proto_helpers.StringTransportWithDisconnection()
        self.clock     = task.Clock()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory   = MQTTFactory(MQTTFactory.PUBLISHER)
        self._rebuild()

    def _connect(self, cleanStart=True):
        '''
        Go to connected state: send CONNECT, discard the written bytes and
        feed back a CONNACK (session=False, resultCode=0).
        '''
        ack = CONNACK()
        ack.session = False
        ack.resultCode = 0
        ack.encode()
        self.protocol.connect("TwistedMQTT-pub", keepalive=0, cleanStart=cleanStart, version=v31)
        self.transport.clear()
        self.protocol.dataReceived(ack.encoded)

    def _serverDown(self):
        '''Lose the transport, discard written bytes and drop the protocol reference.'''
        self.transport.loseConnection()
        self.transport.clear()
        del self.protocol

    def _rebuild(self):
        '''Build a new protocol from the factory and attach the fake transport.'''
        self.protocol  = self.factory.buildProtocol(0)
        self.transport.protocol = self.protocol
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.protocol.makeConnection(self.transport)

    def _publish(self, n, qos, topic, msg):
        '''
        Publish *msg* to *topic* *n* times at *qos*. QoS 0 deferreds must
        already have fired with None; others must still be pending.
        Returns the list of deferreds.
        '''
        dl = [self.protocol.publish(topic=topic, qos=qos, message=msg) for _ in range(n)]
        self.transport.clear()
        for d in dl:
            if qos == 0:
                self.assertEqual(None, self.successResultOf(d))
            else:
                self.assertNoResult(d)
        return dl

    def _encodeAcks(self, packetClass, dl):
        '''
        Return a bytearray with one encoded *packetClass* packet per deferred
        in *dl*, in order, each carrying that deferred's own msgId.
        '''
        encoded = bytearray()
        for d in dl:
            packet = packetClass()
            packet.msgId = d.msgId
            encoded.extend(packet.encode())
        return encoded

    def _puback(self, dl):
        '''Feed one PUBACK per deferred; each must then fire with its msgId.'''
        self.protocol.dataReceived(self._encodeAcks(PUBACK, dl))
        self.transport.clear()
        for d in dl:
            self.assertEqual(d.msgId, self.successResultOf(d))

    def _pubrec(self, dl):
        '''Feed one PUBREC per deferred; QoS 2 deferreds must stay pending.'''
        self.protocol.dataReceived(self._encodeAcks(PUBREC, dl))
        self.transport.clear()
        for d in dl:
            self.assertNoResult(d)

    def _pubcomp(self, dl):
        '''Feed one PUBCOMP per deferred; each must then fire with its msgId.'''
        # Each PUBCOMP carries its own deferred's msgId. An earlier version
        # encoded the stale loop variable, sending the LAST PUBCOMP len(dl)
        # times, so only the last deferred could complete.
        self.protocol.dataReceived(self._encodeAcks(PUBCOMP, dl))
        self.transport.clear()
        for d in dl:
            self.assertEqual(d.msgId, self.successResultOf(d))

    def test_publish_single_qos0(self):
        self._connect()
        d = self.protocol.publish(topic="foo/bar/baz1", qos=0, message="hello world 0")
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  0)
        self.assertEqual(None, self.successResultOf(d))

    def test_publish_single_qos1(self):
        self._connect()
        d = self.protocol.publish(topic="foo/bar/baz1", qos=1, message="hello world 1")
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  1)
        self.transport.clear()
        ack = PUBACK()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  0)
        self.assertEqual(ack.msgId, self.successResultOf(d))

    def test_publish_single_qos2(self):
        self._connect()
        d = self.protocol.publish(topic="foo/bar/baz1", qos=2, message="hello world 2")
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  1)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)
        self.transport.clear()
        ack = PUBREC()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.transport.clear()
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 1)
        ack = PUBCOMP()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)
        self.assertEqual(ack.msgId, self.successResultOf(d))

    def test_publish_several_qos0(self):
        self._connect()
        self._publish(n=3, qos=0, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  0)

    def test_publish_several_qos1(self):
        self._connect()
        dl = self._publish(n=3, qos=1, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  len(dl))
        self._puback(dl)
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  0)

    def test_publish_several_qos2(self):
        self._connect()
        dl = self._publish(n=3, qos=2, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  len(dl))
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)
        self._pubrec(dl)
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), len(dl))
        self._pubcomp(dl)
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)

    def test_lost_session(self):
        '''With a clean session, losing the server empties the queues and errbacks.'''
        self._connect()
        dl = self._publish(n=3, qos=2, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  len(dl))
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)
        self._serverDown()
        self.assertEqual(len(self.factory.queuePublishTx),  0)
        self.assertEqual(len(self.factory.queuePubRelease), 0)
        for d in dl:
            self.failureResultOf(d).trap(error.ConnectionDone)

    def test_persistent_session_qos1(self):
        '''With cleanStart=False, QoS 1 publishes survive a reconnect.'''
        self._connect(cleanStart=False)
        dl = self._publish(n=3, qos=1, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  len(dl))
        self._serverDown()
        self.assertEqual(len(self.factory.queuePublishTx),  len(dl))
        for d in dl:
            self.assertNoResult(d)
        self._rebuild()
        self._connect(cleanStart=False)
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  len(dl))
        self._puback(dl)
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  0)

    def test_persistent_session_qos2(self):
        '''With cleanStart=False, QoS 2 publishes survive a reconnect.'''
        self._connect(cleanStart=False)
        dl = self._publish(n=3, qos=2, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  len(dl))
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)
        self._serverDown()
        for d in dl:
            self.assertNoResult(d)
        self._rebuild()
        self._connect(cleanStart=False)
        self.assertEqual(len(self.factory.queuePublishTx),  len(dl))
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)
        self._pubrec(dl)
        self.assertEqual(len(self.factory.queuePublishTx), 0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), len(dl))
        self._pubcomp(dl)
        self.assertEqual(len(self.factory.queuePublishTx),  0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)

    def test_persistent_release_qos2(self):
        '''Partially acknowledged QoS 2 publishes resume correctly after a reconnect.'''
        self._connect(cleanStart=False)
        dl = self._publish(n=3, qos=2, topic="foo/bar/baz", msg="Hello World")
        # Acknowledge (PUBREC) only the first two, then simulate a server
        # disconnect with a new client protocol being built on the client.
        self._pubrec(dl[:-1])

        self._serverDown()
        self.assertNoResult(dl[0])
        self.assertNoResult(dl[1])
        self.assertNoResult(dl[2])
        self._rebuild()
        self.assertEqual(len(self.protocol.factory.queuePublishTx),  1)
        self.assertEqual(len(self.protocol.factory.queuePubRelease),  2)
        # Reconnect with the new client protocol object.
        self._connect(cleanStart=False)
        self._pubrec(dl[-1:])   # send the last PUBREC
        self.assertEqual(len(self.protocol.factory.queuePublishTx), 0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 3)
        self._pubcomp(dl[0:1])   # send the first PUBCOMP
        self.assertEqual(len(self.factory.queuePublishTx),  0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 2)
        self._pubcomp(dl[1:2])   # send the second PUBCOMP
        self.assertEqual(len(self.factory.queuePublishTx),  0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 1)
        self._pubcomp(dl[-1:])   # send the last PUBCOMP
        self.assertEqual(len(self.factory.queuePublishTx),  0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)
class TestMQTTPublisherIntervals(unittest.TestCase):


    def setUp(self):
        '''
        Set up a conencted state
        '''
        self.transport = proto_helpers.StringTransportWithDisconnection()
        self.clock     = task.Clock()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory   = MQTTFactory(MQTTFactory.PUBLISHER)
        self.addr = IPv4Address('TCP','localhost',1880)
        self._rebuild()
        # Just to generate connection contexts
        

    def _connect(self, cleanStart=True):
        '''
        Go to connected state
        '''
        ack = CONNACK()
        ack.session = False
        ack.resultCode = 0
        ack.encode()
        self.protocol.connect("TwistedMQTT-pub", keepalive=0, cleanStart=cleanStart, version=v31)
        self.transport.clear()
        self.protocol.dataReceived(ack.encoded)

    def _rebuild(self):
        self.protocol  = self.factory.buildProtocol(self.addr)
        self.transport.protocol = self.protocol
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.protocol.makeConnection(self.transport)


    def _publish(self, n, qos, topic, msg, window=None):
        if window is None:
            self.protocol.setWindowSize(n)
        else:
            self.protocol.setWindowSize(window)
        dl = []
        for i in range(0,n):
            dl.append(self.protocol.publish(topic=topic, qos=qos, message=msg))
        self.transport.clear()
        for d in dl:
            if qos == 0:
                self.assertEqual(None, self.successResultOf(d))
            else:
                self.assertNoResult(d)
        return dl
    
    def _puback(self, dl):
        ackl = []
        for i in range(0, len(dl)):
            ack= PUBACK()
            ack.msgId = dl[i].msgId
            ackl.append(ack)
        encoded = bytearray()
        for ack in ackl:
            encoded.extend(ack.encode())
        self.protocol.dataReceived(encoded)
        self.transport.clear()
        for i in range(0, len(dl)):
            self.assertEqual(dl[i].msgId, self.successResultOf(dl[i]))

    def test_publish_very_large_qos1(self):
        message = '0123456789ABCDEF'*1000000 # Large PDU
        self._connect()

        # Test at 1MByte/sec
        self.protocol.setBandwith(1000000.0)
        d = self.protocol.publish(topic="foo/bar/baz1", qos=1, message=message)
        self.assertEqual(self.protocol.factory.windowPublish[self.addr][1].dup, False)
        self.transport.clear()
        self.clock.advance(10)
        self.assertEqual(self.protocol.factory.windowPublish[self.addr][1].dup, False)
        ack = PUBACK()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.transport.clear()
        self.assertEqual(ack.msgId, self.successResultOf(d))

        # A large PDU with a large bandwith estimation may retransmit
        self.protocol.setBandwith(10000000.0)
        d = self.protocol.publish(topic="foo/bar/baz1", qos=1, message=message)
        self.assertEqual(self.protocol.factory.windowPublish[self.addr][2].dup, False)
        self.transport.clear()
        self.clock.advance(10)
        self.assertEqual(self.protocol.factory.windowPublish[self.addr][2].dup, True)
        ack = PUBACK()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.transport.clear()
        self.assertEqual(ack.msgId, self.successResultOf(d))
Example #11
0
                if payload['slots'][0]['value']['value'] == 'white':
                    self.datastore.rgbw_brightness[3]=float(payload['slots'][1]['value']['value']/100)
                if payload['slots'][0]['value']['value'] == 'master':
                    self.datastore.master_brightness=float(payload['slots'][0]['value']['value']/100)
                print(self.datastore.rgbw_brightness)
                        


    def onDisconnection(self, reason):
        '''
        Disconnection hook: log why the link dropped, then wait for the next
        protocol instance (the next retry) and hand it to ``connectToBroker``.
        '''
        self.log.debug("<Connection was lost !> <reason={r}>", r=reason)
        next_connection = self.whenConnected()
        next_connection.addCallback(self.connectToBroker)


# Script entry point: enable debug logging for the 'mqtt' and '__main__'
# namespaces, start an MQTT subscriber service on the BROKER endpoint, and
# run the reactor.
if __name__ == '__main__':
    import sys  # NOTE(review): not used in this block
    log = Logger()  # module-level logger
    startLogging()
    setLogLevel(namespace='mqtt',     levelStr='debug')
    setLogLevel(namespace='__main__', levelStr='debug')

    # Subscriber-profile factory wrapped in a service built on the BROKER endpoint
    factory    = MQTTFactory(profile=MQTTFactory.SUBSCRIBER)
    myEndpoint = clientFromString(reactor, BROKER)
    serv       = MQTTService(myEndpoint, factory)
    serv.startService()
    # Blocks until the reactor is stopped
    reactor.run()
    
Example #12
0
    def __init__(self, core_reactor, options, config):
        '''
        Build the anode service: web socket / REST front ends, an optional MQTT
        publish service, the configured plugins, and every periodic job.

        :param core_reactor: reactor used as the clock for all LoopingCalls and
            passed to the model puller and plugins
        :param options: parsed options; ``certificate`` and ``db_dir`` are read
        :param config: service configuration dict. ``host``, ``port`` and
            ``profile`` are required; ``publish_*``, ``model_pull_*``,
            ``plugin`` and ``save_seconds`` are optional
        '''
        log_timer = Log(logging.DEBUG).start()
        Log(logging.INFO).log("Service", "state", lambda: "[anode] initialising")
        self.core_reactor = core_reactor
        self.options = options
        self.config = config
        self.plugins = {}
        # TLS (wss/https) is used only when the certificate file exists
        self.certificate = pem.twisted.certificateOptionsFromFiles(self.options.certificate) \
            if os.path.isfile(self.options.certificate) else None
        self.web_ws = WebWsFactory(u"ws" + ("" if self.certificate is None else "s") + "://"
                                   + self.config["host"] + ":" + str(self.config["port"]), self, self.certificate)
        self.web_ws.protocol = WebWs
        self.web_rest = WebRest(self, "http" + ("" if self.certificate is None else "s") + "://"
                                + self.config["host"] + ":" + str(self.config["port"]))
        # NOTE(review): uses the global ``reactor`` rather than core_reactor; confirm intended
        self.web_pool = HTTPConnectionPool(reactor, persistent=True)
        self.publish_service = None
        # Publishing needs both a non-empty host and a positive port
        self.publish = "publish_host" in self.config and len(self.config["publish_host"]) > 0 and \
                       "publish_port" in self.config and self.config["publish_port"] > 0
        if self.publish:
            access_key = config["profile"]["VERNEMQ_ACCESS_KEY"] if "VERNEMQ_ACCESS_KEY" in config["profile"] else None
            secret_key = config["profile"]["VERNEMQ_SECRET_KEY"] if "VERNEMQ_SECRET_KEY" in config["profile"] else None
            mqtt_client_string = clientFromString(reactor, "tcp:" + self.config["publish_host"] + ":" + str(self.config["publish_port"]))
            self.publish_service = MqttPublishService(mqtt_client_string, MQTTFactory(profile=MQTTFactory.PUBLISHER),
                                                      KEEPALIVE_DEFAULT_SECONDS, access_key, secret_key)

        def looping_call(loop_function, loop_seconds):
            # Run loop_function every loop_seconds, clocked by core_reactor
            loop_call = LoopingCall(loop_function)
            loop_call.clock = self.core_reactor
            loop_call.start(loop_seconds)

        if "model_pull_seconds" in self.config and self.config["model_pull_seconds"] > 0:
            model_pull = ModelPull(self, "pullmodel", {
                "pool": self.web_pool, "db_dir": self.options.db_dir,
                "profile": self.config["profile"],
                "model_pull_region": self.config["model_pull_region"] if "model_pull_region" in self.config else S3_REGION,
                "model_pull_bucket": (self.config["model_pull_bucket"] if "model_pull_bucket" in self.config else S3_BUCKET) + (
                    self.config["model_pull_bucket_snapshot"] if ("model_pull_bucket_snapshot" in self.config and
                                                                  APP_VERSION.endswith("-SNAPSHOT")) else "")}, self.core_reactor)
            looping_call(model_pull.poll, self.config["model_pull_seconds"])
        if "plugin" in self.config and self.config["plugin"] is not None:
            for plugin_name in self.config["plugin"]:
                plugin_config = self.config["plugin"][plugin_name]
                plugin_config["pool"] = self.web_pool
                plugin_config["db_dir"] = self.options.db_dir
                plugin_config["profile"] = self.config["profile"]
                if self.publish_service is not None:
                    plugin_config["publish_service"] = self.publish_service
                # Propagate the optional service-level publish settings to each plugin
                for publish_key in ("publish_batch_seconds", "publish_status_topic", "publish_push_data_topic",
                                    "publish_push_metadata_topic", "publish_batch_datum_topic"):
                    if publish_key in self.config:
                        plugin_config[publish_key] = self.config[publish_key]
                self.plugins[plugin_name] = Plugin.get(self, plugin_name, plugin_config, self.core_reactor)
                if "poll_seconds" in plugin_config and plugin_config["poll_seconds"] > 0:
                    looping_call(self.plugins[plugin_name].poll, plugin_config["poll_seconds"])
                if "repeat_seconds" in plugin_config and plugin_config["repeat_seconds"] > 0:
                    looping_call(self.plugins[plugin_name].repeat, plugin_config["repeat_seconds"])
        for plugin in self.plugins.itervalues():
            # Every check must read *this* plugin's config, not the plugin_name
            # left over from the configuration loop above
            plugin_config = self.config["plugin"][plugin.name]
            if "history_partition_seconds" in plugin_config and \
                    plugin_config["history_partition_seconds"] > 0 and \
                    "repeat_seconds" in plugin_config and \
                    plugin_config["repeat_seconds"] >= 0:
                time_current = plugin.get_time()
                time_partition = plugin_config["history_partition_seconds"]
                # Delay until the next partition boundary (assumes get_time_period
                # returns the start of the current period -- confirm)
                time_partition_next = time_partition - (time_current - plugin.get_time_period(time_current, time_partition))
                plugin_partition_call = LoopingCall(self.plugins[plugin.name].repeat, force=True)
                plugin_partition_call.clock = self.core_reactor
                # Call and period are passed as arguments so each lambda uses this
                # iteration's values rather than late-bound loop variables
                self.core_reactor.callLater(time_partition_next,
                                            lambda _plugin_partition_call, _time_partition:
                                            _plugin_partition_call.start(_time_partition), plugin_partition_call, time_partition)
        if self.publish and "publish_batch_seconds" in self.config and self.config["publish_batch_seconds"] > 0:
            looping_call(self.publish_datums, self.config["publish_batch_seconds"])
        if "save_seconds" in self.config and self.config["save_seconds"] > 0:
            looping_call(self.store_state, self.config["save_seconds"])
        log_timer.log("Service", "timer", lambda: "[anode] initialised", context=self.__init__)
Example #13
0
class TestMQTTPublisherIntervals(unittest.TestCase):
    def setUp(self):
        '''
        Set up a conencted state
        '''
        self.transport = proto_helpers.StringTransportWithDisconnection()
        self.clock = task.Clock()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory = MQTTFactory(MQTTFactory.PUBLISHER)
        self.addr = IPv4Address('TCP', 'localhost', 1880)
        self._rebuild()
        # Just to generate connection contexts

    def _connect(self, cleanStart=True):
        '''
        Go to connected state
        '''
        ack = CONNACK()
        ack.session = False
        ack.resultCode = 0
        ack.encode()
        self.protocol.connect("TwistedMQTT-pub",
                              keepalive=0,
                              cleanStart=cleanStart,
                              version=v31)
        self.transport.clear()
        self.protocol.dataReceived(ack.encoded)

    def _rebuild(self):
        self.protocol = self.factory.buildProtocol(self.addr)
        self.transport.protocol = self.protocol
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.protocol.makeConnection(self.transport)

    def _publish(self, n, qos, topic, msg, window=None):
        if window is None:
            self.protocol.setWindowSize(n)
        else:
            self.protocol.setWindowSize(window)
        dl = []
        for i in range(0, n):
            dl.append(self.protocol.publish(topic=topic, qos=qos, message=msg))
        self.transport.clear()
        for d in dl:
            if qos == 0:
                self.assertEqual(None, self.successResultOf(d))
            else:
                self.assertNoResult(d)
        return dl

    def _puback(self, dl):
        ackl = []
        for i in range(0, len(dl)):
            ack = PUBACK()
            ack.msgId = dl[i].msgId
            ackl.append(ack)
        encoded = bytearray()
        for ack in ackl:
            encoded.extend(ack.encode())
        self.protocol.dataReceived(encoded)
        self.transport.clear()
        for i in range(0, len(dl)):
            self.assertEqual(dl[i].msgId, self.successResultOf(dl[i]))

    def test_publish_very_large_qos1(self):
        message = '0123456789ABCDEF' * 1000000  # Large PDU
        self._connect()

        # Test at 1MByte/sec
        self.protocol.setBandwith(1000000.0)
        d = self.protocol.publish(topic="foo/bar/baz1", qos=1, message=message)
        self.assertEqual(self.protocol.factory.windowPublish[self.addr][1].dup,
                         False)
        self.transport.clear()
        self.clock.advance(10)
        self.assertEqual(self.protocol.factory.windowPublish[self.addr][1].dup,
                         False)
        ack = PUBACK()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.transport.clear()
        self.assertEqual(ack.msgId, self.successResultOf(d))

        # A large PDU with a large bandwith estimation may retransmit
        self.protocol.setBandwith(10000000.0)
        d = self.protocol.publish(topic="foo/bar/baz1", qos=1, message=message)
        self.assertEqual(self.protocol.factory.windowPublish[self.addr][2].dup,
                         False)
        self.transport.clear()
        self.clock.advance(10)
        self.assertEqual(self.protocol.factory.windowPublish[self.addr][2].dup,
                         True)
        ack = PUBACK()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.transport.clear()
        self.assertEqual(ack.msgId, self.successResultOf(d))
Example #14
0
class TestMQTTPublisher1(unittest.TestCase):
    def setUp(self):
        '''
        Set up a conencted state
        '''
        self.transport = proto_helpers.StringTransportWithDisconnection()
        self.clock = task.Clock()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory = MQTTFactory(MQTTFactory.PUBLISHER)
        self._rebuild()
        # Just to generate connection contexts

    def _connect(self, cleanStart=True):
        '''
        Go to connected state
        '''
        ack = CONNACK()
        ack.session = False
        ack.resultCode = 0
        ack.encode()
        self.protocol.connect("TwistedMQTT-pub",
                              keepalive=0,
                              cleanStart=cleanStart,
                              version=v31)
        self.transport.clear()
        self.protocol.dataReceived(ack.encoded)

    def _serverDown(self):
        self.transport.loseConnection()
        self.transport.clear()
        del self.protocol

    def _rebuild(self):
        self.addr = IPv4Address('TCP', 'localhost', 1880)
        self.protocol = self.factory.buildProtocol(self.addr)
        self.transport.protocol = self.protocol
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.protocol.makeConnection(self.transport)

    def _publish(self, n, qos, topic, msg, window=None):
        if window is None:
            self.protocol.setWindowSize(n)
        else:
            self.protocol.setWindowSize(window)
        dl = []
        for i in range(0, n):
            dl.append(self.protocol.publish(topic=topic, qos=qos, message=msg))
        self.transport.clear()
        for d in dl:
            if qos == 0:
                self.assertEqual(None, self.successResultOf(d))
            else:
                self.assertNoResult(d)
        return dl

    def _puback(self, dl):
        ackl = []
        for i in range(0, len(dl)):
            ack = PUBACK()
            ack.msgId = dl[i].msgId
            ackl.append(ack)
        encoded = bytearray()
        for ack in ackl:
            encoded.extend(ack.encode())
        self.protocol.dataReceived(encoded)
        self.transport.clear()
        for i in range(0, len(dl)):
            self.assertEqual(dl[i].msgId, self.successResultOf(dl[i]))

    def _pubrec(self, dl):
        recl = []
        for i in range(0, len(dl)):
            rec = PUBREC()
            rec.msgId = dl[i].msgId
            recl.append(rec)
        encoded = bytearray()
        for rec in recl:
            encoded.extend(rec.encode())
        self.protocol.dataReceived(encoded)
        self.transport.clear()
        for i in range(0, len(dl)):
            self.assertNoResult(dl[i])

    def _pubcomp(self, dl):
        compl = []
        for i in range(0, len(dl)):
            comp = PUBCOMP()
            comp.msgId = dl[i].msgId
            compl.append(comp)
        encoded = bytearray()
        for comp in compl:
            encoded.extend(comp.encode())
        self.protocol.dataReceived(encoded)
        self.transport.clear()
        for i in range(0, len(dl)):
            self.assertEqual(dl[i].msgId, self.successResultOf(dl[i]))

    def test_publish_single_qos0(self):
        self._connect()
        d = self.protocol.publish(topic="foo/bar/baz1",
                                  qos=0,
                                  message="hello world 0")
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         0)
        self.assertEqual(None, self.successResultOf(d))

    def test_publish_single_qos1(self):
        self._connect()
        d = self.protocol.publish(topic="foo/bar/baz1",
                                  qos=1,
                                  message="hello world 1")
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         1)
        self.transport.clear()
        ack = PUBACK()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         0)
        self.assertEqual(ack.msgId, self.successResultOf(d))

    def test_publish_single_qos2(self):
        self._connect()
        d = self.protocol.publish(topic="foo/bar/baz1",
                                  qos=2,
                                  message="hello world 2")
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         1)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 0)
        self.transport.clear()
        ack = PUBREC()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.transport.clear()
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         0)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 1)
        ack = PUBCOMP()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         0)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 0)
        self.assertEqual(ack.msgId, self.successResultOf(d))

    def test_publish_several_qos0(self):
        self._connect()
        dl = self._publish(n=3, qos=0, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         0)

    def test_publish_several_qos1(self):
        self._connect()
        dl = self._publish(n=3, qos=1, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         len(dl))
        self._puback(dl)
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         0)

    def test_publish_several_qos2(self):
        self._connect()
        dl = self._publish(n=3, qos=2, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         len(dl))
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 0)
        self._pubrec(dl)
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         0)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), len(dl))
        self._pubcomp(dl)
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         0)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 0)

    def test_publish_many_qos1(self):
        '''
        Test enqueuing when not all ACKs arrives
        '''
        self._connect()
        window = 3
        n = 7
        dl = self._publish(n=n,
                           window=window,
                           qos=1,
                           topic="foo/bar/baz",
                           msg="Hello World")
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         window)
        self.assertEqual(len(self.protocol.factory.queuePublishTx[self.addr]),
                         n - window)
        self._puback(dl[0:window])
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         window)
        self.assertEqual(len(self.protocol.factory.queuePublishTx[self.addr]),
                         1)
        self._puback(dl[window:2 * window])
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         1)
        self.assertEqual(len(self.protocol.factory.queuePublishTx[self.addr]),
                         0)
        self._puback(dl[2 * window:])
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         0)
        self.assertEqual(len(self.protocol.factory.queuePublishTx[self.addr]),
                         0)

    def test_publish_many_qos2(self):
        '''
        Test enqueuing when not all ACKs arrives
        '''
        self._connect()
        window = 3
        n = 7
        dl = self._publish(n=n,
                           window=window,
                           qos=2,
                           topic="foo/bar/baz",
                           msg="Hello World")
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         window)
        self.assertEqual(len(self.protocol.factory.queuePublishTx[self.addr]),
                         n - window)
        self._pubrec(dl[0:window])
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         0)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), window)
        self._pubcomp(dl[0:window])
        self.assertEqual(len(self.protocol.factory.queuePublishTx[self.addr]),
                         n - 2 * window)
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         window)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 0)

    def test_lost_session(self):
        self._connect()
        dl = self._publish(n=3, qos=2, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         len(dl))
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 0)
        self._serverDown()
        self.assertEqual(len(self.factory.windowPublish[self.addr]), 0)
        self.assertEqual(len(self.factory.windowPubRelease[self.addr]), 0)
        for d in dl:
            self.failureResultOf(d).trap(error.ConnectionDone)

    def test_persistent_session_qos1(self):
        self._connect(cleanStart=False)
        dl = self._publish(n=3, qos=1, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         len(dl))
        self._serverDown()
        self.assertEqual(len(self.factory.windowPublish[self.addr]), len(dl))
        for d in dl:
            self.assertNoResult(d)
        self._rebuild()
        self._connect(cleanStart=False)
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         len(dl))
        self._puback(dl)
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         0)

    def test_persistent_session_qos2(self):
        self._connect(cleanStart=False)
        dl = self._publish(n=3, qos=2, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         len(dl))
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 0)
        self._serverDown()
        for d in dl:
            self.assertNoResult(d)
        self._rebuild()
        self._connect(cleanStart=False)
        self.assertEqual(len(self.factory.windowPublish[self.addr]), len(dl))
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 0)
        self._pubrec(dl)
        self.assertEqual(len(self.factory.windowPublish[self.addr]), 0)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), len(dl))
        self._pubcomp(dl)
        self.assertEqual(len(self.factory.windowPublish[self.addr]), 0)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 0)

    def test_persistent_release_qos2(self):
        self._connect(cleanStart=False)
        dl = self._publish(n=3, qos=2, topic="foo/bar/baz", msg="Hello World")
        #- generate two ACK and simulate a server disconnect with a new client protocol
        # being built on the client
        self._pubrec(dl[:-1])  # Only the first two

        self._serverDown()
        self.assertNoResult(dl[0])
        self.assertNoResult(dl[1])
        self.assertNoResult(dl[2])
        self._rebuild()
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         1)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 2)
        # Reconnect with the new client protcol object
        self._connect(cleanStart=False)
        self._pubrec(dl[-1:])  # send the last one
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         0)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 3)
        self._pubcomp(dl[0:1])  # send the first comp
        self.assertEqual(len(self.factory.windowPublish[self.addr]), 0)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 2)
        self._pubcomp(dl[1:2])  # send the second comp
        self.assertEqual(len(self.factory.windowPublish[self.addr]), 0)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 1)
        self._pubcomp(dl[-1:])  # send the last comp
        self.assertEqual(len(self.factory.windowPublish[self.addr]), 0)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 0)
Example #15
0
class TestMQTTPublisherDisconnect(unittest.TestCase):
    '''
    Testing various cases of disconnect callback
    '''
    def setUp(self):
        '''
        Set up a connencted state
        '''
        self.transport = proto_helpers.StringTransportWithDisconnection()
        self.clock = task.Clock()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory = MQTTFactory(MQTTFactory.PUBLISHER)
        self._rebuild()
        self.disconnected = False

    def _connect(self, cleanStart=True):
        '''
        Go to connected state
        '''
        ack = CONNACK()
        ack.session = False
        ack.resultCode = 0
        ack.encode()
        self.protocol.connect("TwistedMQTT-pub",
                              keepalive=0,
                              cleanStart=cleanStart,
                              version=v31)
        self.transport.clear()
        self.protocol.dataReceived(ack.encoded)

    def _disconnected(self, reason):
        self.disconnected = True

    def _serverDown(self):
        self.transport.loseConnection()
        self.transport.clear()
        del self.protocol

    def _rebuild(self):
        self.protocol = self.factory.buildProtocol(0)
        self.transport.protocol = self.protocol
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.protocol.makeConnection(self.transport)

    def test_disconnect_1(self):
        '''Just connect and lose the transport'''
        self._connect()
        self.protocol.onDisconnection = self._disconnected
        self.transport.loseConnection()
        self.assertEqual(self.disconnected, True)

    def test_disconnect_2(self):
        '''connect and disconnect'''
        self._connect()
        self.protocol.onDisconnection = self._disconnected
        self.protocol.disconnect()
        self.assertEqual(self.disconnected, True)

    def test_disconnect_3(self):
        '''connect, generate a deferred and lose the transport'''
        self._connect()
        self.protocol.onDisconnection = self._disconnected
        d = self.protocol.publish(topic="foo/bar/baz1",
                                  qos=1,
                                  message="hello world 1")
        self.transport.clear()
        self.transport.loseConnection()
        self.assertEqual(self.disconnected, True)
        self.failureResultOf(d).trap(error.ConnectionDone)

    def test_disconnect_4(self):
        '''connect, generate a deferred and disconnect'''
        self._connect()
        self.protocol.onDisconnection = self._disconnected
        d = self.protocol.publish(topic="foo/bar/baz1",
                                  qos=1,
                                  message="hello world 1")
        self.transport.clear()
        self.protocol.disconnect()
        self.assertEqual(self.disconnected, True)
        self.failureResultOf(d).trap(error.ConnectionDone)

    def test_disconnect_5(self):
        '''connect with persistent session, generate a deferred and disconnect'''
        self._connect(cleanStart=False)
        self.protocol.onDisconnection = self._disconnected
        d = self.protocol.publish(topic="foo/bar/baz1",
                                  qos=1,
                                  message="hello world 1")
        self.transport.clear()
        self.protocol.disconnect()
        self.assertEqual(self.disconnected, True)
        self.assertNoResult(d)

    def test_disconnect_6(self):
        '''connect with persistent session, generate a deferred , rebuilds protocol'''
        self._connect(cleanStart=False)
        self.protocol.onDisconnection = self._disconnected
        d = self.protocol.publish(topic="foo/bar/baz1",
                                  qos=1,
                                  message="hello world 1")
        self._serverDown()
        self._rebuild()
        self.assertEqual(self.disconnected, True)
        self.assertNoResult(d)