Example #1
0
class TestMQTTBaseProtocol2(unittest.TestCase):
    '''
    Connection, ping and ping-timeout handling of a subscriber protocol
    whose delayed calls run on a simulated clock.
    '''

    def setUp(self):
        '''
        Create a subscriber protocol attached to an in-memory transport.
        The protocol is left unconnected; each test calls _connect() itself.
        '''
        self.clock = task.Clock()
        self.transport = proto_helpers.StringTransportWithDisconnection()
        # Route the protocol's callLater() through the fake clock.
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory = MQTTFactory(MQTTFactory.SUBSCRIBER)
        self._rebuild()

    def tearDown(self):
        '''
        Close the transport after every test. Needed because ping() sets up
        a LoopingCall, which lives outside the Clock-simulated callLater().
        '''
        self.transport.loseConnection()

    def _connect(self, keepalive=0, cleanStart=True):
        '''
        Send CONNECT and answer it with a CONNACK (resultCode 0) so the
        protocol reaches the connected state.

        :param keepalive: keepalive interval forwarded to connect()
        :param cleanStart: cleanStart flag forwarded to connect()
        '''
        connack = CONNACK()
        connack.session = False
        connack.resultCode = 0
        connack.encode()
        options = dict(keepalive=keepalive, cleanStart=cleanStart, version=v31)
        self.protocol.connect("TwistedMQTT-sub", **options)
        self.transport.clear()
        self.protocol.dataReceived(connack.encoded)

    def _rebuild(self):
        '''Build a new protocol from the factory and hook it to the transport.'''
        proto = self.factory.buildProtocol(0)
        self.protocol = proto
        self.transport.protocol = proto
        proto.makeConnection(self.transport)

    def test_disconnect(self):
        '''A connected protocol accepts disconnect().'''
        self._connect()
        self.assertEqual(self.protocol.state, MQTTBaseProtocol.CONNECTED)
        self.protocol.disconnect()
        self.transport.clear()

    def test_ping(self):
        '''Receiving a PINGRESP leaves the protocol connected.'''
        self._connect(keepalive=5)
        pingresp = PINGRES().encode()
        self.protocol.dataReceived(pingresp)
        self.transport.clear()
        self.assertEqual(self.protocol.state, MQTTBaseProtocol.CONNECTED)

    def test_ping_timeout(self):
        '''
        A ping that gets no reply, followed by 6 simulated seconds with a
        5 second keepalive, leaves the protocol IDLE.
        '''
        self._connect(keepalive=5)
        self.protocol.ping()
        self.transport.clear()
        self.clock.advance(6)
        self.assertEqual(self.protocol.state, MQTTBaseProtocol.IDLE)
class TestMQTTPublisherForbiddenOps(unittest.TestCase):
    '''
    A publisher-only protocol must reject subscriber operations
    (subscribe, unsubscribe and installing an onPublish handler) with
    MQTTStateError.
    '''
    def setUp(self):
        '''
        Build a single publisher protocol and bring it to the connected state.
        '''
        self.transport = proto_helpers.StringTransportWithDisconnection()
        self.clock = task.Clock()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory = MQTTFactory(MQTTFactory.PUBLISHER)
        self._rebuild()
        self._connect()

    def _connect(self, cleanStart=True):
        '''
        Go to the connected state: send CONNECT, then feed back a CONNACK
        with resultCode 0.
        '''
        ack = CONNACK()
        ack.session = False
        ack.resultCode = 0
        ack.encode()
        self.protocol.connect("TwistedMQTT-pub",
                              keepalive=0,
                              cleanStart=cleanStart,
                              version=v31)
        self.transport.clear()
        self.protocol.dataReceived(ack.encoded)

    def _rebuild(self):
        '''Build a new protocol from the factory and attach it to the transport.'''
        self.protocol = self.factory.buildProtocol(0)
        self.transport.protocol = self.protocol
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.protocol.makeConnection(self.transport)

    def test_forbidden_subscribe(self):
        '''subscribe() on a publisher fails with MQTTStateError.'''
        d = self.protocol.subscribe("foo/bar/baz1", 2)
        self.failureResultOf(d).trap(MQTTStateError)

    def test_forbidden_unsubscribe(self):
        '''unsubscribe() on a publisher fails with MQTTStateError.'''
        d = self.protocol.unsubscribe("foo/bar/baz1")
        self.failureResultOf(d).trap(MQTTStateError)

    def test_forbidden_publish_callback(self):
        '''Calling the publisher's onPublish with a handler raises MQTTStateError.'''
        def onPublish(topic, payload, qos, dup, retain, msgId):
            pass

        self.assertRaises(MQTTStateError, self.protocol.onPublish, onPublish)
Example #3
0
class PDUTestCase(unittest.TestCase):
    '''
    MQTTFactory.buildProtocol() returns the protocol class matching the
    role flags given to the factory, and rejects a factory with no role.
    '''

    def test_buildProtocol_publisher(self):
        '''PUBLISHER alone builds an MQTTPublisherProtocol.'''
        factory = MQTTFactory(MQTTFactory.PUBLISHER)
        self.assertIsInstance(factory.buildProtocol(0), MQTTPublisherProtocol)

    def test_buildProtocol_subscriber(self):
        '''SUBSCRIBER alone builds an MQTTSubscriberProtocol.'''
        factory = MQTTFactory(MQTTFactory.SUBSCRIBER)
        self.assertIsInstance(factory.buildProtocol(0), MQTTSubscriberProtocol)

    def test_buildProtocol_pubsubs(self):
        '''Both roles together build an MQTTPubSubsProtocol.'''
        # Role constants are treated as bit flags (assumed): combine them
        # with |, which, unlike +, stays correct if flags ever overlap.
        factory = MQTTFactory(MQTTFactory.SUBSCRIBER | MQTTFactory.PUBLISHER)
        self.assertIsInstance(factory.buildProtocol(0), MQTTPubSubsProtocol)

    def test_buildProtocol_other(self):
        '''A factory with no role makes buildProtocol() raise ValueError.'''
        factory = MQTTFactory(0)
        self.assertRaises(ValueError, factory.buildProtocol, 0)
Example #4
0
class TestMQTTBaseProtocol2(unittest.TestCase):
    '''
    Disconnection, PINGRESP handling and ping timeout for a subscriber
    protocol running on a simulated clock.
    '''

    def setUp(self):
        '''
        Create an unconnected subscriber protocol on an in-memory transport.
        '''
        self.clock = task.Clock()
        self.transport = proto_helpers.StringTransportWithDisconnection()
        # Delayed calls made by the protocol go to the fake clock.
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory = MQTTFactory(MQTTFactory.SUBSCRIBER)
        self._rebuild()

    def tearDown(self):
        '''
        Drop the transport. Needed because ping() sets up a LoopingCall,
        outside the Clock-simulated callLater().
        '''
        self.transport.loseConnection()

    def _connect(self, keepalive=0, cleanStart=True):
        '''
        Reach the connected state by issuing CONNECT and then feeding back a
        CONNACK whose resultCode is 0.
        '''
        connack = CONNACK()
        connack.session = False
        connack.resultCode = 0
        connack.encode()
        self.protocol.connect(
            "TwistedMQTT-sub",
            keepalive=keepalive,
            cleanStart=cleanStart,
            version=v31,
        )
        self.transport.clear()
        self.protocol.dataReceived(connack.encoded)

    def _rebuild(self):
        '''Replace self.protocol with a freshly built one bound to the transport.'''
        proto = self.factory.buildProtocol(0)
        self.protocol = proto
        self.transport.protocol = proto
        proto.makeConnection(self.transport)

    def test_disconnect(self):
        '''disconnect() can be called on a connected protocol.'''
        self._connect()
        self.assertEqual(self.protocol.state, self.protocol.CONNECTED)
        self.protocol.disconnect()
        self.transport.clear()

    def test_ping(self):
        '''The protocol stays connected after a PINGRESP arrives.'''
        self._connect(keepalive=5)
        pingresp = PINGRES().encode()
        self.protocol.dataReceived(pingresp)
        self.transport.clear()
        self.assertEqual(self.protocol.state, self.protocol.CONNECTED)

    def test_ping_timeout(self):
        '''
        With keepalive 5, an unanswered ping plus 6 simulated seconds
        leaves the protocol IDLE.
        '''
        self._connect(keepalive=5)
        self.protocol.ping()
        self.transport.clear()
        self.clock.advance(6)
        self.assertEqual(self.protocol.state, self.protocol.IDLE)
class TestMQTTPublisherForbiddenOps(unittest.TestCase):
    '''
    A publisher-only protocol must reject subscriber operations
    (subscribe, unsubscribe and installing an onPublish handler) with
    MQTTStateError.
    '''

    def setUp(self):
        '''
        Build a single publisher protocol and bring it to the connected state.
        '''
        self.transport = proto_helpers.StringTransportWithDisconnection()
        self.clock = task.Clock()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory = MQTTFactory(MQTTFactory.PUBLISHER)
        self._rebuild()
        self._connect()

    def _connect(self, cleanStart=True):
        '''
        Go to the connected state: send CONNECT, then feed back a CONNACK
        with resultCode 0.
        '''
        ack = CONNACK()
        ack.session = False
        ack.resultCode = 0
        ack.encode()
        self.protocol.connect("TwistedMQTT-pub", keepalive=0,
                              cleanStart=cleanStart, version=v31)
        self.transport.clear()
        self.protocol.dataReceived(ack.encoded)

    def _rebuild(self):
        '''Build a new protocol from the factory and attach it to the transport.'''
        self.protocol = self.factory.buildProtocol(0)
        self.transport.protocol = self.protocol
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.protocol.makeConnection(self.transport)

    def test_forbidden_subscribe(self):
        '''subscribe() on a publisher fails with MQTTStateError.'''
        d = self.protocol.subscribe("foo/bar/baz1", 2)
        self.failureResultOf(d).trap(MQTTStateError)

    def test_forbidden_unsubscribe(self):
        '''unsubscribe() on a publisher fails with MQTTStateError.'''
        d = self.protocol.unsubscribe("foo/bar/baz1")
        self.failureResultOf(d).trap(MQTTStateError)

    def test_forbidden_publish_callback(self):
        '''Calling the publisher's onPublish with a handler raises MQTTStateError.'''
        def onPublish(topic, payload, qos, dup, retain, msgId):
            pass

        self.assertRaises(MQTTStateError, self.protocol.onPublish, onPublish)
class TestMQTTSubscriber1(unittest.TestCase):
    '''
    Subscribe/unsubscribe round trips, in-flight window limits and
    reception of PUBLISH packets (QoS 0, 1 and 2) on a connected
    subscriber protocol.
    '''

    def setUp(self):
        '''
        Set up a connected subscriber protocol. The protocol is built for a
        fixed IPv4 address because the tests look up the factory's
        subscribe/unsubscribe windows by that address.
        '''
        self.transport = proto_helpers.StringTransportWithDisconnection()
        self.clock = task.Clock()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory = MQTTFactory(MQTTFactory.SUBSCRIBER)
        self.addr = IPv4Address('TCP', 'localhost', 1880)
        self._rebuild()
        self._connect()

    def _connect(self, cleanStart=True):
        '''
        Go to the connected state: send CONNECT, then feed back a CONNACK
        with resultCode 0.
        '''
        ack = CONNACK()
        ack.session = False
        ack.resultCode = 0
        ack.encode()
        self.protocol.connect("TwistedMQTT-sub", keepalive=0,
                              cleanStart=cleanStart, version=v31)
        self.transport.clear()
        self.protocol.dataReceived(ack.encoded)

    def _serverDown(self):
        '''Simulate a broker outage: lose the transport and drop the protocol.'''
        self.transport.loseConnection()
        self.transport.clear()
        del self.protocol

    def _rebuild(self):
        '''Build a new protocol for self.addr and attach it to the transport.'''
        self.protocol = self.factory.buildProtocol(self.addr)
        self.transport.protocol = self.protocol
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.protocol.makeConnection(self.transport)

    def _subscribe(self, n, qos, topic):
        '''
        Set the window size to n, subscribe to "<topic>0" .. "<topic>{n-1}"
        and check that every returned deferred is still pending.

        :returns: the list of pending deferreds
        '''
        self.protocol.setWindowSize(n)
        dl = [self.protocol.subscribe("{0}{1}".format(topic, i), qos)
              for i in range(n)]
        self.transport.clear()
        for d in dl:
            self.assertNoResult(d)
        return dl

    def _unsubscribe(self, n, topic):
        '''
        Set the window size to n, unsubscribe from "<topic>0" ..
        "<topic>{n-1}" and check that every returned deferred is pending.

        :returns: the list of pending deferreds
        '''
        self.protocol.setWindowSize(n)
        dl = [self.protocol.unsubscribe("{0}{1}".format(topic, i))
              for i in range(n)]
        self.transport.clear()
        for d in dl:
            self.assertNoResult(d)
        return dl

    def test_subscribe_single(self):
        '''A single subscription fires with the granted list from SUBACK.'''
        d = self.protocol.subscribe("foo/bar/baz1", 2)
        self.transport.clear()
        ack = SUBACK()
        ack.msgId = d.msgId
        ack.granted = [(2, False)]
        self.protocol.dataReceived(ack.encode())
        self.assertEqual([(2, False)], self.successResultOf(d))

    def test_subscribe_single_large_qos(self):
        '''QoS 3 is rejected with ValueError.'''
        d = self.protocol.subscribe("foo/bar/baz1", 3)
        self.transport.clear()
        self.failureResultOf(d).trap(ValueError)

    def test_subscribe_single_negative_qos(self):
        '''Negative QoS is rejected with ValueError.'''
        d = self.protocol.subscribe("foo/bar/baz1", -1)
        self.transport.clear()
        self.failureResultOf(d).trap(ValueError)

    def test_subscribe_tuple(self):
        '''A (topic, qos) tuple is accepted as a single subscription.'''
        d = self.protocol.subscribe(("foo/bar/baz1", 2))
        self.transport.clear()
        ack = SUBACK()
        ack.msgId = d.msgId
        ack.granted = [(2, False)]
        self.protocol.dataReceived(ack.encode())
        self.assertEqual([(2, False)], self.successResultOf(d))

    def test_subscribe_list(self):
        '''A list of (topic, qos) tuples is subscribed in one request.'''
        d = self.protocol.subscribe([("foo/bar/baz1", 2),
                                     ("foo/bar/baz2", 1),
                                     ("foo/bar/baz3", 0)])
        self.transport.clear()
        ack = SUBACK()
        ack.msgId = d.msgId
        ack.granted = [(2, False), (1, False), (0, False)]
        self.protocol.dataReceived(ack.encode())
        self.assertEqual([(2, False), (1, False), (0, False)],
                         self.successResultOf(d))

    def test_subscribe_several_fail(self):
        '''Pending subscriptions errback with ConnectionDone on connection loss.'''
        dl = self._subscribe(n=3, qos=2, topic="foo/bar/baz")
        self.assertEqual(len(self.protocol.factory.windowSubscribe[self.addr]), 3)
        self._serverDown()
        for d in dl:
            self.failureResultOf(d).trap(error.ConnectionDone)

    def test_subscribe_several_window_fail(self):
        '''A subscription beyond a full window fails with MQTTWindowError.'''
        # _subscribe() sets the window size to n (3 here).
        dl = self._subscribe(n=3, qos=2, topic="foo/bar/baz")
        self.assertEqual(len(self.protocol.factory.windowSubscribe[self.addr]), 3)
        d4 = self.protocol.subscribe("foo/bar/baz3", 2)
        self.assertEqual(len(self.protocol.factory.windowSubscribe[self.addr]), 3)
        self.failureResultOf(d4).trap(MQTTWindowError)
        self._serverDown()
        for d in dl:
            self.failureResultOf(d).trap(error.ConnectionDone)

    def test_unsubscribe_single(self):
        '''A single unsubscription fires with the UNSUBACK msgId.'''
        d = self.protocol.unsubscribe("foo/bar/baz1")
        self.transport.clear()
        ack = UNSUBACK()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.assertEqual(ack.msgId, self.successResultOf(d))

    def test_unsubscribe_list(self):
        '''A list of topics is unsubscribed in one request.'''
        d = self.protocol.unsubscribe(["foo/bar/baz1", "foo/bar/baz2", "foo/bar/baz3"])
        self.transport.clear()
        ack = UNSUBACK()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.assertEqual(ack.msgId, self.successResultOf(d))

    def test_unsubscribe_several_fail(self):
        '''Pending unsubscriptions errback with ConnectionDone on connection loss.'''
        dl = self._unsubscribe(n=3, topic="foo/bar/baz")
        self.assertEqual(len(self.protocol.factory.windowUnsubscribe[self.addr]), 3)
        self._serverDown()
        for d in dl:
            self.failureResultOf(d).trap(error.ConnectionDone)

    def test_unsubscribe_several_window_fail(self):
        '''An unsubscription beyond a full window fails with MQTTWindowError.'''
        dl = self._unsubscribe(n=3, topic="foo/bar/baz")
        self.assertEqual(len(self.protocol.factory.windowUnsubscribe[self.addr]), 3)
        d4 = self.protocol.unsubscribe("foo/bar/baz4")
        self.assertEqual(len(self.protocol.factory.windowUnsubscribe[self.addr]), 3)
        self.failureResultOf(d4).trap(MQTTWindowError)
        self._serverDown()
        for d in dl:
            self.failureResultOf(d).trap(error.ConnectionDone)

    def test_publish_recv_qos0(self):
        '''A QoS 0 PUBLISH is handed to onPublish immediately.'''
        def onPublish(topic, payload, qos, dup, retain, msgId):
            self.topic = topic
            self.payload = payload.decode('utf-8')
            self.qos = qos
            self.retain = retain
            self.msgId = msgId
        self.protocol.onPublish = onPublish
        pub = PUBLISH()
        pub.qos = 0
        pub.dup = False
        pub.retain = False
        pub.topic = "foo/bar/baz0"
        pub.msgId = None
        pub.payload = "Hello world 0"
        self.protocol.dataReceived(pub.encode())
        self.assertEqual(self.topic, pub.topic)
        self.assertEqual(self.payload, pub.payload)
        self.assertEqual(self.qos, pub.qos)
        self.assertEqual(self.retain, pub.retain)
        self.assertEqual(self.msgId, pub.msgId)

    def test_publish_recv_qos1(self):
        '''A QoS 1 PUBLISH is handed to onPublish immediately.'''
        def onPublish(topic, payload, qos, dup, retain, msgId):
            self.topic = topic
            self.payload = payload.decode('utf-8')
            self.qos = qos
            self.retain = retain
            self.msgId = msgId
            self.dup = dup
        self.protocol.onPublish = onPublish
        pub = PUBLISH()
        pub.qos = 1
        pub.dup = False
        pub.retain = False
        pub.topic = "foo/bar/baz1"
        pub.msgId = 1
        pub.payload = "Hello world 1"
        self.protocol.dataReceived(pub.encode())
        self.transport.clear()
        self.assertEqual(self.topic, pub.topic)
        self.assertEqual(self.payload, pub.payload)
        self.assertEqual(self.qos, pub.qos)
        self.assertEqual(self.retain, pub.retain)
        self.assertEqual(self.msgId, pub.msgId)
        self.assertEqual(self.dup, pub.dup)

    def test_publish_recv_qos2(self):
        '''A QoS 2 PUBLISH reaches onPublish only after the PUBREL.'''
        self.called = False

        def onPublish(topic, payload, qos, dup, retain, msgId):
            self.called = True
            self.topic = topic
            self.payload = payload.decode('utf-8')
            self.qos = qos
            self.retain = retain
            self.msgId = msgId
            self.dup = dup
        self.protocol.onPublish = onPublish
        pub = PUBLISH()
        pub.qos = 2
        pub.dup = False
        pub.retain = False
        pub.topic = "foo/bar/baz2"
        pub.msgId = 1
        pub.payload = "Hello world 2"
        self.protocol.dataReceived(pub.encode())
        self.transport.clear()
        self.assertEqual(self.called, False)
        rel = PUBREL()
        rel.msgId = pub.msgId
        self.protocol.dataReceived(rel.encode())
        # The handler must have run once the PUBREL was processed.
        self.assertTrue(self.called)
        self.assertEqual(self.topic, pub.topic)
        self.assertEqual(self.payload, pub.payload)
        self.assertEqual(self.qos, pub.qos)
        self.assertEqual(self.retain, pub.retain)
        self.assertEqual(self.msgId, pub.msgId)
        self.assertEqual(self.dup, pub.dup)
class TestMQTTSubscriber2(unittest.TestCase):
    '''
    Installing an onPublish handler, before or after connecting, must not
    disturb a subscribe/SUBACK round trip.
    '''

    def setUp(self):
        '''
        Create an unconnected subscriber protocol on an in-memory transport.
        '''
        self.transport = proto_helpers.StringTransportWithDisconnection()
        self.clock = task.Clock()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory = MQTTFactory(MQTTFactory.SUBSCRIBER)
        self._rebuild()

    def _connect(self, cleanStart=True):
        '''
        Issue CONNECT and feed back a CONNACK with resultCode 0.
        '''
        connack = CONNACK()
        connack.session = False
        connack.resultCode = 0
        connack.encode()
        self.protocol.connect(
            "TwistedMQTT-sub",
            keepalive=0,
            cleanStart=cleanStart,
            version=v31,
        )
        self.transport.clear()
        self.protocol.dataReceived(connack.encoded)

    def _serverDown(self):
        '''Simulate a broker outage: close the transport, forget the protocol.'''
        transport = self.transport
        transport.loseConnection()
        transport.clear()
        del self.protocol

    def _rebuild(self):
        '''Build a fresh protocol and attach it to the shared transport.'''
        proto = self.factory.buildProtocol(0)
        self.protocol = proto
        self.transport.protocol = proto
        MQTTBaseProtocol.callLater = self.clock.callLater
        proto.makeConnection(self.transport)

    def _subscribe(self, n, qos, topic):
        '''
        Set the window to n, subscribe to n numbered topics and check that
        none of the returned deferreds has fired yet.
        '''
        self.protocol.setWindowSize(n)
        pending = [self.protocol.subscribe("{0}{1}".format(topic, idx), qos)
                   for idx in range(0, n)]
        self.transport.clear()
        for dfr in pending:
            self.assertNoResult(dfr)
        return pending

    def _unsubscribe(self, n, topic):
        '''
        Set the window to n, unsubscribe from n numbered topics and check
        that none of the returned deferreds has fired yet.
        '''
        self.protocol.setWindowSize(n)
        pending = [self.protocol.unsubscribe("{0}{1}".format(topic, idx))
                   for idx in range(0, n)]
        self.transport.clear()
        for dfr in pending:
            self.assertNoResult(dfr)
        return pending

    def _assertSubscribeAcked(self):
        '''Subscribe to one topic, answer with SUBACK, check the granted list.'''
        dfr = self.protocol.subscribe("foo/bar/baz1", 2)
        self.transport.clear()
        suback = SUBACK()
        suback.msgId = dfr.msgId
        suback.granted = [(2, False)]
        self.protocol.dataReceived(suback.encode())
        self.assertEqual([(2, False)], self.successResultOf(dfr))

    def test_subscribe_setPubishHandler1(self):
        '''Handler installed before connecting.'''
        def onPublish(topic, payload, qos, dup, retain, msgId):
            self.called = True
        self.protocol.onPublish = onPublish
        self._connect()
        self._assertSubscribeAcked()

    def test_subscribe_setPubishHandler2(self):
        '''Handler installed after connecting.'''
        def onPublish(topic, payload, qos, dup, retain, msgId):
            self.called = True
        self._connect()
        self.protocol.onPublish = onPublish
        self._assertSubscribeAcked()
class TestMQTTSubscriberDisconnect(unittest.TestCase):
    '''
    Testing various cases of disconnect callback
    '''

    def setUp(self):
        '''
        Set up a connencted state
        '''
        self.transport = proto_helpers.StringTransportWithDisconnection()
        self.clock     = task.Clock()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory   = MQTTFactory(MQTTFactory.SUBSCRIBER)
        self._rebuild()
        self.disconnected = False

    def _connect(self, cleanStart=True):
        '''
        Go to connected state
        '''
        ack = CONNACK()
        ack.session = False
        ack.resultCode = 0
        ack.encode()
        self.protocol.connect("TwistedMQTT-sub", keepalive=0, cleanStart=cleanStart, version=v31)
        self.transport.clear()
        self.protocol.dataReceived(ack.encoded)

    def _disconnected(self, reason):
        self.disconnected = True

    def _rebuild(self):
        self.protocol  = self.factory.buildProtocol(0)
        self.transport.protocol = self.protocol
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.protocol.makeConnection(self.transport)

    def _serverDown(self):
        self.transport.loseConnection()
        self.transport.clear()
        del self.protocol

    def test_disconnect_1(self):
        '''Just connect and lose the transport'''
        self._connect()
        self.protocol.onDisconnection = self._disconnected
        self.transport.loseConnection()
        self.assertEqual(self.disconnected, True)

    def test_disconnect_2(self):
        '''connect and disconnect'''
        self._connect()
        self.protocol.onDisconnection = self._disconnected
        self.protocol.disconnect()
        self.assertEqual(self.disconnected, True)

    def test_disconnect_3(self):
        '''connect, generate a deferred and lose the transport'''
        self._connect()
        self.protocol.onDisconnection = self._disconnected
        d = self.protocol.subscribe("foo/bar/baz1", 2 )
        self.transport.clear()
        self.transport.loseConnection()
        self.assertEqual(self.disconnected, True)
        self.failureResultOf(d).trap(error.ConnectionDone)

    def test_disconnect_4(self):
        '''connect, generate a deferred and disconnect'''
        self._connect()
        self.protocol.onDisconnection = self._disconnected
        d = self.protocol.subscribe("foo/bar/baz1", 2 )
        self.transport.clear()
        self.protocol.disconnect()
        self.assertEqual(self.disconnected, True)
        self.failureResultOf(d).trap(error.ConnectionDone)

    def test_disconnect_5(self):
        '''connect with persistent session, 
        enerate a deferred that will not errback 
        and then disconnect'''
        self._connect(cleanStart=False)
        self.protocol.onDisconnection = self._disconnected
        d = self.protocol.subscribe("foo/bar/baz1", 2 )
        self.transport.clear()
        self.protocol.disconnect()
        self.assertEqual(self.disconnected, True)
        self.assertNoResult(d)

    def test_disconnect_6(self):
        '''connect with persistent session, 
        generate a deferred that will not errback yet, 
        then rebuilds protocol'''
        self._connect(cleanStart=False)
        self.protocol.onDisconnection = self._disconnected
        d = self.protocol.subscribe("foo/bar/baz1", 2 )
        self._serverDown()
        self._rebuild()
        self.assertEqual(self.disconnected, True)
        self.assertNoResult(d)
Example #9
0
class TestMQTTPublisher1(unittest.TestCase):
    '''
    QoS 0/1/2 publishing on a publisher protocol: acknowledgement flows,
    the factory's in-flight queues (queuePublishTx, queuePubRelease) and
    their behaviour across clean and persistent sessions.
    '''

    def setUp(self):
        '''
        Create an unconnected publisher protocol on an in-memory transport.
        '''
        self.transport = proto_helpers.StringTransportWithDisconnection()
        self.clock = task.Clock()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory = MQTTFactory(MQTTFactory.PUBLISHER)
        self._rebuild()

    def _connect(self, cleanStart=True):
        '''
        Go to the connected state: send CONNECT, then feed back a CONNACK
        with resultCode 0.
        '''
        ack = CONNACK()
        ack.session = False
        ack.resultCode = 0
        ack.encode()
        self.protocol.connect("TwistedMQTT-pub", keepalive=0,
                              cleanStart=cleanStart, version=v31)
        self.transport.clear()
        self.protocol.dataReceived(ack.encoded)

    def _serverDown(self):
        '''Simulate a broker outage: lose the transport and drop the protocol.'''
        self.transport.loseConnection()
        self.transport.clear()
        del self.protocol

    def _rebuild(self):
        '''Build a new protocol from the factory and attach it to the transport.'''
        self.protocol = self.factory.buildProtocol(0)
        self.transport.protocol = self.protocol
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.protocol.makeConnection(self.transport)

    def _publish(self, n, qos, topic, msg):
        '''
        Publish msg to topic n times. QoS 0 deferreds must already have
        fired with None; deferreds for higher QoS must still be pending.

        :returns: the list of deferreds, in publication order
        '''
        dl = [self.protocol.publish(topic=topic, qos=qos, message=msg)
              for _ in range(n)]
        self.transport.clear()
        for d in dl:
            if qos == 0:
                self.assertEqual(None, self.successResultOf(d))
            else:
                self.assertNoResult(d)
        return dl

    def _sendAcks(self, packetClass, dl):
        '''
        Feed the protocol, in a single dataReceived() call, one packetClass
        packet per deferred in dl, each carrying that deferred's own msgId.
        '''
        encoded = bytearray()
        for d in dl:
            packet = packetClass()
            packet.msgId = d.msgId
            encoded.extend(packet.encode())
        self.protocol.dataReceived(encoded)
        self.transport.clear()

    def _puback(self, dl):
        '''Send PUBACKs for dl; every deferred must fire with its msgId.'''
        self._sendAcks(PUBACK, dl)
        for d in dl:
            self.assertEqual(d.msgId, self.successResultOf(d))

    def _pubrec(self, dl):
        '''Send PUBRECs for dl; the deferreds must still be pending.'''
        self._sendAcks(PUBREC, dl)
        for d in dl:
            self.assertNoResult(d)

    def _pubcomp(self, dl):
        '''Send PUBCOMPs for dl; every deferred must fire with its msgId.'''
        self._sendAcks(PUBCOMP, dl)
        for d in dl:
            self.assertEqual(d.msgId, self.successResultOf(d))

    def test_publish_single_qos0(self):
        self._connect()
        d = self.protocol.publish(topic="foo/bar/baz1", qos=0, message="hello world 0")
        self.assertEqual(len(self.protocol.factory.queuePublishTx), 0)
        self.assertEqual(None, self.successResultOf(d))

    def test_publish_single_qos1(self):
        self._connect()
        d = self.protocol.publish(topic="foo/bar/baz1", qos=1, message="hello world 1")
        self.assertEqual(len(self.protocol.factory.queuePublishTx), 1)
        self.transport.clear()
        ack = PUBACK()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.assertEqual(len(self.protocol.factory.queuePublishTx), 0)
        self.assertEqual(ack.msgId, self.successResultOf(d))

    def test_publish_single_qos2(self):
        self._connect()
        d = self.protocol.publish(topic="foo/bar/baz1", qos=2, message="hello world 2")
        self.assertEqual(len(self.protocol.factory.queuePublishTx), 1)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)
        self.transport.clear()
        ack = PUBREC()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.transport.clear()
        self.assertEqual(len(self.protocol.factory.queuePublishTx), 0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 1)
        ack = PUBCOMP()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.assertEqual(len(self.protocol.factory.queuePublishTx), 0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)
        self.assertEqual(ack.msgId, self.successResultOf(d))

    def test_publish_several_qos0(self):
        self._connect()
        self._publish(n=3, qos=0, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.queuePublishTx), 0)

    def test_publish_several_qos1(self):
        self._connect()
        dl = self._publish(n=3, qos=1, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.queuePublishTx), len(dl))
        self._puback(dl)
        self.assertEqual(len(self.protocol.factory.queuePublishTx), 0)

    def test_publish_several_qos2(self):
        self._connect()
        dl = self._publish(n=3, qos=2, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.queuePublishTx), len(dl))
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)
        self._pubrec(dl)
        self.assertEqual(len(self.protocol.factory.queuePublishTx), 0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), len(dl))
        self._pubcomp(dl)
        self.assertEqual(len(self.protocol.factory.queuePublishTx), 0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)

    def test_lost_session(self):
        '''With a clean session, losing the connection empties the queues.'''
        self._connect()
        dl = self._publish(n=3, qos=2, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.queuePublishTx), len(dl))
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)
        self._serverDown()
        self.assertEqual(len(self.factory.queuePublishTx), 0)
        self.assertEqual(len(self.factory.queuePubRelease), 0)
        for d in dl:
            self.failureResultOf(d).trap(error.ConnectionDone)

    def test_persistent_session_qos1(self):
        '''With a persistent session, QoS 1 publications survive a reconnect.'''
        self._connect(cleanStart=False)
        dl = self._publish(n=3, qos=1, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.queuePublishTx), len(dl))
        self._serverDown()
        self.assertEqual(len(self.factory.queuePublishTx), len(dl))
        for d in dl:
            self.assertNoResult(d)
        self._rebuild()
        self._connect(cleanStart=False)
        self.assertEqual(len(self.protocol.factory.queuePublishTx), len(dl))
        self._puback(dl)
        self.assertEqual(len(self.protocol.factory.queuePublishTx), 0)

    def test_persistent_session_qos2(self):
        '''With a persistent session, QoS 2 publications survive a reconnect.'''
        self._connect(cleanStart=False)
        dl = self._publish(n=3, qos=2, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.queuePublishTx), len(dl))
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)
        self._serverDown()
        for d in dl:
            self.assertNoResult(d)
        self._rebuild()
        self._connect(cleanStart=False)
        self.assertEqual(len(self.factory.queuePublishTx), len(dl))
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)
        self._pubrec(dl)
        self.assertEqual(len(self.factory.queuePublishTx), 0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), len(dl))
        self._pubcomp(dl)
        self.assertEqual(len(self.factory.queuePublishTx), 0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)

    def test_persistent_release_qos2(self):
        '''
        With a persistent session, a QoS 2 flow split across a reconnect
        (two PUBRECs before, the rest after) still completes.
        '''
        self._connect(cleanStart=False)
        dl = self._publish(n=3, qos=2, topic="foo/bar/baz", msg="Hello World")
        # PUBREC only the first two publications, then simulate a broker
        # outage and build a new client protocol.
        self._pubrec(dl[:-1])

        self._serverDown()
        self.assertNoResult(dl[0])
        self.assertNoResult(dl[1])
        self.assertNoResult(dl[2])
        self._rebuild()
        self.assertEqual(len(self.protocol.factory.queuePublishTx), 1)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 2)
        # Reconnect with the new client protocol object.
        self._connect(cleanStart=False)
        self._pubrec(dl[-1:])    # PUBREC the last one
        self.assertEqual(len(self.protocol.factory.queuePublishTx), 0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 3)
        self._pubcomp(dl[0:1])   # first PUBCOMP
        self.assertEqual(len(self.factory.queuePublishTx), 0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 2)
        self._pubcomp(dl[1:2])   # second PUBCOMP
        self.assertEqual(len(self.factory.queuePublishTx), 0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 1)
        self._pubcomp(dl[-1:])   # last PUBCOMP
        self.assertEqual(len(self.factory.queuePublishTx), 0)
        self.assertEqual(len(self.protocol.factory.queuePubRelease), 0)
Example #10
0
class TestMQTTPublisherDisconnect(unittest.TestCase):
    '''
    Check that the onDisconnection callback fires exactly once, and what
    happens to an outstanding QoS 1 publish, for each way a connection can
    end: transport lost, explicit disconnect, or protocol rebuilt.
    '''

    def setUp(self):
        '''
        Build an unconnected PUBLISHER|SUBSCRIBER protocol on a fake
        transport whose timers run on a simulated clock.
        '''
        self.disconnected = 0
        self.clock     = task.Clock()
        self.transport = proto_helpers.StringTransportWithDisconnection()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory   = MQTTFactory(MQTTFactory.PUBLISHER | MQTTFactory.SUBSCRIBER)
        self._rebuild()

    def _connect(self, cleanStart=True):
        '''
        Send CONNECT (keepalive off) and feed back an accepting CONNACK.
        '''
        connack = CONNACK()
        connack.session, connack.resultCode = False, 0
        connack.encode()
        self.protocol.connect("TwistedMQTT-pubsubs",
                              keepalive=0,
                              cleanStart=cleanStart,
                              version=v31)
        # Discard the CONNECT bytes before injecting the broker's reply
        self.transport.clear()
        self.protocol.dataReceived(connack.encoded)

    def _disconnected(self, reason):
        '''onDisconnection hook: count invocations, ignore the reason.'''
        self.disconnected += 1

    def _watchDisconnection(self):
        '''Route the protocol's onDisconnection callback to the counter.'''
        self.protocol.onDisconnection = self._disconnected

    def _publishPending(self):
        '''Publish one QoS 1 message; these tests never acknowledge it.'''
        return self.protocol.publish(topic="foo/bar/baz1",
                                     qos=1,
                                     message="hello world 1")

    def _serverDown(self):
        '''Drop the transport, discard its output, forget the protocol.'''
        transport = self.transport
        transport.loseConnection()
        transport.clear()
        del self.protocol

    def _rebuild(self):
        '''Create a fresh protocol from the shared factory and attach it.'''
        protocol = self.factory.buildProtocol(0)
        self.protocol = protocol
        self.transport.protocol = protocol
        MQTTBaseProtocol.callLater = self.clock.callLater
        protocol.makeConnection(self.transport)

    def test_disconnect_1(self):
        '''Just connect and lose the transport'''
        self._connect()
        self._watchDisconnection()
        self.transport.loseConnection()
        self.assertEqual(self.disconnected, 1)

    def test_disconnect_2(self):
        '''connect and disconnect'''
        self._connect()
        self._watchDisconnection()
        self.protocol.disconnect()
        self.assertEqual(self.disconnected, 1)

    def test_disconnect_3(self):
        '''connect, generate a deferred and lose the transport'''
        self._connect()
        self._watchDisconnection()
        pending = self._publishPending()
        self.transport.clear()
        self.transport.loseConnection()
        self.assertEqual(self.disconnected, 1)
        # Clean session: the outstanding publish fails with ConnectionDone
        self.failureResultOf(pending).trap(error.ConnectionDone)

    def test_disconnect_4(self):
        '''connect, generate a deferred and disconnect'''
        self._connect()
        self._watchDisconnection()
        pending = self._publishPending()
        self.transport.clear()
        self.protocol.disconnect()
        self.assertEqual(self.disconnected, 1)
        self.failureResultOf(pending).trap(error.ConnectionDone)

    def test_disconnect_5(self):
        '''connect with persistent session, generate a deferred and disconnect'''
        self._connect(cleanStart=False)
        self._watchDisconnection()
        pending = self._publishPending()
        self.transport.clear()
        self.protocol.disconnect()
        self.assertEqual(self.disconnected, 1)
        # Persistent session: the publish is kept pending, not failed
        self.assertNoResult(pending)

    def test_disconnect_6(self):
        '''connect with persistent session, generate a deferred , rebuilds protocol'''
        self._connect(cleanStart=False)
        self._watchDisconnection()
        pending = self._publishPending()
        self._serverDown()
        self._rebuild()
        self.assertEqual(self.disconnected, 1)
        self.assertNoResult(pending)
class TestMQTTPublisherIntervals(unittest.TestCase):
    '''
    Retransmission timing of large QoS 1 publishes as a function of the
    bandwidth estimate configured with setBandwith().

    NOTE(review): a second class named TestMQTTPublisherIntervals is
    defined later in this module; it rebinds the name, so test discovery
    only sees the later one. Confirm and rename or drop one of them.
    '''

    def setUp(self):
        '''
        Set up an unconnected publisher protocol for a fixed peer address
        on a fake transport driven by a simulated clock.
        '''
        self.transport = proto_helpers.StringTransportWithDisconnection()
        self.clock     = task.Clock()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory   = MQTTFactory(MQTTFactory.PUBLISHER)
        # The factory's per-peer window state is keyed by this address
        self.addr = IPv4Address('TCP', 'localhost', 1880)
        self._rebuild()

    def _connect(self, cleanStart=True):
        '''
        Go to connected state: send CONNECT and feed back an accepting
        CONNACK.
        '''
        ack = CONNACK()
        ack.session = False
        ack.resultCode = 0
        ack.encode()
        self.protocol.connect("TwistedMQTT-pub", keepalive=0, cleanStart=cleanStart, version=v31)
        self.transport.clear()
        self.protocol.dataReceived(ack.encoded)

    def _rebuild(self):
        '''Build a fresh protocol for self.addr and attach the transport.'''
        self.protocol  = self.factory.buildProtocol(self.addr)
        self.transport.protocol = self.protocol
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.protocol.makeConnection(self.transport)

    def _publish(self, n, qos, topic, msg, window=None):
        '''
        Publish *n* copies of *msg* with the window set to *window*
        (default: *n*) and return the deferreds. QoS 0 deferreds must have
        fired with None; higher-QoS ones must still be pending.
        '''
        self.protocol.setWindowSize(n if window is None else window)
        dl = [self.protocol.publish(topic=topic, qos=qos, message=msg)
              for _ in range(n)]
        self.transport.clear()
        for d in dl:
            if qos == 0:
                self.assertEqual(None, self.successResultOf(d))
            else:
                self.assertNoResult(d)
        return dl

    def _puback(self, dl):
        '''
        Deliver one PUBACK per deferred in *dl* as a single chunk and check
        each deferred fired with its own msgId.
        '''
        encoded = bytearray()
        for d in dl:
            ack = PUBACK()
            ack.msgId = d.msgId
            encoded.extend(ack.encode())
        self.protocol.dataReceived(encoded)
        self.transport.clear()
        for d in dl:
            self.assertEqual(d.msgId, self.successResultOf(d))

    def _inFlightDup(self, msgId):
        '''Return the dup flag of the in-flight PUBLISH with *msgId*.'''
        return self.protocol.factory.windowPublish[self.addr][msgId].dup

    def _publishLargeQoS1(self, message, bandwidth, expectRetransmit):
        '''
        Publish *message* at QoS 1 with the given bandwidth estimate
        (bytes/sec), advance the simulated clock 10 s, check whether the PDU
        has been marked dup (retransmitted), then PUBACK it.
        '''
        self.protocol.setBandwith(bandwidth)
        d = self.protocol.publish(topic="foo/bar/baz1", qos=1, message=message)
        # Look the PDU up by the msgId actually assigned instead of assuming
        # message ids are handed out as 1, 2, ...
        self.assertEqual(self._inFlightDup(d.msgId), False)
        self.transport.clear()
        self.clock.advance(10)
        self.assertEqual(self._inFlightDup(d.msgId), expectRetransmit)
        ack = PUBACK()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.transport.clear()
        self.assertEqual(ack.msgId, self.successResultOf(d))

    def test_publish_very_large_qos1(self):
        message = '0123456789ABCDEF'*1000000 # Large PDU (16 million chars)
        self._connect()

        # Test at 1MByte/sec: 10 s is not enough to trigger a retransmission
        self._publishLargeQoS1(message, 1000000.0, expectRetransmit=False)

        # A large PDU with a large bandwith estimation may retransmit
        self._publishLargeQoS1(message, 10000000.0, expectRetransmit=True)
class TestMQTTSubscriber1(unittest.TestCase):
    '''
    SUBSCRIBE/UNSUBSCRIBE round trips and inbound PUBLISH delivery for a
    SUBSCRIBER protocol; setUp() leaves it in the connected state.
    '''

    def setUp(self):
        '''
        Set up a connected state
        '''
        self.transport = proto_helpers.StringTransportWithDisconnection()
        self.clock = task.Clock()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory = MQTTFactory(MQTTFactory.SUBSCRIBER)
        self.addr = IPv4Address('TCP', 'localhost', 1880)
        self._rebuild()
        self._connect()

    def _connect(self, cleanStart=True):
        '''
        Go to connected state
        '''
        ack = CONNACK()
        ack.session = False
        ack.resultCode = 0
        ack.encode()
        self.protocol.connect("TwistedMQTT-sub",
                              keepalive=0,
                              cleanStart=cleanStart,
                              version=v31)
        self.transport.clear()
        self.protocol.dataReceived(ack.encoded)

    def _serverDown(self):
        '''Drop the transport, discard its output, forget the protocol.'''
        self.transport.loseConnection()
        self.transport.clear()
        del self.protocol

    def _rebuild(self):
        '''Build a fresh protocol for self.addr and attach the transport.'''
        self.protocol = self.factory.buildProtocol(self.addr)
        self.transport.protocol = self.protocol
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.protocol.makeConnection(self.transport)

    def _subscribe(self, n, qos, topic):
        '''
        Subscribe to topic0 .. topic{n-1} with window size n and return the
        deferreds, all of which must still be pending.
        '''
        self.protocol.setWindowSize(n)
        dl = [self.protocol.subscribe("{0}{1}".format(topic, i), qos)
              for i in range(n)]
        self.transport.clear()
        for d in dl:
            self.assertNoResult(d)
        return dl

    def _unsubscribe(self, n, topic):
        '''
        Unsubscribe from topic0 .. topic{n-1} with window size n and return
        the deferreds, all of which must still be pending.
        '''
        self.protocol.setWindowSize(n)
        dl = [self.protocol.unsubscribe("{0}{1}".format(topic, i))
              for i in range(n)]
        self.transport.clear()
        for d in dl:
            self.assertNoResult(d)
        return dl

    def _recordPublish(self):
        '''
        Install an onPublish handler that copies its arguments onto the
        test case and sets self.called; self.called starts out False.
        '''
        self.called = False

        def onPublish(topic, payload, qos, dup, retain, msgId):
            self.called = True
            self.topic = topic
            self.payload = payload.decode('utf-8')
            self.qos = qos
            self.retain = retain
            self.msgId = msgId
            self.dup = dup

        self.protocol.onPublish = onPublish

    def _makePublish(self, qos, topic, msgId, payload):
        '''Build a non-retained, non-duplicate PUBLISH packet.'''
        pub = PUBLISH()
        pub.qos = qos
        pub.dup = False
        pub.retain = False
        pub.topic = topic
        pub.msgId = msgId
        pub.payload = payload
        return pub

    def _assertPublished(self, pub, checkDup=True):
        '''Assert the recorded onPublish arguments match the packet *pub*.'''
        # Fail with a clear assertion rather than an AttributeError on
        # self.topic when the handler was never invoked.
        self.assertTrue(self.called, "onPublish was not called")
        self.assertEqual(self.topic, pub.topic)
        self.assertEqual(self.payload, pub.payload)
        self.assertEqual(self.qos, pub.qos)
        self.assertEqual(self.retain, pub.retain)
        self.assertEqual(self.msgId, pub.msgId)
        if checkDup:
            self.assertEqual(self.dup, pub.dup)

    def test_subscribe_single(self):
        d = self.protocol.subscribe("foo/bar/baz1", 2)
        self.transport.clear()
        ack = SUBACK()
        ack.msgId = d.msgId
        ack.granted = [(2, False)]
        self.protocol.dataReceived(ack.encode())
        self.assertEqual([(2, False)], self.successResultOf(d))

    def test_subscribe_single_large_qos(self):
        d = self.protocol.subscribe("foo/bar/baz1", 3)
        self.transport.clear()
        self.failureResultOf(d).trap(ValueError)

    def test_subscribe_single_negative_qos(self):
        d = self.protocol.subscribe("foo/bar/baz1", -1)
        self.transport.clear()
        self.failureResultOf(d).trap(ValueError)

    def test_subscribe_tuple(self):
        d = self.protocol.subscribe(("foo/bar/baz1", 2))
        self.transport.clear()
        ack = SUBACK()
        ack.msgId = d.msgId
        ack.granted = [(2, False)]
        self.protocol.dataReceived(ack.encode())
        self.assertEqual([(2, False)], self.successResultOf(d))

    def test_subscribe_list(self):
        d = self.protocol.subscribe([("foo/bar/baz1", 2), ("foo/bar/baz2", 1),
                                     ("foo/bar/baz3", 0)])
        # The granted list is checked once, via successResultOf() below; a
        # callback-based assertEqual would depend on trial returning its
        # first argument to keep the deferred's result intact.
        self.transport.clear()
        ack = SUBACK()
        ack.msgId = d.msgId
        ack.granted = [(2, False), (1, False), (0, False)]
        self.protocol.dataReceived(ack.encode())
        self.assertEqual([(2, False), (1, False), (0, False)],
                         self.successResultOf(d))

    def test_subscribe_several_fail(self):
        dl = self._subscribe(n=3, qos=2, topic="foo/bar/baz")
        self.assertEqual(len(self.protocol.factory.windowSubscribe[self.addr]),
                         3)
        self._serverDown()
        for d in dl:
            self.failureResultOf(d).trap(error.ConnectionDone)

    def test_subscribe_several_window_fail(self):
        self.protocol.setWindowSize(3)
        dl = self._subscribe(n=3, qos=2, topic="foo/bar/baz")
        self.assertEqual(len(self.protocol.factory.windowSubscribe[self.addr]),
                         3)
        # The window is full, so a fourth subscription is rejected outright
        d4 = self.protocol.subscribe("foo/bar/baz3", 2)
        self.assertEqual(len(self.protocol.factory.windowSubscribe[self.addr]),
                         3)
        self.failureResultOf(d4).trap(MQTTWindowError)
        self._serverDown()
        for d in dl:
            self.failureResultOf(d).trap(error.ConnectionDone)

    def test_unsubscribe_single(self):
        d = self.protocol.unsubscribe("foo/bar/baz1")
        self.transport.clear()
        ack = UNSUBACK()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.assertEqual(ack.msgId, self.successResultOf(d))

    def test_unsubscribe_list(self):
        d = self.protocol.unsubscribe(
            ["foo/bar/baz1", "foo/bar/baz2", "foo/bar/baz3"])
        self.transport.clear()
        ack = UNSUBACK()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.assertEqual(ack.msgId, self.successResultOf(d))

    def test_unsubscribe_several_fail(self):
        dl = self._unsubscribe(n=3, topic="foo/bar/baz")
        self.assertEqual(
            len(self.protocol.factory.windowUnsubscribe[self.addr]), 3)
        self._serverDown()
        for d in dl:
            self.failureResultOf(d).trap(error.ConnectionDone)

    def test_unsubscribe_several_window_fail(self):
        dl = self._unsubscribe(n=3, topic="foo/bar/baz")
        self.assertEqual(
            len(self.protocol.factory.windowUnsubscribe[self.addr]), 3)
        d4 = self.protocol.unsubscribe("foo/bar/baz4")
        self.assertEqual(
            len(self.protocol.factory.windowUnsubscribe[self.addr]), 3)
        self.failureResultOf(d4).trap(MQTTWindowError)
        self._serverDown()
        for d in dl:
            self.failureResultOf(d).trap(error.ConnectionDone)

    def test_publish_recv_qos0(self):
        self._recordPublish()
        pub = self._makePublish(0, "foo/bar/baz0", None, "Hello world 0")
        self.protocol.dataReceived(pub.encode())
        self._assertPublished(pub, checkDup=False)

    def test_publish_recv_qos1(self):
        self._recordPublish()
        pub = self._makePublish(1, "foo/bar/baz1", 1, "Hello world 1")
        self.protocol.dataReceived(pub.encode())
        self.transport.clear()
        self._assertPublished(pub)

    def test_publish_recv_qos2(self):
        self._recordPublish()
        pub = self._makePublish(2, "foo/bar/baz2", 1, "Hello world 2")
        self.protocol.dataReceived(pub.encode())
        self.transport.clear()
        # QoS 2: delivery to the application waits for PUBREL
        self.assertEqual(self.called, False)
        rel = PUBREL()
        rel.msgId = pub.msgId
        self.protocol.dataReceived(rel.encode())
        self._assertPublished(pub)
class TestMQTTSubscriber2(unittest.TestCase):
    '''
    A SUBSCRIBE round trip must succeed whether the onPublish handler is
    installed before or after the connection is established.
    '''

    def setUp(self):
        '''
        Prepare an unconnected SUBSCRIBER protocol on a fake transport whose
        timers run on a simulated clock.
        '''
        self.clock = task.Clock()
        self.transport = proto_helpers.StringTransportWithDisconnection()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory = MQTTFactory(MQTTFactory.SUBSCRIBER)
        self._rebuild()

    def _connect(self, cleanStart=True):
        '''
        Send CONNECT (keepalive off) and feed back an accepting CONNACK.
        '''
        connack = CONNACK()
        connack.session, connack.resultCode = False, 0
        connack.encode()
        self.protocol.connect("TwistedMQTT-sub", keepalive=0,
                              cleanStart=cleanStart, version=v31)
        self.transport.clear()
        self.protocol.dataReceived(connack.encoded)

    def _serverDown(self):
        '''Lose the transport, drop buffered output, forget the protocol.'''
        transport = self.transport
        transport.loseConnection()
        transport.clear()
        del self.protocol

    def _rebuild(self):
        '''Build a fresh protocol from the factory and attach it.'''
        protocol = self.factory.buildProtocol(0)
        self.protocol = protocol
        self.transport.protocol = protocol
        MQTTBaseProtocol.callLater = self.clock.callLater
        protocol.makeConnection(self.transport)

    def _subscribe(self, n, qos, topic):
        '''
        Subscribe to topic0 .. topic{n-1} with window size n; return the
        deferreds, which must all still be pending.
        '''
        self.protocol.setWindowSize(n)
        pending = [self.protocol.subscribe("{0}{1}".format(topic, idx), qos)
                   for idx in range(n)]
        self.transport.clear()
        for d in pending:
            self.assertNoResult(d)
        return pending

    def _unsubscribe(self, n, topic):
        '''
        Unsubscribe from topic0 .. topic{n-1} with window size n; return the
        deferreds, which must all still be pending.
        '''
        self.protocol.setWindowSize(n)
        pending = [self.protocol.unsubscribe("{0}{1}".format(topic, idx))
                   for idx in range(n)]
        self.transport.clear()
        for d in pending:
            self.assertNoResult(d)
        return pending

    def _subscribeOneAndAck(self):
        '''
        Subscribe to foo/bar/baz1 at QoS 2, feed back a SUBACK granting it
        and check the deferred fired with the granted list.
        '''
        d = self.protocol.subscribe("foo/bar/baz1", 2)
        self.transport.clear()
        suback = SUBACK()
        suback.msgId = d.msgId
        suback.granted = [(2, False)]
        self.protocol.dataReceived(suback.encode())
        self.assertEqual([(2, False)], self.successResultOf(d))

    def test_subscribe_setPubishHandler1(self):
        '''onPublish is installed before connecting.'''
        def onPublish(topic, payload, qos, dup, retain, msgId):
            self.called = True

        self.protocol.onPublish = onPublish
        self._connect()
        self._subscribeOneAndAck()

    def test_subscribe_setPubishHandler2(self):
        '''onPublish is installed after connecting.'''
        def onPublish(topic, payload, qos, dup, retain, msgId):
            self.called = True

        self._connect()
        self.protocol.onPublish = onPublish
        self._subscribeOneAndAck()
class TestMQTTPublisherIntervals(unittest.TestCase):
    '''
    Retransmission timing of large QoS 1 publishes as a function of the
    bandwidth estimate configured with setBandwith().

    NOTE(review): this module already defines a TestMQTTPublisherIntervals
    class earlier on; this definition rebinds the name and shadows it.
    Confirm and rename or drop one of them.
    '''

    def setUp(self):
        '''
        Set up an unconnected publisher protocol for a fixed peer address
        on a fake transport driven by a simulated clock.
        '''
        self.transport = proto_helpers.StringTransportWithDisconnection()
        self.clock = task.Clock()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory = MQTTFactory(MQTTFactory.PUBLISHER)
        # The factory's per-peer window state is keyed by this address
        self.addr = IPv4Address('TCP', 'localhost', 1880)
        self._rebuild()

    def _connect(self, cleanStart=True):
        '''
        Go to connected state: send CONNECT and feed back an accepting
        CONNACK.
        '''
        ack = CONNACK()
        ack.session = False
        ack.resultCode = 0
        ack.encode()
        self.protocol.connect("TwistedMQTT-pub",
                              keepalive=0,
                              cleanStart=cleanStart,
                              version=v31)
        self.transport.clear()
        self.protocol.dataReceived(ack.encoded)

    def _rebuild(self):
        '''Build a fresh protocol for self.addr and attach the transport.'''
        self.protocol = self.factory.buildProtocol(self.addr)
        self.transport.protocol = self.protocol
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.protocol.makeConnection(self.transport)

    def _publish(self, n, qos, topic, msg, window=None):
        '''
        Publish *n* copies of *msg* with the window set to *window*
        (default: *n*) and return the deferreds. QoS 0 deferreds must have
        fired with None; higher-QoS ones must still be pending.
        '''
        self.protocol.setWindowSize(n if window is None else window)
        dl = [self.protocol.publish(topic=topic, qos=qos, message=msg)
              for _ in range(n)]
        self.transport.clear()
        for d in dl:
            if qos == 0:
                self.assertEqual(None, self.successResultOf(d))
            else:
                self.assertNoResult(d)
        return dl

    def _puback(self, dl):
        '''
        Deliver one PUBACK per deferred in *dl* as a single chunk and check
        each deferred fired with its own msgId.
        '''
        encoded = bytearray()
        for d in dl:
            ack = PUBACK()
            ack.msgId = d.msgId
            encoded.extend(ack.encode())
        self.protocol.dataReceived(encoded)
        self.transport.clear()
        for d in dl:
            self.assertEqual(d.msgId, self.successResultOf(d))

    def _inFlightDup(self, msgId):
        '''Return the dup flag of the in-flight PUBLISH with *msgId*.'''
        return self.protocol.factory.windowPublish[self.addr][msgId].dup

    def _publishLargeQoS1(self, message, bandwidth, expectRetransmit):
        '''
        Publish *message* at QoS 1 with the given bandwidth estimate
        (bytes/sec), advance the simulated clock 10 s, check whether the PDU
        has been marked dup (retransmitted), then PUBACK it.
        '''
        self.protocol.setBandwith(bandwidth)
        d = self.protocol.publish(topic="foo/bar/baz1", qos=1, message=message)
        # Look the PDU up by the msgId actually assigned instead of assuming
        # message ids are handed out as 1, 2, ...
        self.assertEqual(self._inFlightDup(d.msgId), False)
        self.transport.clear()
        self.clock.advance(10)
        self.assertEqual(self._inFlightDup(d.msgId), expectRetransmit)
        ack = PUBACK()
        ack.msgId = d.msgId
        self.protocol.dataReceived(ack.encode())
        self.transport.clear()
        self.assertEqual(ack.msgId, self.successResultOf(d))

    def test_publish_very_large_qos1(self):
        message = '0123456789ABCDEF' * 1000000  # Large PDU (16 million chars)
        self._connect()

        # Test at 1MByte/sec: 10 s is not enough to trigger a retransmission
        self._publishLargeQoS1(message, 1000000.0, expectRetransmit=False)

        # A large PDU with a large bandwith estimation may retransmit
        self._publishLargeQoS1(message, 10000000.0, expectRetransmit=True)
class TestMQTTPublisher1(unittest.TestCase):
    def setUp(self):
        '''
        Prepare an unconnected PUBLISHER protocol on a fake transport whose
        timers run on a simulated clock; tests call _connect() themselves.
        '''
        self.clock = task.Clock()
        self.transport = proto_helpers.StringTransportWithDisconnection()
        MQTTBaseProtocol.callLater = self.clock.callLater
        self.factory = MQTTFactory(MQTTFactory.PUBLISHER)
        # _rebuild() also records self.addr, the key for per-peer state
        self._rebuild()

    def _connect(self, cleanStart=True):
        '''
        Send CONNECT (keepalive off, protocol version v31) and feed back an
        accepting CONNACK so the protocol reaches the connected state.
        '''
        connack = CONNACK()
        connack.session, connack.resultCode = False, 0
        connack.encode()
        options = dict(keepalive=0, cleanStart=cleanStart, version=v31)
        self.protocol.connect("TwistedMQTT-pub", **options)
        # Discard the CONNECT bytes before injecting the broker's reply
        self.transport.clear()
        self.protocol.dataReceived(connack.encoded)

    def _serverDown(self):
        '''
        Simulate the broker going away: drop the transport, discard any
        buffered output and forget the protocol instance.
        '''
        transport = self.transport
        transport.loseConnection()
        transport.clear()
        del self.protocol

    def _rebuild(self):
        '''
        Build a fresh protocol for the fixed test peer from the shared
        factory and attach it to the (reused) fake transport.
        '''
        self.addr = IPv4Address('TCP', 'localhost', 1880)
        protocol = self.factory.buildProtocol(self.addr)
        self.protocol = protocol
        self.transport.protocol = protocol
        # Re-point protocol timers at the simulated clock before connecting
        MQTTBaseProtocol.callLater = self.clock.callLater
        protocol.makeConnection(self.transport)

    def _publish(self, n, qos, topic, msg, window=None):
        '''
        Publish *n* copies of *msg* on *topic* with the protocol window set
        to *window* (defaults to *n*) and return the deferreds.

        QoS 0 deferreds must already have fired with None; for higher QoS
        they must still be pending.
        '''
        self.protocol.setWindowSize(n if window is None else window)
        deferreds = [self.protocol.publish(topic=topic, qos=qos, message=msg)
                     for _ in range(n)]
        self.transport.clear()
        if qos == 0:
            for d in deferreds:
                self.assertEqual(None, self.successResultOf(d))
        else:
            for d in deferreds:
                self.assertNoResult(d)
        return deferreds

    def _puback(self, dl):
        '''
        Deliver one PUBACK per deferred in *dl* to the protocol as a single
        chunk and check each deferred fired with its own msgId.
        '''
        chunk = bytearray()
        for d in dl:
            puback = PUBACK()
            puback.msgId = d.msgId
            chunk.extend(puback.encode())
        self.protocol.dataReceived(chunk)
        self.transport.clear()
        for d in dl:
            self.assertEqual(d.msgId, self.successResultOf(d))

    def _pubrec(self, dl):
        '''
        Deliver one PUBREC per deferred in *dl* as a single chunk; the QoS 2
        deferreds must remain pending (they complete on PUBCOMP).
        '''
        chunk = bytearray()
        for d in dl:
            pubrec = PUBREC()
            pubrec.msgId = d.msgId
            chunk.extend(pubrec.encode())
        self.protocol.dataReceived(chunk)
        self.transport.clear()
        for d in dl:
            self.assertNoResult(d)

    def _pubcomp(self, dl):
        '''
        Deliver one PUBCOMP per deferred in *dl* as a single chunk and check
        each deferred fired with its own msgId.
        '''
        chunk = bytearray()
        for d in dl:
            pubcomp = PUBCOMP()
            pubcomp.msgId = d.msgId
            chunk.extend(pubcomp.encode())
        self.protocol.dataReceived(chunk)
        self.transport.clear()
        for d in dl:
            self.assertEqual(d.msgId, self.successResultOf(d))

    def test_publish_single_qos0(self):
        '''A QoS 0 publish fires immediately and leaves nothing in flight.'''
        self._connect()
        published = self.protocol.publish(topic="foo/bar/baz1", qos=0,
                                          message="hello world 0")
        inFlight = len(self.protocol.factory.windowPublish[self.addr])
        self.assertEqual(inFlight, 0)
        self.assertEqual(None, self.successResultOf(published))

    def test_publish_single_qos1(self):
        '''A QoS 1 publish stays in the window until its PUBACK arrives.'''
        self._connect()

        def inFlight():
            return len(self.protocol.factory.windowPublish[self.addr])

        published = self.protocol.publish(topic="foo/bar/baz1", qos=1,
                                          message="hello world 1")
        self.assertEqual(inFlight(), 1)
        self.transport.clear()
        puback = PUBACK()
        puback.msgId = published.msgId
        self.protocol.dataReceived(puback.encode())
        self.assertEqual(inFlight(), 0)
        self.assertEqual(puback.msgId, self.successResultOf(published))

    def test_publish_single_qos2(self):
        '''
        A QoS 2 publish moves from the publish window to the release window
        on PUBREC, leaves both on PUBCOMP, and then fires its deferred.
        '''
        self._connect()

        def sent():
            return len(self.protocol.factory.windowPublish[self.addr])

        def releasing():
            return len(self.protocol.factory.windowPubRelease[self.addr])

        published = self.protocol.publish(topic="foo/bar/baz1", qos=2,
                                          message="hello world 2")
        self.assertEqual(sent(), 1)
        self.assertEqual(releasing(), 0)
        self.transport.clear()

        pubrec = PUBREC()
        pubrec.msgId = published.msgId
        self.protocol.dataReceived(pubrec.encode())
        self.transport.clear()
        self.assertEqual(sent(), 0)
        self.assertEqual(releasing(), 1)

        pubcomp = PUBCOMP()
        pubcomp.msgId = published.msgId
        self.protocol.dataReceived(pubcomp.encode())
        self.assertEqual(sent(), 0)
        self.assertEqual(releasing(), 0)
        self.assertEqual(pubcomp.msgId, self.successResultOf(published))

    def test_publish_several_qos0(self):
        '''
        Several QoS 0 publishes complete immediately and leave nothing in
        the publish window.
        '''
        self._connect()
        # _publish() itself asserts that every QoS 0 deferred fired with
        # None, so the returned list is not needed here.
        self._publish(n=3, qos=0, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         0)

    def test_publish_several_qos1(self):
        '''Several QoS 1 publishes stay in the window until PUBACKed.'''
        self._connect()

        def inFlight():
            return len(self.protocol.factory.windowPublish[self.addr])

        published = self._publish(n=3, qos=1, topic="foo/bar/baz",
                                  msg="Hello World")
        self.assertEqual(inFlight(), len(published))
        self._puback(published)
        self.assertEqual(inFlight(), 0)

    def test_publish_several_qos2(self):
        '''
        Several QoS 2 publishes move to the release window on PUBREC and
        leave it on PUBCOMP.
        '''
        self._connect()

        def sent():
            return len(self.protocol.factory.windowPublish[self.addr])

        def releasing():
            return len(self.protocol.factory.windowPubRelease[self.addr])

        published = self._publish(n=3, qos=2, topic="foo/bar/baz",
                                  msg="Hello World")
        total = len(published)
        self.assertEqual(sent(), total)
        self.assertEqual(releasing(), 0)
        self._pubrec(published)
        self.assertEqual(sent(), 0)
        self.assertEqual(releasing(), total)
        self._pubcomp(published)
        self.assertEqual(sent(), 0)
        self.assertEqual(releasing(), 0)

    def test_publish_many_qos1(self):
        '''
        Test enqueuing when not all ACKs arrive: with a window of 3 and 7
        QoS 1 messages the surplus waits in queuePublishTx and moves into
        the window as PUBACKs free slots.
        '''
        self._connect()
        window, n = 3, 7

        def check(inWindow, queued):
            factory = self.protocol.factory
            self.assertEqual(len(factory.windowPublish[self.addr]), inWindow)
            self.assertEqual(len(factory.queuePublishTx[self.addr]), queued)

        dl = self._publish(n=n, window=window, qos=1,
                           topic="foo/bar/baz", msg="Hello World")
        check(window, n - window)
        # Each batch of PUBACKs frees window slots for queued messages
        stages = (
            (dl[0:window], window, 1),
            (dl[window:2 * window], 1, 0),
            (dl[2 * window:], 0, 0),
        )
        for batch, inWindow, queued in stages:
            self._puback(batch)
            check(inWindow, queued)

    def test_publish_many_qos2(self):
        '''
        Test enqueuing when not all ACKs arrive (QoS 2): after PUBREC and
        PUBCOMP of the first batch, the next queued batch fills the window.
        '''
        self._connect()
        window, n = 3, 7

        def sent():
            return len(self.protocol.factory.windowPublish[self.addr])

        def queued():
            return len(self.protocol.factory.queuePublishTx[self.addr])

        def releasing():
            return len(self.protocol.factory.windowPubRelease[self.addr])

        dl = self._publish(n=n, window=window, qos=2,
                           topic="foo/bar/baz", msg="Hello World")
        firstBatch = dl[:window]
        self.assertEqual(sent(), window)
        self.assertEqual(queued(), n - window)
        self._pubrec(firstBatch)
        self.assertEqual(sent(), 0)
        self.assertEqual(releasing(), window)
        self._pubcomp(firstBatch)
        self.assertEqual(queued(), n - 2 * window)
        self.assertEqual(sent(), window)
        self.assertEqual(releasing(), 0)

    def test_lost_session(self):
        '''
        With a clean session, losing the connection empties both windows
        and fails every outstanding QoS 2 publish with ConnectionDone.
        '''
        self._connect()
        dl = self._publish(n=3, qos=2, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         len(dl))
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 0)
        self._serverDown()
        # self.protocol is gone now; inspect the factory directly
        for windowName in ('windowPublish', 'windowPubRelease'):
            windows = getattr(self.factory, windowName)
            self.assertEqual(len(windows[self.addr]), 0)
        for d in dl:
            self.failureResultOf(d).trap(error.ConnectionDone)

    def test_persistent_session_qos1(self):
        '''
        With a persistent session (cleanStart=False), unacknowledged QoS 1
        publications stay in the publish window across a server disconnect
        and are cleared by PUBACKs received after reconnecting with a
        rebuilt protocol.
        '''
        self._connect(cleanStart=False)
        dl = self._publish(n=3, qos=1, topic="foo/bar/baz", msg="Hello World")
        expected = len(dl)

        def in_flight():
            # Re-read self.protocol each time: _rebuild() replaces it.
            return len(self.protocol.factory.windowPublish[self.addr])

        self.assertEqual(in_flight(), expected)

        self._serverDown()
        self.assertEqual(len(self.factory.windowPublish[self.addr]), expected)
        # Nothing may fire while the session is merely suspended.
        for deferred in dl:
            self.assertNoResult(deferred)

        # Resume the same session over a fresh protocol instance.
        self._rebuild()
        self._connect(cleanStart=False)
        self.assertEqual(in_flight(), expected)

        self._puback(dl)
        self.assertEqual(in_flight(), 0)

    def test_persistent_session_qos2(self):
        '''
        With a persistent session (cleanStart=False), unacknowledged QoS 2
        publications stay in the publish window across a server disconnect.
        After reconnecting with a rebuilt protocol, PUBREC moves them to the
        release window and PUBCOMP clears them.
        '''
        self._connect(cleanStart=False)
        dl = self._publish(n=3, qos=2, topic="foo/bar/baz", msg="Hello World")
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         len(dl))
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 0)
        self._serverDown()
        # Persistent session: unlike test_lost_session, the in-flight state
        # must survive the disconnect (same check as the QoS 1 variant).
        self.assertEqual(len(self.factory.windowPublish[self.addr]), len(dl))
        self.assertEqual(len(self.factory.windowPubRelease[self.addr]), 0)
        for d in dl:
            self.assertNoResult(d)
        # Resume the same session over a fresh protocol instance.
        self._rebuild()
        self._connect(cleanStart=False)
        self.assertEqual(len(self.factory.windowPublish[self.addr]), len(dl))
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 0)
        # PUBREC: publish window -> release window.
        self._pubrec(dl)
        self.assertEqual(len(self.factory.windowPublish[self.addr]), 0)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), len(dl))
        # PUBCOMP: release window emptied.
        self._pubcomp(dl)
        self.assertEqual(len(self.factory.windowPublish[self.addr]), 0)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 0)

    def test_persistent_release_qos2(self):
        '''
        Persistent QoS 2 session interrupted mid-handshake: two of three
        publications have received PUBREC when the server goes down. After
        rebuilding the protocol, one publication is still in the publish
        window and two are in the release window. Reconnecting and completing
        the remaining PUBREC and the three PUBCOMPs drains both windows.
        '''
        self._connect(cleanStart=False)
        dl = self._publish(n=3, qos=2, topic="foo/bar/baz", msg="Hello World")
        # Send PUBREC for the first two publications only, then simulate a
        # server disconnect and build a new client protocol instance.
        self._pubrec(dl[:-1])  # Only the first two

        self._serverDown()
        # Persistent session: no deferred may fire on disconnect.
        self.assertNoResult(dl[0])
        self.assertNoResult(dl[1])
        self.assertNoResult(dl[2])
        self._rebuild()
        # Pre-disconnect state is retained: 1 awaiting PUBREC, 2 awaiting PUBCOMP.
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         1)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 2)
        # Reconnect with the new client protocol object
        self._connect(cleanStart=False)
        self._pubrec(dl[-1:])  # PUBREC for the last publication
        self.assertEqual(len(self.protocol.factory.windowPublish[self.addr]),
                         0)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 3)
        self._pubcomp(dl[0:1])  # PUBCOMP for the first publication
        self.assertEqual(len(self.factory.windowPublish[self.addr]), 0)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 2)
        self._pubcomp(dl[1:2])  # PUBCOMP for the second publication
        self.assertEqual(len(self.factory.windowPublish[self.addr]), 0)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 1)
        self._pubcomp(dl[-1:])  # PUBCOMP for the last publication
        self.assertEqual(len(self.factory.windowPublish[self.addr]), 0)
        self.assertEqual(
            len(self.protocol.factory.windowPubRelease[self.addr]), 0)