def create_sent_message(self, *args, **kwargs):
    """Build the Message a test expects the client to have sent.

    The message is constructed from the given arguments. A missing client ID
    is filled in from ``self.client.client_id`` when the client has one. A
    missing message ID is filled in with ``str(self.client.message_id)``.
    """
    expected = Message(*args, **kwargs)
    needs_client_id = not expected.client_id and self.client.client_id
    if needs_client_id:
        expected.client_id = self.client.client_id
    if not expected.id:
        expected.id = str(self.client.message_id)
    return expected
def test_from_message(self):
    """FailureMessage.from_message mirrors the request's channel and id.

    The failure also keeps the request itself. It carries the default advice
    and whatever exception/successful values were supplied.
    """
    # Defaults: an empty request gives an unsuccessful failure with no exception
    request = Message()
    failure = FailureMessage.from_message(request)
    assert isinstance(failure, FailureMessage)
    assert failure.exception is None
    assert failure == dict(
        channel=None,
        id=None,
        request=request,
        successful=False,
        exception=None,
        advice={'reconnect': 'none', 'interval': 0},
    )

    # Explicit exception/successful overrides are carried through
    request = Message(channel='/test', id='1')
    error = Exception()
    failure = FailureMessage.from_message(request, exception=error, successful=True)
    assert isinstance(failure, FailureMessage)
    assert failure == dict(
        channel='/test',
        id='1',
        request=request,
        successful=True,
        exception=error,
        advice={'reconnect': 'none', 'interval': 0},
    )
def test_receive_handshake(self):
    """Handshake replies switch on server ACK support only when they carry the ack field."""
    handshake = Message(channel=ChannelId.META_HANDSHAKE)

    # Without the ext field, server ACK support is not assumed
    assert self.extension.receive(handshake) == handshake
    assert not self.extension.server_supports_acks

    # Advertising the ACK field turns support on
    handshake.ext = {AckExtension.FIELD_ACK: True}
    assert self.extension.receive(handshake) == handshake
    assert self.extension.server_supports_acks
def test_receive_handshake(self):
    """receive() returns the same handshake object and tracks server ACK support."""
    reply = Message(channel=ChannelId.META_HANDSHAKE)
    extension = self.extension

    # A plain handshake reply does not enable ACKs
    assert extension.receive(reply) is reply
    assert not extension.server_supports_acks

    # Once the reply advertises the ACK field, support is recorded
    reply.ext = {AckExtension.FIELD_ACK: True}
    assert extension.receive(reply) is reply
    assert extension.server_supports_acks
def test_update(self):
    """Message.update merges one mapping plus keyword overrides."""
    msg = Message(channel='/bad1')
    assert msg['channel'] == '/bad1'
    expected = {'channel': '/test', 'id': '1'}

    # Keyword arguments take precedence over the positional mapping
    msg.update({'channel': '/bad2', 'id': '1'}, channel='/test')
    assert msg == expected

    # More than one positional mapping is rejected and leaves the message as-is
    self.assertRaises(TypeError, msg.update, {'channel': '/bad3'}, {'id': '2'})
    assert msg == expected
def test_init(self):
    """Message accepts mappings and keyword arguments, with later sources winning."""
    expected = {'channel': '/test', 'id': '1'}
    assert Message() == {}
    # A second positional mapping overrides the first
    assert Message({'channel': '/bad', 'id': '1'}, {'channel': '/test'}) == expected
    assert Message(channel='/test', id='1') == expected
    # Keyword arguments override the positional mapping
    assert Message({'channel': '/bad', 'id': '1'}, channel='/test') == expected
def test_failure(self):
    """`failure` is true unless the message is marked successful."""
    assert Message().failure
    assert not Message(successful=True).failure

    msg = Message(successful=False)
    assert msg.failure
    # The property follows later edits to the 'successful' item
    msg['successful'] = True
    assert not msg.failure
def test_copy(self):
    """copy() returns an equal object that is still a Message."""
    for fields in ({}, {'channel': '/test', 'id': '1', 'ext': {'ack': True}}):
        original = Message(**fields)
        duplicate = original.copy()
        assert original == duplicate
        assert isinstance(duplicate, Message)
def test_to_json(self):
    """to_json always serialises a JSON array, wrapping a lone message."""
    assert Message.to_json([]) == dumps([])

    single = Message(channel='/test', id='1')
    # A single message is serialised as a one-element array
    assert Message.to_json(single) == dumps([single])

    batch = [Message(channel='/test%d' % n, id=str(n)) for n in (1, 2)]
    assert Message.to_json(batch) == dumps(batch)
def test_channel(self):
    """The channel attribute is a ChannelId however it was set."""
    via_mapping = Message({'channel': '/test'})
    assert isinstance(via_mapping.channel, ChannelId)

    via_keyword = Message(channel='/test')
    assert isinstance(via_keyword.channel, ChannelId)

    via_item = Message()
    via_item['channel'] = '/test'
    assert isinstance(via_item.channel, ChannelId)

    via_attribute = Message()
    via_attribute.channel = '/test'
    assert isinstance(via_attribute.channel, ChannelId)
def connect_client(self, client_id='client-1'):
    """Drive the client through a successful handshake and connect.

    Starts a handshake and feeds canned successful handshake and connect
    replies through the mock transport. Then clears the transport's record
    of sent messages.
    """
    transport = self.transport
    self.client.handshake()
    handshake_reply = Message(
        channel=ChannelId.META_HANDSHAKE,
        successful=True,
        client_id=client_id,
        supported_connection_types=[transport.name],
        version=Client.BAYEUX_VERSION,
    )
    transport.receive([handshake_reply])
    connect_reply = Message(channel=ChannelId.META_CONNECT, successful=True)
    transport.receive([connect_reply])
    transport.clear_sent_messages()
def test_fail_messages_connect(self):
    """Failed connects report retry advice and schedule a single reconnect.

    The first failure (no exception) is reported with a zero retry interval.
    The second, which carries an exception, is reported with an interval of
    one backoff increment. Exactly one delayed connect is then scheduled,
    with a deadline of twice the backoff increment. Firing it re-sends the
    connect message.
    """
    self.connect_client()
    mock_message_1 = Message(
        channel=ChannelId.META_CONNECT,
        client_id=self.client.client_id,
        connection_type=self.transport.name,
        advice={Message.FIELD_TIMEOUT: 0}
    )
    mock_message_2 = mock_message_1.copy()
    exception = Exception()
    # A multi-item ``with`` replaces ``contextlib.nested``, which was
    # deprecated in Python 2.7 and removed in Python 3.
    with self.capture_messages() as messages, self.capture_timeouts() as timeouts:
        self.client.fail_messages([mock_message_1])
        self.client.fail_messages([mock_message_2], exception=exception)
        self.check_failure_messages(messages, {
            ChannelId.META_CONNECT: [
                FailureMessage.from_message(
                    mock_message_1,
                    advice={
                        FailureMessage.FIELD_RECONNECT: FailureMessage.RECONNECT_RETRY,
                        FailureMessage.FIELD_INTERVAL: 0
                    }
                ),
                FailureMessage.from_message(
                    mock_message_2,
                    exception=exception,
                    advice={
                        FailureMessage.FIELD_RECONNECT: FailureMessage.RECONNECT_RETRY,
                        FailureMessage.FIELD_INTERVAL: self.DEFAULT_OPTIONS['backoff_period_increment']
                    }
                )
            ]
        })

        # Make sure a single delayed connect was scheduled
        self.transport.clear_sent_messages()
        assert len(timeouts) == 1
        assert timeouts[0].deadline == timedelta(
            milliseconds=self.DEFAULT_OPTIONS['backoff_period_increment'] * 2
        )
        timeouts[0].callback()
        assert self.transport.sent_messages == [
            self.create_sent_message(
                channel=ChannelId.META_CONNECT,
                connection_type=self.transport.name,
                advice={
                    Message.FIELD_TIMEOUT: 0
                }
            )
        ]
def test_setdefault(self):
    """setdefault stores a default only for a missing key and returns the stored value."""
    msg = Message()
    # 'ext' starts out absent; setdefault with no default stores None
    self.assertRaises(KeyError, msg.__getitem__, 'ext')
    assert msg.setdefault('ext') is None
    assert msg['ext'] is None

    # (constructor fields, default passed, value expected afterwards)
    for initial, default, expected in (
        ({}, {}, {}),
        ({'ext': {'ack': 1}}, {}, {'ack': 1}),
    ):
        msg = Message(**initial)
        assert msg.setdefault('ext', default) == expected
        assert msg['ext'] == expected
def test_from_json(self):
    """from_json always returns a list of Message objects."""
    assert Message.from_json(dumps([])) == []

    def check(payload, expected):
        parsed = Message.from_json(dumps(payload))
        assert parsed == expected
        assert all(isinstance(item, Message) for item in parsed)

    # A bare JSON object is wrapped into a one-element list
    single = {'channel': '/test', 'id': '1'}
    check(single, [single])

    pair = [
        {'channel': '/test1', 'id': '1'},
        {'channel': '/test2', 'id': '2'}
    ]
    check(pair, pair)
def _prepare_request(self, messages):
    """Build the POST request that carries *messages* to the server.

    If ``self._append_message_type`` is set and the batch is a single meta
    message, the channel path without its first segment is appended to the
    URL. Headers, body and timeout are logged at debug level. Both connect
    and request timeouts are ``self.get_timeout(messages)`` divided by 1000.
    """
    target_url = self.url
    lone_meta = (
        self._append_message_type
        and len(messages) == 1
        and messages[0].channel.is_meta()
    )
    if lone_meta:
        suffix = '/'.join(messages[0].channel.parts()[1:])
        if not target_url.endswith('/'):
            target_url += '/'
        target_url += suffix

    # Flatten the name -> [values] mapping into a multi-valued header set
    request_headers = HTTPHeaders()
    for name, header_values in self.get_headers().items():
        for header_value in header_values:
            request_headers.add(name, header_value)
    for name, header_value in request_headers.get_all():
        self.log.debug('Request header %s: %s' % (name, header_value))

    payload = Message.to_json(messages, encoding='utf8')
    self.log.debug('Request body (length: %d): %s' % (len(payload), payload))

    # get_timeout() is divided by 1000 to yield seconds
    timeout_seconds = self.get_timeout(messages) / 1000.0
    self.log.debug('Request timeout: %ss' % timeout_seconds)

    return HTTPRequest(
        target_url,
        method='POST',
        headers=request_headers,
        body=payload,
        connect_timeout=timeout_seconds,
        request_timeout=timeout_seconds
    )
def _handle_response(self, response, messages):
    """Process the HTTP response to a batch of sent messages.

    On error, the failure is mapped to a baiocas error and the sent
    *messages* are failed through the client. HTTP code 599 becomes a
    TimeoutError, any other HTTPError a ServerError, and anything else a
    CommunicationError. On success, cookies are updated and the decoded
    reply messages are handed to the client.
    """
    log = self.log
    log.debug('Received response: %s' % response.code)
    for name, header_value in response.headers.get_all():
        log.debug('Response header %s: %s' % (name, header_value))

    failure = response.error
    if failure:
        if not isinstance(failure, HTTPError):
            error = errors.CommunicationError(failure)
        elif failure.code == 599:
            error = errors.TimeoutError()
        else:
            error = errors.ServerError(failure.code)
        log.debug('Failed to send messages: %s' % error)
        self._client.fail_messages(messages, error)
        return

    self.update_cookies(
        response.headers.get_list('Set-Cookie'),
        time_received=response.headers.get('Date')
    )

    log.debug('Received body: %s' % response.body)
    received = Message.from_json(response.body, encoding='utf8')
    self._client.receive_messages(received)
def test_from_json_with_encoding(self):
    """from_json decodes UTF-8 encoded input when given encoding='utf8'."""
    expected = [{'channel': u'/caf\xe9', 'id': '1'}]
    encoded = dumps(expected, ensure_ascii=False).encode('utf8')
    decoded = Message.from_json(encoded, encoding='utf8')
    assert decoded == expected
    assert all(isinstance(item, Message) for item in decoded)
def test_send(self):
    """send() returns the same message, stamped via formatdate(usegmt=True)."""
    outgoing = Message(channel='/test')
    with patch.object(timestamp, 'formatdate') as fake_formatdate:
        # Pin the value the patched formatdate will hand back
        fake_formatdate.return_value = formatdate(usegmt=True)
        assert self.extension.send(outgoing) is outgoing
        fake_formatdate.assert_called_with(usegmt=True)
        assert outgoing.timestamp == fake_formatdate.return_value
def test_disconnect(self):
    """Disconnect sends one request, ignores repeats, and resets state on reply."""
    # Only a connected client can be disconnected
    self.connect_client()

    self.client.disconnect()
    assert self.client.status == ClientStatus.DISCONNECTING
    assert len(self.transport.sent_messages) == 1
    assert self.transport.sent_messages[0] == self.create_sent_message(
        channel=ChannelId.META_DISCONNECT)

    # A second disconnect while one is pending sends nothing new
    self.client.disconnect()
    assert self.client.status == ClientStatus.DISCONNECTING
    assert len(self.transport.sent_messages) == 1

    # The server's reply completes the disconnect and resets client state
    reply = Message(channel=ChannelId.META_DISCONNECT, successful=True)
    self.transport.receive([reply])
    assert self.client.status == ClientStatus.DISCONNECTED
    assert len(self.transport.sent_messages) == 1
    assert self.client.backoff_period == 0
    assert self.client.client_id is None
def _prepare_request(self, messages):
    """Build the POST request that carries *messages* to the server.

    Args:
        messages: the batch of Message objects to send.

    Returns:
        An ``HTTPRequest`` whose connect and request timeouts are both
        ``self.get_timeout(messages)`` divided by 1000 (i.e. in seconds).
    """
    # Determine the URL for the messages. A lone meta message may have its
    # channel path (minus the first segment) appended to the URL.
    url = self.url
    if self._append_message_type and len(messages) == 1 and messages[0].channel.is_meta():
        message_type = '/'.join(messages[0].channel.parts()[1:])
        if not url.endswith('/'):
            url += '/'
        url += message_type

    # Get the headers for the request. ``items()`` works on Python 2 and 3
    # alike; ``iteritems()`` does not exist on Python 3 dicts.
    headers = HTTPHeaders()
    for header, values in self.get_headers().items():
        for value in values:
            headers.add(header, value)
    for header, value in headers.get_all():
        self.log.debug('Request header %s: %s' % (header, value))

    # Get the body for the request
    body = Message.to_json(messages, encoding='utf8')
    self.log.debug('Request body (length: %d): %s' % (len(body), body))

    # Get the timeout (in seconds)
    timeout = self.get_timeout(messages) / 1000.0
    self.log.debug('Request timeout: %ss' % timeout)

    # Build and return the request
    return HTTPRequest(
        url,
        method='POST',
        headers=headers,
        body=body,
        connect_timeout=timeout,
        request_timeout=timeout
    )
def test_from_json_with_encoding(self):
    """from_json decodes UTF-8 bytes into Message objects when told the encoding."""
    expected = [{'channel': '/caf\xe9', 'id': '1'}]
    raw = dumps(expected, ensure_ascii=False).encode('utf8')
    parsed = Message.from_json(raw, encoding='utf8')
    assert parsed == expected
    assert all(isinstance(entry, Message) for entry in parsed)
def test_send_connect(self):
    """Connect messages carry an ACK field only after server support is known."""
    connect = Message(channel=ChannelId.META_CONNECT)
    extension = self.extension
    ack = AckExtension.FIELD_ACK

    # Server ACK support unknown: the message gets no ext data
    assert extension.send(connect) == connect
    assert not connect.ext

    # After an ACK-capable handshake the field is present but unset
    extension.receive(Message(channel=ChannelId.META_HANDSHAKE, ext={ack: True}))
    assert extension.send(connect) == connect
    assert connect.ext[ack] is None

    # The ACK ID from a successful connect reply is echoed on the next send
    extension.receive(Message(channel=ChannelId.META_CONNECT, successful=True, ext={ack: 1}))
    assert extension.send(connect) == connect
    assert connect.ext[ack] == 1
def test_fail_messages_connect(self):
    """Failed connects report retry advice and schedule a single reconnect.

    The first failure (no exception) is reported with a zero retry interval.
    The second, which carries an exception, is reported with an interval of
    one backoff increment. Exactly one delayed connect is then scheduled,
    with a deadline of twice the backoff increment. Firing it re-sends the
    connect message.
    """
    self.connect_client()
    mock_message_1 = Message(channel=ChannelId.META_CONNECT,
                             client_id=self.client.client_id,
                             connection_type=self.transport.name,
                             advice={Message.FIELD_TIMEOUT: 0})
    mock_message_2 = mock_message_1.copy()
    exception = Exception()
    # A multi-item ``with`` replaces ``contextlib.nested``, which was
    # deprecated in Python 2.7 and removed in Python 3.
    with self.capture_messages() as messages, self.capture_timeouts() as timeouts:
        self.client.fail_messages([mock_message_1])
        self.client.fail_messages([mock_message_2], exception=exception)
        self.check_failure_messages(
            messages, {
                ChannelId.META_CONNECT: [
                    FailureMessage.from_message(
                        mock_message_1,
                        advice={
                            FailureMessage.FIELD_RECONNECT: FailureMessage.RECONNECT_RETRY,
                            FailureMessage.FIELD_INTERVAL: 0
                        }),
                    FailureMessage.from_message(
                        mock_message_2,
                        exception=exception,
                        advice={
                            FailureMessage.FIELD_RECONNECT: FailureMessage.RECONNECT_RETRY,
                            FailureMessage.FIELD_INTERVAL: self.DEFAULT_OPTIONS['backoff_period_increment']
                        })
                ]
            })

        # Make sure a single delayed connect was scheduled
        self.transport.clear_sent_messages()
        assert len(timeouts) == 1
        assert timeouts[0].deadline == timedelta(
            milliseconds=self.DEFAULT_OPTIONS['backoff_period_increment'] * 2)
        timeouts[0].callback()
        assert self.transport.sent_messages == [
            self.create_sent_message(channel=ChannelId.META_CONNECT,
                                     connection_type=self.transport.name,
                                     advice={Message.FIELD_TIMEOUT: 0})
        ]
def test_send_handshake(self):
    """Handshakes advertise ACK support according to the ack_enabled option."""
    handshake = Message(channel=ChannelId.META_HANDSHAKE)
    field = AckExtension.FIELD_ACK

    # By default the handshake advertises ACK support
    assert self.extension.send(handshake) == handshake
    assert handshake.ext[field]
    assert self.extension.ack_id is None

    # Disabling ACKs on the client turns the advertised flag off
    self.client.configure(ack_enabled=False)
    assert self.extension.send(handshake) == handshake
    assert not handshake.ext[field]
    assert self.extension.ack_id is None
def test_from_json(self):
    """from_json yields Message lists for empty, single-object and array input."""
    # (JSON payload, expected decoded list)
    cases = [
        ([], []),
        ({'channel': '/test', 'id': '1'}, [{'channel': '/test', 'id': '1'}]),
        ([{'channel': '/test1', 'id': '1'}, {'channel': '/test2', 'id': '2'}],
         [{'channel': '/test1', 'id': '1'}, {'channel': '/test2', 'id': '2'}]),
    ]
    for payload, expected in cases:
        parsed = Message.from_json(dumps(payload))
        assert parsed == expected
        for item in parsed:
            assert isinstance(item, Message)
def test_receive_connect(self):
    """Connect replies update the ACK ID only when it is valid and supported.

    ACK IDs are compared with ``==``. ``is`` against an int literal tests
    object identity, passes only because CPython caches small ints, and
    triggers a SyntaxWarning on Python 3.8+.
    """
    # Check that nothing happens when no ACK ID is included
    message = Message(channel=ChannelId.META_CONNECT, successful=True)
    assert self.extension.receive(message) is message
    assert self.extension.ack_id is None
    assert not self.extension.server_supports_acks

    # Check that nothing happens when server support is unknown
    message.ext = {AckExtension.FIELD_ACK: 1}
    assert self.extension.receive(message) is message
    assert self.extension.ack_id is None
    assert not self.extension.server_supports_acks

    # Notify the extension that server supports ACKs
    self.extension.receive(Message(
        channel=ChannelId.META_HANDSHAKE,
        ext={AckExtension.FIELD_ACK: True}
    ))

    # Check that the ACK ID is captured
    assert self.extension.server_supports_acks
    assert self.extension.receive(message) is message
    assert self.extension.ack_id == 1

    # Check that the ACK ID is ignored for failed messages
    message.ext[AckExtension.FIELD_ACK] = 2
    message.successful = False
    assert self.extension.receive(message) is message
    assert self.extension.ack_id == 1

    # Check that the ACK ID is ignored if not an integer
    message.ext[AckExtension.FIELD_ACK] = '2'
    message.successful = True
    assert self.extension.receive(message) is message
    assert self.extension.ack_id == 1

    # Check that updates to the ACK ID are captured
    message.ext[AckExtension.FIELD_ACK] = 2
    assert self.extension.receive(message) is message
    assert self.extension.ack_id == 2
def subscribe(self, function, *extra_args, **extra_kwargs):
    """Register *function* as a subscriber of this channel.

    When the channel has no subscriptions yet, a META_SUBSCRIBE message for
    this channel is sent through the client first. An optional
    ``properties`` keyword is removed from *extra_kwargs* and passed as the
    first positional argument of that Message. The remaining extra
    arguments are forwarded to ``_add_listener`` along with *function*.

    Returns:
        Whatever ``_add_listener`` returns (presumably a listener ID; confirm).
    """
    properties = extra_kwargs.pop('properties', None)
    first_subscription = not self.has_subscriptions
    if first_subscription:
        self.log.debug('Subscribe to channel "%s"' % self._channel_id)
        request = Message(
            properties,
            channel=ChannelId.META_SUBSCRIBE,
            subscription=self._channel_id
        )
        self._client.send(request)
    return self._add_listener(
        self._subscriptions, function, extra_args, extra_kwargs)
def unsubscribe(self, id=None, function=None, properties=None):
    """Remove a subscriber and send META_UNSUBSCRIBE once none remain.

    The listener is identified by *id* or *function* via
    ``_remove_listener``. *properties* is passed as the first positional
    argument of the unsubscribe Message.

    NOTE(review): the unsubscribe message is sent whenever the channel has
    no subscriptions afterwards, even if nothing was actually removed.
    Confirm this is intended.

    Returns:
        The result of ``_remove_listener``.
    """
    removed = self._remove_listener(
        self._subscriptions, id=id, function=function)
    if self.has_subscriptions:
        return removed
    self.log.debug(
        'Channel has no remaining subscriptions, sending unsubscribe')
    self._client.send(Message(
        properties,
        channel=ChannelId.META_UNSUBSCRIBE,
        subscription=self._channel_id
    ))
    return removed
def test_fields(self):
    """Snake_case keyword fields are stored under their camelCase wire keys."""
    stamp = formatdate(usegmt=True)
    message = Message(
        advice={'reconnect': 'none'},
        channel='/test',
        client_id='client-1',
        connection_type='long-polling',
        data='dummy',
        error='402 Client Unauthorized',
        ext={},
        id='1',
        interval=0,
        minimum_version='0.9',
        reconnect='none',
        subscription='/topic',
        supported_connection_types=['long-polling', 'callback-polling'],
        timeout=1000,
        timestamp=stamp,
        version='1.0',
    )
    expected = dict(
        advice={'reconnect': 'none'},
        channel='/test',
        clientId='client-1',
        connectionType='long-polling',
        data='dummy',
        error='402 Client Unauthorized',
        ext={},
        id='1',
        interval=0,
        minimumVersion='0.9',
        reconnect='none',
        subscription='/topic',
        supportedConnectionTypes=['long-polling', 'callback-polling'],
        timeout=1000,
        timestamp=stamp,
        version='1.0',
    )
    assert message == expected
def test_receive_connect(self):
    """Connect replies update the ACK ID only when it is valid and supported."""
    extension = self.extension
    ack = AckExtension.FIELD_ACK
    reply = Message(channel=ChannelId.META_CONNECT, successful=True)

    # No ext field: nothing is captured
    assert extension.receive(reply) is reply
    assert extension.ack_id is None
    assert not extension.server_supports_acks

    # An ACK ID arriving before server support is known is ignored
    reply.ext = {ack: 1}
    assert extension.receive(reply) == reply
    assert extension.ack_id is None
    assert not extension.server_supports_acks

    # A handshake reply advertises server ACK support
    extension.receive(Message(channel=ChannelId.META_HANDSHAKE, ext={ack: True}))
    assert extension.server_supports_acks
    assert extension.receive(reply) == reply
    assert extension.ack_id == 1

    # Unsuccessful replies leave the stored ACK ID alone
    reply.ext[ack] = 2
    reply.successful = False
    assert extension.receive(reply) == reply
    assert extension.ack_id == 1

    # Non-integer ACK IDs are ignored
    reply.ext[ack] = '2'
    reply.successful = True
    assert extension.receive(reply) == reply
    assert extension.ack_id == 1

    # A new integer ACK ID replaces the stored one
    reply.ext[ack] = 2
    assert extension.receive(reply) == reply
    assert extension.ack_id == 2
def test_from_dict(self):
    """from_dict builds an equal Message whose channel is a ChannelId."""
    source = {'channel': '/test', 'id': '1'}
    built = Message.from_dict(source)
    assert built == {'channel': '/test', 'id': '1'}
    assert isinstance(built.channel, ChannelId)
def test_bad_attribute(self):
    """Unknown attributes raise AttributeError and never become message fields."""
    msg = Message()
    self.assertRaises(AttributeError, getattr, msg, 'bad_attribute')
    # Once set, an unknown attribute is readable but stays out of the mapping
    setattr(msg, 'bad_attribute', 'dummy')
    assert getattr(msg, 'bad_attribute') == 'dummy'
    assert msg == {}
def test_to_json_with_encoding(self):
    """to_json honours the requested output encoding for non-ASCII content."""
    cafe = Message(channel=u'/caf\xe9', id='1')
    expected_bytes = dumps([cafe], ensure_ascii=False).encode('utf8')
    assert Message.to_json(cafe, encoding='utf8') == expected_bytes
def test_attribute(self):
    """Known fields read as None when unset and write through to the mapping."""
    msg = Message()
    assert msg.channel is None
    # Assigning the attribute stores the value as a message item
    msg.channel = '/test'
    assert msg.channel == '/test'
    assert msg == {'channel': '/test'}
def test_receive(self):
    """receive() hands back the same message with its contents untouched."""
    incoming = Message(channel='/test')
    assert self.extension.receive(incoming) is incoming
    assert incoming == {'channel': '/test'}
def setUp(self):
    """Create a fresh IOLoop, client, mock transport and sample message per test."""
    loop = self.get_new_ioloop()
    client = Client('http://www.example.com', io_loop=loop)
    transport = MockTransport('mock-transport')
    client.register_transport(transport)
    self.io_loop = loop
    self.client = client
    self.transport = transport
    self.mock_message = Message(channel='/test', data='dummy')
class TestClient(AsyncTestCase): DEFAULT_OPTIONS = { 'backoff_period_increment': 1000, 'maximum_backoff_period': 60000, 'reverse_incoming_extensions': True, 'advice': { 'timeout': 60000, 'interval': 0, 'reconnect': 'retry' } } def check_failure_messages(self, messages, expected_messages): if ChannelId.META_UNSUCCESSFUL not in expected_messages: all_messages = sum(expected_messages.values(), []) expected_messages[ChannelId.META_UNSUCCESSFUL] = all_messages assert sorted(messages.keys()) == sorted(expected_messages.keys()) for channel_id, channel_messages in expected_messages.iteritems(): assert messages[channel_id] == channel_messages def connect_client(self, client_id='client-1'): self.client.handshake() self.transport.receive([ Message( channel=ChannelId.META_HANDSHAKE, successful=True, client_id=client_id, supported_connection_types=[self.transport.name], version=Client.BAYEUX_VERSION ) ]) self.transport.receive([ Message( channel=ChannelId.META_CONNECT, successful=True ) ]) self.transport.clear_sent_messages() def create_sent_message(self, *args, **kwargs): message = Message(*args, **kwargs) if not message.client_id and self.client.client_id: message.client_id = self.client.client_id if not message.id: message.id = str(self.client.message_id) return message def create_mock_function(self, name='mock', **kwargs): mock = Mock(**kwargs) mock.__name__ = name return mock def disconnect_client(self): self.client.disconnect() self.transport.receive([ Message( channel=ChannelId.META_DISCONNECT, successful=True ) ]) self.transport.clear_sent_messages() def setUp(self): self.io_loop = self.get_new_ioloop() self.client = Client('http://www.example.com', io_loop=self.io_loop) self.transport = MockTransport('mock-transport') self.client.register_transport(self.transport) self.mock_message = Message(channel='/test', data='dummy') def test_init(self): assert isinstance(self.client.log, logging.Logger) assert self.client.log.name == 'baiocas.client.Client' assert 
self.client.advice == {} assert self.client.backoff_period == 0 assert self.client.client_id is None assert not self.client.is_batching assert not self.client.is_disconnected assert self.client.options == self.DEFAULT_OPTIONS assert self.client.status == ClientStatus.UNCONNECTED assert self.client.transport is None assert self.client.url == 'http://www.example.com' def test_advice(self): advice = self.client.advice assert advice == {} advice['temp'] = 'dummy' assert self.client.advice == {} def test_is_batching(self): assert not self.client.is_batching with self.client.batch(): assert self.client.is_batching assert not self.client.is_batching def test_is_batching_manual(self): assert not self.client.is_batching self.client.start_batch() assert self.client.is_batching self.client.end_batch() assert not self.client.is_batching def test_is_batching_nested(self): assert not self.client.is_batching with self.client.batch(): with self.client.batch(): assert self.client.is_batching assert self.client.is_batching assert not self.client.is_batching def test_is_disconnected(self): assert not self.client.is_disconnected self.connect_client() assert not self.client.is_disconnected self.disconnect_client() assert self.client.is_disconnected def test_message_id(self): assert self.client.message_id == 0 self.client.handshake() assert self.client.message_id == 1 def test_options(self): options = self.client.options assert options == self.DEFAULT_OPTIONS options['temp'] = 'dummy' assert self.client.options == self.DEFAULT_OPTIONS def test_clear_subscriptions(self): mock_listener = self.create_mock_function() mock_subscription = self.create_mock_function() channel_1 = self.client.get_channel('/test1') channel_2 = self.client.get_channel('/test2') channel_1.add_listener(mock_listener) channel_1.subscribe(mock_subscription) channel_1.subscribe(mock_subscription) channel_2.subscribe(mock_subscription) assert channel_1.has_subscriptions assert channel_2.has_subscriptions 
self.client.clear_subscriptions() assert not channel_1.has_subscriptions assert not channel_2.has_subscriptions channel_1.notify_listeners(channel_1, self.mock_message) channel_2.notify_listeners(channel_2, self.mock_message) mock_listener.called_with(channel_1, self.mock_message) assert not mock_subscription.called def test_configure(self): # Make sure blank calls keep the defaults self.client.configure() assert self.client.options == self.DEFAULT_OPTIONS # Make sure we can add an option self.client.configure(temp='dummy') options = self.DEFAULT_OPTIONS.copy() options['temp'] = 'dummy' assert self.client.options == options # Check that the option sticks self.client.configure() assert self.client.options == options # Make sure we can change existing options self.client.configure(temp=False) options['temp'] = False assert self.client.options == options # Make sure we can change a default self.client.configure(backoff_period_increment=0) options['backoff_period_increment'] = 0 assert self.client.options == options def test_disconnect(self): # Connect the client so we can disconnect self.connect_client() # Issue the disconnect request self.client.disconnect() assert self.client.status == ClientStatus.DISCONNECTING assert len(self.transport.sent_messages) == 1 message = self.transport.sent_messages[0] assert message == self.create_sent_message(channel=ChannelId.META_DISCONNECT) # Make sure we can't attempt to disconnect again during a disconnect self.client.disconnect() assert self.client.status == ClientStatus.DISCONNECTING assert len(self.transport.sent_messages) == 1 # Complete the disconnect self.transport.receive([ Message( channel=ChannelId.META_DISCONNECT, successful=True ) ]) assert self.client.status == ClientStatus.DISCONNECTED assert len(self.transport.sent_messages) == 1 assert self.client.backoff_period == 0 assert self.client.client_id is None def test_disconnect_second_response(self): self.connect_client() self.client.disconnect() 
self.transport.receive([ Message( channel=ChannelId.META_DISCONNECT, successful=True ) ]) def test_disconnect_with_queued_messages(self): self.client.handshake() mock_message = self.mock_message.copy() self.client.send(mock_message) with self.capture_messages(only_failures=True) as messages: self.disconnect_client() self.check_failure_messages(messages, { ChannelId.META_PUBLISH: [ FailureMessage.from_message( mock_message, exception=errors.StatusError(ClientStatus.DISCONNECTED) ) ] }) def test_disconnect_with_properties(self): self.connect_client() self.client.disconnect(properties={'temp': 'dummy'}) assert len(self.transport.sent_messages) == 1 message = self.transport.sent_messages[0] assert message == self.create_sent_message( {'temp': 'dummy'}, channel=ChannelId.META_DISCONNECT ) def test_end_batch(self): self.assertRaises(errors.BatchError, self.client.end_batch) with patch.object(self.client, 'flush_batch') as mock_flush_batch: self.client.start_batch() self.client.start_batch() self.client.end_batch() assert not mock_flush_batch.called self.client.end_batch() assert mock_flush_batch.call_count == 1 self.assertRaises(errors.BatchError, self.client.end_batch) def test_fail_messages(self): self.connect_client() mock_message_1 = self.mock_message.copy() mock_message_2 = self.mock_message.copy() exception = Exception() with self.capture_messages() as messages: self.client.fail_messages([]) self.client.fail_messages([mock_message_1]) self.client.fail_messages([mock_message_2], exception=exception) self.check_failure_messages(messages, { ChannelId.META_PUBLISH: [ FailureMessage.from_message(mock_message_1), FailureMessage.from_message(mock_message_2, exception=exception) ] }) def test_fail_messages_connect(self): self.connect_client() mock_message_1 = Message( channel=ChannelId.META_CONNECT, client_id=self.client.client_id, connection_type=self.transport.name, advice={Message.FIELD_TIMEOUT: 0} ) mock_message_2 = mock_message_1.copy() exception = Exception() with 
nested( self.capture_messages(), self.capture_timeouts() ) as (messages, timeouts): self.client.fail_messages([mock_message_1]) self.client.fail_messages([mock_message_2], exception=exception) self.check_failure_messages(messages, { ChannelId.META_CONNECT: [ FailureMessage.from_message( mock_message_1, advice={ FailureMessage.FIELD_RECONNECT: FailureMessage.RECONNECT_RETRY, FailureMessage.FIELD_INTERVAL: 0 } ), FailureMessage.from_message( mock_message_2, exception=exception, advice={ FailureMessage.FIELD_RECONNECT: FailureMessage.RECONNECT_RETRY, FailureMessage.FIELD_INTERVAL: self.DEFAULT_OPTIONS['backoff_period_increment'] } ) ] }) # Make sure a single delayed connect was scheduled self.transport.clear_sent_messages() assert len(timeouts) == 1 assert timeouts[0].deadline == timedelta( milliseconds=self.DEFAULT_OPTIONS['backoff_period_increment'] * 2 ) timeouts[0].callback() assert self.transport.sent_messages == [ self.create_sent_message( channel=ChannelId.META_CONNECT, connection_type=self.transport.name, advice={ Message.FIELD_TIMEOUT: 0 } ) ] def test_fire(self): # Register mock listeners event = self.client.EVENT_EXTENSION_EXCEPTION mock_listener = self.create_mock_function() bad_listener = self.create_mock_function(side_effect=Exception()) self.client.register_listener(event, mock_listener) self.client.register_listener(event, bad_listener) self.client.register_listener(event, mock_listener, 2, foo='bar2') self.client.register_listener(event, mock_listener, 3) self.client.register_listener(event, mock_listener, foo='bar4') # Make sure the listeners don't fire for a different event self.client.fire('mock_event') assert not mock_listener.called # Check the basic functionality self.client.fire(event) assert mock_listener.call_args_list == [ ((self.client,),), ((self.client, 2,), {'foo': 'bar2'}), ((self.client, 3,),), ((self.client,), {'foo': 'bar4'}) ] assert bad_listener.call_count == 1 mock_listener.reset_mock() # Make sure args/kwargs get combined 
correctly self.client.fire(event, 5, 6, temp='dummy', foo='bar5') assert mock_listener.call_args_list == [ ((self.client, 5, 6), {'temp': 'dummy', 'foo': 'bar5'}), ((self.client, 5, 6, 2,), {'temp': 'dummy', 'foo': 'bar2'}), ((self.client, 5, 6, 3,), {'temp': 'dummy', 'foo': 'bar5'}), ((self.client, 5, 6), {'temp': 'dummy', 'foo': 'bar4'}) ] def test_flush_batch(self): self.connect_client() self.client.flush_batch() mock_message = self.mock_message.copy() with self.client.batch(): self.client.send(mock_message) assert self.transport.sent_messages == [] self.client.flush_batch() assert self.transport.sent_messages == [mock_message] self.transport.clear_sent_messages() self.client.flush_batch() assert self.transport.sent_messages == [] assert self.transport.sent_messages == [] def test_get_channel(self): channel = self.client.get_channel('/test') assert channel.channel_id == '/test' assert self.client.get_channel('/test') is channel assert self.client.get_channel(ChannelId('/test')) is channel other_channel = self.client.get_channel('/Test') assert other_channel.channel_id == '/Test' assert channel is not other_channel def test_get_known_transports(self): transports = self.client.get_known_transports() assert len(transports) == 1 assert self.transport.name in transports transport2 = MockTransport('mock-transport-2', only_versions=[]) self.client.register_transport(transport2) transports = self.client.get_known_transports() assert len(transports) == 2 assert self.transport.name in transports assert transport2.name in transports def test_get_transport(self): assert self.client.get_transport(self.transport.name) is self.transport assert self.client.get_transport('bad-transport') is None assert self.client.get_transport(self.transport.name.upper()) is None def test_register_extension(self): # Register extensions mock_extension_1 = MockExtension('mock-extension-1') mock_extension_2 = MockExtension('mock-extension-2') assert self.client.register_extension(mock_extension_1) 
assert self.client.register_extension(mock_extension_2) assert mock_extension_1.client is self.client assert mock_extension_2.client is self.client # Check that they get called for received messages mock_messages = [self.mock_message.copy()] self.client.receive_messages(mock_messages) assert mock_extension_1.received_messages == mock_messages assert mock_extension_1.sent_messages == [] assert mock_extension_2.received_messages == mock_messages assert mock_extension_2.sent_messages == [] # Connect the client to test sending self.connect_client() mock_extension_1.clear_messages() mock_extension_2.clear_messages() # Check that they get called for sent messages mock_message = self.mock_message.copy() self.client.send(mock_message) assert mock_extension_1.received_messages == [] assert mock_extension_1.sent_messages == [mock_message] assert mock_extension_2.received_messages == [] assert mock_extension_2.sent_messages == [mock_message] def test_register_listener(self): # Test the basic functionality event = self.client.EVENT_EXTENSION_EXCEPTION mock_listener_1 = self.create_mock_function() listener_id = self.client.register_listener(event, mock_listener_1, 1, foo='bar') assert not mock_listener_1.called self.client.fire(event) mock_listener_1.assert_called_once_with(self.client, 1, foo='bar') mock_listener_1.reset_mock() # Make sure multiple listeners per event can be registered mock_listener_2 = self.create_mock_function() self.client.register_listener(event, mock_listener_2) self.client.fire(event) mock_listener_1.assert_called_once_with(self.client, 1, foo='bar') mock_listener_2.assert_called_once_with(self.client) # Make sure listeners are registered only for the right event self.client.fire('mock-event') assert mock_listener_1.call_count == 1 assert mock_listener_2.call_count == 1 # Make sure the right listener ID is returned assert self.client.unregister_listener(listener_id) self.client.fire(event) assert mock_listener_1.call_count == 1 assert 
mock_listener_2.call_count == 2 def test_register_transport(self): transport2 = MockTransport('mock-transport-2') assert not self.client.register_transport(self.transport) assert self.client.register_transport(transport2) assert transport2.name in self.client.get_known_transports() def test_unregister_extension(self): # Connect the client to test sending messages self.connect_client() # Create the extensions mock_extension_1 = MockExtension('mock-extension-1') mock_extension_2 = MockExtension('mock-extension-2') # Check that unregistering invalid extensions doesn't fail assert not self.client.unregister_extension(mock_extension_1) # Register the extensions assert self.client.register_extension(mock_extension_1) assert self.client.register_extension(mock_extension_2) assert mock_extension_1.client is self.client assert mock_extension_2.client is self.client # Unregister the extension assert self.client.unregister_extension(mock_extension_1) assert mock_extension_1.client is None assert mock_extension_2.client is self.client # Make sure messages only get routed to registered extensions mock_messages = [self.mock_message.copy()] self.client.receive_messages(mock_messages) assert mock_extension_1.received_messages == [] assert mock_extension_1.sent_messages == [] assert mock_extension_2.received_messages == mock_messages assert mock_extension_2.sent_messages == [] self.client.send(mock_messages[0]) assert mock_extension_1.received_messages == [] assert mock_extension_1.sent_messages == [] assert mock_extension_2.received_messages == mock_messages assert mock_extension_2.sent_messages == mock_messages def test_unregister_listener(self): # Add a listener event = self.client.EVENT_EXTENSION_EXCEPTION mock_listener = self.create_mock_function() listener_id = self.client.register_listener(event, mock_listener) # Check validation of optional arguments self.assertRaises(ValueError, self.client.unregister_listener) self.assertRaises(ValueError, self.client.unregister_listener, 
id=listener_id, event=event) self.assertRaises(ValueError, self.client.unregister_listener, id=listener_id, function=mock_listener) # Make sure non-matches are handled correctly assert not self.client.unregister_listener(id=listener_id - 1) assert not self.client.unregister_listener(event='mock-event') assert not self.client.unregister_listener(function=self.create_mock_function()) assert not self.client.unregister_listener(event='mock-event', function=mock_listener) assert not self.client.unregister_listener(event=event, function=self.create_mock_function()) # Test removal by ID assert self.client.unregister_listener(id=listener_id) self.client.fire(event) assert not mock_listener.called # Test removal by event self.client.register_listener(event, mock_listener) self.client.register_listener(event, mock_listener) self.client.register_listener('mock-event', mock_listener) self.client.fire(event) assert mock_listener.call_count == 2 assert self.client.unregister_listener(event=event) self.client.fire(event) assert mock_listener.call_count == 2 self.client.fire('mock-event') assert mock_listener.call_count == 3 mock_listener.reset_mock() # Test removal by function mock_listener_2 = self.create_mock_function() self.client.register_listener(event, mock_listener) self.client.register_listener(event, mock_listener_2) self.client.register_listener(event, mock_listener) self.client.fire(event) assert mock_listener.call_count == 2 assert mock_listener_2.call_count == 1 assert self.client.unregister_listener(function=mock_listener) self.client.fire(event) assert mock_listener.call_count == 2 assert mock_listener_2.call_count == 2 mock_listener.reset_mock() mock_listener_2.reset_mock() # Test removal by event and function self.client.register_listener(event, mock_listener) self.client.register_listener(event, mock_listener) self.client.register_listener('mock-event', mock_listener) self.client.fire(event) assert mock_listener.call_count == 2 assert mock_listener_2.call_count 
== 1 assert self.client.unregister_listener(event=event, function=mock_listener) self.client.fire(event) assert mock_listener.call_count == 2 assert mock_listener_2.call_count == 2 self.client.fire('mock-event') assert mock_listener.call_count == 3 def test_unregister_transport(self): assert len(self.client.get_known_transports()) == 1 assert self.client.unregister_transport('bad-transport') is None assert self.client.unregister_transport(self.transport.name.upper()) is None assert self.client.unregister_transport(self.transport.name) is self.transport assert self.client.unregister_transport(self.transport.name) is None assert len(self.client.get_known_transports()) == 0 def test_batch(self): self.connect_client() mock_message = self.mock_message.copy() with self.client.batch(): assert self.client.is_batching self.client.send(mock_message) assert self.transport.sent_messages == [] assert not self.client.is_batching assert self.transport.sent_messages == [mock_message] @contextmanager def capture_messages(self, only_failures=False): # Create a listener that logs messages keyed by channel for all channels captured_messages = defaultdict(list) skipped_messages = [0] def _receive_message(channel, message): if message.failure or not only_failures: captured_messages[channel.channel_id].append(message) channel = self.client.get_channel('/**') listener_id = channel.add_listener(_receive_message) # Yield to the wrapped functionality, removing the listener on exit try: yield captured_messages finally: channel.remove_listener(id=listener_id) @contextmanager def capture_timeouts(self): # Keep track of the timeouts timeouts = [] # Create add/remove_timeout methods that update the timeouts list. We # don't log the timeout references directly because the deadline on # those gets converted and the class is private to Tornado. 
def _add_timeout(deadline, callback): timeout = IOLoop.add_timeout(self.io_loop, deadline, callback) timeouts.append(Timeout( callback=callback, deadline=deadline, reference=timeout )) return timeout def _remove_timeout(reference): IOLoop.remove_timeout(self.io_loop, reference) for index, timeout in enumerate(timeouts): if timeout.reference == reference: del timeouts[index] break # Grab all calls to add_timeout/remove_timeout with nested( patch.object(self.io_loop, 'add_timeout'), patch.object(self.io_loop, 'remove_timeout', mocksignature=True) ) as (mock_add_timeout, mock_remove_timeout): mock_add_timeout.side_effect = _add_timeout mock_remove_timeout.side_effect = _remove_timeout yield timeouts
def test_key(self):
    """Check item access on a Message.

    Looking up 'channel' before it is set must raise KeyError; after
    assigning it, the same key reads back the assigned value.
    """
    msg = Message()
    # Nothing has been stored yet, so the lookup is expected to fail.
    self.assertRaises(KeyError, msg.__getitem__, 'channel')
    msg['channel'] = '/test'
    assert msg['channel'] == '/test'
def test_to_json_with_encoding(self):
    """Check that Message.to_json honours an explicit encoding argument.

    The expected output is the message wrapped in a one-element list,
    dumped with ensure_ascii=False, then encoded as UTF-8.
    """
    msg = Message(channel='/caf\xe9', id='1')
    # Build the reference bytes first so the call order matches the
    # original test (to_json is invoked after dumps).
    expected = dumps([msg], ensure_ascii=False).encode('utf8')
    actual = Message.to_json(msg, encoding='utf8')
    assert actual == expected