def test_channel_open_close(self):
        # Open a channel, then close it again, and verify the Channel.Open
        # and Channel.Close method frames written to the frame writer.
        writer_factory = Mock()
        conn = Connection(frame_writer=writer_factory)
        with patch.object(conn, 'Transport') as transport_mock:
            handshake(conn, transport_mock)

            channel_id = 1
            open_ok_frame = build_frame_type_1(
                spec.Channel.OpenOk,
                channel=channel_id,
                args=(1, False),
                arg_format='Lb',
            )
            close_ok_frame = build_frame_type_1(
                spec.Channel.CloseOk, channel=channel_id,
            )
            # Frames the mocked transport returns, in order: the reply to
            # Channel.Open, then the reply to Channel.Close.
            transport_mock().read_frame.side_effect = [
                open_ok_frame,
                close_ok_frame,
            ]

            writer = writer_factory()
            writer.reset_mock()

            on_open = Mock()
            channel = conn.channel(channel_id=channel_id, callback=on_open)
            on_open.assert_called_once_with(channel)
            assert channel.is_open is True

            channel.close()
            expected_writes = [
                call(1, 1, spec.Channel.Open, dumps('s', ('', )), None),
                call(1, 1, spec.Channel.Close, dumps('BsBB', (0, '', 0, 0)),
                     None),
            ]
            writer.assert_has_calls(expected_writes)
            assert channel.is_open is False
Beispiel #2
0
class AMQPClient(object):

    """Small convenience wrapper around an AMQP ``Connection``.

    The connection is created lazily by :meth:`_ensure` from the keyword
    arguments given to ``__init__``, and is re-created whenever the current
    one reports (via ``is_alive()``) that it is no longer usable.
    """

    def __init__(self, host='localhost', userid='guest',
                 password='******', virtual_host='/', heartbeat=0):
        super(AMQPClient, self).__init__()
        # Keyword arguments passed verbatim to ``Connection(...)``.
        self.conn_params = {
            'host': host,
            'userid': userid,
            'password': password,
            'virtual_host': virtual_host,
            'heartbeat': heartbeat,
        }
        self.conn = None
        # NOTE(review): not assigned anywhere else in this class; kept so
        # external code reading ``client.channel`` keeps working.
        self.channel = None

    def _ensure(self):
        """Make sure ``self.conn`` holds a live connection.

        An existing connection is reused when ``is_alive()`` is true.
        Otherwise it is closed (best effort) and a new ``Connection`` is
        created from ``self.conn_params``.
        """
        if self.conn:
            if self.conn.is_alive():
                return
            self.close()

        self.conn = Connection(**self.conn_params)

    def close(self):
        """Close the current connection, if any, and forget it.

        Closing is best effort: any ``Exception`` raised by the underlying
        ``close()`` is ignored, so a broken connection can still be
        discarded. ``KeyboardInterrupt`` and ``SystemExit`` are no longer
        swallowed. ``self.conn`` is reset to ``None`` in every case.
        """
        if not self.conn:
            return
        try:
            self.conn.close()
        except Exception:
            # The connection may already be dead; discarding it is enough.
            pass
        finally:
            self.conn = None

    def _queue(self, queue=None, exchange=None, exchange_type='direct', routing_key=None):
        """Open a channel and declare/bind a durable queue and exchange.

        :param queue: queue name to declare and bind.
        :param exchange: exchange name to declare.
        :param exchange_type: exchange type (default ``'direct'``).
        :param routing_key: binding key used by ``queue_bind``.
        :return: the newly opened channel.
        """
        self._ensure()

        channel = self.conn.channel()
        channel.queue_declare(queue=queue, durable=True,
                              exclusive=False, auto_delete=False)
        channel.exchange_declare(
            exchange=exchange, type=exchange_type, durable=True, auto_delete=False)
        channel.queue_bind(queue=queue, exchange=exchange,
                           routing_key=routing_key)

        return channel

    def queue(self, queue=None, exchange=None, exchange_type='direct', routing_key=None):
        """Declare/bind the queue and wrap it in an ``AMQPQueue``."""
        channel = self._queue(queue, exchange, exchange_type, routing_key)
        return AMQPQueue(self, channel, queue, exchange, exchange_type, routing_key)

    def wait(self, timeout=None):
        """Ensure a live connection, then call its ``drain_events(timeout=...)``."""
        self._ensure()
        self.conn.drain_events(timeout=timeout)
Beispiel #3
0
class AMQPClient(object):

    """Small convenience wrapper around an AMQP ``Connection``.

    The connection is created lazily by :meth:`_ensure` from the keyword
    arguments given to ``__init__``, and is re-created whenever the current
    one reports (via ``is_alive()``) that it is no longer usable.
    """

    def __init__(self, host="localhost", userid="guest", password="******", virtual_host="/", heartbeat=0):
        super(AMQPClient, self).__init__()
        # Keyword arguments passed verbatim to ``Connection(...)``.
        self.conn_params = {
            "host": host,
            "userid": userid,
            "password": password,
            "virtual_host": virtual_host,
            "heartbeat": heartbeat,
        }
        self.conn = None
        # NOTE(review): not assigned anywhere else in this class; kept so
        # external code reading ``client.channel`` keeps working.
        self.channel = None

    def _ensure(self):
        """Make sure ``self.conn`` holds a live connection.

        An existing connection is reused when ``is_alive()`` is true.
        Otherwise it is closed (best effort) and a new ``Connection`` is
        created from ``self.conn_params``.
        """
        if self.conn:
            if self.conn.is_alive():
                return
            self.close()

        self.conn = Connection(**self.conn_params)

    def close(self):
        """Close the current connection, if any, and forget it.

        Closing is best effort: any ``Exception`` raised by the underlying
        ``close()`` is ignored, so a broken connection can still be
        discarded. ``KeyboardInterrupt`` and ``SystemExit`` are no longer
        swallowed. ``self.conn`` is reset to ``None`` in every case.
        """
        if not self.conn:
            return
        try:
            self.conn.close()
        except Exception:
            # The connection may already be dead; discarding it is enough.
            pass
        finally:
            self.conn = None

    def _queue(self, queue=None, exchange=None, exchange_type="direct", routing_key=None):
        """Open a channel and declare/bind a durable queue and exchange.

        :param queue: queue name to declare and bind.
        :param exchange: exchange name to declare.
        :param exchange_type: exchange type (default ``"direct"``).
        :param routing_key: binding key used by ``queue_bind``.
        :return: the newly opened channel.
        """
        self._ensure()

        channel = self.conn.channel()
        channel.queue_declare(queue=queue, durable=True, exclusive=False, auto_delete=False)
        channel.exchange_declare(exchange=exchange, type=exchange_type, durable=True, auto_delete=False)
        channel.queue_bind(queue=queue, exchange=exchange, routing_key=routing_key)

        return channel

    def queue(self, queue=None, exchange=None, exchange_type="direct", routing_key=None):
        """Declare/bind the queue and wrap it in an ``AMQPQueue``."""
        channel = self._queue(queue, exchange, exchange_type, routing_key)
        return AMQPQueue(self, channel, queue, exchange, exchange_type, routing_key)

    def wait(self, timeout=None):
        """Ensure a live connection, then call its ``drain_events(timeout=...)``."""
        self._ensure()
        self.conn.drain_events(timeout=timeout)
Beispiel #4
0
    def test_channel_ignore_methods_during_close(self):
        # After Channel.Close has been sent, py-amqp must drop every incoming
        # method except Close and Close-OK. Here a Basic.Deliver (method,
        # content-header and body frames) arrives before the Close-OK and
        # must never reach the deliver handler.
        writer_factory = Mock()
        conn = Connection(frame_writer=writer_factory)
        consumer_tag = 'amq.ctag-PCmzXGkhCw_v0Zq7jXyvkg'
        with patch.object(conn, 'Transport') as transport_mock:
            handshake(conn, transport_mock)

            channel_id = 1
            inbound_frames = [
                # Reply to Channel.Open
                build_frame_type_1(spec.Channel.OpenOk,
                                   channel=channel_id,
                                   args=(1, False),
                                   arg_format='Lb'),
                # Basic.Deliver method frame; args are consumer-tag,
                # delivery-tag, redelivered, exchange-name, routing-key
                build_frame_type_1(spec.Basic.Deliver,
                                   channel=channel_id,
                                   arg_format='sLbss',
                                   args=(consumer_tag, 1, False,
                                         'foo_exchange', 'routing-key')),
                # Content header and body belonging to the delivery above
                build_frame_type_2(channel=channel_id,
                                   body_len=12,
                                   properties=b'0\x00\x00\x00\x00\x00\x01'),
                build_frame_type_3(channel=channel_id, body=b'Hello World!'),
                # Reply to Channel.Close
                build_frame_type_1(spec.Channel.CloseOk, channel=channel_id),
            ]
            transport_mock().read_frame.side_effect = inbound_frames

            writer = writer_factory()
            writer.reset_mock()

            with patch('amqp.Channel._on_basic_deliver') as deliver_handler:
                channel = conn.channel(channel_id=channel_id)
                channel.close()
                deliver_handler.assert_not_called()
            expected_writes = [
                call(1, 1, spec.Channel.Open, dumps('s', ('', )), None),
                call(1, 1, spec.Channel.Close, dumps('BsBB', (0, '', 0, 0)),
                     None),
            ]
            writer.assert_has_calls(expected_writes)
            assert channel.is_open is False
Beispiel #5
0
class Connection(BaseConnection):
    """
    An AMQP broker connection.
    """

    # NOTE: a ``__metaclass__`` class attribute is only honoured by
    # Python 2; on Python 3 it is an ordinary attribute and has no effect.
    __metaclass__ = ThreadSingleton

    @staticmethod
    def ssl_domain(connector):
        """
        Get SSL properties
        :param connector: A broker object.
        :type connector: Connector
        :return: The SSL properties, or None when the connector does
            not use SSL.
        :rtype: dict
        :raise: ValueError
        """
        domain = None
        if connector.use_ssl():
            domain = {}
            connector.ssl.validate()
            # Require a verified server certificate only when a CA
            # certificate is configured to verify it against.
            if connector.ssl.ca_certificate:
                required = ssl.CERT_REQUIRED
            else:
                required = ssl.CERT_NONE
            domain.update(cert_reqs=required,
                          ca_certs=connector.ssl.ca_certificate,
                          keyfile=connector.ssl.client_key,
                          certfile=connector.ssl.client_certificate)
        return domain

    def __init__(self, url):
        """
        :param url: The connector url.
        :type url: str
        """
        BaseConnection.__init__(self, url)
        # The underlying RealConnection; None while not open.
        self._impl = None

    def is_open(self):
        """
        Get whether the connection has been opened.
        :return: True if open.
        :rtype bool
        """
        return self._impl is not None

    @retry(*CONNECTION_EXCEPTIONS)
    def open(self):
        """
        Open a connection to the broker.
        Does nothing when the connection is already open.
        """
        if self.is_open():
            # already open
            return
        connector = Connector.find(self.url)
        host = ':'.join((connector.host, utf8(connector.port)))
        virtual_host = connector.virtual_host or VIRTUAL_HOST
        domain = self.ssl_domain(connector)
        userid = connector.userid or USERID
        password = connector.password or PASSWORD
        log.info('open: %s', connector)
        self._impl = RealConnection(host=host,
                                    virtual_host=virtual_host,
                                    ssl=domain,
                                    userid=userid,
                                    password=password,
                                    confirm_publish=True)
        log.info('opened: %s', self.url)

    def channel(self):
        """
        Open a channel.
        The connection must be open; otherwise ``self._impl`` is None and
        an AttributeError is raised.
        :return The *real* channel.
        """
        return self._impl.channel()

    def close(self):
        """
        Close the connection.
        Does nothing when the connection is not open. Errors raised while
        closing are logged and suppressed.
        """
        connection = self._impl
        self._impl = None
        if connection is None:
            # Never opened, or already closed: nothing to release.
            return
        try:
            connection.close()
            log.info('closed: %s', self.url)
        except Exception as pe:
            log.exception(utf8(pe))
Beispiel #6
0
class test_Connection:
    """Unit tests for ``Connection``.

    The autouse fixture builds a connection whose channel factory,
    transport, ``send_method`` and ``frame_writer`` are mocks, so most
    tests only inspect what the connection would have sent.
    """

    @pytest.fixture(autouse=True)
    def setup_conn(self):
        self.frame_handler = Mock(name='frame_handler')
        self.frame_writer = Mock(name='frame_writer_cls')
        self.conn = Connection(
            frame_handler=self.frame_handler,
            frame_writer=self.frame_writer,
            authentication=AMQPLAIN('foo', 'bar'),
        )
        self.conn.Channel = Mock(name='Channel')
        self.conn.Transport = Mock(name='Transport')
        self.conn.transport = self.conn.Transport.return_value
        self.conn.send_method = Mock(name='send_method')
        self.conn.frame_writer = Mock(name='frame_writer')

    def test_sasl_authentication(self):
        authentication = SASL()
        self.conn = Connection(authentication=authentication)
        assert self.conn.authentication == (authentication,)

    def test_sasl_authentication_iterable(self):
        authentication = SASL()
        self.conn = Connection(authentication=(authentication,))
        assert self.conn.authentication == (authentication,)

    def test_gssapi(self):
        self.conn = Connection()
        assert isinstance(self.conn.authentication[0], GSSAPI)

    def test_external(self):
        self.conn = Connection()
        assert isinstance(self.conn.authentication[1], EXTERNAL)

    def test_amqplain(self):
        # The password passed in must be the one asserted below.
        self.conn = Connection(userid='foo', password='bar')
        auth = self.conn.authentication[2]
        assert isinstance(auth, AMQPLAIN)
        assert auth.username == 'foo'
        assert auth.password == 'bar'

    def test_plain(self):
        # The password passed in must be the one asserted below.
        self.conn = Connection(userid='foo', password='bar')
        auth = self.conn.authentication[3]
        assert isinstance(auth, PLAIN)
        assert auth.username == 'foo'
        assert auth.password == 'bar'

    def test_login_method_gssapi(self):
        try:
            self.conn = Connection(userid=None, password=None,
                                   login_method='GSSAPI')
        except NotImplementedError:
            pass
        else:
            auths = self.conn.authentication
            assert len(auths) == 1
            assert isinstance(auths[0], GSSAPI)

    def test_login_method_external(self):
        self.conn = Connection(userid=None, password=None,
                               login_method='EXTERNAL')
        auths = self.conn.authentication
        assert len(auths) == 1
        assert isinstance(auths[0], EXTERNAL)

    def test_login_method_amqplain(self):
        self.conn = Connection(login_method='AMQPLAIN')
        auths = self.conn.authentication
        assert len(auths) == 1
        assert isinstance(auths[0], AMQPLAIN)

    def test_login_method_plain(self):
        self.conn = Connection(login_method='PLAIN')
        auths = self.conn.authentication
        assert len(auths) == 1
        assert isinstance(auths[0], PLAIN)

    def test_enter_exit(self):
        self.conn.connect = Mock(name='connect')
        self.conn.close = Mock(name='close')
        with self.conn:
            self.conn.connect.assert_called_with()
        self.conn.close.assert_called_with()

    def test__enter__socket_error(self):
        # test when entering
        self.conn = Connection()
        self.conn.close = Mock(name='close')
        reached = False
        with patch('socket.socket', side_effect=socket.error):
            with pytest.raises(socket.error):
                with self.conn:
                    reached = True
        assert not reached and not self.conn.close.called
        assert self.conn._transport is None and not self.conn.connected

    def test__exit__socket_error(self):
        # test when exiting
        connection = self.conn
        transport = connection._transport
        transport.connected = True
        connection.send_method = Mock(name='send_method',
                                      side_effect=socket.error)
        reached = False
        with pytest.raises(socket.error):
            with connection:
                reached = True
        assert reached
        assert connection.send_method.called and transport.close.called
        assert self.conn._transport is None and not self.conn.connected

    def test_then(self):
        self.conn.on_open = Mock(name='on_open')
        on_success = Mock(name='on_success')
        on_error = Mock(name='on_error')
        self.conn.then(on_success, on_error)
        self.conn.on_open.then.assert_called_with(on_success, on_error)

    def test_connect(self):
        self.conn.transport.connected = False
        self.conn.drain_events = Mock(name='drain_events')

        def on_drain(*args, **kwargs):
            self.conn._handshake_complete = True
        self.conn.drain_events.side_effect = on_drain
        self.conn.connect()
        self.conn.Transport.assert_called_with(
            self.conn.host, self.conn.connect_timeout, self.conn.ssl,
            self.conn.read_timeout, self.conn.write_timeout,
            socket_settings=self.conn.socket_settings,
        )

    def test_connect__already_connected(self):
        callback = Mock(name='callback')
        self.conn.transport.connected = True
        assert self.conn.connect(callback) == callback.return_value
        callback.assert_called_with()

    def test_connect__socket_error(self):
        # check Transport.Connect error
        # socket.error derives from IOError
        # ssl.SSLError derives from socket.error
        self.conn = Connection()
        self.conn.Transport = Mock(name='Transport')
        transport = self.conn.Transport.return_value
        transport.connect.side_effect = IOError
        assert self.conn._transport is None and not self.conn.connected
        with pytest.raises(IOError):
            self.conn.connect()
        # A bare ``transport.connect.assert_called`` attribute access would
        # assert nothing, so check the ``called`` flag explicitly.
        assert transport.connect.called
        assert self.conn._transport is None and not self.conn.connected

    def test_on_start(self):
        self.conn._on_start(3, 4, {'foo': 'bar'}, b'x y z AMQPLAIN PLAIN',
                            'en_US en_GB')
        assert self.conn.version_major == 3
        assert self.conn.version_minor == 4
        assert self.conn.server_properties == {'foo': 'bar'}
        assert self.conn.mechanisms == [b'x', b'y', b'z',
                                        b'AMQPLAIN', b'PLAIN']
        assert self.conn.locales == ['en_US', 'en_GB']
        self.conn.send_method.assert_called_with(
            spec.Connection.StartOk, 'FsSs', (
                self.conn.client_properties, b'AMQPLAIN',
                self.conn.authentication[0].start(self.conn), self.conn.locale,
            ),
        )

    def test_on_start_string_mechanisms(self):
        self.conn._on_start(3, 4, {'foo': 'bar'}, 'x y z AMQPLAIN PLAIN',
                            'en_US en_GB')
        assert self.conn.version_major == 3
        assert self.conn.version_minor == 4
        assert self.conn.server_properties == {'foo': 'bar'}
        assert self.conn.mechanisms == [b'x', b'y', b'z',
                                        b'AMQPLAIN', b'PLAIN']
        assert self.conn.locales == ['en_US', 'en_GB']
        self.conn.send_method.assert_called_with(
            spec.Connection.StartOk, 'FsSs', (
                self.conn.client_properties, b'AMQPLAIN',
                self.conn.authentication[0].start(self.conn), self.conn.locale,
            ),
        )

    def test_missing_credentials(self):
        with pytest.raises(ValueError):
            self.conn = Connection(userid=None, password=None,
                                   login_method='AMQPLAIN')
        with pytest.raises(ValueError):
            self.conn = Connection(password=None, login_method='PLAIN')

    def test_invalid_method(self):
        with pytest.raises(ValueError):
            self.conn = Connection(login_method='any')

    def test_mechanism_mismatch(self):
        with pytest.raises(ConnectionError):
            self.conn._on_start(3, 4, {'foo': 'bar'}, b'x y z',
                                'en_US en_GB')

    def test_login_method_response(self):
        # An old way of doing things.:
        login_method, login_response = b'foo', b'bar'
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            self.conn = Connection(login_method=login_method,
                                   login_response=login_response)
            self.conn.send_method = Mock(name='send_method')
            self.conn._on_start(3, 4, {'foo': 'bar'}, login_method,
                                'en_US en_GB')
            assert len(w) == 1
            assert issubclass(w[0].category, DeprecationWarning)

        self.conn.send_method.assert_called_with(
            spec.Connection.StartOk, 'FsSs', (
                self.conn.client_properties, login_method,
                login_response, self.conn.locale,
            ),
        )

    def test_on_start__consumer_cancel_notify(self):
        self.conn._on_start(
            3, 4, {'capabilities': {'consumer_cancel_notify': 1}},
            b'AMQPLAIN', '',
        )
        cap = self.conn.client_properties['capabilities']
        assert cap['consumer_cancel_notify']

    def test_on_start__connection_blocked(self):
        self.conn._on_start(
            3, 4, {'capabilities': {'connection.blocked': 1}},
            b'AMQPLAIN', '',
        )
        cap = self.conn.client_properties['capabilities']
        assert cap['connection.blocked']

    def test_on_start__authentication_failure_close(self):
        self.conn._on_start(
            3, 4, {'capabilities': {'authentication_failure_close': 1}},
            b'AMQPLAIN', '',
        )
        cap = self.conn.client_properties['capabilities']
        assert cap['authentication_failure_close']

    def test_on_start__authentication_failure_close__disabled(self):
        self.conn._on_start(
            3, 4, {'capabilities': {}},
            b'AMQPLAIN', '',
        )
        assert 'capabilities' not in self.conn.client_properties

    def test_on_secure(self):
        # Smoke test: only checks that _on_secure does not raise.
        self.conn._on_secure('vfz')

    def test_on_tune(self):
        self.conn.client_heartbeat = 16
        self.conn._on_tune(345, 16, 10)
        assert self.conn.channel_max == 345
        assert self.conn.frame_max == 16
        assert self.conn.server_heartbeat == 10
        assert self.conn.heartbeat == 10
        self.conn.send_method.assert_called_with(
            spec.Connection.TuneOk, 'BlB', (
                self.conn.channel_max, self.conn.frame_max,
                self.conn.heartbeat,
            ),
            callback=self.conn._on_tune_sent,
        )

    def test_on_tune__client_heartbeat_disabled(self):
        self.conn.client_heartbeat = 0
        self.conn._on_tune(345, 16, 10)
        assert self.conn.heartbeat == 0

    def test_on_tune_sent(self):
        self.conn._on_tune_sent()
        self.conn.send_method.assert_called_with(
            spec.Connection.Open, 'ssb', (self.conn.virtual_host, '', False),
        )

    def test_on_open_ok(self):
        self.conn.on_open = Mock(name='on_open')
        self.conn._on_open_ok()
        assert self.conn._handshake_complete
        self.conn.on_open.assert_called_with(self.conn)

    def test_connected(self):
        self.conn.transport.connected = False
        assert not self.conn.connected
        self.conn.transport.connected = True
        assert self.conn.connected
        self.conn.transport = None
        assert not self.conn.connected

    def test_collect(self):
        channels = self.conn.channels = {
            0: self.conn, 1: Mock(name='c1'), 2: Mock(name='c2'),
        }
        transport = self.conn.transport
        self.conn.collect()
        transport.close.assert_called_with()
        for i, channel in items(channels):
            if i:
                channel.collect.assert_called_with()
        assert self.conn._transport is None

    def test_collect__channel_raises_socket_error(self):
        self.conn.channels = {1: Mock(name='c1')}
        self.conn.channels[1].collect.side_effect = socket.error()
        self.conn.collect()

    def test_collect_no_transport(self):
        self.conn = Connection()
        self.conn.connect = Mock(name='connect')
        assert not self.conn.connected
        self.conn.collect()
        assert not self.conn.connect.called

    def test_collect_again(self):
        self.conn = Connection()
        self.conn.collect()
        self.conn.collect()

    def test_get_free_channel_id__raises_IndexError(self):
        self.conn._avail_channel_ids = []
        with pytest.raises(ResourceError):
            self.conn._get_free_channel_id()

    def test_claim_channel_id(self):
        self.conn._claim_channel_id(30)
        with pytest.raises(ConnectionError):
            self.conn._claim_channel_id(30)

    def test_channel(self):
        callback = Mock(name='callback')
        c = self.conn.channel(3, callback)
        self.conn.Channel.assert_called_with(self.conn, 3, on_open=callback)
        c2 = self.conn.channel(3, callback)
        assert c2 is c

    def test_is_alive(self):
        with pytest.raises(NotImplementedError):
            self.conn.is_alive()

    def test_drain_events(self):
        self.conn.blocking_read = Mock(name='blocking_read')
        self.conn.drain_events(30)
        self.conn.blocking_read.assert_called_with(30)

    def test_blocking_read__no_timeout(self):
        self.conn.on_inbound_frame = Mock(name='on_inbound_frame')
        self.conn.transport.having_timeout = ContextMock()
        ret = self.conn.blocking_read(None)
        self.conn.transport.read_frame.assert_called_with()
        self.conn.on_inbound_frame.assert_called_with(
            self.conn.transport.read_frame(),
        )
        assert ret is self.conn.on_inbound_frame()

    def test_blocking_read__timeout(self):
        self.conn.transport = TCPTransport('localhost:5672')
        sock = self.conn.transport.sock = Mock(name='sock')
        sock.gettimeout.return_value = 1
        self.conn.transport.read_frame = Mock(name='read_frame')
        self.conn.on_inbound_frame = Mock(name='on_inbound_frame')
        self.conn.blocking_read(3)
        sock.gettimeout.assert_called_with()
        sock.settimeout.assert_has_calls([call(3), call(1)])
        self.conn.transport.read_frame.assert_called_with()
        self.conn.on_inbound_frame.assert_called_with(
            self.conn.transport.read_frame(),
        )
        sock.gettimeout.return_value = 3
        self.conn.blocking_read(3)

    def test_blocking_read__SSLError(self):
        self.conn.on_inbound_frame = Mock(name='on_inbound_frame')
        self.conn.transport = TCPTransport('localhost:5672')
        sock = self.conn.transport.sock = Mock(name='sock')
        sock.gettimeout.return_value = 1
        self.conn.transport.read_frame = Mock(name='read_frame')
        self.conn.transport.read_frame.side_effect = SSLError(
            'operation timed out')
        with pytest.raises(socket.timeout):
            self.conn.blocking_read(3)
        self.conn.transport.read_frame.side_effect = SSLError(
            'The operation did not complete foo bar')
        with pytest.raises(socket.timeout):
            self.conn.blocking_read(3)
        self.conn.transport.read_frame.side_effect = SSLError(
            'oh noes')
        with pytest.raises(SSLError):
            self.conn.blocking_read(3)

    def test_on_inbound_method(self):
        self.conn.channels[1] = self.conn.channel(1)
        self.conn.on_inbound_method(1, (50, 60), 'payload', 'content')
        self.conn.channels[1].dispatch_method.assert_called_with(
            (50, 60), 'payload', 'content',
        )

    def test_close(self):
        self.conn.collect = Mock(name='collect')
        self.conn.close(reply_text='foo', method_sig=spec.Channel.Open)
        self.conn.send_method.assert_called_with(
            spec.Connection.Close, 'BsBB',
            (0, 'foo', spec.Channel.Open[0], spec.Channel.Open[1]),
            wait=spec.Connection.CloseOk,
        )

    def test_close__already_closed(self):
        self.conn.transport = None
        self.conn.close()

    def test_close__socket_error(self):
        self.conn.send_method = Mock(name='send_method',
                                     side_effect=socket.error)
        with pytest.raises(socket.error):
            self.conn.close()
        self.conn.send_method.assert_called()
        assert self.conn._transport is None and not self.conn.connected

    def test_on_close(self):
        self.conn._x_close_ok = Mock(name='_x_close_ok')
        with pytest.raises(NotFound):
            self.conn._on_close(404, 'bah not found', 50, 60)

    def test_x_close_ok(self):
        self.conn._x_close_ok()
        self.conn.send_method.assert_called_with(
            spec.Connection.CloseOk, callback=self.conn._on_close_ok,
        )

    def test_on_close_ok(self):
        self.conn.collect = Mock(name='collect')
        self.conn._on_close_ok()
        self.conn.collect.assert_called_with()

    def test_on_blocked(self):
        # First call runs with no on_blocked handler installed.
        self.conn._on_blocked()
        self.conn.on_blocked = Mock(name='on_blocked')
        self.conn._on_blocked()
        self.conn.on_blocked.assert_called_with(
            'connection blocked, see broker logs')

    def test_on_unblocked(self):
        # First call runs with no on_unblocked handler installed.
        self.conn._on_unblocked()
        self.conn.on_unblocked = Mock(name='on_unblocked')
        self.conn._on_unblocked()
        self.conn.on_unblocked.assert_called_with()

    def test_send_heartbeat(self):
        self.conn.send_heartbeat()
        self.conn.frame_writer.assert_called_with(
            8, 0, None, None, None,
        )

    def test_heartbeat_tick__no_heartbeat(self):
        self.conn.heartbeat = 0
        self.conn.heartbeat_tick()

    def test_heartbeat_tick(self):
        self.conn.heartbeat = 3
        self.conn.heartbeat_tick()
        self.conn.bytes_sent = 3124
        self.conn.bytes_recv = 123
        self.conn.heartbeat_tick()
        self.conn.last_heartbeat_received -= 1000
        self.conn.last_heartbeat_sent -= 1000
        with pytest.raises(ConnectionError):
            self.conn.heartbeat_tick()

    def test_server_capabilities(self):
        self.conn.server_properties['capabilities'] = {'foo': 1}
        assert self.conn.server_capabilities == {'foo': 1}
Beispiel #7
0
class test_Connection(Case):
    """Unit tests for Connection with its collaborators replaced by mocks.

    ``setup`` swaps in mock ``Channel``/``Transport`` classes, a mock
    ``send_method`` and a mock ``_frame_writer``, so these tests check which
    calls the connection makes rather than what reaches a socket.
    """

    def setup(self):
        self.frame_handler = Mock(name='frame_handler')
        self.frame_writer = Mock(name='frame_writer')
        self.conn = Connection(
            frame_handler=self.frame_handler,
            frame_writer=self.frame_writer,
        )
        self.conn.Channel = Mock(name='Channel')
        self.conn.Transport = Mock(name='Transport')
        self.conn.transport = self.conn.Transport.return_value
        self.conn.send_method = Mock(name='send_method')
        self.conn._frame_writer = Mock(name='_frame_writer')

    def test_login_response(self):
        self.conn = Connection(login_response='foo')
        self.assertEqual(self.conn.login_response, 'foo')

    def test_enter_exit(self):
        # The context manager connects on entry and closes on exit.
        self.conn.connect = Mock(name='connect')
        self.conn.close = Mock(name='close')
        with self.conn:
            self.conn.connect.assert_called_with()
        self.conn.close.assert_called_with()

    def test_then(self):
        self.conn.on_open = Mock(name='on_open')
        on_success = Mock(name='on_success')
        on_error = Mock(name='on_error')
        self.conn.then(on_success, on_error)
        self.conn.on_open.then.assert_called_with(on_success, on_error)

    def test_connect(self):
        self.conn.transport.connected = False
        self.conn.drain_events = Mock(name='drain_events')

        # Mark the handshake complete on the first drain so connect() can
        # return (presumably it drains until _handshake_complete is set).
        def on_drain(*args, **kwargs):
            self.conn._handshake_complete = True
        self.conn.drain_events.side_effect = on_drain
        self.conn.connect()
        self.conn.Transport.assert_called_with(
            self.conn.host, self.conn.connect_timeout, self.conn.ssl,
            self.conn.read_timeout, self.conn.write_timeout,
            socket_settings=self.conn.socket_settings,
        )

    def test_connect__already_connected(self):
        # With a connected transport, connect() just returns callback().
        callback = Mock(name='callback')
        self.conn.transport.connected = True
        self.assertIs(self.conn.connect(callback), callback.return_value)
        callback.assert_called_with()

    def test_on_start(self):
        self.conn._on_start(3, 4, {'foo': 'bar'}, 'x y z', 'en_US en_GB')
        self.assertEqual(self.conn.version_major, 3)
        self.assertEqual(self.conn.version_minor, 4)
        self.assertEqual(self.conn.server_properties, {'foo': 'bar'})
        # Space-separated mechanism/locale strings are split into lists.
        self.assertEqual(self.conn.mechanisms, ['x', 'y', 'z'])
        self.assertEqual(self.conn.locales, ['en_US', 'en_GB'])
        self.conn.send_method.assert_called_with(
            spec.Connection.StartOk, 'FsSs', (
                self.conn.client_properties, self.conn.login_method,
                self.conn.login_response, self.conn.locale,
            ),
        )

    def test_on_start__consumer_cancel_notify(self):
        self.conn._on_start(
            3, 4, {'capabilities': {'consumer_cancel_notify': 1}},
            '', '',
        )
        cap = self.conn.client_properties['capabilities']
        self.assertTrue(cap['consumer_cancel_notify'])

    def test_on_start__connection_blocked(self):
        self.conn._on_start(
            3, 4, {'capabilities': {'connection.blocked': 1}},
            '', '',
        )
        cap = self.conn.client_properties['capabilities']
        self.assertTrue(cap['connection.blocked'])

    def test_on_secure(self):
        self.conn._on_secure('vfz')

    def test_on_tune(self):
        self.conn.client_heartbeat = 16
        self.conn._on_tune(345, 16, 10)
        self.assertEqual(self.conn.channel_max, 345)
        self.assertEqual(self.conn.frame_max, 16)
        self.assertEqual(self.conn.server_heartbeat, 10)
        self.assertEqual(self.conn.heartbeat, 10)
        self.conn.send_method.assert_called_with(
            spec.Connection.TuneOk, 'BlB', (
                self.conn.channel_max, self.conn.frame_max,
                self.conn.heartbeat,
            ),
            callback=self.conn._on_tune_sent,
        )

    def test_on_tune__client_heartbeat_disabled(self):
        # A client heartbeat of 0 wins over the server's proposed value.
        self.conn.client_heartbeat = 0
        self.conn._on_tune(345, 16, 10)
        self.assertEqual(self.conn.heartbeat, 0)

    def test_on_tune_sent(self):
        self.conn._on_tune_sent()
        self.conn.send_method.assert_called_with(
            spec.Connection.Open, 'ssb', (self.conn.virtual_host, '', False),
        )

    def test_on_open_ok(self):
        self.conn.on_open = Mock(name='on_open')
        self.conn._on_open_ok()
        self.assertTrue(self.conn._handshake_complete)
        self.conn.on_open.assert_called_with(self.conn)

    def test_connected(self):
        self.conn.transport.connected = False
        self.assertFalse(self.conn.connected)
        self.conn.transport.connected = True
        self.assertTrue(self.conn.connected)
        self.conn.transport = None
        self.assertFalse(self.conn.connected)

    def test_collect(self):
        channels = self.conn.channels = {
            0: self.conn, 1: Mock(name='c1'), 2: Mock(name='c2'),
        }
        transport = self.conn.transport
        self.conn.collect()
        transport.close.assert_called_with()
        # Channel 0 is the connection itself; only real channels are collected.
        for i, channel in items(channels):
            if i:
                channel.collect.assert_called_with()

    def test_collect__channel_raises_socket_error(self):
        # A socket error while collecting a channel must not propagate.
        self.conn.channels = {1: Mock(name='c1')}
        self.conn.channels[1].collect.side_effect = socket.error()
        self.conn.collect()

    def test_get_free_channel_id__raises_IndexError(self):
        self.conn._avail_channel_ids = []
        with self.assertRaises(ResourceError):
            self.conn._get_free_channel_id()

    def test_claim_channel_id(self):
        # Claiming the same id twice is an error.
        self.conn._claim_channel_id(30)
        with self.assertRaises(ConnectionError):
            self.conn._claim_channel_id(30)

    def test_channel(self):
        callback = Mock(name='callback')
        c = self.conn.channel(3, callback)
        self.conn.Channel.assert_called_with(self.conn, 3, on_open=callback)
        # Asking for an existing channel id returns the cached instance.
        c2 = self.conn.channel(3, callback)
        self.assertIs(c2, c)

    def test_is_alive(self):
        with self.assertRaises(NotImplementedError):
            self.conn.is_alive()

    def test_drain_events(self):
        self.conn.blocking_read = Mock(name='blocking_read')
        self.conn.drain_events(30)
        self.conn.blocking_read.assert_called_with(30)

    def test_blocking_read__no_timeout(self):
        self.conn.on_inbound_frame = Mock(name='on_inbound_frame')
        self.conn.transport.having_timeout = ContextMock()
        ret = self.conn.blocking_read(None)
        self.conn.transport.read_frame.assert_called_with()
        self.conn.on_inbound_frame.assert_called_with(
            self.conn.transport.read_frame(),
        )
        self.assertIs(ret, self.conn.on_inbound_frame())

    def test_blocking_read__timeout(self):
        self.conn.transport = TCPTransport('localhost:5672')
        sock = self.conn.transport.sock = Mock(name='sock')
        sock.gettimeout.return_value = 1
        self.conn.transport.read_frame = Mock(name='read_frame')
        self.conn.on_inbound_frame = Mock(name='on_inbound_frame')
        self.conn.blocking_read(3)
        sock.gettimeout.assert_called_with()
        # The requested timeout is applied, then the previous one restored.
        sock.settimeout.assert_has_calls([call(3), call(1)])
        self.conn.transport.read_frame.assert_called_with()
        self.conn.on_inbound_frame.assert_called_with(
            self.conn.transport.read_frame(),
        )
        # Socket already at the requested timeout: must still succeed.
        sock.gettimeout.return_value = 3
        self.conn.blocking_read(3)

    def test_blocking_read__SSLError(self):
        self.conn.on_inbound_frame = Mock(name='on_inbound_frame')
        self.conn.transport = TCPTransport('localhost:5672')
        sock = self.conn.transport.sock = Mock(name='sock')
        sock.gettimeout.return_value = 1
        self.conn.transport.read_frame = Mock(name='read_frame')
        # SSL errors whose message looks like a timeout become socket.timeout.
        self.conn.transport.read_frame.side_effect = SSLError(
            'operation timed out')
        with self.assertRaises(socket.timeout):
            self.conn.blocking_read(3)
        self.conn.transport.read_frame.side_effect = SSLError(
            'The operation did not complete foo bar')
        with self.assertRaises(socket.timeout):
            self.conn.blocking_read(3)
        # Any other SSL error propagates unchanged.
        self.conn.transport.read_frame.side_effect = SSLError(
            'oh noes')
        with self.assertRaises(SSLError):
            self.conn.blocking_read(3)

    def test_on_inbound_method(self):
        self.conn.channels[1] = self.conn.channel(1)
        self.conn.on_inbound_method(1, (50, 60), 'payload', 'content')
        self.conn.channels[1].dispatch_method.assert_called_with(
            (50, 60), 'payload', 'content',
        )

    def test_close(self):
        self.conn.close(reply_text='foo', method_sig=spec.Channel.Open)
        # Connection.Close carries four fields: reply-code (short),
        # reply-text (shortstr), class-id (short), method-id (short).
        # That is 'BsBB', the same signature Channel.Close uses.
        self.conn.send_method.assert_called_with(
            spec.Connection.Close, 'BsBB',
            (0, 'foo', spec.Channel.Open[0], spec.Channel.Open[1]),
            wait=spec.Connection.CloseOk,
        )

    def test_close__already_closed(self):
        self.conn.transport = None
        self.conn.close()

    def test_on_close(self):
        self.conn._x_close_ok = Mock(name='_x_close_ok')
        with self.assertRaises(NotFound):
            self.conn._on_close(404, 'bah not found', 50, 60)

    def test_x_close_ok(self):
        self.conn._x_close_ok()
        self.conn.send_method.assert_called_with(
            spec.Connection.CloseOk, callback=self.conn._on_close_ok,
        )

    def test_on_close_ok(self):
        self.conn.collect = Mock(name='collect')
        self.conn._on_close_ok()
        self.conn.collect.assert_called_with()

    def test_on_blocked(self):
        self.conn._on_blocked()
        self.conn.on_blocked = Mock(name='on_blocked')
        self.conn._on_blocked()
        self.conn.on_blocked.assert_called_with(
            'connection blocked, see broker logs')

    def test_on_unblocked(self):
        self.conn._on_unblocked()
        self.conn.on_unblocked = Mock(name='on_unblocked')
        self.conn._on_unblocked()
        self.conn.on_unblocked.assert_called_with()

    def test_send_heartbeat(self):
        self.conn.send_heartbeat()
        self.conn._frame_writer.send.assert_called_with(
            (8, 0, None, None, None),
        )
        # Once _frame_writer.send raises StopIteration, send_heartbeat must
        # report a RecoverableConnectionError.
        self.conn._frame_writer.send.side_effect = StopIteration()
        with self.assertRaises(RecoverableConnectionError):
            self.conn.send_heartbeat()

    def test_heartbeat_tick__no_heartbeat(self):
        self.conn.heartbeat = 0
        self.conn.heartbeat_tick()

    def test_heartbeat_tick(self):
        self.conn.heartbeat = 3
        self.conn.heartbeat_tick()
        self.conn.bytes_sent = 3124
        self.conn.bytes_recv = 123
        self.conn.heartbeat_tick()
        # With both heartbeat timestamps far in the past, a tick must fail.
        self.conn.last_heartbeat_received -= 1000
        self.conn.last_heartbeat_sent -= 1000
        with self.assertRaises(ConnectionError):
            self.conn.heartbeat_tick()

    def test_server_capabilities(self):
        self.conn.server_properties['capabilities'] = {'foo': 1}
        self.assertEqual(self.conn.server_capabilities, {'foo': 1})
Beispiel #8
0
class test_Connection:
    """Unit tests for Connection with its collaborators replaced by mocks.

    The autouse ``setup_conn`` fixture runs before every test. It swaps in
    mock ``Channel``/``Transport`` classes, a mock ``send_method`` and a mock
    ``frame_writer``, so tests check calls rather than network traffic.
    """

    @pytest.fixture(autouse=True)
    def setup_conn(self):
        self.frame_handler = Mock(name='frame_handler')
        self.frame_writer = Mock(name='frame_writer_cls')
        self.conn = Connection(
            frame_handler=self.frame_handler,
            frame_writer=self.frame_writer,
        )
        self.conn.Channel = Mock(name='Channel')
        self.conn.Transport = Mock(name='Transport')
        self.conn.transport = self.conn.Transport.return_value
        self.conn.send_method = Mock(name='send_method')
        self.conn.frame_writer = Mock(name='frame_writer')

    def test_login_response(self):
        self.conn = Connection(login_response='foo')
        assert self.conn.login_response == 'foo'

    def test_enter_exit(self):
        # The context manager connects on entry and closes on exit.
        self.conn.connect = Mock(name='connect')
        self.conn.close = Mock(name='close')
        with self.conn:
            self.conn.connect.assert_called_with()
        self.conn.close.assert_called_with()

    def test_then(self):
        self.conn.on_open = Mock(name='on_open')
        on_success = Mock(name='on_success')
        on_error = Mock(name='on_error')
        self.conn.then(on_success, on_error)
        self.conn.on_open.then.assert_called_with(on_success, on_error)

    def test_connect(self):
        self.conn.transport.connected = False
        self.conn.drain_events = Mock(name='drain_events')

        # Mark the handshake complete on the first drain so connect() can
        # return (presumably it drains until _handshake_complete is set).
        def on_drain(*args, **kwargs):
            self.conn._handshake_complete = True

        self.conn.drain_events.side_effect = on_drain
        self.conn.connect()
        self.conn.Transport.assert_called_with(
            self.conn.host,
            self.conn.connect_timeout,
            self.conn.ssl,
            self.conn.read_timeout,
            self.conn.write_timeout,
            socket_settings=self.conn.socket_settings,
        )

    def test_connect__already_connected(self):
        # With a connected transport, connect() returns callback()'s result
        # itself, so compare by identity rather than equality.
        callback = Mock(name='callback')
        self.conn.transport.connected = True
        assert self.conn.connect(callback) is callback.return_value
        callback.assert_called_with()

    def test_on_start(self):
        self.conn._on_start(3, 4, {'foo': 'bar'}, 'x y z', 'en_US en_GB')
        assert self.conn.version_major == 3
        assert self.conn.version_minor == 4
        assert self.conn.server_properties == {'foo': 'bar'}
        # Space-separated mechanism/locale strings are split into lists.
        assert self.conn.mechanisms == ['x', 'y', 'z']
        assert self.conn.locales == ['en_US', 'en_GB']
        self.conn.send_method.assert_called_with(
            spec.Connection.StartOk,
            'FsSs',
            (
                self.conn.client_properties,
                self.conn.login_method,
                self.conn.login_response,
                self.conn.locale,
            ),
        )

    def test_on_start__consumer_cancel_notify(self):
        self.conn._on_start(
            3,
            4,
            {'capabilities': {
                'consumer_cancel_notify': 1
            }},
            '',
            '',
        )
        cap = self.conn.client_properties['capabilities']
        assert cap['consumer_cancel_notify']

    def test_on_start__connection_blocked(self):
        self.conn._on_start(
            3,
            4,
            {'capabilities': {
                'connection.blocked': 1
            }},
            '',
            '',
        )
        cap = self.conn.client_properties['capabilities']
        assert cap['connection.blocked']

    def test_on_secure(self):
        self.conn._on_secure('vfz')

    def test_on_tune(self):
        self.conn.client_heartbeat = 16
        self.conn._on_tune(345, 16, 10)
        assert self.conn.channel_max == 345
        assert self.conn.frame_max == 16
        assert self.conn.server_heartbeat == 10
        assert self.conn.heartbeat == 10
        self.conn.send_method.assert_called_with(
            spec.Connection.TuneOk,
            'BlB',
            (
                self.conn.channel_max,
                self.conn.frame_max,
                self.conn.heartbeat,
            ),
            callback=self.conn._on_tune_sent,
        )

    def test_on_tune__client_heartbeat_disabled(self):
        # A client heartbeat of 0 wins over the server's proposed value.
        self.conn.client_heartbeat = 0
        self.conn._on_tune(345, 16, 10)
        assert self.conn.heartbeat == 0

    def test_on_tune_sent(self):
        self.conn._on_tune_sent()
        self.conn.send_method.assert_called_with(
            spec.Connection.Open,
            'ssb',
            (self.conn.virtual_host, '', False),
        )

    def test_on_open_ok(self):
        self.conn.on_open = Mock(name='on_open')
        self.conn._on_open_ok()
        assert self.conn._handshake_complete
        self.conn.on_open.assert_called_with(self.conn)

    def test_connected(self):
        self.conn.transport.connected = False
        assert not self.conn.connected
        self.conn.transport.connected = True
        assert self.conn.connected
        self.conn.transport = None
        assert not self.conn.connected

    def test_collect(self):
        channels = self.conn.channels = {
            0: self.conn,
            1: Mock(name='c1'),
            2: Mock(name='c2'),
        }
        transport = self.conn.transport
        self.conn.collect()
        transport.close.assert_called_with()
        # Channel 0 is the connection itself; only real channels are collected.
        for i, channel in items(channels):
            if i:
                channel.collect.assert_called_with()

    def test_collect__channel_raises_socket_error(self):
        # A socket error while collecting a channel must not propagate.
        self.conn.channels = {1: Mock(name='c1')}
        self.conn.channels[1].collect.side_effect = socket.error()
        self.conn.collect()

    def test_get_free_channel_id__raises_IndexError(self):
        self.conn._avail_channel_ids = []
        with pytest.raises(ResourceError):
            self.conn._get_free_channel_id()

    def test_claim_channel_id(self):
        # Claiming the same id twice is an error.
        self.conn._claim_channel_id(30)
        with pytest.raises(ConnectionError):
            self.conn._claim_channel_id(30)

    def test_channel(self):
        callback = Mock(name='callback')
        c = self.conn.channel(3, callback)
        self.conn.Channel.assert_called_with(self.conn, 3, on_open=callback)
        # Asking for an existing channel id returns the cached instance.
        c2 = self.conn.channel(3, callback)
        assert c2 is c

    def test_is_alive(self):
        with pytest.raises(NotImplementedError):
            self.conn.is_alive()

    def test_drain_events(self):
        self.conn.blocking_read = Mock(name='blocking_read')
        self.conn.drain_events(30)
        self.conn.blocking_read.assert_called_with(30)

    def test_blocking_read__no_timeout(self):
        self.conn.on_inbound_frame = Mock(name='on_inbound_frame')
        self.conn.transport.having_timeout = ContextMock()
        ret = self.conn.blocking_read(None)
        self.conn.transport.read_frame.assert_called_with()
        self.conn.on_inbound_frame.assert_called_with(
            self.conn.transport.read_frame(), )
        assert ret is self.conn.on_inbound_frame()

    def test_blocking_read__timeout(self):
        self.conn.transport = TCPTransport('localhost:5672')
        sock = self.conn.transport.sock = Mock(name='sock')
        sock.gettimeout.return_value = 1
        self.conn.transport.read_frame = Mock(name='read_frame')
        self.conn.on_inbound_frame = Mock(name='on_inbound_frame')
        self.conn.blocking_read(3)
        sock.gettimeout.assert_called_with()
        # The requested timeout is applied, then the previous one restored.
        sock.settimeout.assert_has_calls([call(3), call(1)])
        self.conn.transport.read_frame.assert_called_with()
        self.conn.on_inbound_frame.assert_called_with(
            self.conn.transport.read_frame(), )
        # Socket already at the requested timeout: must still succeed.
        sock.gettimeout.return_value = 3
        self.conn.blocking_read(3)

    def test_blocking_read__SSLError(self):
        self.conn.on_inbound_frame = Mock(name='on_inbound_frame')
        self.conn.transport = TCPTransport('localhost:5672')
        sock = self.conn.transport.sock = Mock(name='sock')
        sock.gettimeout.return_value = 1
        self.conn.transport.read_frame = Mock(name='read_frame')
        # SSL errors whose message looks like a timeout become socket.timeout.
        self.conn.transport.read_frame.side_effect = SSLError(
            'operation timed out')
        with pytest.raises(socket.timeout):
            self.conn.blocking_read(3)
        self.conn.transport.read_frame.side_effect = SSLError(
            'The operation did not complete foo bar')
        with pytest.raises(socket.timeout):
            self.conn.blocking_read(3)
        # Any other SSL error propagates unchanged.
        self.conn.transport.read_frame.side_effect = SSLError('oh noes')
        with pytest.raises(SSLError):
            self.conn.blocking_read(3)

    def test_on_inbound_method(self):
        self.conn.channels[1] = self.conn.channel(1)
        self.conn.on_inbound_method(1, (50, 60), 'payload', 'content')
        self.conn.channels[1].dispatch_method.assert_called_with(
            (50, 60),
            'payload',
            'content',
        )

    def test_close(self):
        self.conn.close(reply_text='foo', method_sig=spec.Channel.Open)
        self.conn.send_method.assert_called_with(
            spec.Connection.Close,
            'BsBB',
            (0, 'foo', spec.Channel.Open[0], spec.Channel.Open[1]),
            wait=spec.Connection.CloseOk,
        )

    def test_close__already_closed(self):
        self.conn.transport = None
        self.conn.close()

    def test_on_close(self):
        self.conn._x_close_ok = Mock(name='_x_close_ok')
        with pytest.raises(NotFound):
            self.conn._on_close(404, 'bah not found', 50, 60)

    def test_x_close_ok(self):
        self.conn._x_close_ok()
        self.conn.send_method.assert_called_with(
            spec.Connection.CloseOk,
            callback=self.conn._on_close_ok,
        )

    def test_on_close_ok(self):
        self.conn.collect = Mock(name='collect')
        self.conn._on_close_ok()
        self.conn.collect.assert_called_with()

    def test_on_blocked(self):
        self.conn._on_blocked()
        self.conn.on_blocked = Mock(name='on_blocked')
        self.conn._on_blocked()
        self.conn.on_blocked.assert_called_with(
            'connection blocked, see broker logs')

    def test_on_unblocked(self):
        self.conn._on_unblocked()
        self.conn.on_unblocked = Mock(name='on_unblocked')
        self.conn._on_unblocked()
        self.conn.on_unblocked.assert_called_with()

    def test_send_heartbeat(self):
        self.conn.send_heartbeat()
        self.conn.frame_writer.assert_called_with(
            8,
            0,
            None,
            None,
            None,
        )

    def test_heartbeat_tick__no_heartbeat(self):
        self.conn.heartbeat = 0
        self.conn.heartbeat_tick()

    def test_heartbeat_tick(self):
        self.conn.heartbeat = 3
        self.conn.heartbeat_tick()
        self.conn.bytes_sent = 3124
        self.conn.bytes_recv = 123
        self.conn.heartbeat_tick()
        # With both heartbeat timestamps far in the past, a tick must fail.
        self.conn.last_heartbeat_received -= 1000
        self.conn.last_heartbeat_sent -= 1000
        with pytest.raises(ConnectionError):
            self.conn.heartbeat_tick()

    def test_server_capabilities(self):
        self.conn.server_properties['capabilities'] = {'foo': 1}
        assert self.conn.server_capabilities == {'foo': 1}
class TestChannel(unittest.TestCase):
    """Functional channel tests that talk to a real broker.

    Every test opens a fresh connection (from ``settings.connect_args``) and
    a channel on it; both are closed again in ``tearDown``.
    """

    def setUp(self):
        self.conn = Connection(**settings.connect_args)
        self.ch = self.conn.channel()

    def tearDown(self):
        self.ch.close()
        self.conn.close()

    def test_defaults(self):
        """Test how a queue defaults to being bound to an AMQP default
        exchange, and how publishing defaults to the default exchange, and
        basic_get defaults to getting from the most recently declared queue,
        and queue_delete defaults to deleting the most recently declared
        queue."""
        msg = Message(
            'funtest message',
            content_type='text/plain',
            application_headers={'foo': 7, 'bar': 'baz'},
        )

        qname, _, _ = self.ch.queue_declare()
        self.ch.basic_publish(msg, routing_key=qname)

        msg2 = self.ch.basic_get(no_ack=True)
        self.assertEqual(msg, msg2)

        n = self.ch.queue_purge()
        self.assertEqual(n, 0)

        n = self.ch.queue_delete()
        self.assertEqual(n, 0)

    def test_encoding(self):
        # NOTE(review): this test uses the name `unicode`; on Python 3 it must
        # be aliased (e.g. to str) elsewhere in this module -- confirm.
        my_routing_key = 'funtest.test_queue'

        qname, _, _ = self.ch.queue_declare()
        self.ch.queue_bind(qname, 'amq.direct', routing_key=my_routing_key)

        #
        # No encoding, body passed through unchanged
        #
        msg = Message('hello world')
        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)
        msg2 = self.ch.basic_get(qname, no_ack=True)
        if sys.version_info[0] < 3:
            self.assertFalse(hasattr(msg2, 'content_encoding'))
        self.assertTrue(isinstance(msg2.body, str))
        self.assertEqual(msg2.body, 'hello world')

        #
        # Default UTF-8 encoding of unicode body, returned as unicode
        #
        msg = Message(u'hello world')
        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)
        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg2.content_encoding, 'UTF-8')
        self.assertTrue(isinstance(msg2.body, unicode))
        self.assertEqual(msg2.body, u'hello world')

        #
        # Explicit latin_1 encoding, still comes back as unicode
        #
        msg = Message(u'hello world', content_encoding='latin_1')
        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)
        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg2.content_encoding, 'latin_1')
        self.assertTrue(isinstance(msg2.body, unicode))
        self.assertEqual(msg2.body, u'hello world')

        #
        # Plain string with specified encoding comes back as unicode
        #
        msg = Message('hello w\xf6rld', content_encoding='latin_1')
        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)
        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg2.content_encoding, 'latin_1')
        self.assertTrue(isinstance(msg2.body, unicode))
        self.assertEqual(msg2.body, u'hello w\u00f6rld')

        #
        # Plain string (bytes in Python 3.x) with bogus encoding
        #

        # don't really care about latin_1, just want bytes
        test_bytes = u'hello w\xd6rld'.encode('latin_1')
        msg = Message(test_bytes, content_encoding='I made this up')
        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)
        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg2.content_encoding, 'I made this up')
        self.assertTrue(isinstance(msg2.body, bytes))
        self.assertEqual(msg2.body, test_bytes)

        #
        # Turn off auto_decode for remaining tests
        #
        self.ch.auto_decode = False

        #
        # Unicode body comes back as utf-8 encoded str
        #
        msg = Message(u'hello w\u00f6rld')
        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)
        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg2.content_encoding, 'UTF-8')
        self.assertTrue(isinstance(msg2.body, bytes))
        self.assertEqual(msg2.body, u'hello w\xc3\xb6rld'.encode('latin_1'))

        #
        # Plain string with specified encoding stays plain string
        #
        msg = Message('hello w\xf6rld', content_encoding='latin_1')
        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)
        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg2.content_encoding, 'latin_1')
        self.assertTrue(isinstance(msg2.body, bytes))
        self.assertEqual(msg2.body, u'hello w\xf6rld'.encode('latin_1'))

        #
        # Explicit latin_1 encoding, comes back as str
        #
        msg = Message(u'hello w\u00f6rld', content_encoding='latin_1')
        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)
        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg2.content_encoding, 'latin_1')
        self.assertTrue(isinstance(msg2.body, bytes))
        self.assertEqual(msg2.body, u'hello w\xf6rld'.encode('latin_1'))

    def test_exception(self):
        """
        Check that Channel exceptions are actually raised as Python
        exceptions.

        """
        with self.assertRaises(ChannelError):
            self.ch.queue_delete('bogus_queue_that_does_not_exist')

    def test_invalid_header(self):
        """
        Test sending a message with an unserializable object in the header

        http://code.google.com/p/py-amqplib/issues/detail?id=17

        """
        qname, _, _ = self.ch.queue_declare()

        msg = Message(application_headers={'test': None})

        self.assertRaises(
            FrameSyntaxError, self.ch.basic_publish, msg, routing_key=qname,
        )

    def test_large(self):
        """
        Test sending some extra large messages.

        """
        qname, _, _ = self.ch.queue_declare()

        for multiplier in [100, 1000, 10000]:
            msg = Message(
                'funtest message' * multiplier,
                content_type='text/plain',
                application_headers={'foo': 7, 'bar': 'baz'},
            )

            self.ch.basic_publish(msg, routing_key=qname)

            msg2 = self.ch.basic_get(no_ack=True)
            self.assertEqual(msg, msg2)

    def test_publish(self):
        self.ch.exchange_declare('funtest.fanout', 'fanout', auto_delete=True)

        msg = Message(
            'funtest message',
            content_type='text/plain',
            application_headers={'foo': 7, 'bar': 'baz'},
        )

        self.ch.basic_publish(msg, 'funtest.fanout')

    def test_queue(self):
        my_routing_key = 'funtest.test_queue'
        msg = Message(
            'funtest message',
            content_type='text/plain',
            application_headers={'foo': 7, 'bar': 'baz'},
        )

        qname, _, _ = self.ch.queue_declare()
        self.ch.queue_bind(qname, 'amq.direct', routing_key=my_routing_key)

        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)

        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg, msg2)

    def test_unbind(self):
        my_routing_key = 'funtest.test_queue'

        qname, _, _ = self.ch.queue_declare()
        self.ch.queue_bind(qname, 'amq.direct', routing_key=my_routing_key)
        self.ch.queue_unbind(qname, 'amq.direct', routing_key=my_routing_key)

    def test_basic_return(self):
        # Publish to the exchange that is actually declared here. Publishing
        # to an undeclared exchange is rejected by the broker with a channel
        # error, so the test would never reach its assertion.
        exchange = 'funtest.fanout'
        self.ch.exchange_declare(exchange, 'fanout', auto_delete=True)

        msg = Message(
            'funtest message',
            content_type='text/plain',
            application_headers={'foo': 7, 'bar': 'baz'})

        # No queue is bound to the exchange in this test, so none of these
        # publishes can be routed; only the mandatory ones should come back.
        self.ch.basic_publish(msg, exchange)
        self.ch.basic_publish(msg, exchange, mandatory=True)
        self.ch.basic_publish(msg, exchange, mandatory=True)
        self.ch.basic_publish(msg, exchange, mandatory=True)
        self.ch.close()

        #
        # 3 of the 4 messages we sent should have been returned
        #
        self.assertEqual(self.ch.returned_messages.qsize(), 3)

    def test_exchange_bind(self):
        """Test exchange-to-exchange binding.

        Topology (``->`` means "forwards to")::

            source_exchange -> dest_exchange -> queue

        A message published to the source exchange with the bound routing
        key must be delivered to the queue bound to the destination exchange.
        """

        test_routing_key = 'unit_test__key'
        dest_exchange = 'unittest.topic_dest_bind'
        source_exchange = 'unittest.topic_source_bind'

        self.ch.exchange_declare(dest_exchange, 'topic', auto_delete=True)
        self.ch.exchange_declare(source_exchange, 'topic', auto_delete=True)

        qname, _, _ = self.ch.queue_declare()
        self.ch.exchange_bind(destination=dest_exchange,
                              source=source_exchange,
                              routing_key=test_routing_key)

        self.ch.queue_bind(qname, dest_exchange,
                           routing_key=test_routing_key)

        msg = Message('unittest message',
                      content_type='text/plain',
                      application_headers={'foo': 7, 'bar': 'baz'})

        self.ch.basic_publish(msg, source_exchange,
                              routing_key=test_routing_key)

        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg, msg2)

    def test_exchange_unbind(self):
        dest_exchange = 'unittest.topic_dest_unbind'
        source_exchange = 'unittest.topic_source_unbind'
        test_routing_key = 'unit_test__key'

        self.ch.exchange_declare(dest_exchange,
                                 'topic', auto_delete=True)
        self.ch.exchange_declare(source_exchange,
                                 'topic', auto_delete=True)

        self.ch.exchange_bind(destination=dest_exchange,
                              source=source_exchange,
                              routing_key=test_routing_key)

        self.ch.exchange_unbind(destination=dest_exchange,
                                source=source_exchange,
                                routing_key=test_routing_key)
Beispiel #10
0
class test_Connection:
    """Unit tests for ``Connection`` with mocked collaborators.

    The autouse fixture builds a ``Connection`` whose channel class,
    transport class/instance, ``send_method`` and ``frame_writer`` are all
    replaced by ``Mock`` objects. Tests that need a pristine object create
    their own ``Connection()`` instead.
    """

    @pytest.fixture(autouse=True)
    def setup_conn(self):
        """Create ``self.conn`` with mocked transport and writer hooks."""
        self.frame_handler = Mock(name='frame_handler')
        self.frame_writer = Mock(name='frame_writer_cls')
        self.conn = Connection(
            frame_handler=self.frame_handler,
            frame_writer=self.frame_writer,
            authentication=AMQPLAIN('foo', 'bar'),
        )
        self.conn.Channel = Mock(name='Channel')
        self.conn.Transport = Mock(name='Transport')
        self.conn.transport = self.conn.Transport.return_value
        self.conn.send_method = Mock(name='send_method')
        self.conn.frame_writer = Mock(name='frame_writer')

    def test_sasl_authentication(self):
        authentication = SASL()
        self.conn = Connection(authentication=authentication)
        assert self.conn.authentication == (authentication, )

    def test_sasl_authentication_iterable(self):
        authentication = SASL()
        self.conn = Connection(authentication=(authentication, ))
        assert self.conn.authentication == (authentication, )

    def test_gssapi(self):
        self.conn = Connection()
        assert isinstance(self.conn.authentication[0], GSSAPI)

    def test_external(self):
        self.conn = Connection()
        assert isinstance(self.conn.authentication[1], EXTERNAL)

    def test_amqplain(self):
        # The password passed in must match the one asserted below.
        self.conn = Connection(userid='foo', password='bar')
        auth = self.conn.authentication[2]
        assert isinstance(auth, AMQPLAIN)
        assert auth.username == 'foo'
        assert auth.password == 'bar'

    def test_plain(self):
        # The password passed in must match the one asserted below.
        self.conn = Connection(userid='foo', password='bar')
        auth = self.conn.authentication[3]
        assert isinstance(auth, PLAIN)
        assert auth.username == 'foo'
        assert auth.password == 'bar'

    def test_login_method_gssapi(self):
        # GSSAPI support is optional: NotImplementedError is tolerated.
        try:
            self.conn = Connection(userid=None,
                                   password=None,
                                   login_method='GSSAPI')
        except NotImplementedError:
            pass
        else:
            auths = self.conn.authentication
            assert len(auths) == 1
            assert isinstance(auths[0], GSSAPI)

    def test_login_method_external(self):
        self.conn = Connection(userid=None,
                               password=None,
                               login_method='EXTERNAL')
        auths = self.conn.authentication
        assert len(auths) == 1
        assert isinstance(auths[0], EXTERNAL)

    def test_login_method_amqplain(self):
        self.conn = Connection(login_method='AMQPLAIN')
        auths = self.conn.authentication
        assert len(auths) == 1
        assert isinstance(auths[0], AMQPLAIN)

    def test_login_method_plain(self):
        self.conn = Connection(login_method='PLAIN')
        auths = self.conn.authentication
        assert len(auths) == 1
        assert isinstance(auths[0], PLAIN)

    def test_enter_exit(self):
        self.conn.connect = Mock(name='connect')
        self.conn.close = Mock(name='close')
        with self.conn:
            self.conn.connect.assert_called_with()
        self.conn.close.assert_called_with()

    def test__enter__socket_error(self):
        # test when entering
        self.conn = Connection()
        self.conn.close = Mock(name='close')
        reached = False
        with patch('socket.socket', side_effect=socket.error):
            with pytest.raises(socket.error):
                with self.conn:
                    reached = True
        assert not reached and not self.conn.close.called
        assert self.conn._transport is None and not self.conn.connected

    def test__exit__socket_error(self):
        # test when exiting
        connection = self.conn
        transport = connection._transport
        transport.connected = True
        connection.send_method = Mock(name='send_method',
                                      side_effect=socket.error)
        reached = False
        with pytest.raises(socket.error):
            with connection:
                reached = True
        assert reached
        assert connection.send_method.called and transport.close.called
        assert self.conn._transport is None and not self.conn.connected

    def test_then(self):
        self.conn.on_open = Mock(name='on_open')
        on_success = Mock(name='on_success')
        on_error = Mock(name='on_error')
        self.conn.then(on_success, on_error)
        self.conn.on_open.then.assert_called_with(on_success, on_error)

    def test_connect(self):
        self.conn.transport.connected = False
        self.conn.drain_events = Mock(name='drain_events')

        def on_drain(*args, **kwargs):
            self.conn._handshake_complete = True

        self.conn.drain_events.side_effect = on_drain
        self.conn.connect()
        self.conn.Transport.assert_called_with(
            self.conn.host,
            self.conn.connect_timeout,
            self.conn.ssl,
            self.conn.read_timeout,
            self.conn.write_timeout,
            socket_settings=self.conn.socket_settings,
        )

    def test_connect__already_connected(self):
        callback = Mock(name='callback')
        self.conn.transport.connected = True
        assert self.conn.connect(callback) == callback.return_value
        callback.assert_called_with()

    def test_connect__socket_error(self):
        # check Transport.Connect error
        # socket.error derives from IOError
        # ssl.SSLError derives from socket.error
        self.conn = Connection()
        self.conn.Transport = Mock(name='Transport')
        transport = self.conn.Transport.return_value
        transport.connect.side_effect = IOError
        assert self.conn._transport is None and not self.conn.connected
        with pytest.raises(IOError):
            self.conn.connect()
        # Must be *called*: a bare attribute access asserts nothing.
        transport.connect.assert_called()
        assert self.conn._transport is None and not self.conn.connected

    def test_on_start(self):
        self.conn._on_start(3, 4, {'foo': 'bar'}, b'x y z AMQPLAIN PLAIN',
                            'en_US en_GB')
        assert self.conn.version_major == 3
        assert self.conn.version_minor == 4
        assert self.conn.server_properties == {'foo': 'bar'}
        assert self.conn.mechanisms == [
            b'x', b'y', b'z', b'AMQPLAIN', b'PLAIN'
        ]
        assert self.conn.locales == ['en_US', 'en_GB']
        self.conn.send_method.assert_called_with(
            spec.Connection.StartOk,
            'FsSs',
            (
                self.conn.client_properties,
                b'AMQPLAIN',
                self.conn.authentication[0].start(self.conn),
                self.conn.locale,
            ),
        )

    def test_on_start_string_mechanisms(self):
        # Same as test_on_start, but mechanisms arrive as text, not bytes.
        self.conn._on_start(3, 4, {'foo': 'bar'}, 'x y z AMQPLAIN PLAIN',
                            'en_US en_GB')
        assert self.conn.version_major == 3
        assert self.conn.version_minor == 4
        assert self.conn.server_properties == {'foo': 'bar'}
        assert self.conn.mechanisms == [
            b'x', b'y', b'z', b'AMQPLAIN', b'PLAIN'
        ]
        assert self.conn.locales == ['en_US', 'en_GB']
        self.conn.send_method.assert_called_with(
            spec.Connection.StartOk,
            'FsSs',
            (
                self.conn.client_properties,
                b'AMQPLAIN',
                self.conn.authentication[0].start(self.conn),
                self.conn.locale,
            ),
        )

    def test_missing_credentials(self):
        with pytest.raises(ValueError):
            self.conn = Connection(userid=None,
                                   password=None,
                                   login_method='AMQPLAIN')
        with pytest.raises(ValueError):
            self.conn = Connection(password=None, login_method='PLAIN')

    def test_invalid_method(self):
        with pytest.raises(ValueError):
            self.conn = Connection(login_method='any')

    def test_mechanism_mismatch(self):
        with pytest.raises(ConnectionError):
            self.conn._on_start(3, 4, {'foo': 'bar'}, b'x y z', 'en_US en_GB')

    def test_login_method_response(self):
        # An old way of doing things.:
        login_method, login_response = b'foo', b'bar'
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            self.conn = Connection(login_method=login_method,
                                   login_response=login_response)
            self.conn.send_method = Mock(name='send_method')
            self.conn._on_start(3, 4, {'foo': 'bar'}, login_method,
                                'en_US en_GB')
            assert len(w) == 1
            assert issubclass(w[0].category, DeprecationWarning)

        self.conn.send_method.assert_called_with(
            spec.Connection.StartOk,
            'FsSs',
            (
                self.conn.client_properties,
                login_method,
                login_response,
                self.conn.locale,
            ),
        )

    def test_on_start__consumer_cancel_notify(self):
        self.conn._on_start(
            3,
            4,
            {'capabilities': {
                'consumer_cancel_notify': 1
            }},
            b'AMQPLAIN',
            '',
        )
        cap = self.conn.client_properties['capabilities']
        assert cap['consumer_cancel_notify']

    def test_on_start__connection_blocked(self):
        self.conn._on_start(
            3,
            4,
            {'capabilities': {
                'connection.blocked': 1
            }},
            b'AMQPLAIN',
            '',
        )
        cap = self.conn.client_properties['capabilities']
        assert cap['connection.blocked']

    def test_on_start__authentication_failure_close(self):
        self.conn._on_start(
            3,
            4,
            {'capabilities': {
                'authentication_failure_close': 1
            }},
            b'AMQPLAIN',
            '',
        )
        cap = self.conn.client_properties['capabilities']
        assert cap['authentication_failure_close']

    def test_on_start__authentication_failure_close__disabled(self):
        self.conn._on_start(
            3,
            4,
            {'capabilities': {}},
            b'AMQPLAIN',
            '',
        )
        assert 'capabilities' not in self.conn.client_properties

    def test_on_secure(self):
        # Smoke test only: checks that _on_secure does not raise.
        self.conn._on_secure('vfz')

    def test_on_tune(self):
        self.conn.client_heartbeat = 16
        self.conn._on_tune(345, 16, 10)
        assert self.conn.channel_max == 345
        assert self.conn.frame_max == 16
        assert self.conn.server_heartbeat == 10
        assert self.conn.heartbeat == 10
        self.conn.send_method.assert_called_with(
            spec.Connection.TuneOk,
            'BlB',
            (
                self.conn.channel_max,
                self.conn.frame_max,
                self.conn.heartbeat,
            ),
            callback=self.conn._on_tune_sent,
        )

    def test_on_tune__client_heartbeat_disabled(self):
        self.conn.client_heartbeat = 0
        self.conn._on_tune(345, 16, 10)
        assert self.conn.heartbeat == 0

    def test_on_tune_sent(self):
        self.conn._on_tune_sent()
        self.conn.send_method.assert_called_with(
            spec.Connection.Open,
            'ssb',
            (self.conn.virtual_host, '', False),
        )

    def test_on_open_ok(self):
        self.conn.on_open = Mock(name='on_open')
        self.conn._on_open_ok()
        assert self.conn._handshake_complete
        self.conn.on_open.assert_called_with(self.conn)

    def test_connected(self):
        self.conn.transport.connected = False
        assert not self.conn.connected
        self.conn.transport.connected = True
        assert self.conn.connected
        self.conn.transport = None
        assert not self.conn.connected

    def test_collect(self):
        channels = self.conn.channels = {
            0: self.conn,
            1: Mock(name='c1'),
            2: Mock(name='c2'),
        }
        transport = self.conn.transport
        self.conn.collect()
        transport.close.assert_called_with()
        # Channel 0 is the connection itself and is skipped.
        for i, channel in channels.items():
            if i:
                channel.collect.assert_called_with()
        assert self.conn._transport is None

    def test_collect__channel_raises_socket_error(self):
        self.conn.channels = {1: Mock(name='c1')}
        self.conn.channels[1].collect.side_effect = socket.error()
        self.conn.collect()

    def test_collect_no_transport(self):
        self.conn = Connection()
        self.conn.connect = Mock(name='connect')
        assert not self.conn.connected
        self.conn.collect()
        assert not self.conn.connect.called

    def test_collect_again(self):
        self.conn = Connection()
        self.conn.collect()
        self.conn.collect()

    def test_get_free_channel_id__raises_IndexError(self):
        self.conn._avail_channel_ids = []
        with pytest.raises(ResourceError):
            self.conn._get_free_channel_id()

    def test_claim_channel_id(self):
        self.conn._claim_channel_id(30)
        with pytest.raises(ConnectionError):
            self.conn._claim_channel_id(30)

    def test_channel(self):
        callback = Mock(name='callback')
        c = self.conn.channel(3, callback)
        self.conn.Channel.assert_called_with(self.conn, 3, on_open=callback)
        c2 = self.conn.channel(3, callback)
        assert c2 is c

    def test_channel_when_connection_is_closed(self):
        self.conn.collect()
        callback = Mock(name='callback')
        with pytest.raises(RecoverableConnectionError):
            self.conn.channel(3, callback)

    def test_is_alive(self):
        with pytest.raises(NotImplementedError):
            self.conn.is_alive()

    def test_drain_events(self):
        self.conn.blocking_read = Mock(name='blocking_read')
        self.conn.drain_events(30)
        self.conn.blocking_read.assert_called_with(30)

    def test_blocking_read__no_timeout(self):
        self.conn.on_inbound_frame = Mock(name='on_inbound_frame')
        self.conn.transport.having_timeout = ContextMock()
        ret = self.conn.blocking_read(None)
        self.conn.transport.read_frame.assert_called_with()
        self.conn.on_inbound_frame.assert_called_with(
            self.conn.transport.read_frame(), )
        assert ret is self.conn.on_inbound_frame()

    def test_blocking_read__timeout(self):
        self.conn.transport = TCPTransport('localhost:5672')
        sock = self.conn.transport.sock = Mock(name='sock')
        sock.gettimeout.return_value = 1
        self.conn.transport.read_frame = Mock(name='read_frame')
        self.conn.on_inbound_frame = Mock(name='on_inbound_frame')
        self.conn.blocking_read(3)
        sock.gettimeout.assert_called_with()
        # The timeout is set for the read and then restored.
        sock.settimeout.assert_has_calls([call(3), call(1)])
        self.conn.transport.read_frame.assert_called_with()
        self.conn.on_inbound_frame.assert_called_with(
            self.conn.transport.read_frame(), )
        sock.gettimeout.return_value = 3
        self.conn.blocking_read(3)

    def test_blocking_read__SSLError(self):
        self.conn.on_inbound_frame = Mock(name='on_inbound_frame')
        self.conn.transport = TCPTransport('localhost:5672')
        sock = self.conn.transport.sock = Mock(name='sock')
        sock.gettimeout.return_value = 1
        self.conn.transport.read_frame = Mock(name='read_frame')
        self.conn.transport.read_frame.side_effect = SSLError(
            'operation timed out')
        with pytest.raises(socket.timeout):
            self.conn.blocking_read(3)
        self.conn.transport.read_frame.side_effect = SSLError(
            'The operation did not complete foo bar')
        with pytest.raises(socket.timeout):
            self.conn.blocking_read(3)
        self.conn.transport.read_frame.side_effect = SSLError('oh noes')
        with pytest.raises(SSLError):
            self.conn.blocking_read(3)

    def test_on_inbound_method(self):
        self.conn.channels[1] = self.conn.channel(1)
        self.conn.on_inbound_method(1, (50, 60), 'payload', 'content')
        self.conn.channels[1].dispatch_method.assert_called_with(
            (50, 60),
            'payload',
            'content',
        )

    def test_on_inbound_method_when_connection_is_closed(self):
        self.conn.collect()
        with pytest.raises(RecoverableConnectionError):
            self.conn.on_inbound_method(1, (50, 60), 'payload', 'content')

    def test_close(self):
        self.conn.collect = Mock(name='collect')
        self.conn.close(reply_text='foo', method_sig=spec.Channel.Open)
        self.conn.send_method.assert_called_with(
            spec.Connection.Close,
            'BsBB',
            (0, 'foo', spec.Channel.Open[0], spec.Channel.Open[1]),
            wait=spec.Connection.CloseOk,
        )

    def test_close__already_closed(self):
        self.conn.transport = None
        self.conn.close()

    def test_close__socket_error(self):
        self.conn.send_method = Mock(name='send_method',
                                     side_effect=socket.error)
        with pytest.raises(socket.error):
            self.conn.close()
        self.conn.send_method.assert_called()
        assert self.conn._transport is None and not self.conn.connected

    def test_on_close(self):
        self.conn._x_close_ok = Mock(name='_x_close_ok')
        with pytest.raises(NotFound):
            self.conn._on_close(404, 'bah not found', 50, 60)

    def test_x_close_ok(self):
        self.conn._x_close_ok()
        self.conn.send_method.assert_called_with(
            spec.Connection.CloseOk,
            callback=self.conn._on_close_ok,
        )

    def test_on_close_ok(self):
        self.conn.collect = Mock(name='collect')
        self.conn._on_close_ok()
        self.conn.collect.assert_called_with()

    def test_on_blocked(self):
        # First call runs without an on_blocked hook installed.
        self.conn._on_blocked()
        self.conn.on_blocked = Mock(name='on_blocked')
        self.conn._on_blocked()
        self.conn.on_blocked.assert_called_with(
            'connection blocked, see broker logs')

    def test_on_unblocked(self):
        # First call runs without an on_unblocked hook installed.
        self.conn._on_unblocked()
        self.conn.on_unblocked = Mock(name='on_unblocked')
        self.conn._on_unblocked()
        self.conn.on_unblocked.assert_called_with()

    def test_send_heartbeat(self):
        self.conn.send_heartbeat()
        self.conn.frame_writer.assert_called_with(
            8,
            0,
            None,
            None,
            None,
        )

    def test_heartbeat_tick__no_heartbeat(self):
        self.conn.heartbeat = 0
        self.conn.heartbeat_tick()

    def test_heartbeat_tick(self):
        self.conn.heartbeat = 3
        self.conn.heartbeat_tick()
        self.conn.bytes_sent = 3124
        self.conn.bytes_recv = 123
        self.conn.heartbeat_tick()
        # Pretend no heartbeat traffic happened for a long time.
        self.conn.last_heartbeat_received -= 1000
        self.conn.last_heartbeat_sent -= 1000
        with pytest.raises(ConnectionError):
            self.conn.heartbeat_tick()

    def test_server_capabilities(self):
        self.conn.server_properties['capabilities'] = {'foo': 1}
        assert self.conn.server_capabilities == {'foo': 1}
Beispiel #11
0
class TestChannel(unittest.TestCase):
    """Functional tests for channel operations.

    Each test gets a fresh ``Connection`` built from
    ``settings.connect_args`` and one channel on it; both are closed in
    ``tearDown``.
    """

    def setUp(self):
        self.conn = Connection(**settings.connect_args)
        self.ch = self.conn.channel()

    def tearDown(self):
        self.ch.close()
        self.conn.close()

    def test_defaults(self):
        """Test how a queue defaults to being bound to an AMQP default
        exchange, and how publishing defaults to the default exchange, and
        basic_get defaults to getting from the most recently declared queue,
        and queue_delete defaults to deleting the most recently declared
        queue."""
        msg = Message(
            'funtest message',
            content_type='text/plain',
            application_headers={
                'foo': 7,
                'bar': 'baz'
            },
        )

        qname, _, _ = self.ch.queue_declare()
        self.ch.basic_publish(msg, routing_key=qname)

        msg2 = self.ch.basic_get(no_ack=True)
        self.assertEqual(msg, msg2)

        n = self.ch.queue_purge()
        self.assertEqual(n, 0)

        n = self.ch.queue_delete()
        self.assertEqual(n, 0)

    def test_encoding(self):
        # NOTE(review): relies on a `unicode` name (Python 2 builtin or a
        # module-level compat alias) — confirm it exists under Python 3.
        my_routing_key = 'funtest.test_queue'

        qname, _, _ = self.ch.queue_declare()
        self.ch.queue_bind(qname, 'amq.direct', routing_key=my_routing_key)

        #
        # No encoding, body passed through unchanged
        #
        msg = Message('hello world')
        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)
        msg2 = self.ch.basic_get(qname, no_ack=True)
        if sys.version_info[0] < 3:
            self.assertFalse(hasattr(msg2, 'content_encoding'))
        self.assertTrue(isinstance(msg2.body, str))
        self.assertEqual(msg2.body, 'hello world')

        #
        # Default UTF-8 encoding of unicode body, returned as unicode
        #
        msg = Message(u'hello world')
        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)
        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg2.content_encoding, 'UTF-8')
        self.assertTrue(isinstance(msg2.body, unicode))
        self.assertEqual(msg2.body, u'hello world')

        #
        # Explicit latin_1 encoding, still comes back as unicode
        #
        msg = Message(u'hello world', content_encoding='latin_1')
        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)
        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg2.content_encoding, 'latin_1')
        self.assertTrue(isinstance(msg2.body, unicode))
        self.assertEqual(msg2.body, u'hello world')

        #
        # Plain string with specified encoding comes back as unicode
        #
        msg = Message('hello w\xf6rld', content_encoding='latin_1')
        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)
        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg2.content_encoding, 'latin_1')
        self.assertTrue(isinstance(msg2.body, unicode))
        self.assertEqual(msg2.body, u'hello w\u00f6rld')

        #
        # Plain string (bytes in Python 3.x) with bogus encoding
        #

        # don't really care about latin_1, just want bytes
        test_bytes = u'hello w\xd6rld'.encode('latin_1')
        msg = Message(test_bytes, content_encoding='I made this up')
        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)
        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg2.content_encoding, 'I made this up')
        self.assertTrue(isinstance(msg2.body, bytes))
        self.assertEqual(msg2.body, test_bytes)

        #
        # Turn off auto_decode for remaining tests
        #
        self.ch.auto_decode = False

        #
        # Unicode body comes back as utf-8 encoded str
        #
        msg = Message(u'hello w\u00f6rld')
        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)
        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg2.content_encoding, 'UTF-8')
        self.assertTrue(isinstance(msg2.body, bytes))
        self.assertEqual(msg2.body, u'hello w\xc3\xb6rld'.encode('latin_1'))

        #
        # Plain string with specified encoding stays plain string
        #
        msg = Message('hello w\xf6rld', content_encoding='latin_1')
        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)
        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg2.content_encoding, 'latin_1')
        self.assertTrue(isinstance(msg2.body, bytes))
        self.assertEqual(msg2.body, u'hello w\xf6rld'.encode('latin_1'))

        #
        # Explicit latin_1 encoding, comes back as str
        #
        msg = Message(u'hello w\u00f6rld', content_encoding='latin_1')
        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)
        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg2.content_encoding, 'latin_1')
        self.assertTrue(isinstance(msg2.body, bytes))
        self.assertEqual(msg2.body, u'hello w\xf6rld'.encode('latin_1'))

    def test_exception(self):
        """
        Check that Channel exceptions are actually raised as Python
        exceptions.

        """
        with self.assertRaises(ChannelError):
            self.ch.queue_delete('bogus_queue_that_does_not_exist')

    def test_invalid_header(self):
        """
        Test sending a message with an unserializable object in the header

        http://code.google.com/p/py-amqplib/issues/detail?id=17

        """
        qname, _, _ = self.ch.queue_declare()

        msg = Message(application_headers={'test': None})

        self.assertRaises(
            FrameSyntaxError,
            self.ch.basic_publish,
            msg,
            routing_key=qname,
        )

    def test_large(self):
        """
        Test sending some extra large messages.

        """
        qname, _, _ = self.ch.queue_declare()

        for multiplier in [100, 1000, 10000]:
            msg = Message(
                'funtest message' * multiplier,
                content_type='text/plain',
                application_headers={
                    'foo': 7,
                    'bar': 'baz'
                },
            )

            self.ch.basic_publish(msg, routing_key=qname)

            msg2 = self.ch.basic_get(no_ack=True)
            self.assertEqual(msg, msg2)

    def test_publish(self):
        self.ch.exchange_declare('funtest.fanout', 'fanout', auto_delete=True)

        msg = Message(
            'funtest message',
            content_type='text/plain',
            application_headers={
                'foo': 7,
                'bar': 'baz'
            },
        )

        self.ch.basic_publish(msg, 'funtest.fanout')

    def test_queue(self):
        my_routing_key = 'funtest.test_queue'
        msg = Message(
            'funtest message',
            content_type='text/plain',
            application_headers={
                'foo': 7,
                'bar': 'baz'
            },
        )

        qname, _, _ = self.ch.queue_declare()
        self.ch.queue_bind(qname, 'amq.direct', routing_key=my_routing_key)

        self.ch.basic_publish(msg, 'amq.direct', routing_key=my_routing_key)

        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg, msg2)

    def test_unbind(self):
        my_routing_key = 'funtest.test_queue'

        qname, _, _ = self.ch.queue_declare()
        self.ch.queue_bind(qname, 'amq.direct', routing_key=my_routing_key)
        self.ch.queue_unbind(qname, 'amq.direct', routing_key=my_routing_key)

    def test_basic_return(self):
        # Publish to the exchange declared here; a fanout exchange with no
        # bound queues cannot route, so mandatory messages are returned.
        exchange = 'funtest.fanout'
        self.ch.exchange_declare(exchange, 'fanout', auto_delete=True)

        msg = Message('funtest message',
                      content_type='text/plain',
                      application_headers={
                          'foo': 7,
                          'bar': 'baz'
                      })

        self.ch.basic_publish(msg, exchange)
        self.ch.basic_publish(msg, exchange, mandatory=True)
        self.ch.basic_publish(msg, exchange, mandatory=True)
        self.ch.basic_publish(msg, exchange, mandatory=True)
        self.ch.close()

        #
        # 3 of the 4 messages we sent should have been returned
        #
        self.assertEqual(self.ch.returned_messages.qsize(), 3)

    def test_exchange_bind(self):
        """Test exchange-to-exchange binding.

        Network configuration is as follows (-> means "forwards to"):
        source_exchange -> dest_exchange -> queue
        The test checks that a message published to the source exchange
        (unittest.topic_source_bind) is delivered, through the destination
        exchange (unittest.topic_dest_bind), to the queue.
        """

        test_routing_key = 'unit_test__key'
        dest_exchange = 'unittest.topic_dest_bind'
        source_exchange = 'unittest.topic_source_bind'

        self.ch.exchange_declare(dest_exchange, 'topic', auto_delete=True)
        self.ch.exchange_declare(source_exchange, 'topic', auto_delete=True)

        qname, _, _ = self.ch.queue_declare()
        self.ch.exchange_bind(destination=dest_exchange,
                              source=source_exchange,
                              routing_key=test_routing_key)

        self.ch.queue_bind(qname, dest_exchange, routing_key=test_routing_key)

        msg = Message('unittest message',
                      content_type='text/plain',
                      application_headers={
                          'foo': 7,
                          'bar': 'baz'
                      })

        self.ch.basic_publish(msg,
                              source_exchange,
                              routing_key=test_routing_key)

        msg2 = self.ch.basic_get(qname, no_ack=True)
        self.assertEqual(msg, msg2)

    def test_exchange_unbind(self):
        dest_exchange = 'unittest.topic_dest_unbind'
        source_exchange = 'unittest.topic_source_unbind'
        test_routing_key = 'unit_test__key'

        self.ch.exchange_declare(dest_exchange, 'topic', auto_delete=True)
        self.ch.exchange_declare(source_exchange, 'topic', auto_delete=True)

        self.ch.exchange_bind(destination=dest_exchange,
                              source=source_exchange,
                              routing_key=test_routing_key)

        self.ch.exchange_unbind(destination=dest_exchange,
                                source=source_exchange,
                                routing_key=test_routing_key)
Beispiel #12
0
class TestConnection(unittest.TestCase):
    """Functional tests for connection/channel lifecycle and GC behaviour.

    Each test opens a ``Connection`` from ``settings.connect_args``
    (presumably against a live broker — confirm test environment).
    """

    def setUp(self):
        self.conn = Connection(**settings.connect_args)

    def tearDown(self):
        # test_gc_forget sets self.conn to None, so only close a connection
        # that is still referenced.
        if self.conn:
            self.conn.close()

    def test_channel(self):
        """An explicit channel id is used; an automatic id differs from it."""
        ch = self.conn.channel(1)
        self.assertEqual(ch.channel_id, 1)

        ch2 = self.conn.channel()
        self.assertNotEqual(ch2.channel_id, 1)

        ch.close()
        ch2.close()

    def test_close(self):
        """
        Make sure we've broken various references when closing
        channels and connections, to help with GC.

        """
        #
        # Create a channel and make sure it's linked as we'd expect
        #
        ch = self.conn.channel()
        self.assertIn(1, self.conn.channels)
        self.assertEqual(ch.connection, self.conn)
        self.assertEqual(ch.is_open, True)

        #
        # Close the channel and make sure the references are broken
        # that we expect.
        #
        ch.close()
        self.assertIsNone(ch.connection)
        self.assertNotIn(1, self.conn.channels)
        self.assertEqual(ch.callbacks, {})
        self.assertEqual(ch.is_open, False)

        #
        # Close the connection and make sure the references we expect
        # are gone.
        #
        self.conn.close()
        self.assertIsNone(self.conn.connection)
        self.assertIsNone(self.conn.channels)

    def test_gc_closed(self):
        """
        Make sure we've broken various references when closing
        channels and connections, to help with GC.

        gc.garbage: http://docs.python.org/library/gc.html#gc.garbage
            "A list of objects which the collector found to be
            unreachable but could not be freed (uncollectable objects)."
        """
        unreachable_before = len(gc.garbage)
        #
        # Create a channel and make sure it's linked as we'd expect.
        # `ch` is deliberately kept referenced while the connection closes.
        #
        ch = self.conn.channel()
        self.assertIn(1, self.conn.channels)

        #
        # Close the connection and make sure the references we expect
        # are gone.
        #
        self.conn.close()

        gc.collect()
        gc.collect()
        gc.collect()
        self.assertEqual(unreachable_before, len(gc.garbage))

    def test_gc_forget(self):
        """
        Make sure the connection gets gc'ed when there is no more
        references to it.
        """
        unreachable_before = len(gc.garbage)

        ch = self.conn.channel()
        self.assertIn(1, self.conn.channels)

        # remove all the references
        self.conn = None
        ch = None

        gc.collect()
        gc.collect()
        gc.collect()
        self.assertEqual(unreachable_before, len(gc.garbage))
# a priori. We keep track of which partitions are in flight, and make sure
# to return the next free partition when one is requested.
#
from amqp import Connection, Message

partitions = 4096

# sequence number space S = sequence number space in postgres = domain of int8

S_total = 2**64 - 1
S_start = -2**63
S_end = 2**63 - 1

# Integer (floor) division. A plain ``/`` yields a float on Python 3, and
# floats of this magnitude cannot represent the "- 1" partition bounds below
# exactly, so adjacent ranges could overlap. ``//`` gives the same value that
# Python 2's integer ``/`` produced.
# Note: partitions * partsize == 2**64 - partitions, so the top
# ``partitions`` values of the int8 domain are never handed out.
partsize = S_total // partitions

print("Seeding rabbitmq sequence_space with %ld partitions of %ld size" % (partitions, partsize))

conn = Connection(host='localhost', userid='admin', password='******', virtual_host='/messaging')
# NOTE(review): newer py-amqp releases (>= 2.0) expect an explicit
# conn.connect() before channels are opened; this script relies on the older
# implicit behaviour -- confirm against the installed amqp version.
try:
    channel = conn.channel()
    for p in range(partitions):
        # Inclusive [lo, hi] bounds of partition p; consecutive partitions abut.
        lo = S_start + p * partsize
        hi = S_start + (p + 1) * partsize - 1
        body = "%ld:%ld" % (lo, hi)
        channel.basic_publish(
            Message(
                body,
                content_type='text/plain',
                content_encoding='utf-8',
                delivery_mode=2),
            exchange='sequencer',
            routing_key='pond')
        print(body)
finally:
    # Release the broker connection even if publishing fails part-way.
    conn.close()
Beispiel #14
0
class Connection(BaseConnection):
    """
    An AMQP broker connection.

    The underlying ``RealConnection`` is created by :meth:`open` and dropped
    by :meth:`close`; ``self._impl`` is ``None`` whenever the connection is
    not open.
    """

    # Python 2 style metaclass declaration (Python 3 ignores this attribute).
    # ThreadSingleton is defined elsewhere -- presumably it shares one
    # instance per thread; TODO confirm.
    __metaclass__ = ThreadSingleton

    @staticmethod
    def ssl_domain(connector):
        """
        Get SSL properties
        :param connector: A broker object.
        :type connector: Connector
        :return: The SSL properties, or None when the connector does not
            use SSL.
        :rtype: dict
        :raise: ValueError
        """
        if not connector.use_ssl():
            return None
        connector.ssl.validate()
        # Require and verify the peer certificate only when a CA
        # certificate has been configured to verify it against.
        if connector.ssl.ca_certificate:
            required = ssl.CERT_REQUIRED
        else:
            required = ssl.CERT_NONE
        return dict(
            cert_reqs=required,
            ca_certs=connector.ssl.ca_certificate,
            keyfile=connector.ssl.client_key,
            certfile=connector.ssl.client_certificate)

    def __init__(self, url):
        """
        :param url: The connector url.
        :type url: str
        """
        BaseConnection.__init__(self, url)
        self._impl = None

    def is_open(self):
        """
        Get whether the connection has been opened.
        :return: True if open.
        :rtype: bool
        """
        return self._impl is not None

    # retry() is defined elsewhere -- presumably it re-attempts open() when
    # one of CONNECTION_EXCEPTIONS is raised; TODO confirm.
    @retry(*CONNECTION_EXCEPTIONS)
    def open(self):
        """
        Open a connection to the broker.

        Does nothing if the connection is already open. Missing connector
        fields fall back to the VIRTUAL_HOST / USERID / PASSWORD defaults.
        """
        if self.is_open():
            # already open
            return
        connector = Connector.find(self.url)
        host = ':'.join((connector.host, utf8(connector.port)))
        virtual_host = connector.virtual_host or VIRTUAL_HOST
        domain = self.ssl_domain(connector)
        userid = connector.userid or USERID
        password = connector.password or PASSWORD
        log.info('open: %s', connector)
        self._impl = RealConnection(
            host=host,
            virtual_host=virtual_host,
            ssl=domain,
            userid=userid,
            password=password,
            confirm_publish=True)
        log.info('opened: %s', self.url)

    def channel(self):
        """
        Open a channel.
        :return: The *real* channel.
        :raise: AttributeError when the connection is not open
            (``self._impl`` is None).
        """
        return self._impl.channel()

    def close(self):
        """
        Close the connection.

        A no-op when the connection is not open (never opened or already
        closed). Errors raised by the underlying close are logged, not
        propagated.
        """
        connection = self._impl
        # Forget the implementation first so is_open() reports False even
        # if the underlying close fails.
        self._impl = None
        if connection is None:
            # Nothing to release; avoids calling close() on None and logging
            # a spurious AttributeError traceback.
            return
        try:
            connection.close()
            log.info('closed: %s', self.url)
        except Exception as pe:
            log.exception(utf8(pe))
Beispiel #15
0
import requests
from amqp import Message, Connection, ConnectionError, ChannelError
import nb_log
# Single name reused below as the exchange name, the queue name and the
# routing key.
TEST_QUEUE = 'pyrabbit.testq'

# Connect to a local broker; the port is given inline in the host string.
connection = Connection(host='localhost:5672',
                        userid='guest',
                        password='******',
                        virtual_host='/')
channel = connection.channel()
# Remove any queue left over from a previous run before re-declaring it.
# NOTE(review): whether deleting a non-existent queue errors depends on the
# broker/library version -- confirm.
channel.queue_delete(TEST_QUEUE)

# Declare a 'direct' exchange with the same name as the queue.
channel.exchange_declare(TEST_QUEUE, 'direct')
# The declare result is kept in ``x``. The commented-out checks below look
# copied from a unittest method (hence ``self``); they suggest the result
# exposes message_count / consumer_count / queue both as attributes and by
# index -- TODO confirm against the amqp version in use.
x = channel.queue_declare(TEST_QUEUE)
# self.assertEqual(x.message_count, x[1])
# self.assertEqual(x.consumer_count, x[2])
# self.assertEqual(x.queue, TEST_QUEUE)
# Bind queue TEST_QUEUE to exchange TEST_QUEUE with routing key TEST_QUEUE.
channel.queue_bind(TEST_QUEUE, TEST_QUEUE, TEST_QUEUE)