예제 #1
0
class TestBuildFrame(TestCase):
    """Serialization tests for ``StompProtocol.build_frame``."""

    def setUp(self):
        # Each test gets its own protocol instance.
        self.protocol = StompProtocol()

    def _assert_frame(self, command, headers, expected, *body):
        """Build a frame from the arguments and compare it to ``expected``.

        ``body`` is forwarded positionally only when given, so omitting it
        exercises ``build_frame``'s own default.
        """
        self.assertEqual(
            self.protocol.build_frame(command, headers, *body), expected)

    def test_build_frame_with_body(self):
        expected = b''.join([
            b'HELLO\n',
            b'from:me\n',
            b'to:you\n',
            b'\n',
            b'I Am The Walrus',
            b'\x00',
        ])
        self._assert_frame('HELLO', {'from': 'me', 'to': 'you'}, expected,
                           'I Am The Walrus')

    def test_build_frame_without_body(self):
        expected = b''.join([
            b'HI\n',
            b'from:1\n',
            b'to:2\n',
            b'\n',
            b'\x00',
        ])
        self._assert_frame('HI', {'from': '1', 'to': '2'}, expected)
예제 #2
0
    def __init__(self,
                 host='localhost',
                 port=61613,
                 connect_headers=None,
                 on_error=None,
                 on_disconnect=None,
                 on_connect=None,
                 reconnect_max_attempts=-1,
                 reconnect_timeout=1000):
        """Store connection settings and initialise client state.

        :param host: broker host.
        :param port: broker TCP port.
        :param connect_headers: extra headers for the CONNECT frame. The
            mapping is copied before ``accept-version`` is added, so neither
            the caller's dict nor a default shared between instances is
            mutated.
        :param on_error: error callback, stored as ``_on_error``.
        :param on_disconnect: disconnect callback, stored as
            ``_on_disconnect``.
        :param on_connect: connect callback, stored as ``_on_connect``.
        :param reconnect_max_attempts: reconnect attempt limit (default
            ``-1``), stored as-is.
        :param reconnect_timeout: reconnect delay in milliseconds.
        """
        self.host = host
        self.port = port
        self.logger = logging.getLogger('TorStomp')

        # Copy so that adding accept-version never mutates shared state.
        self._connect_headers = dict(connect_headers or {})
        self._connect_headers['accept-version'] = self.VERSION
        self._heart_beat_handler = None
        self.connected = False
        self._disconnecting = False
        self._protocol = StompProtocol()
        self._subscriptions = {}
        self._last_subscribe_id = 0
        self._on_error = on_error
        self._on_disconnect = on_disconnect
        self._on_connect = on_connect

        self._reconnect_max_attempts = reconnect_max_attempts
        self._reconnect_timeout = timedelta(milliseconds=reconnect_timeout)
        self._reconnect_attempts = 0
예제 #3
0
class TestReadFrame(TestCase):
    """Parsing of one complete frame passed to ``add_data``."""

    def setUp(self):
        self.protocol = StompProtocol()

    def test_single_packet(self):
        raw = "CONNECT\naccept-version:1.0\n\n\x00"
        self.protocol.add_data(raw)

        ready = self.protocol._frames_ready
        self.assertEqual(len(ready), 1)

        parsed = ready[0]
        self.assertEqual(parsed.command, "CONNECT")
        self.assertEqual(parsed.headers, {"accept-version": "1.0"})
        self.assertEqual(parsed.body, None)
예제 #4
0
class TestBuildFrame(TestCase):
    """Check the exact bytes returned by ``StompProtocol.build_frame``."""

    def setUp(self):
        self.protocol = StompProtocol()

    def test_build_frame_with_body(self):
        headers = {"from": "me", "to": "you"}
        body = "I Am The Walrus"
        expected = b"HELLO\nfrom:me\nto:you\n\n" + b"I Am The Walrus" + b"\x00"

        result = self.protocol.build_frame("HELLO", headers, body)

        self.assertEqual(result, expected)

    def test_build_frame_without_body(self):
        headers = {"from": "1", "to": "2"}
        expected = b"HI\nfrom:1\nto:2\n\n" + b"\x00"

        result = self.protocol.build_frame("HI", headers)

        self.assertEqual(result, expected)
예제 #5
0
class TestReadFrame(TestCase):
    """A single complete packet must produce exactly one queued frame."""

    def setUp(self):
        self.protocol = StompProtocol()

    def test_single_packet(self):
        self.protocol.add_data('CONNECT\naccept-version:1.0\n\n\x00')

        frames = self.protocol._frames_ready
        self.assertEqual(len(frames), 1)

        only_frame = frames[0]
        expectations = (
            (only_frame.command, 'CONNECT'),
            (only_frame.headers, {'accept-version': '1.0'}),
            (only_frame.body, None),
        )
        for actual, expected in expectations:
            self.assertEqual(actual, expected)
예제 #6
0
파일: __init__.py 프로젝트: v0lk3r/torstomp
    def __init__(self, host='localhost', port=61613, host_list=None,
                 connect_headers=None, on_error=None, on_disconnect=None,
                 on_connect=None, reconnect_max_attempts=-1,
                 reconnect_timeout=1000, log_name='TorStomp'):
        """Store connection settings and initialise client state.

        :param host: broker host.
        :param port: broker TCP port.
        :param host_list: optional list of ``(host, port)`` tuples, stored
            as ``self.host_list``.
        :param connect_headers: extra headers for the CONNECT frame. The
            mapping is copied before ``accept-version`` is added, so neither
            the caller's dict nor a default shared between instances is
            mutated.
        :param on_error: error callback, stored as ``_on_error``.
        :param on_disconnect: disconnect callback, stored as
            ``_on_disconnect``.
        :param on_connect: connect callback, stored as ``_on_connect``.
        :param reconnect_max_attempts: reconnect attempt limit (default
            ``-1``), stored as-is.
        :param reconnect_timeout: reconnect delay in milliseconds.
        :param log_name: logger name, also passed to ``StompProtocol``.
        """
        self.host = host
        self.port = port
        # `host_list` is a list of (host, port) tuples. Overrides host, port attributes
        # In case of a reconnection, the list is cycled until a reconnect succeeds
        self.host_list = host_list
        self.logger = logging.getLogger(log_name)

        # Copy so that adding accept-version never mutates shared state.
        self._connect_headers = dict(connect_headers or {})
        self._connect_headers['accept-version'] = self.VERSION
        self._heart_beat_handler = None
        self.connected = False
        self._disconnecting = False
        self._protocol = StompProtocol(log_name=log_name)
        self._subscriptions = {}
        self._last_subscribe_id = 0
        self._on_error = on_error
        self._on_disconnect = on_disconnect
        self._on_connect = on_connect

        self._reconnect_max_attempts = reconnect_max_attempts
        self._reconnect_timeout = timedelta(milliseconds=reconnect_timeout)
        self._reconnect_attempts = 0
예제 #7
0
class TestReadFrame(TestCase):
    """Feed raw bytes to the parser and inspect the queued frame."""

    def setUp(self):
        self.protocol = StompProtocol()

    def test_single_packet(self):
        packet = b'CONNECT\naccept-version:1.0\n\n\x00'
        self.protocol.add_data(packet)

        queued = self.protocol._frames_ready
        self.assertEqual(len(queued), 1)

        head = queued[0]
        self.assertEqual(head.command, 'CONNECT')
        self.assertEqual(head.headers, {'accept-version': '1.0'})
        self.assertEqual(head.body, None)
예제 #8
0
class TestBuildFrame(TestCase):
    """Byte-level expectations for ``StompProtocol.build_frame``."""

    # NUL byte that ends every expected frame in these tests.
    FRAME_END = b'\x00'

    def setUp(self):
        self.protocol = StompProtocol()

    def test_build_frame_with_body(self):
        frame_bytes = self.protocol.build_frame(
            'HELLO', {'from': 'me', 'to': 'you'}, 'I Am The Walrus')
        wanted = b'HELLO\nfrom:me\nto:you\n\nI Am The Walrus' + self.FRAME_END
        self.assertEqual(frame_bytes, wanted)

    def test_build_frame_without_body(self):
        frame_bytes = self.protocol.build_frame('HI', {'from': '1', 'to': '2'})
        wanted = b'HI\nfrom:1\nto:2\n\n' + self.FRAME_END
        self.assertEqual(frame_bytes, wanted)
예제 #9
0
    def __init__(self, host='localhost', port=61613, connect_headers=None,
                 on_error=None, on_disconnect=None, on_connect=None,
                 reconnect_max_attempts=-1, reconnect_timeout=1000):
        """Store connection settings and initialise client state.

        :param host: broker host.
        :param port: broker TCP port.
        :param connect_headers: extra headers for the CONNECT frame. The
            mapping is copied before ``accept-version`` is added, so neither
            the caller's dict nor a default shared between instances is
            mutated.
        :param on_error: error callback, stored as ``_on_error``.
        :param on_disconnect: disconnect callback, stored as
            ``_on_disconnect``.
        :param on_connect: connect callback, stored as ``_on_connect``.
        :param reconnect_max_attempts: reconnect attempt limit (default
            ``-1``), stored as-is.
        :param reconnect_timeout: reconnect delay in milliseconds.
        """
        self.host = host
        self.port = port
        self.logger = logging.getLogger('TorStomp')

        # Copy so that adding accept-version never mutates shared state.
        self._connect_headers = dict(connect_headers or {})
        self._connect_headers['accept-version'] = self.VERSION
        self._heart_beat_handler = None
        self.connected = False
        self._disconnecting = False
        self._protocol = StompProtocol()
        self._subscriptions = {}
        self._last_subscribe_id = 0
        self._on_error = on_error
        self._on_disconnect = on_disconnect
        self._on_connect = on_connect

        self._reconnect_max_attempts = reconnect_max_attempts
        self._reconnect_timeout = timedelta(milliseconds=reconnect_timeout)
        self._reconnect_attempts = 0
예제 #10
0
class TestRecvFrame(TestCase):
    """Decoding and frame-splitting behaviour of ``StompProtocol``.

    Most tests replace ``_proccess_frame`` with a mock and check which frame
    texts the parser handed to it after one or more ``add_data`` calls.
    """

    # CONNECT frame text as handed to ``_proccess_frame`` (NUL not included).
    CONNECT_FRAME = "CONNECT\naccept-version:1.0\n\n"

    def setUp(self):
        self.protocol = StompProtocol()

    def _stub_processing(self, with_heart_beat=False):
        """Mock the frame handler (and optionally the heart-beat hook).

        Returns the frame-handler mock.
        """
        processor = MagicMock()
        self.protocol._proccess_frame = processor
        if with_heart_beat:
            self.protocol._recv_heart_beat = MagicMock()
        return processor

    def _push(self, *chunks):
        """Feed the chunks to ``add_data`` one at a time, in order."""
        for chunk in chunks:
            self.protocol.add_data(chunk)

    def _assert_processed(self, processor, *payloads):
        """``processor`` received exactly ``payloads``; nothing is pending."""
        self.assertTrue(processor.called)
        self.assertEqual(processor.call_count, len(payloads))
        for position, payload in enumerate(payloads):
            self.assertEqual(processor.call_args_list[position][0][0], payload)
        self.assertEqual(self.protocol._pending_parts, [])

    def test_decode(self):
        text = u"éĂ"
        self.assertEqual(self.protocol._decode(text), text)

    def test_on_decode_error_show_string(self):
        failure = UnicodeDecodeError(
            "hitchhiker", b"", 42, 43, "the universe and everything else")
        data = MagicMock(spec=six.binary_type)
        data.decode.side_effect = failure
        with self.assertRaises(UnicodeDecodeError):
            self.protocol._decode(data)

    def test_single_packet(self):
        processor = self._stub_processing()
        self._push(self.CONNECT_FRAME + "\x00")
        self._assert_processed(processor, self.CONNECT_FRAME)

    def test_parcial_packet(self):
        processor = self._stub_processing()
        self._push("CONNECT\n", "accept-version:1.0\n\n\x00")
        self._assert_processed(processor, self.CONNECT_FRAME)

    def test_multi_parcial_packet1(self):
        processor = self._stub_processing()
        self._push(
            "CONNECT\n",
            "accept-version:1.0\n\n\x00\n",
            "CONNECTED\n",
            "accept-version:1.0\n\n\x00\n",
        )
        self._assert_processed(
            processor,
            self.CONNECT_FRAME,
            "CONNECTED\naccept-version:1.0\n\n",
        )

    def test_multi_parcial_packet2(self):
        processor = self._stub_processing()
        self._push(
            "CONNECTED\naccept-version:1.0\n\n",
            "\x00\nERROR\n",
            "header:1.0\n\n\x00\n",
        )
        self._assert_processed(
            processor,
            "CONNECTED\naccept-version:1.0\n\n",
            "ERROR\nheader:1.0\n\n",
        )

    def test_heart_beat_packet1(self):
        processor = self._stub_processing(with_heart_beat=True)
        self._push("\n")
        self.assertFalse(processor.called)
        self.assertTrue(self.protocol._recv_heart_beat.called)
        self.assertEqual(self.protocol._pending_parts, [])

    def test_heart_beat_packet2(self):
        processor = self._stub_processing(with_heart_beat=True)
        self._push(self.CONNECT_FRAME + "\x00\n")
        self.assertTrue(processor.called)
        self.assertTrue(self.protocol._recv_heart_beat.called)
        self.assertEqual(self.protocol._pending_parts, [])

    def test_heart_beat_packet3(self):
        processor = self._stub_processing(with_heart_beat=True)
        self._push("\n" + self.CONNECT_FRAME + "\x00")
        self.assertTrue(processor.called)
        self.assertTrue(self.protocol._recv_heart_beat.called)
        self.assertEqual(self.protocol._pending_parts, [])
예제 #11
0
class TestRecvFrame(TestCase):
    """Frame-boundary detection in ``StompProtocol.add_data``.

    Frame processing is replaced by a mock so each test can check exactly
    which frame texts the parser handed over.
    """

    def setUp(self):
        self.protocol = StompProtocol()

    def _install_mocks(self, heart_beat=False):
        """Mock ``_proccess_frame`` (and optionally ``_recv_heart_beat``)."""
        handler = MagicMock()
        self.protocol._proccess_frame = handler
        if heart_beat:
            self.protocol._recv_heart_beat = MagicMock()
        return handler

    def _check_frames(self, handler, expected):
        """Assert ``handler`` received exactly ``expected``, in order."""
        self.assertTrue(handler.called)
        self.assertEqual(handler.call_count, len(expected))
        for entry, frame_text in zip(handler.call_args_list, expected):
            self.assertEqual(entry[0][0], frame_text)
        self.assertEqual(self.protocol._pending_parts, [])

    def test_single_packet(self):
        handler = self._install_mocks()
        self.protocol.add_data('CONNECT\naccept-version:1.0\n\n\x00')
        self._check_frames(handler, ['CONNECT\naccept-version:1.0\n\n'])

    def test_parcial_packet(self):
        handler = self._install_mocks()
        for piece in ('CONNECT\n', 'accept-version:1.0\n\n\x00'):
            self.protocol.add_data(piece)
        self._check_frames(handler, ['CONNECT\naccept-version:1.0\n\n'])

    def test_multi_parcial_packet1(self):
        handler = self._install_mocks()
        pieces = (
            'CONNECT\n',
            'accept-version:1.0\n\n\x00\n',
            'CONNECTED\n',
            'accept-version:1.0\n\n\x00\n',
        )
        for piece in pieces:
            self.protocol.add_data(piece)
        self._check_frames(handler, [
            'CONNECT\naccept-version:1.0\n\n',
            'CONNECTED\naccept-version:1.0\n\n',
        ])

    def test_multi_parcial_packet2(self):
        handler = self._install_mocks()
        pieces = (
            'CONNECTED\naccept-version:1.0\n\n',
            '\x00\nERROR\n',
            'header:1.0\n\n\x00\n',
        )
        for piece in pieces:
            self.protocol.add_data(piece)
        self._check_frames(handler, [
            'CONNECTED\naccept-version:1.0\n\n',
            'ERROR\nheader:1.0\n\n',
        ])

    def test_heart_beat_packet1(self):
        handler = self._install_mocks(heart_beat=True)
        self.protocol.add_data('\n')
        self.assertFalse(handler.called)
        self.assertTrue(self.protocol._recv_heart_beat.called)
        self.assertEqual(self.protocol._pending_parts, [])

    def test_heart_beat_packet2(self):
        handler = self._install_mocks(heart_beat=True)
        self.protocol.add_data('CONNECT\naccept-version:1.0\n\n\x00\n')
        self.assertTrue(handler.called)
        self.assertTrue(self.protocol._recv_heart_beat.called)
        self.assertEqual(self.protocol._pending_parts, [])

    def test_heart_beat_packet3(self):
        handler = self._install_mocks(heart_beat=True)
        self.protocol.add_data('\nCONNECT\naccept-version:1.0\n\n\x00')
        self.assertTrue(handler.called)
        self.assertTrue(self.protocol._recv_heart_beat.called)
        self.assertEqual(self.protocol._pending_parts, [])
예제 #12
0
class TorStomp(object):
    """STOMP client built on Tornado's ``IOStream`` / ``IOLoop``.

    Opens a TCP connection and sends CONNECT. It re-sends every registered
    subscription on each successful connection. It dispatches incoming
    MESSAGE/CONNECTED/ERROR frames. It schedules reconnects after connection
    failures or unexpected disconnects.
    """

    # Advertised in the ``accept-version`` header of the CONNECT frame.
    VERSION = '1.1'

    def __init__(self,
                 host='localhost',
                 port=61613,
                 connect_headers=None,
                 on_error=None,
                 on_disconnect=None,
                 on_connect=None,
                 reconnect_max_attempts=-1,
                 reconnect_timeout=1000):
        """Store settings and initialise state; performs no network I/O.

        :param host: broker host.
        :param port: broker TCP port.
        :param connect_headers: extra CONNECT headers. Copied before
            ``accept-version`` is added, so neither the caller's mapping nor
            a default shared between instances is mutated.
        :param on_error: called with a ``StompError`` for each ERROR frame.
        :param on_disconnect: called with no arguments when the TCP
            connection closes.
        :param on_connect: called with no arguments once CONNECT and all
            SUBSCRIBE frames have been written.
        :param reconnect_max_attempts: maximum consecutive reconnect
            attempts; ``-1`` retries forever.
        :param reconnect_timeout: delay before each reconnect, in ms.
        """
        self.host = host
        self.port = port
        self.logger = logging.getLogger('TorStomp')

        self._connect_headers = dict(connect_headers or {})
        self._connect_headers['accept-version'] = self.VERSION
        self._heart_beat_handler = None
        self.connected = False
        self._disconnecting = False
        self._protocol = StompProtocol()
        self._subscriptions = {}
        self._last_subscribe_id = 0
        self._on_error = on_error
        self._on_disconnect = on_disconnect
        self._on_connect = on_connect

        self._reconnect_max_attempts = reconnect_max_attempts
        self._reconnect_timeout = timedelta(milliseconds=reconnect_timeout)
        self._reconnect_attempts = 0

    @gen.coroutine
    def connect(self):
        """Open the TCP connection and perform the STOMP handshake.

        On ``socket.error`` a reconnect is scheduled and the coroutine
        returns early. On success it starts the read loop and resets the
        connection state. It then writes CONNECT, re-sends every stored
        subscription, and calls ``on_connect`` if one was given.
        """
        self.stream = self._build_io_stream()

        try:
            yield self.stream.connect((self.host, self.port))
            self.logger.info('Stomp connection estabilished')
        except socket.error as error:
            self.logger.error('[attempt: %d] Connect error: %s',
                              self._reconnect_attempts, error)
            self._schedule_reconnect()
            return

        self.stream.set_close_callback(self._on_disconnect_socket)
        # One handler for both streamed chunks and the final callback;
        # _on_data ignores empty payloads.
        self.stream.read_until_close(streaming_callback=self._on_data,
                                     callback=self._on_data)

        self.connected = True
        self._disconnecting = False
        self._reconnect_attempts = 0
        # Presumably discards parser state from a previous connection.
        self._protocol.reset()

        yield self._send_frame('CONNECT', self._connect_headers)

        for subscription in self._subscriptions.values():
            yield self._send_subscribe_frame(subscription)

        if self._on_connect:
            self._on_connect()

    def subscribe(self,
                  destination,
                  ack='auto',
                  extra_headers=None,
                  callback=None):
        """Register a subscription, sending SUBSCRIBE now if connected.

        Subscriptions are remembered and re-sent by ``connect`` after every
        (re)connection. ``callback`` is invoked as ``callback(frame, body)``
        for each MESSAGE frame addressed to this subscription.

        :param extra_headers: additional SUBSCRIBE headers (copied).
        """
        self._last_subscribe_id += 1

        subscription = Subscription(destination=destination,
                                    id=self._last_subscribe_id,
                                    ack=ack,
                                    extra_headers=dict(extra_headers or {}),
                                    callback=callback)

        # Keyed by the string form: it is looked up with the parsed
        # 'subscription' header of incoming MESSAGE frames.
        self._subscriptions[str(self._last_subscribe_id)] = subscription

        if self.connected:
            self._send_subscribe_frame(subscription)

    def send(self, destination, body='', headers=None, send_content_length=True):
        """Send a SEND frame to ``destination``.

        :param body: payload; a non-empty body is encoded by the protocol.
        :param headers: extra frame headers. Copied, so ``destination`` and
            ``content-length`` never leak into the caller's dict or, via a
            shared default, into later calls.
        :param send_content_length: add ``content-length`` for non-empty
            bodies.
        :returns: the result of writing the frame to the stream.
        """
        headers = dict(headers or {})
        headers['destination'] = destination

        if body:
            body = self._protocol._encode(body)

            # ActiveMQ determines the type of a message by the
            # inclusion of the content-length header
            if send_content_length:
                headers['content-length'] = len(body)

        return self._send_frame('SEND', headers, body)

    def ack(self, frame):
        """Acknowledge ``frame`` by sending an ACK frame."""
        return self._send_frame('ACK', self._ack_headers(frame))

    def nack(self, frame):
        """Negatively acknowledge ``frame`` by sending a NACK frame."""
        return self._send_frame('NACK', self._ack_headers(frame))

    @staticmethod
    def _ack_headers(frame):
        """Headers identifying ``frame`` in an ACK/NACK."""
        return {
            'subscription': frame.headers['subscription'],
            'message-id': frame.headers['message-id'],
        }

    def _build_io_stream(self):
        """Create an unconnected IPv4 TCP ``IOStream``."""
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        return IOStream(s)

    def _on_disconnect_socket(self):
        """Stream close callback.

        Stops heart-beats and marks the client disconnected. It schedules a
        reconnect unless ``_disconnecting`` is set, then calls
        ``on_disconnect``.
        """
        self._stop_scheduled_heart_beat()
        self.connected = False

        if self._disconnecting:
            self.logger.info('TCP connection end gracefully')
        else:
            self.logger.info('TCP connection unexpected end')
            self._schedule_reconnect()

        if self._on_disconnect:
            self._on_disconnect()

    def _schedule_reconnect(self):
        """Schedule ``connect`` after the reconnect delay, within the limit."""
        if self._reconnect_max_attempts == -1 or \
                self._reconnect_attempts < self._reconnect_max_attempts:

            self._reconnect_attempts += 1
            self._reconnect_timeout_handler = IOLoop.current().add_timeout(
                self._reconnect_timeout, self.connect)
        else:
            self.logger.error('All Connection attempts failed')

    def _on_data(self, data):
        """Feed received data to the parser and dispatch complete frames."""
        if not data:
            return

        self._protocol.add_data(data)

        frames = self._protocol.pop_frames()
        if frames:
            self._received_frames(frames)

    def _send_frame(self, command, headers=None, body=''):
        """Serialize a frame and write it; returns the stream write result."""
        if headers is None:
            headers = {}
        buf = self._protocol.build_frame(command, headers, body)
        return self.stream.write(buf)

    def _set_connected(self, connected_frame):
        """Handle CONNECTED: start client heart-beats if the server asks.

        The ``heart-beat`` header has the form ``"sx,sy"``. A non-zero
        ``sy`` is used as the outgoing heart-beat interval, in milliseconds.
        """
        heartbeat = connected_frame.headers.get('heart-beat')

        if heartbeat:
            _sx, sy = (int(value) for value in heartbeat.split(','))

            if sy:
                self._set_heart_beat(sy)

    def _set_heart_beat(self, time):
        """(Re)start heart-beating every ``time`` ms, sending one now."""
        self._heart_beat_delta = timedelta(milliseconds=time)
        self._stop_scheduled_heart_beat()

        self._do_heart_beat()

    def _schedule_heart_beat(self):
        """Schedule the next heart-beat one interval from now."""
        self._heart_beat_handler = IOLoop.current().add_timeout(
            self._heart_beat_delta, self._do_heart_beat)

    def _stop_scheduled_heart_beat(self):
        """Cancel a pending heart-beat timeout, if any."""
        if self._heart_beat_handler:
            IOLoop.current().remove_timeout(self._heart_beat_handler)

        self._heart_beat_handler = None

    def _do_heart_beat(self):
        """Write a single newline heart-beat and schedule the next one."""
        self.logger.debug('Sending heartbeat')
        self.stream.write(b'\n')
        self._schedule_heart_beat()

    def _received_frames(self, frames):
        """Dispatch each parsed frame by its command."""
        for frame in frames:
            if frame.command == 'MESSAGE':
                self._received_message_frame(frame)
            elif frame.command == 'CONNECTED':
                self._set_connected(frame)
            elif frame.command == 'ERROR':
                self._received_error_frame(frame)
            else:
                self._received_unhandled_frame(frame)

    def _received_message_frame(self, frame):
        """Route a MESSAGE frame to its subscription's callback."""
        subscription_header = frame.headers.get('subscription')

        subscription = self._subscriptions.get(subscription_header)

        if not subscription:
            # %s, not %d: the header value is a string (or None); %d raised
            # TypeError instead of logging.
            self.logger.error('Not found subscription %s',
                              subscription_header)
            return

        subscription.callback(frame, frame.body)

    def _received_error_frame(self, frame):
        """Log an ERROR frame and forward it to ``on_error`` if set."""
        message = frame.headers.get('message')

        self.logger.error('Received error: %s', message)
        self.logger.debug('Error detail %s', frame.body)

        if self._on_error:
            self._on_error(StompError(message, frame.body))

    def _received_unhandled_frame(self, frame):
        """Log frames whose command is not handled."""
        self.logger.warning('Received unhandled frame: %s', frame.command)

    def _send_subscribe_frame(self, subscription):
        """Send SUBSCRIBE for ``subscription``; returns the write result."""
        headers = {
            'id': subscription.id,
            'destination': subscription.destination,
            'ack': subscription.ack
        }
        headers.update(subscription.extra_headers)

        return self._send_frame('SUBSCRIBE', headers)
예제 #13
0
class TestRecvFrame(TestCase):
    """Decoding and stream-splitting tests for ``StompProtocol``."""

    def setUp(self):
        self.protocol = StompProtocol()

    def _feed(self, chunks, heart_beat=False):
        """Mock the frame handler, feed ``chunks`` and return the mock.

        With ``heart_beat=True`` the heart-beat hook is mocked as well.
        """
        processor = MagicMock()
        self.protocol._proccess_frame = processor
        if heart_beat:
            self.protocol._recv_heart_beat = MagicMock()
        for chunk in chunks:
            self.protocol.add_data(chunk)
        return processor

    def _assert_payloads(self, processor, *payloads):
        """The handler saw exactly ``payloads`` and nothing is buffered."""
        self.assertTrue(processor.called)
        self.assertEqual(processor.call_count, len(payloads))
        seen = [entry[0][0] for entry in processor.call_args_list]
        self.assertEqual(seen, list(payloads))
        self.assertEqual(self.protocol._pending_parts, [])

    def test_decode(self):
        sample = u'éĂ'
        self.assertEqual(self.protocol._decode(sample), sample)

    def test_on_decode_error_show_string(self):
        boom = UnicodeDecodeError(
            'hitchhiker', b"", 42, 43, 'the universe and everything else')
        data = MagicMock(spec=six.binary_type)
        data.decode.side_effect = boom
        with self.assertRaises(UnicodeDecodeError):
            self.protocol._decode(data)

    def test_single_packet(self):
        processor = self._feed(['CONNECT\naccept-version:1.0\n\n\x00'])
        self._assert_payloads(processor, 'CONNECT\naccept-version:1.0\n\n')

    def test_parcial_packet(self):
        processor = self._feed(['CONNECT\n', 'accept-version:1.0\n\n\x00'])
        self._assert_payloads(processor, 'CONNECT\naccept-version:1.0\n\n')

    def test_multi_parcial_packet1(self):
        processor = self._feed([
            'CONNECT\n',
            'accept-version:1.0\n\n\x00\n',
            'CONNECTED\n',
            'accept-version:1.0\n\n\x00\n',
        ])
        self._assert_payloads(processor,
                              'CONNECT\naccept-version:1.0\n\n',
                              'CONNECTED\naccept-version:1.0\n\n')

    def test_multi_parcial_packet2(self):
        processor = self._feed([
            'CONNECTED\naccept-version:1.0\n\n',
            '\x00\nERROR\n',
            'header:1.0\n\n\x00\n',
        ])
        self._assert_payloads(processor,
                              'CONNECTED\naccept-version:1.0\n\n',
                              'ERROR\nheader:1.0\n\n')

    def test_heart_beat_packet1(self):
        processor = self._feed(['\n'], heart_beat=True)
        self.assertFalse(processor.called)
        self.assertTrue(self.protocol._recv_heart_beat.called)
        self.assertEqual(self.protocol._pending_parts, [])

    def test_heart_beat_packet2(self):
        processor = self._feed(['CONNECT\naccept-version:1.0\n\n\x00\n'],
                               heart_beat=True)
        self.assertTrue(processor.called)
        self.assertTrue(self.protocol._recv_heart_beat.called)
        self.assertEqual(self.protocol._pending_parts, [])

    def test_heart_beat_packet3(self):
        processor = self._feed(['\nCONNECT\naccept-version:1.0\n\n\x00'],
                               heart_beat=True)
        self.assertTrue(processor.called)
        self.assertTrue(self.protocol._recv_heart_beat.called)
        self.assertEqual(self.protocol._pending_parts, [])
예제 #14
0
 def setUp(self):
     """Give every test its own freshly constructed protocol."""
     protocol = StompProtocol()
     self.protocol = protocol
예제 #15
0
 def setUp(self):
     """Create a new ``StompProtocol`` for the test about to run."""
     fresh_protocol = StompProtocol()
     self.protocol = fresh_protocol
예제 #16
0
class TestRecvFrame(TestCase):
    """End-to-end parsing: raw bytes in, decoded frames out."""

    def setUp(self):
        self.protocol = StompProtocol()

    def _push(self, *chunks):
        """Feed every chunk to the parser, in order."""
        for chunk in chunks:
            self.protocol.add_data(chunk)

    def _assert_frame(self, frame, command, headers, body):
        """Compare the command, headers and body of a parsed frame."""
        self.assertEqual(frame.command, command)
        self.assertEqual(frame.headers, headers)
        self.assertEqual(frame.body, body)

    def _assert_nothing_pending(self):
        """No partial frame data is left buffered."""
        self.assertEqual(self.protocol._pending_parts, [])

    def test_decode(self):
        sample = u'éĂ'
        self.assertEqual(self.protocol._decode(sample), sample)

    def test_on_decode_error_show_string(self):
        error = UnicodeDecodeError(
            'hitchhiker', b"", 42, 43, 'the universe and everything else')
        data = MagicMock(spec=six.binary_type)
        data.decode.side_effect = error
        with self.assertRaises(UnicodeDecodeError):
            self.protocol._decode(data)

    def test_single_packet(self):
        self._push(b'CONNECT\naccept-version:1.0\n\n\x00')

        frames = self.protocol.pop_frames()

        self.assertEqual(len(frames), 1)
        self._assert_frame(frames[0], u'CONNECT',
                           {u'accept-version': u'1.0'}, None)
        self._assert_nothing_pending()

    def test_parcial_packet(self):
        self._push(b'CONNECT\n', b'accept-version:1.0\n\n\x00')

        frames = self.protocol.pop_frames()

        self.assertEqual(len(frames), 1)
        self._assert_frame(frames[0], u'CONNECT',
                           {u'accept-version': u'1.0'}, None)

    def test_multi_parcial_packet1(self):
        self._push(
            b'CONNECT\n',
            b'accept-version:1.0\n\n\x00\n',
            b'CONNECTED\n',
            b'version:1.0\n\n\x00\n',
        )

        frames = self.protocol.pop_frames()
        self.assertEqual(len(frames), 2)

        self._assert_frame(frames[0], u'CONNECT',
                           {u'accept-version': u'1.0'}, None)
        self._assert_frame(frames[1], u'CONNECTED',
                           {u'version': u'1.0'}, None)
        self._assert_nothing_pending()

    def test_multi_parcial_packet2(self):
        self._push(
            b'CONNECTED\nversion:1.0\n\n',
            b'\x00\nERROR\n',
            b'header:1.0\n\n',
            b'Hey dude\x00\n',
        )

        frames = self.protocol.pop_frames()
        self.assertEqual(len(frames), 2)

        self._assert_frame(frames[0], u'CONNECTED',
                           {u'version': u'1.0'}, None)
        self._assert_frame(frames[1], u'ERROR',
                           {u'header': u'1.0'}, u'Hey dude')
        self._assert_nothing_pending()

    def test_multi_parcial_packet_with_utf8(self):
        # The two UTF-8 bytes of u'ç' (\xc3 \xa7) arrive in separate chunks.
        self._push(
            b'CONNECTED\naccept-version:1.0\n\n',
            b'\x00\nERROR\n',
            b'header:1.0\n\n\xc3',
            b'\xa7\x00\n',
        )

        ready = self.protocol._frames_ready
        self.assertEqual(len(ready), 2)
        self._assert_nothing_pending()

        self.assertEqual(ready[0].body, None)
        self.assertEqual(ready[1].body, u'ç')

    def test_heart_beat_packet1(self):
        self.protocol._recv_heart_beat = MagicMock()
        self._push(b'\n')

        self._assert_nothing_pending()
        self.assertTrue(self.protocol._recv_heart_beat.called)

    def test_heart_beat_packet2(self):
        self.protocol._recv_heart_beat = MagicMock()
        self._push(b'CONNECT\naccept-version:1.0\n\n\x00\n')

        self.assertTrue(self.protocol._recv_heart_beat.called)
        self._assert_nothing_pending()

    def test_heart_beat_packet3(self):
        self.protocol._recv_heart_beat = MagicMock()
        self._push(b'\nCONNECT\naccept-version:1.0\n\n\x00')

        frames = self.protocol.pop_frames()
        self.assertEqual(len(frames), 1)
        self._assert_frame(frames[0], u'CONNECT',
                           {u'accept-version': u'1.0'}, None)

        self.assertTrue(self.protocol._recv_heart_beat.called)
        self._assert_nothing_pending()
예제 #17
0
class TorStomp(object):

    VERSION = '1.1'

    def __init__(self, host='localhost', port=61613, connect_headers=None,
                 on_error=None, on_disconnect=None, on_connect=None,
                 reconnect_max_attempts=-1, reconnect_timeout=1000):
        """Create a client. No connection is made until connect() is called.

        :param host: broker host name.
        :param port: broker port.
        :param connect_headers: extra headers sent with the CONNECT frame.
            The mapping is copied; ``accept-version`` is always set to
            ``VERSION``. ``None`` means no extra headers.
        :param on_error: called with a ``StompError`` for each ERROR frame.
        :param on_disconnect: called with no arguments when the TCP stream
            closes.
        :param on_connect: called with no arguments after the CONNECT frame
            and the known subscriptions have been sent.
        :param reconnect_max_attempts: number of reconnect attempts; ``-1``
            retries forever.
        :param reconnect_timeout: delay before each reconnect attempt, in
            milliseconds.
        """
        self.host = host
        self.port = port
        self.logger = logging.getLogger('TorStomp')

        # Copy so that adding 'accept-version' never mutates the caller's
        # dict (or a dict shared between instances).
        self._connect_headers = dict(connect_headers or {})
        self._connect_headers['accept-version'] = self.VERSION
        self._heart_beat_handler = None
        self.connected = False
        self._disconnecting = False
        self._protocol = StompProtocol()
        # Keyed by the string form of the subscription id.
        self._subscriptions = {}
        self._last_subscribe_id = 0
        self._on_error = on_error
        self._on_disconnect = on_disconnect
        self._on_connect = on_connect

        self._reconnect_max_attempts = reconnect_max_attempts
        self._reconnect_timeout = timedelta(milliseconds=reconnect_timeout)
        self._reconnect_attempts = 0

    @gen.coroutine
    def connect(self):
        """Open the TCP connection and perform the STOMP handshake.

        On a socket error the failure is logged and a reconnect is scheduled
        instead of raising. On success it installs the close callback and
        the streaming read loop, resets the parser, and sends CONNECT. It
        then re-sends every known subscription and finally invokes
        ``on_connect``, if one was given.
        """
        self.stream = self._build_io_stream()
        address = (self.host, self.port)

        try:
            yield self.stream.connect(address)
        except socket.error as exc:
            self.logger.error(
                '[attempt: %d] Connect error: %s',
                self._reconnect_attempts, exc)
            self._schedule_reconnect()
            return
        else:
            self.logger.info('Stomp connection estabilished')

        stream = self.stream
        stream.set_close_callback(self._on_disconnect_socket)
        # The same handler receives streamed chunks and the final remainder.
        stream.read_until_close(
            callback=self._on_data,
            streaming_callback=self._on_data)

        self.connected = True
        self._disconnecting = False
        self._reconnect_attempts = 0
        self._protocol.reset()

        yield self._send_frame('CONNECT', self._connect_headers)

        # Restore subscriptions registered before (or across) reconnects.
        for sub in self._subscriptions.values():
            yield self._send_subscribe_frame(sub)

        on_connect = self._on_connect
        if on_connect:
            on_connect()

    def subscribe(self, destination, ack='auto', extra_headers=None,
                  callback=None):
        """Register a subscription and send SUBSCRIBE if already connected.

        The subscription is stored under the string form of its id, so that
        connect() can re-send it and incoming MESSAGE frames can find it.

        :param destination: destination to subscribe to.
        :param ack: value of the ``ack`` header (default ``'auto'``).
        :param extra_headers: additional SUBSCRIBE headers; ``None`` means
            none.
        :param callback: called as ``callback(frame, body)`` for each MESSAGE
            frame delivered to this subscription.
        """
        if extra_headers is None:
            # A fresh dict per call: the Subscription keeps a reference to
            # it, so a shared default would be aliased by every subscription.
            extra_headers = {}

        self._last_subscribe_id += 1

        subscription = Subscription(
            destination=destination,
            id=self._last_subscribe_id,
            ack=ack,
            extra_headers=extra_headers,
            callback=callback)

        self._subscriptions[str(self._last_subscribe_id)] = subscription

        if self.connected:
            self._send_subscribe_frame(subscription)

    def send(self, destination, body='', headers=None,
             send_content_length=True):
        """Send a SEND frame to ``destination``.

        :param destination: value of the ``destination`` header.
        :param body: message body. A non-empty body is encoded with the
            protocol encoder before sending.
        :param headers: extra frame headers. The mapping is copied and never
            modified; ``None`` means no extra headers.
        :param send_content_length: when true and the body is non-empty, add
            a ``content-length`` header with the encoded body length.
        :returns: the result of ``_send_frame`` (the stream write result).
        """
        # Work on a copy. Mutating the argument would leak 'destination' and
        # 'content-length' into the caller's dict, and through a shared
        # default into later calls (e.g. a stale length on an empty body).
        headers = dict(headers or {})
        headers['destination'] = destination

        if body:
            body = self._protocol._encode(body)

            # ActiveMQ determines the type of a message by the
            # inclusion of the content-length header
            if send_content_length:
                headers['content-length'] = len(body)

        return self._send_frame('SEND', headers, body)

    def ack(self, frame):
        """Acknowledge ``frame`` by sending an ACK frame.

        The ``subscription`` and ``message-id`` headers are copied from the
        received frame; a KeyError propagates if either is missing.
        """
        received = frame.headers
        return self._send_frame('ACK', {
            'subscription': received['subscription'],
            'message-id': received['message-id'],
        })

    def nack(self, frame):
        """Reject ``frame`` by sending a NACK frame.

        The ``subscription`` and ``message-id`` headers are copied from the
        received frame; a KeyError propagates if either is missing.
        """
        received = frame.headers
        return self._send_frame('NACK', {
            'subscription': received['subscription'],
            'message-id': received['message-id'],
        })

    def _build_io_stream(self):
        """Create a new IPv4 TCP socket and wrap it in an IOStream."""
        sock = socket.socket(
            family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0)
        return IOStream(sock)

    def _on_disconnect_socket(self):
        """Handle the stream being closed.

        Cancels any pending heart-beat and marks the client as disconnected.
        Unless the close was requested (``_disconnecting``), it schedules a
        reconnect. ``on_disconnect`` is invoked in both cases.
        """
        self._stop_scheduled_heart_beat()
        self.connected = False

        if not self._disconnecting:
            self.logger.info('TCP connection unexpected end')
            self._schedule_reconnect()
        else:
            self.logger.info('TCP connection end gracefully')

        handler = self._on_disconnect
        if handler:
            handler()

    def _schedule_reconnect(self):
        """Schedule connect() after the reconnect delay, if attempts remain.

        A ``_reconnect_max_attempts`` of -1 means retry without limit. When
        the limit is reached, an error is logged and nothing is scheduled.
        """
        limit = self._reconnect_max_attempts
        if limit != -1 and self._reconnect_attempts >= limit:
            self.logger.error('All Connection attempts failed')
            return

        self._reconnect_attempts += 1
        loop = IOLoop.current()
        self._reconnect_timeout_handler = loop.add_timeout(
            self._reconnect_timeout, self.connect)

    def _on_data(self, data):
        """Feed received bytes to the parser and dispatch completed frames.

        Empty chunks are ignored.
        """
        if data:
            self._protocol.add_data(data)
            ready = self._protocol.pop_frames()
            if ready:
                self._received_frames(ready)

    def _send_frame(self, command, headers=None, body=''):
        """Serialize a frame with the protocol and write it to the stream.

        :param command: STOMP command name, e.g. ``'SEND'``.
        :param headers: frame headers; ``None`` means no headers.
        :param body: frame body (default empty).
        :returns: the result of ``self.stream.write``.
        """
        # None sentinel instead of a shared mutable ``{}`` default.
        if headers is None:
            headers = {}
        buf = self._protocol.build_frame(command, headers, body)
        return self.stream.write(buf)

    def _set_connected(self, connected_frame):
        """Configure client heart-beating from a CONNECTED frame.

        The ``heart-beat`` header is parsed as two comma-separated integers.
        A malformed value raises ValueError. When the second value is
        non-zero, it becomes the interval in milliseconds between heart-beats
        sent by this client.
        """
        heart_beat = connected_frame.headers.get('heart-beat')
        if not heart_beat:
            return

        # Both values are converted, but only the second one is used.
        # NOTE(review): STOMP 1.1 negotiates the client's outgoing interval
        # as max(cx, sy), with cx taken from the CONNECT header; this uses
        # sy alone. Confirm that is intended.
        _sx, interval_ms = map(int, heart_beat.split(','))

        if interval_ms:
            self._set_heart_beat(interval_ms)

    def _set_heart_beat(self, time):
        """(Re)start client heart-beats every ``time`` milliseconds.

        Any previously scheduled heart-beat is cancelled. One heart-beat is
        then sent immediately, and that call schedules the next.
        """
        delta = timedelta(milliseconds=time)
        self._heart_beat_delta = delta
        self._stop_scheduled_heart_beat()
        self._do_heart_beat()

    def _schedule_heart_beat(self):
        """Arrange for _do_heart_beat to run after ``_heart_beat_delta``."""
        loop = IOLoop.current()
        self._heart_beat_handler = loop.add_timeout(
            self._heart_beat_delta, self._do_heart_beat)

    def _stop_scheduled_heart_beat(self):
        """Cancel the pending heart-beat timeout, if any, and forget it."""
        pending = self._heart_beat_handler
        if pending:
            IOLoop.current().remove_timeout(pending)

        self._heart_beat_handler = None

    def _do_heart_beat(self):
        """Write one EOL heart-beat to the stream and schedule the next."""
        self.logger.debug('Sending heartbeat')
        eol = b'\n'
        self.stream.write(eol)
        self._schedule_heart_beat()

    def _received_frames(self, frames):
        """Route each parsed frame to the handler for its command.

        MESSAGE, CONNECTED and ERROR have dedicated handlers; any other
        command goes to _received_unhandled_frame.
        """
        handlers = {
            'MESSAGE': self._received_message_frame,
            'CONNECTED': self._set_connected,
            'ERROR': self._received_error_frame,
        }
        fallback = self._received_unhandled_frame

        for frame in frames:
            handler = handlers.get(frame.command, fallback)
            handler(frame)

    def _received_message_frame(self, frame):
        """Deliver a MESSAGE frame to the callback of its subscription.

        The frame's ``subscription`` header is looked up in
        ``self._subscriptions``. An unknown subscription is logged and the
        frame is dropped. Otherwise the subscription's callback is called as
        ``callback(frame, frame.body)``.
        """
        subscription_header = frame.headers.get('subscription')

        subscription = self._subscriptions.get(subscription_header)

        if not subscription:
            # Use '%s' with lazy logging args: the header value is a string
            # (or None when absent), which '%d' cannot format and would
            # raise TypeError instead of logging.
            self.logger.error(
                'Not found subscription %s', subscription_header)
            return

        subscription.callback(frame, frame.body)

    def _received_error_frame(self, frame):
        """Log an ERROR frame and forward it to ``on_error``, if set.

        The ``message`` header and the frame body are wrapped in a
        StompError before being passed to the callback.
        """
        message = frame.headers.get('message')
        detail = frame.body

        self.logger.error('Received error: %s', message)
        self.logger.debug('Error detail %s', detail)

        on_error = self._on_error
        if on_error:
            on_error(StompError(message, detail))

    def _received_unhandled_frame(self, frame):
        """Log a warning for a frame whose command has no handler."""
        # Logger.warning, not the deprecated Logger.warn alias.
        self.logger.warning('Received unhandled frame: %s', frame.command)

    def _send_subscribe_frame(self, subscription):
        """Send a SUBSCRIBE frame describing ``subscription``.

        ``extra_headers`` are merged last, so they override the ``id``,
        ``destination`` and ``ack`` headers if they contain the same keys.
        """
        headers = dict(
            id=subscription.id,
            destination=subscription.destination,
            ack=subscription.ack,
        )
        headers.update(subscription.extra_headers)

        return self._send_frame('SUBSCRIBE', headers)