예제 #1
0
 def __init__(self, hostname, username, password, port=5672, **kwargs):
     """Collect the connection settings and build the helper objects.

     The connection is opened before returning unless ``lazy`` is truthy.

     :param str hostname: Hostname
     :param str username: Username
     :param str password: Password
     :param int port: Server port
     :param str virtual_host: Virtualhost (defaults to ``'/'``)
     :param int heartbeat: RabbitMQ Heartbeat interval (defaults to 60)
     :param int|float timeout: Socket timeout (defaults to 30)
     :param bool ssl: Enable SSL (defaults to ``False``)
     :param dict ssl_options: SSL kwargs (from ssl.wrap_socket)
     :param bool lazy: Lazy initialize the connection
     :return:
     """
     super(Connection, self).__init__()

     # Optional settings arrive through **kwargs; fall back to defaults.
     option = kwargs.get
     self.parameters = dict(
         hostname=hostname,
         username=username,
         password=password,
         port=port,
         virtual_host=option('virtual_host', '/'),
         heartbeat=option('heartbeat', 60),
         timeout=option('timeout', 30),
         ssl=option('ssl', False),
         ssl_options=option('ssl_options', {})
     )
     self._validate_parameters()

     heartbeat_interval = self.parameters['heartbeat']
     self.heartbeat = Heartbeat(heartbeat_interval)
     self._io = IO(self.parameters, on_read=self._read_buffer)
     self._channel0 = Channel0(self)
     self._channels = {}

     if option('lazy', False):
         return
     self.open()
예제 #2
0
    def test_io_socket_close(self):
        """close() must leave the IO handler without a socket."""
        fake_connection = FakeConnection()
        handler = IO(fake_connection.parameters)
        # A mocked socket keeps the test off the network.
        mocked_socket = Mock(name='socket', spec=socket.socket)
        handler.socket = mocked_socket

        handler.close()

        self.assertIsNone(handler.socket)
예제 #3
0
    def test_connection_close(self):
        """close() on an open connection drops its channels and closes it."""
        conn = Connection('127.0.0.1', 'guest', 'guest', lazy=True)
        conn.set_state(conn.OPEN)
        handler = IO(conn.parameters, [])
        handler.socket = Mock(name='socket', spec=socket.socket)
        conn._io = handler

        # Register ten fake channels with ids 1 through 10.
        channel_ids = range(1, 11)
        for channel_id in channel_ids:
            conn._channels[channel_id] = Channel(channel_id, conn, 360)

        def confirm_close(frame_out):
            # Check the outgoing frame type, then invoke channel 0's
            # close-ok handler directly.
            self.assertIsInstance(frame_out, specification.Connection.Close)
            conn._channel0._close_connection_ok()

        conn._channel0._write_frame = confirm_close

        self.assertFalse(conn.is_closed)

        conn.close()

        # None of the fake channels may remain registered.
        for channel_id in channel_ids:
            self.assertNotIn(channel_id, conn._channels)

        self.assertTrue(conn.is_closed)
예제 #4
0
    def test_connection_close_when_already_closed(self):
        """close() on a closed connection still clears its channels."""
        conn = Connection('127.0.0.1', 'guest', 'guest', lazy=True)
        conn.set_state(conn.OPEN)
        handler = IO(conn.parameters, [])
        handler.socket = Mock(name='socket', spec=socket.socket)
        conn._io = handler

        conn.set_state(conn.CLOSED)

        # Register ten fake channels with ids 1 through 10.
        channel_ids = range(1, 11)
        for channel_id in channel_ids:
            conn._channels[channel_id] = Channel(channel_id, conn, 360)

        def expect_closed_state(state):
            # From here on the only state ever set must be CLOSED.
            self.assertEqual(state, conn.CLOSED)

        conn.set_state = expect_closed_state

        self.assertTrue(conn.is_closed)

        conn.close()

        # None of the fake channels may remain registered.
        for channel_id in channel_ids:
            self.assertNotIn(channel_id, conn._channels)

        self.assertFalse(conn._channels)
        self.assertTrue(conn.is_closed)
예제 #5
0
    def test_connection_close_handles_raise_on_write(self):
        """close() completes even when writing the close frame raises."""
        conn = Connection('127.0.0.1', 'guest', 'guest', lazy=True)
        conn.set_state(conn.OPEN)
        handler = IO(conn.parameters, [])
        handler.socket = Mock(name='socket', spec=socket.socket)
        conn._io = handler

        # Register ten fake channels with ids 1 through 10.
        channel_ids = range(1, 11)
        for channel_id in channel_ids:
            conn._channels[channel_id] = Channel(channel_id, conn, 360)

        def fail_on_write(_):
            # Every frame write on channel 0 fails.
            raise AMQPConnectionError('travis-ci')

        conn._channel0._write_frame = fail_on_write

        self.assertFalse(conn.is_closed)

        conn.close()

        # None of the fake channels may remain registered.
        for channel_id in channel_ids:
            self.assertNotIn(channel_id, conn._channels)

        self.assertFalse(conn._channels)
        self.assertTrue(conn.is_closed)
예제 #6
0
 def test_io_receive_raises_socket_timeout(self):
     """check_for_errors() returns None after recv() times out."""
     fake_connection = FakeConnection()
     handler = IO(fake_connection.parameters)
     mocked_socket = mock.Mock(name='socket', spec=socket.socket)
     mocked_socket.recv.side_effect = socket.timeout('timeout')
     handler.socket = mocked_socket

     handler._receive()

     outcome = fake_connection.check_for_errors()
     self.assertIsNone(outcome)
예제 #7
0
파일: io_tests.py 프로젝트: exg77/amqpstorm
    def test_io_simple_receive(self):
        """_receive() hands back the data returned by socket.recv()."""
        fake_connection = FakeConnection()
        handler = IO(fake_connection.parameters)
        payload = '12345'
        mocked_socket = MagicMock(name='socket', spec=socket.socket)
        mocked_socket.recv.return_value = payload
        handler.socket = mocked_socket

        received = handler._receive()

        self.assertEqual(received, payload)
예제 #8
0
    def test_io_shutdown_with_io_error(self):
        """_close_socket() must not propagate OSError from shutdown()."""
        fake_connection = FakeConnection()

        handler = IO(fake_connection.parameters)
        handler._exceptions = []
        mocked_socket = mock.Mock(name='socket', spec=socket.socket)
        mocked_socket.shutdown.side_effect = OSError()
        handler.socket = mocked_socket

        # No assertion: the test passes as long as nothing is raised.
        handler._close_socket()
예제 #9
0
    def test_io_receive_does_not_raise_on_block(self):
        """An EWOULDBLOCK error from recv() leaves no pending error."""
        fake_connection = FakeConnection()

        handler = IO(fake_connection.parameters,
                     exceptions=fake_connection.exceptions)
        mocked_socket = mock.Mock(name='socket', spec=socket.socket)
        mocked_socket.recv.side_effect = socket.error(EWOULDBLOCK)
        handler.socket = mocked_socket

        handler._receive()

        outcome = fake_connection.check_for_errors()
        self.assertIsNone(outcome)
예제 #10
0
    def test_io_receive_raises_ssl_want_read_error(self):
        """An SSLWantReadError from recv() leaves no pending error."""
        fake_connection = FakeConnection()

        handler = IO(fake_connection.parameters,
                     exceptions=fake_connection.exceptions)
        mocked_socket = mock.Mock(name='socket', spec=socket.socket)
        mocked_socket.recv.side_effect = compatibility.SSLWantReadError()
        handler.socket = mocked_socket

        handler._receive()

        outcome = fake_connection.check_for_errors()
        self.assertIsNone(outcome)
예제 #11
0
    def test_io_receive_does_not_raise_on_block(self):
        """check_for_errors() returns None after recv() hits EWOULDBLOCK."""
        fake_connection = FakeConnection()

        blocking_error = socket.error(EWOULDBLOCK)
        handler = IO(fake_connection.parameters,
                     exceptions=fake_connection.exceptions)
        handler.socket = mock.Mock(name='socket', spec=socket.socket)
        handler.socket.recv.side_effect = blocking_error

        handler._receive()

        self.assertIsNone(fake_connection.check_for_errors())
예제 #12
0
    def test_io_simple_send_with_io_error(self):
        """Writing while no socket is set records an exception."""
        fake_connection = FakeConnection()

        handler = IO(fake_connection.parameters)
        handler._exceptions = []
        handler.socket = None

        handler.write_to_socket('12345')

        recorded = handler._exceptions
        self.assertTrue(recorded)
예제 #13
0
    def test_io_simple_send_with_io_error(self):
        """Sending self.message with no socket set records an exception."""
        fake_connection = FakeConnection()

        handler = IO(fake_connection.parameters)
        handler._exceptions = []
        handler.socket = None

        handler.write_to_socket(self.message)

        recorded = handler._exceptions
        self.assertTrue(recorded)
예제 #14
0
    def test_connection_fileno_property(self):
        """Connection.fileno yields the value of the socket's fileno()."""
        conn = Connection('127.0.0.1', 'guest', 'guest', lazy=True)
        conn.set_state(conn.OPENING)

        handler = IO(conn.parameters, [])
        mocked_socket = Mock(name='socket', spec=socket.socket)
        mocked_socket.fileno.return_value = 5
        handler.socket = mocked_socket
        conn._io = handler

        self.assertEqual(conn.fileno, 5)
예제 #15
0
    def test_io_simple_receive_when_socket_not_set(self):
        """Without a socket, _receive() is empty and an error is queued."""
        fake_connection = FakeConnection()
        handler = IO(fake_connection.parameters,
                     exceptions=fake_connection.exceptions)

        self.assertFalse(handler.use_ssl)

        received = handler._receive()
        self.assertEqual(received, bytes())

        # The queued error surfaces on the connection's error check.
        with self.assertRaisesRegexp(AMQPConnectionError,
                                     'connection/socket error'):
            fake_connection.check_for_errors()
예제 #16
0
    def test_io_receive_raises_socket_error(self):
        """A socket error in recv() surfaces through check_for_errors()."""
        fake_connection = FakeConnection()

        handler = IO(fake_connection.parameters,
                     exceptions=fake_connection.exceptions)
        mocked_socket = mock.Mock(name='socket', spec=socket.socket)
        mocked_socket.recv.side_effect = socket.error('travis-ci')
        handler.socket = mocked_socket

        handler._receive()

        with self.assertRaisesRegexp(AMQPConnectionError, 'travis-ci'):
            fake_connection.check_for_errors()
예제 #17
0
    def test_io_receive_raises_socket_error(self):
        """A socket error in recv() is stored as AMQPConnectionError."""
        fake_connection = FakeConnection()

        handler = IO(fake_connection.parameters)
        handler._exceptions = []
        mocked_socket = MagicMock(name='socket', spec=socket.socket)
        mocked_socket.recv.side_effect = socket.error('error')
        handler.socket = mocked_socket

        handler._receive()

        first_error = handler._exceptions[0]
        self.assertIsInstance(first_error, AMQPConnectionError)
예제 #18
0
    def test_connection_wait_for_connection_raises_on_timeout(self):
        """Waiting on a connection stuck in OPENING raises an error."""
        conn = Connection('127.0.0.1', 'guest', 'guest',
                          timeout=1, lazy=True)
        conn.set_state(conn.OPENING)
        handler = IO(conn.parameters, [])
        handler.socket = MagicMock(name='socket', spec=socket.socket)
        conn._io = handler

        # Nothing moves the state to OPEN, so the wait has to fail.
        with self.assertRaises(AMQPConnectionError):
            conn._wait_for_connection_to_open()
예제 #19
0
    def test_connection_wait_for_connection_raises_on_timeout(self):
        """Waiting on a connection stuck in OPENING raises an error."""
        conn = Connection('127.0.0.1', 'guest', 'guest',
                          timeout=0.1, lazy=True)
        conn.set_state(conn.OPENING)
        handler = IO(conn.parameters, [])
        handler.socket = MagicMock(name='socket', spec=socket.socket)
        conn._io = handler

        # Nothing moves the state to OPEN, so the wait has to fail.
        with self.assertRaises(AMQPConnectionError):
            conn._wait_for_connection_to_open()
예제 #20
0
    def test_io_simple_send_with_error(self):
        """A socket error in send() is stored as AMQPConnectionError."""
        fake_connection = FakeConnection()

        handler = IO(fake_connection.parameters)
        handler._exceptions = []
        mocked_socket = mock.Mock(name='socket', spec=socket.socket)
        mocked_socket.send.side_effect = socket.error('error')
        handler.socket = mocked_socket

        handler.write_to_socket(self.message)

        first_error = handler._exceptions[0]
        self.assertIsInstance(first_error, AMQPConnectionError)
예제 #21
0
파일: io_tests.py 프로젝트: exg77/amqpstorm
    def test_io_receive_raises_socket_error(self):
        """recv() failing with socket.error queues an AMQPConnectionError."""
        fake_connection = FakeConnection()

        receive_failure = socket.error('error')
        handler = IO(fake_connection.parameters)
        handler._exceptions = []
        handler.socket = MagicMock(name='socket', spec=socket.socket)
        handler.socket.recv.side_effect = receive_failure

        handler._receive()

        queued = handler._exceptions
        self.assertIsInstance(queued[0], AMQPConnectionError)
예제 #22
0
    def test_io_simple_send_zero_bytes_sent(self):
        """send() returning 0 surfaces as a connection/socket error."""
        fake_connection = FakeConnection()

        handler = IO(fake_connection.parameters,
                     exceptions=fake_connection.exceptions)
        mocked_socket = Mock(name='socket', spec=socket.socket)
        mocked_socket.send.return_value = 0
        handler.socket = mocked_socket

        handler.write_to_socket(self.message)

        with self.assertRaisesRegexp(AMQPConnectionError,
                                     'connection/socket error'):
            fake_connection.check_for_errors()
예제 #23
0
    def test_io_simple_send_zero_bytes_sent(self):
        """send() returning 0 is stored as AMQPConnectionError."""
        fake_connection = FakeConnection()

        handler = IO(fake_connection.parameters)
        handler._exceptions = []
        mocked_socket = MagicMock(name='socket', spec=socket.socket)
        mocked_socket.send.return_value = 0
        handler.socket = mocked_socket
        handler.poller = MagicMock(name='poller', spec=amqpstorm.io.Poller)

        handler.write_to_socket('afasffa')

        first_error = handler._exceptions[0]
        self.assertIsInstance(first_error, AMQPConnectionError)
예제 #24
0
파일: io_tests.py 프로젝트: exg77/amqpstorm
    def test_io_simple_send_with_error(self):
        """A socket error in send() is stored as AMQPConnectionError."""
        fake_connection = FakeConnection()

        handler = IO(fake_connection.parameters)
        handler._exceptions = []
        mocked_socket = MagicMock(name='socket', spec=socket.socket)
        mocked_socket.send.side_effect = socket.error('error')
        handler.socket = mocked_socket
        handler.poller = MagicMock(name='poller', spec=amqpstorm.io.Poller)

        handler.write_to_socket('12345')

        first_error = handler._exceptions[0]
        self.assertIsInstance(first_error, AMQPConnectionError)
예제 #25
0
파일: io_tests.py 프로젝트: exg77/amqpstorm
    def test_io_simple_send_zero_bytes_sent(self):
        """Zero bytes sent queues an AMQPConnectionError on the handler."""
        fake_connection = FakeConnection()

        mocked_poller = MagicMock(name='poller', spec=amqpstorm.io.Poller)
        handler = IO(fake_connection.parameters)
        handler._exceptions = []
        handler.socket = MagicMock(name='socket', spec=socket.socket)
        handler.poller = mocked_poller
        handler.socket.send.return_value = 0

        handler.write_to_socket('afasffa')

        queued = handler._exceptions
        self.assertIsInstance(queued[0], AMQPConnectionError)
예제 #26
0
    def test_io_simple_send_with_error(self):
        """send() failing with socket.error queues AMQPConnectionError."""
        fake_connection = FakeConnection()

        send_failure = socket.error('error')
        mocked_poller = MagicMock(name='poller', spec=amqpstorm.io.Poller)
        handler = IO(fake_connection.parameters)
        handler._exceptions = []
        handler.socket = MagicMock(name='socket', spec=socket.socket)
        handler.poller = mocked_poller
        handler.socket.send.side_effect = send_failure

        handler.write_to_socket('12345')

        queued = handler._exceptions
        self.assertIsInstance(queued[0], AMQPConnectionError)
예제 #27
0
    def test_io_receive_raises_socket_error(self):
        """A socket error in recv() surfaces through check_for_errors()."""
        fake_connection = FakeConnection()

        handler = IO(fake_connection.parameters,
                     exceptions=fake_connection.exceptions)
        mocked_socket = mock.Mock(name='socket', spec=socket.socket)
        mocked_socket.recv.side_effect = socket.error('travis-ci')
        handler.socket = mocked_socket

        handler._receive()

        with self.assertRaisesRegexp(AMQPConnectionError, 'travis-ci'):
            fake_connection.check_for_errors()
예제 #28
0
    def test_io_simple_receive_when_socket_not_set(self):
        """Without a socket, _receive() is empty and an error is queued."""
        fake_connection = FakeConnection()
        handler = IO(fake_connection.parameters,
                     exceptions=fake_connection.exceptions)

        self.assertFalse(handler.use_ssl)

        received = handler._receive()
        self.assertEqual(received, bytes())

        # The queued error surfaces on the connection's error check.
        with self.assertRaisesRegexp(AMQPConnectionError,
                                     'connection/socket error'):
            fake_connection.check_for_errors()
예제 #29
0
    def test_io_simple_send_zero_bytes_sent(self):
        """send() returning 0 surfaces as a connection/socket error."""
        fake_connection = FakeConnection()

        handler = IO(fake_connection.parameters,
                     exceptions=fake_connection.exceptions)
        mocked_socket = Mock(name='socket', spec=socket.socket)
        mocked_socket.send.return_value = 0
        handler.socket = mocked_socket

        handler.write_to_socket(self.message)

        with self.assertRaisesRegexp(AMQPConnectionError,
                                     'connection/socket error'):
            fake_connection.check_for_errors()
예제 #30
0
    def test_io_simple_ssl_receive(self):
        """With SSL enabled, _receive() returns the data from read()."""
        fake_connection = FakeConnection()
        fake_connection.parameters['ssl'] = True
        handler = IO(fake_connection.parameters)

        self.assertTrue(handler.use_ssl)

        # Spec the mock on whichever SSL type this Python provides,
        # trying SSLObject before SSLSocket.
        for type_name in ('SSLObject', 'SSLSocket'):
            if hasattr(ssl, type_name):
                ssl_type = getattr(ssl, type_name)
                handler.socket = Mock(name='socket', spec=ssl_type)
                break

        handler.socket.read.return_value = '12345'

        received = handler._receive()
        self.assertEqual(received, '12345')
예제 #31
0
    def test_io_set_ssl_verify_req(self):
        """``verify_mode: 'required'`` must map to ``ssl.CERT_REQUIRED``.

        A real socket is created for wrapping, so cleanups are registered
        to close both the raw and the wrapped socket when the test ends.
        """
        connection = FakeConnection()
        connection.parameters['ssl_options'] = {'verify_mode': 'required'}

        io = IO(connection.parameters)

        # Close the sockets created here so the test does not leak file
        # descriptors; the raw socket is covered even if wrapping raises.
        raw_socket = socket.socket()
        self.addCleanup(raw_socket.close)
        sock = io._ssl_wrap_socket(raw_socket)
        self.addCleanup(sock.close)

        self.assertEqual(sock.context.verify_mode, ssl.CERT_REQUIRED)
예제 #32
0
 def __init__(self, hostname, username, password, port=5672, **kwargs):
     """Collect the connection settings and build the helper objects.

     The connection is opened before returning unless ``lazy`` is truthy.

     :param str hostname: Hostname
     :param str username: Username
     :param str password: Password
     :param int port: Server port
     :param str virtual_host: Virtualhost (defaults to ``'/'``)
     :param int heartbeat: RabbitMQ Heartbeat interval (defaults to 60)
     :param int|float timeout: Socket timeout (defaults to 30)
     :param bool ssl: Enable SSL (defaults to ``False``)
     :param dict ssl_options: SSL kwargs (from ssl.wrap_socket)
     :param bool lazy: Lazy initialize the connection
     :return:
     """
     super(Connection, self).__init__()

     # Optional settings arrive through **kwargs; fall back to defaults.
     option = kwargs.get
     self.parameters = dict(
         hostname=hostname,
         username=username,
         password=password,
         port=port,
         virtual_host=option('virtual_host', '/'),
         heartbeat=option('heartbeat', 60),
         timeout=option('timeout', 30),
         ssl=option('ssl', False),
         ssl_options=option('ssl_options', {})
     )
     self._validate_parameters()

     heartbeat_interval = self.parameters['heartbeat']
     self.heartbeat = Heartbeat(heartbeat_interval)
     self._io = IO(self.parameters, on_read=self._read_buffer)
     self._channel0 = Channel0(self)
     self._channels = {}

     if option('lazy', False):
         return
     self.open()
예제 #33
0
    def __init__(self, hostname, username, password, port=5672, **kwargs):
        """Create a new instance of the Connection class.

        The parameters are validated before the IO handler and channel 0
        are created, and the connection is opened before returning.

        :param str hostname: Hostname
        :param str username: Username
        :param str password: Password
        :param int port: Server port
        :param str virtual_host: Virtual host (defaults to ``'/'``)
        :param int heartbeat: RabbitMQ Heartbeat interval (defaults to 60)
        :param int|float timeout: Socket timeout (defaults to 0)
        :param bool ssl: Enable SSL (defaults to ``False``)
        :param dict ssl_options: SSL Kwargs
        :return:
        """
        super(Connection, self).__init__()
        self.parameters = {
            'hostname': hostname,
            'username': username,
            'password': password,
            'port': port,
            'virtual_host': kwargs.get('virtual_host', '/'),
            'heartbeat': kwargs.get('heartbeat', 60),
            'timeout': kwargs.get('timeout', 0),
            'ssl': kwargs.get('ssl', False),
            'ssl_options': kwargs.get('ssl_options', {})
        }
        # Validate first so invalid arguments are rejected before any IO
        # or channel objects are built from them.
        self._validate_parameters()
        self.io = IO(self.parameters,
                     on_read=self._read_buffer,
                     on_error=self._handle_socket_error)
        self._channel0 = Channel0(self)
        self._channels = {}
        self.open()
예제 #34
0
    def test_connection_wait_for_connection(self):
        """Waiting for OPEN returns once another thread sets that state."""
        conn = Connection('127.0.0.1', 'guest', 'guest', lazy=True)
        conn.set_state(conn.OPENING)
        handler = IO(conn.parameters, [])
        handler.socket = Mock(name='socket', spec=socket.socket)
        conn._io = handler

        self.assertFalse(conn.is_open)

        def mark_open(target):
            target.set_state(target.OPEN)

        # The timer calls mark_open(conn) after 0.1 seconds.
        timer = threading.Timer(interval=0.1, function=mark_open,
                                args=(conn,))
        timer.start()
        conn._wait_for_connection_state(conn.OPEN)

        self.assertTrue(conn.is_open)
예제 #35
0
 def test_io_raises_gaierror(self, _):
     """_get_socket_addresses() must raise 'could not connect'.

     ``_`` is an extra, unused argument.
     NOTE(review): presumably injected by a patch decorator that is
     outside this view -- confirm.
     """
     fake_connection = FakeConnection()
     parameters = fake_connection.parameters
     parameters['hostname'] = 'localhost'
     parameters['port'] = 1234
     handler = IO(parameters)

     with self.assertRaisesRegexp(AMQPConnectionError, 'could not connect'):
         handler._get_socket_addresses()
예제 #36
0
 def test_io_normal_connection_without_ssl_library(self, _):
     """open() must raise 'Could not connect to localhost:1234'.

     ``_`` is an extra, unused argument.
     NOTE(review): presumably injected by a patch decorator that is
     outside this view -- confirm.
     """
     fake_connection = FakeConnection()
     parameters = fake_connection.parameters
     parameters['hostname'] = 'localhost'
     parameters['port'] = 1234
     handler = IO(parameters)

     with self.assertRaisesRegexp(AMQPConnectionError,
                                  'Could not connect to localhost:1234'):
         handler.open()
예제 #37
0
    def test_io_set_ssl_context(self):
        """Wrapping with a caller-supplied SSL context returns a socket.

        A real socket is created for wrapping, so cleanups are registered
        to close both the raw and the wrapped socket when the test ends.
        """
        connection = FakeConnection()
        connection.parameters['ssl_options'] = {
            'context': ssl.create_default_context(),
            'server_hostname': 'localhost',
        }

        io = IO(connection.parameters)

        # Close the sockets created here so the test does not leak file
        # descriptors; the raw socket is covered even if wrapping raises.
        raw_socket = socket.socket()
        self.addCleanup(raw_socket.close)
        wrapped_socket = io._ssl_wrap_socket(raw_socket)

        self.assertTrue(wrapped_socket)
        self.addCleanup(wrapped_socket.close)
예제 #38
0
    def test_io_get_socket_address(self):
        """_get_socket_addresses() reports the configured host and port."""
        fake_connection = FakeConnection()
        parameters = fake_connection.parameters
        parameters['hostname'] = '127.0.0.1'
        parameters['port'] = 5672
        handler = IO(parameters)

        first_entry = handler._get_socket_addresses()[0]

        # Index 4 of the entry is expected to be the (host, port) pair.
        self.assertEqual(first_entry[4], ('127.0.0.1', 5672))
예제 #39
0
    def test_io_simple_send_with_timeout_error(self):
        """A single send() timeout is retried and records no exception."""
        fake_connection = FakeConnection()
        self.raised = False

        def timeout_once(*_):
            # The first call times out; every later call reports 1.
            if not self.raised:
                self.raised = True
                raise socket.timeout()
            return 1

        handler = IO(fake_connection.parameters)
        handler._exceptions = []
        mocked_socket = mock.Mock(name='socket', spec=socket.socket)
        mocked_socket.send.side_effect = timeout_once
        handler.socket = mocked_socket

        handler.write_to_socket(self.message)

        self.assertTrue(self.raised)
        self.assertFalse(handler._exceptions)
예제 #40
0
    def test_io_socket_read_fails_with_ssl(self):
        """With SSL on and no socket, _read_from_socket() must raise."""
        fake_connection = FakeConnection()
        ssl_parameters = FakeConnection().parameters
        ssl_parameters['ssl'] = True
        handler = IO(ssl_parameters, exceptions=fake_connection.exceptions)

        self.assertTrue(handler.use_ssl)

        with self.assertRaisesRegexp(socket.error,
                                     'connection/socket error'):
            handler._read_from_socket()
예제 #41
0
    def test_io_simple_send_with_timeout_error(self):
        """A single send() timeout is retried and records no exception."""
        fake_connection = FakeConnection()
        self.raised = False

        def timeout_once(*_args, **_kwargs):
            # The first call times out; every later call reports 1.
            if not self.raised:
                self.raised = True
                raise socket.timeout()
            return 1

        handler = IO(fake_connection.parameters)
        handler._exceptions = []
        mocked_socket = MagicMock(name='socket', spec=socket.socket)
        mocked_socket.send.side_effect = timeout_once
        handler.socket = mocked_socket
        handler.poller = MagicMock(name='poller', spec=amqpstorm.io.Poller)

        handler.write_to_socket('12345')

        self.assertTrue(self.raised)
        self.assertFalse(handler._exceptions)
예제 #42
0
    def __init__(self, hostname, username, password, port=5672, **kwargs):
        """Collect the connection settings and build the helper objects.

        The connection is opened before returning unless ``lazy`` is
        truthy.

        :param str hostname: Hostname
        :param str username: Username
        :param str password: Password
        :param int port: Server port
        :param str virtual_host: Virtual host
        :param int heartbeat: RabbitMQ Heartbeat interval
        :param int|float timeout: Socket timeout
        :param bool ssl: Enable SSL
        :param dict ssl_options: SSL kwargs (from ssl.wrap_socket)
        :param dict client_properties: None or dict of client properties
        :param bool lazy: Lazy initialize the connection

        :raises AMQPConnectionError: Raises if the connection
                                     encountered an error.

        :return:
        """
        super(Connection, self).__init__()

        # Optional settings arrive through **kwargs; fall back to the
        # module-level defaults.
        option = kwargs.get
        self.parameters = dict(
            hostname=hostname,
            username=username,
            password=password,
            port=port,
            virtual_host=option('virtual_host', DEFAULT_VIRTUAL_HOST),
            heartbeat=option('heartbeat', DEFAULT_HEARTBEAT_INTERVAL),
            timeout=option('timeout', DEFAULT_SOCKET_TIMEOUT),
            ssl=option('ssl', False),
            ssl_options=option('ssl_options', {}),
            client_properties=option('client_properties', {})
        )
        self._validate_parameters()

        self._io = IO(self.parameters, exceptions=self._exceptions,
                      on_read_impl=self._read_buffer)
        client_properties = self.parameters['client_properties']
        self._channel0 = Channel0(self, client_properties)
        self._channels = {}
        self._last_channel_id = None

        # The heartbeat is given channel 0's send_heartbeat callable.
        heartbeat_interval = self.parameters['heartbeat']
        self.heartbeat = Heartbeat(heartbeat_interval,
                                   self._channel0.send_heartbeat)

        if option('lazy', False):
            return
        self.open()
예제 #43
0
    def test_connection_open(self):
        """open() ends with the connection in the OPEN state."""
        conn = Connection('127.0.0.1', 'guest', 'guest', lazy=True)
        handler = IO(conn.parameters, [])
        handler.socket = Mock(name='socket', spec=socket.socket)
        conn._io = handler

        def fake_open():
            # Stand-in for IO.open that does nothing.
            pass

        def mark_open_on_write(_):
            # Any write to the socket flips the connection to OPEN.
            conn.set_state(conn.OPEN)

        conn._io.open = fake_open
        conn._io.write_to_socket = mark_open_on_write

        self.assertTrue(conn.is_closed)

        conn.open()

        self.assertTrue(conn.is_open)
예제 #44
0
    def __init__(self, hostname, username, password, port=5672, **kwargs):
        """Collect the connection settings and build the helper objects.

        The connection is opened before returning unless ``lazy`` is
        truthy.

        :param str hostname: Hostname
        :param str username: Username
        :param str password: Password
        :param int port: Server port
        :param str virtual_host: Virtual host
        :param int heartbeat: RabbitMQ Heartbeat interval
        :param int|float timeout: Socket timeout
        :param bool ssl: Enable SSL
        :param dict ssl_options: SSL kwargs (from ssl.wrap_socket)
        :param bool lazy: Lazy initialize the connection

        :raises AMQPConnectionError: Raises if the connection
                                     encountered an error.

        :return:
        """
        super(Connection, self).__init__()

        # Optional settings arrive through **kwargs; fall back to the
        # module-level defaults.
        option = kwargs.get
        self.parameters = dict(
            hostname=hostname,
            username=username,
            password=password,
            port=port,
            virtual_host=option('virtual_host', DEFAULT_VIRTUAL_HOST),
            heartbeat=option('heartbeat', DEFAULT_HEARTBEAT_INTERVAL),
            timeout=option('timeout', DEFAULT_SOCKET_TIMEOUT),
            ssl=option('ssl', False),
            ssl_options=option('ssl_options', {})
        )
        self._validate_parameters()

        self._io = IO(self.parameters, exceptions=self._exceptions,
                      on_read_impl=self._read_buffer)
        self._channel0 = Channel0(self)
        self._channels = {}
        self._last_channel_id = None

        # The heartbeat is given channel 0's send_heartbeat callable.
        heartbeat_interval = self.parameters['heartbeat']
        self.heartbeat = Heartbeat(heartbeat_interval,
                                   self._channel0.send_heartbeat)

        if option('lazy', False):
            return
        self.open()
예제 #45
0
class Connection(Stateful):
    """RabbitMQ Connection Class."""

    def __init__(self, hostname, username, password, port=5672, **kwargs):
        """Create a new instance of the Connection class.

        Builds the parameter dictionary, validates it, creates the IO
        handler and channel 0, and then opens the connection.

        :param str hostname: Hostname
        :param str username: Username
        :param str password: Password
        :param int port: Server port
        :param str virtual_host: Virtual host (defaults to ``'/'``)
        :param int heartbeat: RabbitMQ Heartbeat interval (defaults to 60)
        :param int|float timeout: Socket timeout (defaults to 0)
        :param bool ssl: Enable SSL (defaults to ``False``)
        :param dict ssl_options: SSL Kwargs
        :raises AMQPInvalidArgument: If a parameter has an invalid type.
        :return:
        """
        super(Connection, self).__init__()
        self.parameters = {
            'hostname': hostname,
            'username': username,
            'password': password,
            'port': port,
            'virtual_host': kwargs.get('virtual_host', '/'),
            'heartbeat': kwargs.get('heartbeat', 60),
            'timeout': kwargs.get('timeout', 0),
            'ssl': kwargs.get('ssl', False),
            'ssl_options': kwargs.get('ssl_options', {})
        }
        # Validate first so invalid arguments are rejected before any IO
        # or channel objects are built from them.
        self._validate_parameters()
        self.io = IO(self.parameters,
                     on_read=self._read_buffer,
                     on_error=self._handle_socket_error)
        self._channel0 = Channel0(self)
        self._channels = {}
        self.open()

    def __enter__(self):
        """Return the connection itself for use in a ``with`` block."""
        return self

    def __exit__(self, exception_type, exception_value, _):
        """Close the connection when leaving a ``with`` block.

        A warning naming the exception type is logged first when the
        block exited with a truthy exception value.
        """
        if exception_value:
            message = 'Closing connection due to an unhandled exception: {0!s}'
            LOGGER.warning(message.format(exception_type))
        self.close()

    @property
    def is_blocked(self):
        """Is the connection currently being blocked from publishing by
        the remote server.

        :rtype: bool
        """
        return self._channel0.is_blocked

    @property
    def server_properties(self):
        """Returns the RabbitMQ Server properties.

        :rtype: dict
        """
        return self._channel0.server_properties

    @property
    def socket(self):
        """Returns an instance of the socket.

        :return:
        """
        return self.io.socket

    @property
    def fileno(self):
        """Socket Fileno.

        Returns the socket's ``fileno`` attribute without calling it.

        :return:
        """
        # NOTE(review): this hands back the bound method rather than the
        # descriptor number; left unchanged because callers may invoke
        # the returned value -- confirm the intended contract.
        return self.io.socket.fileno

    def open(self):
        """Open Connection.

        Opens the socket, sends the protocol header and then polls until
        the state is open, checking for errors on every iteration.
        """
        LOGGER.debug('Connection Opening.')
        self._exceptions = []
        self.set_state(self.OPENING)
        self.io.open(self.parameters['hostname'],
                     self.parameters['port'])
        self._send_handshake()
        while not self.is_open:
            self.check_for_errors()
            sleep(IDLE_WAIT)
        LOGGER.debug('Connection Opened.')

    def close(self):
        """Close connection.

        When the connection is not already closed and a socket exists,
        open channels are closed and a close frame is sent. The socket is
        released and the state set to CLOSED even if that step raises;
        the exception then propagates to the caller.
        """
        LOGGER.debug('Connection Closing.')
        try:
            if not self.is_closed and self.io.socket:
                self._close_channels()
                self.set_state(self.CLOSING)
                self._channel0.send_close_connection_frame()
        finally:
            # Always release the socket, so a failure while closing the
            # channels or sending the close frame cannot leak it.
            self.io.close()
            self.set_state(self.CLOSED)
        LOGGER.debug('Connection Closed.')

    def channel(self, rpc_timeout=360):
        """Open Channel.

        :param int rpc_timeout: Passed on to the new Channel.
        :raises AMQPInvalidArgument: If rpc_timeout is not an integer.
        :return: The newly opened channel.
        """
        LOGGER.debug('Opening new Channel.')
        if not compatibility.is_integer(rpc_timeout):
            raise AMQPInvalidArgument('rpc_timeout should be an integer')
        # Id allocation and registration happen under the IO lock.
        with self.io.lock:
            channel_id = len(self._channels) + 1
            channel = Channel(channel_id, self, rpc_timeout)
            self._channels[channel_id] = channel
            channel.open()
        LOGGER.debug('Channel #%d Opened.', channel_id)
        return self._channels[channel_id]

    def check_for_errors(self):
        """Check connection for potential errors.

        A missing socket is recorded as a 'socket/connection closed'
        error before delegating to the parent class check.

        :return:
        """
        if not self.io.socket:
            self._handle_socket_error('socket/connection closed')
        super(Connection, self).check_for_errors()

    def write_frame(self, channel_id, frame_out):
        """Marshal and write an outgoing pamqp frame to the socket.

        :param int channel_id:
        :param pamqp_spec.Frame frame_out: Amqp frame.
        :return:
        """
        frame_data = pamqp_frame.marshal(frame_out, channel_id)
        self.io.write_to_socket(frame_data)

    def write_frames(self, channel_id, multiple_frames):
        """Marshal and write multiple outgoing pamqp frames to the socket.

        All frames are marshalled into one buffer and written with a
        single call.

        :param int channel_id:
        :param list multiple_frames: Amqp frames.
        :return:
        """
        frame_data = EMPTY_BUFFER
        for single_frame in multiple_frames:
            frame_data += pamqp_frame.marshal(single_frame, channel_id)
        self.io.write_to_socket(frame_data)

    def _validate_parameters(self):
        """Validate Connection Parameters.

        :raises AMQPInvalidArgument: On the first parameter whose type
                                     check fails.
        :return:
        """
        if not compatibility.is_string(self.parameters['hostname']):
            raise AMQPInvalidArgument('hostname should be a string')
        elif not compatibility.is_integer(self.parameters['port']):
            raise AMQPInvalidArgument('port should be an integer')
        elif not compatibility.is_string(self.parameters['username']):
            raise AMQPInvalidArgument('username should be a string')
        elif not compatibility.is_string(self.parameters['password']):
            raise AMQPInvalidArgument('password should be a string')
        elif not compatibility.is_string(self.parameters['virtual_host']):
            raise AMQPInvalidArgument('virtual_host should be a string')
        elif not isinstance(self.parameters['timeout'], (int, float)):
            raise AMQPInvalidArgument('timeout should be an integer or float')
        elif not compatibility.is_integer(self.parameters['heartbeat']):
            raise AMQPInvalidArgument('heartbeat should be an integer')

    def _send_handshake(self):
        """Send RabbitMQ Handshake.

        Writes the marshalled pamqp protocol header to the socket.

        :return:
        """
        self.io.write_to_socket(pamqp_header.ProtocolHeader().marshal())

    def _read_buffer(self, buffer):
        """Process the socket buffer, and direct the data to the correct
        channel.

        :param buffer: Data read from the socket.
        :return: The part of the buffer that was not consumed.
        """
        while buffer:
            buffer, channel_id, frame_in = \
                self._handle_amqp_frame(buffer)

            # No frame could be decoded from what is left; stop here and
            # hand the remainder back to the caller.
            if frame_in is None:
                break

            if channel_id == 0:
                self._channel0.on_frame(frame_in)
            else:
                self._channels[channel_id].on_frame(frame_in)

        return buffer

    @staticmethod
    def _handle_amqp_frame(data_in):
        """Unmarshal any incoming RabbitMQ frames and return the result.

        When nothing can be unmarshalled the data is returned unchanged
        together with ``None`` for both the channel id and the frame.

        :param data_in: socket data
        :return: buffer, channel_id, frame
        """
        if not data_in:
            return data_in, None, None
        try:
            byte_count, channel_id, frame_in = pamqp_frame.unmarshal(data_in)
            return data_in[byte_count:], channel_id, frame_in
        except pamqp_exception.UnmarshalingException:
            return data_in, None, None
        except pamqp_spec.AMQPFrameError as why:
            LOGGER.error('AMQPFrameError: %r', why, exc_info=True)
            return data_in, None, None

    def _close_channels(self):
        """Close any open channels.

        :return:
        """
        # Iterate over a snapshot so the mapping may be modified while
        # individual channels are being closed.
        for channel in list(self._channels.values()):
            if channel.is_open:
                channel.close()

    def _handle_socket_error(self, why):
        """Handle any critical errors.

        Marks the connection closed, closes the IO handler and records
        the error as an AMQPConnectionError. The error is logged only
        when the connection was not already closed.

        :param exception why:
        :return:
        """
        previous_state = self._state
        self.set_state(self.CLOSED)
        if previous_state != self.CLOSED:
            LOGGER.error(why, exc_info=False)
        self.io.close()
        self._exceptions.append(AMQPConnectionError(why))
예제 #46
0
class Connection(Stateful):
    """RabbitMQ Connection.

    Holds the socket IO layer (``_io``), the connection-level channel
    (``_channel0``), the heartbeat checker and the table of channels opened
    through :meth:`channel`, and routes incoming frames to the channel they
    are addressed to.
    """
    __slots__ = [
        'heartbeat', 'parameters', '_channel0', '_channels', '_io'
    ]

    def __init__(self, hostname, username, password, port=5672, **kwargs):
        """
        :param str hostname: Hostname
        :param str username: Username
        :param str password: Password
        :param int port: Server port
        :param str virtual_host: Virtual host
        :param int heartbeat: RabbitMQ Heartbeat interval
        :param int|float timeout: Socket timeout
        :param bool ssl: Enable SSL
        :param dict ssl_options: SSL kwargs (from ssl.wrap_socket)
        :param bool lazy: Lazy initialize the connection

        :raises AMQPInvalidArgument: Raises if a parameter has the wrong
                                     type.
        :raises AMQPConnectionError: Raises if the connection
                                     encountered an error.

        :return:
        """
        super(Connection, self).__init__()
        self.parameters = {
            'hostname': hostname,
            'username': username,
            'password': password,
            'port': port,
            'virtual_host': kwargs.get('virtual_host', '/'),
            'heartbeat': kwargs.get('heartbeat', 60),
            'timeout': kwargs.get('timeout', 10),
            'ssl': kwargs.get('ssl', False),
            'ssl_options': kwargs.get('ssl_options', {})
        }
        self._validate_parameters()
        self._io = IO(self.parameters, exceptions=self._exceptions,
                      on_read=self._read_buffer)
        self._channel0 = Channel0(self)
        self._channels = {}
        self.heartbeat = Heartbeat(self.parameters['heartbeat'],
                                   self._channel0.send_heartbeat)
        # Unless ``lazy`` is set the connection is opened right away.
        if not kwargs.get('lazy', False):
            self.open()

    def __enter__(self):
        """Return the connection itself for use in a ``with`` block."""
        return self

    def __exit__(self, exception_type, exception_value, _):
        """Close the connection when leaving a ``with`` block.

        An exception raised inside the block is logged as a warning before
        the connection is closed.
        """
        if exception_type:
            message = 'Closing connection due to an unhandled exception: %s'
            LOGGER.warning(message, exception_value)
        self.close()

    @property
    def fileno(self):
        """Returns the Socket File number.

        :return: The socket's file number, or None if there is no socket.
        """
        if not self._io.socket:
            return None
        return self._io.socket.fileno()

    @property
    def is_blocked(self):
        """Is the connection currently being blocked from publishing by
        the remote server.

        :rtype: bool
        """
        return self._channel0.is_blocked

    @property
    def server_properties(self):
        """Returns the RabbitMQ Server Properties.

        :rtype: dict
        """
        return self._channel0.server_properties

    @property
    def socket(self):
        """Returns an instance of the Socket used by the Connection.

        :rtype: socket
        """
        return self._io.socket

    def channel(self, rpc_timeout=60):
        """Open Channel.

        :param int rpc_timeout: Timeout before we give up waiting for an RPC
                                response from the server.

        :raises AMQPInvalidArgument: Invalid Parameters
        :raises AMQPChannelError: Raises if the channel encountered an error.
        :raises AMQPConnectionError: Raises if the connection
                                     encountered an error.

        :return: The newly opened channel.
        """
        LOGGER.debug('Opening a new Channel')
        if not compatibility.is_integer(rpc_timeout):
            raise AMQPInvalidArgument('rpc_timeout should be an integer')
        elif self.is_closed:
            raise AMQPConnectionError('socket/connection closed')

        # The id allocation and registration happen under the lock so that
        # two callers cannot pick the same channel id.
        with self.lock:
            channel_id = len(self._channels) + 1
            channel = Channel(channel_id, self, rpc_timeout)
            self._channels[channel_id] = channel
            channel.open()
        LOGGER.debug('Channel #%d Opened', channel_id)
        return self._channels[channel_id]

    def check_for_errors(self):
        """Check Connection for errors.

        Returns silently when no exception has been recorded and the
        connection is not closed. Otherwise the connection is closed and the
        first recorded exception is raised.

        :raises AMQPConnectionError: Raises if the connection
                                     encountered an error.
        :return:
        """
        if not self.exceptions:
            if not self.is_closed:
                return
            # Closed without a recorded cause: record a generic one so that
            # there is something to raise below.
            why = AMQPConnectionError('connection was closed')
            self.exceptions.append(why)
        self.set_state(self.CLOSED)
        self.close()
        raise self.exceptions[0]

    def close(self):
        """Close connection.

        Stops the heartbeat, closes the remaining channels and, if the
        connection is not already closed and a socket exists, sends a close
        request and waits for the CLOSED state. The IO layer is closed and
        the state set to CLOSED regardless of the outcome.

        :return:
        """
        LOGGER.debug('Connection Closing')
        if not self.is_closed:
            self.set_state(self.CLOSING)
        self.heartbeat.stop()
        try:
            self._close_remaining_channels()
            if not self.is_closed and self.socket:
                self._channel0.send_close_connection()
                self._wait_for_connection_state(state=Stateful.CLOSED)
        except AMQPConnectionError:
            # Best effort: the connection is being torn down anyway.
            pass
        finally:
            self._io.close()
            self.set_state(self.CLOSED)
        LOGGER.debug('Connection Closed')

    def kill(self):
        """Remove the consumer tags of every channel and kill the IO layer.

        :return:
        """
        # Iterate the channel objects themselves; ``items()`` would yield
        # (channel_id, channel) tuples, which have no remove_consumer_tag().
        for channel in self._channels.values():
            channel.remove_consumer_tag()

        self._io.kill()

    def open(self):
        """Open Connection.

        Resets the recorded exceptions and the channel table, opens the IO
        layer, sends the handshake and waits for the OPEN state before
        starting the heartbeat.

        :raises AMQPConnectionError: Raises if the connection
                                     encountered an error.
        """
        LOGGER.debug('Connection Opening')
        self.set_state(self.OPENING)
        self._exceptions = []
        self._channels = {}
        self._io.open()
        self._send_handshake()
        self._wait_for_connection_state(state=Stateful.OPEN)
        self.heartbeat.start(self._exceptions)
        LOGGER.debug('Connection Opened')

    def write_frame(self, channel_id, frame_out):
        """Marshal and write an outgoing pamqp frame to the Socket.

        :param int channel_id: Channel ID.
        :param pamqp_spec.Frame frame_out: Amqp frame.

        :return:
        """
        frame_data = pamqp_frame.marshal(frame_out, channel_id)
        self.heartbeat.register_write()
        self._io.write_to_socket(frame_data)

    def write_frames(self, channel_id, frames_out):
        """Marshal and write multiple outgoing pamqp frames to the Socket.

        All frames are marshalled into one buffer and written with a single
        call to the IO layer.

        :param int channel_id: Channel ID.
        :param list frames_out: Amqp frames.

        :return:
        """
        data_out = EMPTY_BUFFER
        for single_frame in frames_out:
            data_out += pamqp_frame.marshal(single_frame, channel_id)
        self.heartbeat.register_write()
        self._io.write_to_socket(data_out)

    def _close_remaining_channels(self):
        """Close any open channels.

        Each open channel is first moved to the CLOSED state and then has
        ``close()`` called on it.

        :return:
        """
        for channel_id in self._channels:
            if not self._channels[channel_id].is_open:
                continue
            self._channels[channel_id].set_state(Channel.CLOSED)
            self._channels[channel_id].close()

    def _handle_amqp_frame(self, data_in):
        """Unmarshal a single AMQP frame and return the result.

        When no frame can be produced the input is returned unchanged with
        ``None`` for the channel id and the frame. A ValueError is also
        recorded as an AMQPConnectionError on the connection.

        :param data_in: socket data
        :return: data_in, channel_id, frame
        """
        if not data_in:
            return data_in, None, None
        try:
            byte_count, channel_id, frame_in = pamqp_frame.unmarshal(data_in)
            return data_in[byte_count:], channel_id, frame_in
        except pamqp_exception.UnmarshalingException:
            pass
        except pamqp_spec.AMQPFrameError as why:
            LOGGER.error('AMQPFrameError: %r', why, exc_info=True)
        except ValueError as why:
            LOGGER.error(why, exc_info=True)
            self.exceptions.append(AMQPConnectionError(why))
        return data_in, None, None

    def _read_buffer(self, data_in):
        """Process the socket buffer, and direct the data to the appropriate
        channel.

        :param data_in: socket data
        :return: The data that was not consumed.
        :rtype: bytes
        """
        while data_in:
            data_in, channel_id, frame_in = \
                self._handle_amqp_frame(data_in)

            # No frame could be produced from the buffered data; return the
            # leftover to the caller.
            if frame_in is None:
                break

            self.heartbeat.register_read()
            if channel_id == 0:
                self._channel0.on_frame(frame_in)
            else:
                # NOTE(review): a frame for a channel id that is not in
                # self._channels raises KeyError here.
                self._channels[channel_id].on_frame(frame_in)

        return data_in

    def _send_handshake(self):
        """Send a RabbitMQ Handshake.

        :return:
        """
        self._io.write_to_socket(pamqp_header.ProtocolHeader().marshal())

    def _validate_parameters(self):
        """Validate Connection Parameters.

        :raises AMQPInvalidArgument: Raises on the first parameter that has
                                     the wrong type.

        :return:
        """
        if not compatibility.is_string(self.parameters['hostname']):
            raise AMQPInvalidArgument('hostname should be a string')
        elif not compatibility.is_integer(self.parameters['port']):
            raise AMQPInvalidArgument('port should be an integer')
        elif not compatibility.is_string(self.parameters['username']):
            raise AMQPInvalidArgument('username should be a string')
        elif not compatibility.is_string(self.parameters['password']):
            raise AMQPInvalidArgument('password should be a string')
        elif not compatibility.is_string(self.parameters['virtual_host']):
            raise AMQPInvalidArgument('virtual_host should be a string')
        elif not isinstance(self.parameters['timeout'], (int, float)):
            raise AMQPInvalidArgument('timeout should be an integer or float')
        elif not compatibility.is_integer(self.parameters['heartbeat']):
            raise AMQPInvalidArgument('heartbeat should be an integer')

    def _wait_for_connection_state(self, state=Stateful.OPEN):
        """Wait for a Connection state.

        Polls the current state, checking for recorded errors on every
        iteration and sleeping ``IDLE_WAIT`` between checks.

        :param int state: State that we expect

        :raises AMQPConnectionError: Raises if we reach the connection timeout.

        :return:
        """
        start_time = time.time()
        # Wait up to three times the socket timeout; a falsy timeout falls
        # back to 10 before being multiplied.
        timeout = (self.parameters['timeout'] or 10) * 3
        while self.current_state != state:
            self.check_for_errors()
            if time.time() - start_time > timeout:
                raise AMQPConnectionError('Connection timed out')
            sleep(IDLE_WAIT)
예제 #47
0
 def test_io_receive_raises_socket_timeout(self):
     # Exercises IO._receive() when the socket's recv() raises
     # socket.timeout.
     connection = FakeConnection()
     io = IO(connection.parameters)
     # Replace the real socket with a stand-in restricted to the
     # socket.socket interface.
     io.socket = mock.Mock(name='socket', spec=socket.socket)
     io.socket.recv.side_effect = socket.timeout('timeout')
     # The call is not wrapped in any exception check, so the test fails
     # if the timeout propagates out of _receive().
     io._receive()