class TestOutputBuffer(unittest.TestCase):
    """Tests for OutputBuffer: enqueue/prepend, flushing, advancing and safe_empty."""

    def setUp(self):
        # Most tests run with buffering enabled; tests that need it off build their own buffer.
        self.output_buffer = OutputBuffer(enable_buffering=True)

    def test_get_buffer(self):
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())

        data1 = bytearray(range(20))
        self.output_buffer.enqueue_msgbytes(data1)
        self.output_buffer.flush()
        self.assertEqual(data1, self.output_buffer.get_buffer())

        # A second flushed message must not change what get_buffer() returns:
        # the test expects only the head message to be exposed.
        data2 = bytearray(range(20, 40))
        self.output_buffer.enqueue_msgbytes(data2)
        self.output_buffer.flush()
        self.assertEqual(data1, self.output_buffer.get_buffer())

        # Moving the index inside the head message exposes only its remaining tail.
        new_index = 10
        self.output_buffer.index = new_index
        self.assertEqual(data1[new_index:], self.output_buffer.get_buffer())

    def test_advance_buffer(self):
        # Advancing an empty buffer is rejected.
        with self.assertRaises(ValueError):
            self.output_buffer.advance_buffer(5)

        data1 = bytearray(range(20))
        self.output_buffer.enqueue_msgbytes(data1)
        self.output_buffer.flush()
        data2 = bytearray(range(20, 40))
        self.output_buffer.enqueue_msgbytes(data2)
        self.output_buffer.flush()

        self.output_buffer.advance_buffer(10)
        self.assertEqual(10, self.output_buffer.index)
        self.assertEqual(30, self.output_buffer.length)

        # Consuming the rest of the first message drops it and resets the index.
        self.output_buffer.advance_buffer(10)
        self.assertEqual(0, self.output_buffer.index)
        self.assertEqual(1, len(self.output_buffer.output_msgs))

    def test_at_msg_boundary(self):
        self.assertTrue(self.output_buffer.at_msg_boundary())
        self.output_buffer.index = 1
        self.assertFalse(self.output_buffer.at_msg_boundary())

    def test_enqueue_msgbytes(self):
        # A str payload is rejected.
        with self.assertRaises(ValueError):
            self.output_buffer.enqueue_msgbytes("f")

        data1 = bytearray(range(20))
        self.output_buffer.enqueue_msgbytes(data1)
        self.output_buffer.flush()
        self.assertEqual(data1, self.output_buffer.get_buffer())

        data2 = bytearray(range(20, 40))
        self.output_buffer.enqueue_msgbytes(data2)
        self.output_buffer.flush()
        self.assertEqual(data1, self.output_buffer.get_buffer())

        new_index = 10
        self.output_buffer.index = new_index
        self.assertEqual(data1[new_index:], self.output_buffer.get_buffer())

    def test_prepend_msgbytes(self):
        # A str payload is rejected.
        with self.assertRaises(ValueError):
            self.output_buffer.prepend_msgbytes("f")

        data1 = bytearray(range(20))
        self.output_buffer.prepend_msgbytes(data1)
        data2 = bytearray(range(20, 40))
        self.output_buffer.prepend_msgbytes(data2)

        # Each prepend goes in front of what is already queued.
        self.assertEqual(deque([data2, data1]), self.output_buffer.output_msgs)
        self.assertEqual(40, self.output_buffer.length)

        # Once the head message is partially sent, a prepended message is placed
        # right after it instead of in front of it.
        self.output_buffer.advance_buffer(10)
        data3 = bytearray(range(40, 60))
        self.output_buffer.prepend_msgbytes(data3)
        self.assertEqual(deque([data2, data3, data1]), self.output_buffer.output_msgs)
        self.assertEqual(50, self.output_buffer.length)

    def test_has_more_bytes(self):
        self.assertFalse(self.output_buffer.has_more_bytes())
        self.output_buffer.length = 1
        self.assertTrue(self.output_buffer.has_more_bytes())

    def test_flush_get_buffer_on_time(self):
        data1 = bytearray(range(20))
        self.output_buffer.enqueue_msgbytes(data1)
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())

        # Pretend the maximum hold time has already elapsed. The real clock is restored
        # afterwards so the patched time.time does not leak into other tests.
        real_time = time.time
        time.time = MagicMock(return_value=real_time() + OUTPUT_BUFFER_BATCH_MAX_HOLD_TIME + 0.001)
        try:
            self.assertEqual(data1, self.output_buffer.get_buffer())
        finally:
            time.time = real_time

    def test_flush_get_buffer_on_size(self):
        data1 = bytearray(range(20))
        self.output_buffer.enqueue_msgbytes(data1)
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())

        # Enqueuing at least OUTPUT_BUFFER_MIN_SIZE more bytes makes data available.
        data2 = bytearray([1] * OUTPUT_BUFFER_MIN_SIZE)
        self.output_buffer.enqueue_msgbytes(data2)
        self.assertNotEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())

    def test_safe_empty(self):
        self.output_buffer = OutputBuffer(enable_buffering=False)
        messages = [
            helpers.generate_bytearray(10),
            helpers.generate_bytearray(10)
        ]
        for message in messages:
            self.output_buffer.enqueue_msgbytes(message)

        self.output_buffer.advance_buffer(5)
        self.assertEqual(15, len(self.output_buffer))

        # safe_empty keeps the unsent remainder (5 bytes) of the partially sent head message.
        self.output_buffer.safe_empty()
        self.assertEqual(5, len(self.output_buffer))

    def test_safe_empty_no_contents(self):
        # Must not raise on an empty buffer.
        self.output_buffer = OutputBuffer(enable_buffering=False)
        self.output_buffer.safe_empty()

    def test_safe_empty_buffering(self):
        messages = [
            helpers.generate_bytearray(10),
            helpers.generate_bytearray(10)
        ]
        for message in messages:
            self.output_buffer.enqueue_msgbytes(message)

        # Unflushed bytes count toward len() but are not yet visible via get_buffer().
        self.assertEqual(20, len(self.output_buffer))
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())

        self.output_buffer.safe_empty()
        self.assertEqual(0, len(self.output_buffer))
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())
class InternalNodeConnection(AbstractConnection[Node]):
    """
    Connection between internal nodes that speaks the bloXroute message protocol.

    Handles protocol-version negotiation, ping/pong bookkeeping and the
    transaction-service sync message exchange.
    """
    __metaclass__ = ABCMeta

    def __init__(self, sock, address, node, from_me=False):
        super(InternalNodeConnection, self).__init__(sock, address, node, from_me)

        # Enable buffering only on internal connections
        self.enable_buffered_send = node.opts.enable_buffered_send
        self.outputbuf = OutputBuffer(enable_buffering=self.enable_buffered_send)

        self.network_num = node.network_num
        self.version_manager = bloxroute_version_manager

        # Setting default protocol version and message factory; override when hello message received
        self.message_factory = bloxroute_message_factory
        self.protocol_version = self.version_manager.CURRENT_PROTOCOL_VERSION

        self.ping_message = PingMessage()
        self.pong_message = PongMessage()
        self.ack_message = AckMessage()

        self.can_send_pings = True
        # Ping nonce -> time the ping was sent; presumably entries expire after
        # REQUEST_EXPIRATION_TIME (ExpiringDict semantics) — confirm.
        self.ping_message_timestamps = ExpiringDict(self.node.alarm_queue, constants.REQUEST_EXPIRATION_TIME)
        self.message_validator = BloxrouteMessageValidator(None, self.protocol_version)

    def disable_buffering(self):
        """
        Disable buffering on this particular connection.

        Flushes anything already buffered, turns buffering off and asks the socket
        connection to send what is pending.
        :return:
        """
        self.enable_buffered_send = False
        self.outputbuf.flush()
        self.outputbuf.enable_buffering = False
        self.socket_connection.send()

    def set_protocol_version_and_message_factory(self):
        """
        Gets protocol version from the first bytes of hello message if not known.
        Sets protocol version and creates message factory for that protocol version

        :return: True when the version is known (or was just detected); False when it
                 could not be read from the input buffer yet, or is unsupported (in which
                 case the connection is marked for close)
        """
        # Outgoing connections use current version of protocol and message factory
        if self.from_me or self.state & ConnectionState.HELLO_RECVD:
            return True

        protocol_version = self.version_manager.get_connection_protocol_version(self.inputbuf)

        # presumably None means not enough bytes have arrived yet — confirm in version manager
        if protocol_version is None:
            return False

        if not self.version_manager.is_protocol_supported(protocol_version):
            self.log_debug(
                "Protocol version {} of remote node '{}' is not supported. Closing connection.",
                protocol_version,
                self.peer_desc
            )
            self.mark_for_close()
            return False

        self.protocol_version = protocol_version
        self.message_factory = self.version_manager.get_message_factory_for_version(protocol_version)

        self.log_trace("Detected incoming connection with protocol version {}".format(protocol_version))

        return True

    def pre_process_msg(self):
        """
        Make sure the protocol version is known before delegating to the base class.

        :return: (False, None, None) when the version cannot be determined; otherwise
                 whatever the base implementation returns
        """
        success = self.set_protocol_version_and_message_factory()

        if not success:
            return False, None, None

        return super(InternalNodeConnection, self).pre_process_msg()

    def enqueue_msg(self, msg, prepend=False):
        """
        Queue ``msg`` for sending, converting it to the peer's older protocol version if needed.
        Messages are dropped silently once the connection is marked for close.
        """
        if self.state & ConnectionState.MARK_FOR_CLOSE:
            return

        if self.protocol_version < self.version_manager.CURRENT_PROTOCOL_VERSION:
            versioned_message = self.version_manager.convert_message_to_older_version(self.protocol_version, msg)
        else:
            versioned_message = msg

        super(InternalNodeConnection, self).enqueue_msg(versioned_message, prepend)

    def pop_next_message(self, payload_len):
        """
        Pop the next message and, for peers on an older protocol version, convert it
        to the current version's format.
        """
        msg = super(InternalNodeConnection, self).pop_next_message(payload_len)

        if msg is None or self.protocol_version >= self.version_manager.CURRENT_PROTOCOL_VERSION:
            return msg

        versioned_msg = self.version_manager.convert_message_from_older_version(self.protocol_version, msg)

        return versioned_msg

    def msg_hello(self, msg):
        """
        Handle a hello: close on network-number mismatch, otherwise record the peer's
        network number and start the periodic ping alarm.
        """
        super(InternalNodeConnection, self).msg_hello(msg)

        if self.state & ConnectionState.MARK_FOR_CLOSE:
            self.log_trace("Connection has been closed: {}, Ignoring: {} ", self, msg)
            return

        network_num = msg.network_num()

        # ALL_NETWORK_NUM on this node accepts peers from any network.
        if self.node.network_num != constants.ALL_NETWORK_NUM and network_num != self.node.network_num:
            self.log_warning(
                "Network number mismatch. Current network num {}, remote network num {}. Closing connection.",
                self.node.network_num, network_num)
            self.mark_for_close()
            return

        self.network_num = network_num

        self.node.alarm_queue.register_alarm(self.ping_interval_s, self.send_ping)

    def peek_broadcast_msg_network_num(self, input_buffer):
        """Return the network number of a broadcast message; protocol v1 has none, so the default is used."""
        if self.protocol_version == 1:
            return constants.DEFAULT_NETWORK_NUM

        return BroadcastMessage.peek_network_num(input_buffer)

    def send_ping(self):
        """
        Send a ping (and reschedule if called from alarm queue)

        :return: the ping interval to keep the alarm running, or CANCEL_ALARMS to stop it
        """
        if self.can_send_pings and not self.state & ConnectionState.MARK_FOR_CLOSE:
            nonce = nonce_generator.get_nonce()
            msg = PingMessage(nonce=nonce)
            self.enqueue_msg(msg)
            # Remember when this nonce was sent so msg_pong can compute the round-trip time.
            self.ping_message_timestamps.add(nonce, time.time())

            return self.ping_interval_s

        return constants.CANCEL_ALARMS

    def msg_ping(self, msg):
        """Answer a ping with a pong carrying the same nonce."""
        nonce = msg.nonce()
        self.enqueue_msg(PongMessage(nonce=nonce))

    def msg_pong(self, msg: PongMessage):
        """Record the round-trip time for a pong that matches an outstanding ping."""
        nonce = msg.nonce()
        if nonce in self.ping_message_timestamps.contents:
            request_msg_timestamp = self.ping_message_timestamps.contents[nonce]
            request_response_time = time.time() - request_msg_timestamp
            self.log_trace("Pong for nonce {} had response time: {}", nonce, request_response_time)
            hooks.add_measurement(self.peer_desc, MeasurementType.PING, request_response_time)
        elif nonce is not None:
            self.log_debug("Pong message had no matching ping request. Nonce: {}", nonce)

    def msg_tx_service_sync_txs(self, msg: TxServiceSyncTxsMessage):
        """
        Transaction service sync message receive txs data
        """
        network_num = msg.network_num()
        self.node.last_sync_message_received_by_network[network_num] = time.time()
        tx_service = self.node.get_tx_service(network_num)
        txs_content_short_ids = msg.txs_content_short_ids()

        for tx_content_short_ids in txs_content_short_ids:
            tx_hash = tx_content_short_ids.tx_hash
            tx_service.set_transaction_contents(tx_hash, tx_content_short_ids.tx_content)
            for short_id in tx_content_short_ids.short_ids:
                tx_service.assign_short_id(tx_hash, short_id)

    def _create_txs_service_msg(self, network_num: int, tx_service_snap: List[Sha256Hash]) -> List[TxContentShortIds]:
        """
        Pop hashes off ``tx_service_snap`` (mutating it) until the serialized size reaches
        TXS_MSG_SIZE or the snapshot is empty.

        :return: the content/short-id entries for the popped hashes
        """
        # Looked up once instead of twice per popped hash.
        tx_service = self.node.get_tx_service(network_num)
        txs_content_short_ids: List[TxContentShortIds] = []
        txs_msg_len = 0

        while tx_service_snap:
            tx_hash = tx_service_snap.pop()
            tx_content_short_ids = TxContentShortIds(
                tx_hash,
                tx_service.get_transaction_by_hash(tx_hash),
                tx_service.get_short_ids(tx_hash)
            )

            txs_msg_len += txs_serializer.get_serialized_tx_content_short_ids_bytes_len(tx_content_short_ids)

            txs_content_short_ids.append(tx_content_short_ids)
            if txs_msg_len >= constants.TXS_MSG_SIZE:
                break

        return txs_content_short_ids

    def send_tx_service_sync_req(self, network_num: int):
        """
        sending transaction service sync request
        """
        self.enqueue_msg(TxServiceSyncReqMessage(network_num))

    def send_tx_service_sync_complete(self, network_num: int):
        """Tell the peer the transaction-service sync for ``network_num`` is finished."""
        self.enqueue_msg(TxServiceSyncCompleteMessage(network_num))

    def send_tx_service_sync_blocks_short_ids(self, network_num: int):
        """Send the short ids seen in each block of ``network_num``'s transaction service."""
        start_time = time.time()
        blocks_short_ids: List[BlockShortIds] = [
            BlockShortIds(block_hash, short_ids)
            for block_hash, short_ids in self.node.get_tx_service(network_num).iter_short_ids_seen_in_block()
        ]

        block_short_ids_msg = TxServiceSyncBlocksShortIdsMessage(network_num, blocks_short_ids)
        duration = time.time() - start_time
        self.log_trace("Sending {} block short ids took {:.3f} seconds.", len(blocks_short_ids), duration)
        self.enqueue_msg(block_short_ids_msg)

    def send_tx_service_sync_txs(self, network_num: int, tx_service_snap: List[Sha256Hash], duration: float = 0,
                                 msgs_count: int = 0, total_tx_count: int = 0, sending_tx_msgs_start_time: float = 0):
        """
        Send the next chunk of ``tx_service_snap`` as a TxServiceSyncTxsMessage and re-register
        on the alarm queue until the snapshot is drained or the timeout elapses. In both cases
        a TxServiceSyncCompleteMessage is sent at the end.

        :param network_num: network whose transaction service is being synced
        :param tx_service_snap: transaction hashes still to send; consumed in place
        :param duration: accumulated time spent building/enqueueing sync messages
        :param msgs_count: number of sync messages sent so far
        :param total_tx_count: number of transactions sent so far
        :param sending_tx_msgs_start_time: time.time() at which the sync started; 0 (the default) means "now"
        """
        if not sending_tx_msgs_start_time:
            # With a start time of 0 the elapsed time would be the whole epoch and the
            # timeout branch would fire before anything was sent.
            sending_tx_msgs_start_time = time.time()

        # NOTE(review): time.time() is in seconds but the constant is named *_MS;
        # confirm the unit of SENDING_TX_MSGS_TIMEOUT_MS.
        if (time.time() - sending_tx_msgs_start_time) >= constants.SENDING_TX_MSGS_TIMEOUT_MS:
            # if time is up - upgrade this node as synced - giving up
            self.log_trace("Sending {} transactions and {} messages took more than {} seconds. Giving up.",
                           total_tx_count, msgs_count, constants.SENDING_TX_MSGS_TIMEOUT_MS)
            self.send_tx_service_sync_complete(network_num)
            return

        if tx_service_snap:
            start = time.time()
            txs_content_short_ids = self._create_txs_service_msg(network_num, tx_service_snap)
            self.enqueue_msg(TxServiceSyncTxsMessage(network_num, txs_content_short_ids))
            duration += time.time() - start
            msgs_count += 1
            total_tx_count += len(txs_content_short_ids)

        # checks again if tx_snap in case we still have msgs to send, else no need to wait
        # TX_SERVICE_SYNC_TXS_S seconds
        if tx_service_snap:
            self.node.alarm_queue.register_alarm(
                constants.TX_SERVICE_SYNC_TXS_S, self.send_tx_service_sync_txs, network_num, tx_service_snap,
                duration, msgs_count, total_tx_count, sending_tx_msgs_start_time
            )
        else:
            # if all txs were sent, send complete msg
            self.log_trace("Sending {} transactions and {} messages took {:.3f} seconds.",
                           total_tx_count, msgs_count, duration)
            self.send_tx_service_sync_complete(network_num)

    def msg_tx_service_sync_complete(self, msg: TxServiceSyncCompleteMessage):
        """Stop tracking the sync for the message's network and notify the node."""
        network_num = msg.network_num()
        self.node.last_sync_message_received_by_network.pop(network_num, None)
        self.node.on_fully_updated_tx_service()
class InternalNodeConnection(AbstractConnection[Node]):
    """
    Connection between internal nodes that speaks the bloXroute message protocol.

    Handles protocol-version negotiation, ping/pong latency tracking (including
    per-network latency checks used around tx sync) and connection-type queries.
    """
    __metaclass__ = ABCMeta

    def __init__(self, sock: AbstractSocketConnectionProtocol, node: Node) -> None:
        super(InternalNodeConnection, self).__init__(sock, node)

        # Enable buffering only on internal connections
        self.enable_buffered_send = node.opts.enable_buffered_send
        self.outputbuf = OutputBuffer(enable_buffering=self.enable_buffered_send)

        self.network_num = node.network_num
        self.version_manager = bloxroute_version_manager

        # Setting default protocol version; override when hello message received
        self.protocol_version = self.version_manager.CURRENT_PROTOCOL_VERSION

        self.pong_message = PongMessage()
        self.ack_message = AckMessage()

        self.can_send_pings = True
        self.pong_timeout_enabled = True

        # Ping nonce -> time the ping was sent (filled in by ping_message()).
        self.ping_message_timestamps = ExpiringDict(
            self.node.alarm_queue,
            constants.REQUEST_EXPIRATION_TIME,
            f"{str(self)}_ping_timestamps"
        )

        # network_num -> latest measured ping round-trip; None while the pong is outstanding.
        self.sync_ping_latencies: Dict[int, Optional[float]] = {}
        # Ping nonce -> network_num the latency check was issued for.
        self._nonce_to_network_num: Dict[int, int] = {}
        self.message_validator = BloxrouteMessageValidator(None, self.protocol_version)
        self.tx_sync_service = TxSyncService(self)
        # NOTE(review): initialised to the current wall-clock time rather than a duration;
        # presumably a large "not measured yet" sentinel — confirm with callers.
        self.inbound_peer_latency: float = time.time()

    def connection_message_factory(self) -> AbstractMessageFactory:
        return bloxroute_message_factory

    def ping_message(self) -> AbstractMessage:
        """Create a ping with a fresh nonce and remember when it was created."""
        nonce = nonce_generator.get_nonce()
        self.ping_message_timestamps.add(nonce, time.time())
        return PingMessage(nonce)

    def disable_buffering(self):
        """
        Disable buffering on this particular connection.

        Flushes anything already buffered, turns buffering off and asks the socket
        connection to send what is pending.
        :return:
        """
        self.enable_buffered_send = False
        self.outputbuf.flush()
        self.outputbuf.enable_buffering = False
        self.socket_connection.send()

    def set_protocol_version_and_message_factory(self) -> bool:
        """
        Gets protocol version from the first bytes of hello message if not known.
        Sets protocol version and creates message factory for that protocol version.
        A peer version newer than ours is clamped to CURRENT_PROTOCOL_VERSION.

        :return: True when the version is known (or was just detected); False when it
                 could not be read from the input buffer yet, or is unsupported (in which
                 case the connection is marked for close)
        """
        # Nothing to detect once the hello message has been received; keep the
        # protocol version and message factory already in place.
        if ConnectionState.HELLO_RECVD in self.state:
            return True

        protocol_version = self.version_manager.get_connection_protocol_version(self.inputbuf)

        # presumably None means not enough bytes have arrived yet — confirm in version manager
        if protocol_version is None:
            return False

        if not self.version_manager.is_protocol_supported(protocol_version):
            self.log_debug(
                "Protocol version {} of remote node '{}' is not supported. Closing connection.",
                protocol_version, self.peer_desc)
            self.mark_for_close()
            return False

        if protocol_version > self.version_manager.CURRENT_PROTOCOL_VERSION:
            self.log_debug(
                "Got message protocol {} that is higher the current version {}. Using current protocol version",
                protocol_version, self.version_manager.CURRENT_PROTOCOL_VERSION)
            protocol_version = self.version_manager.CURRENT_PROTOCOL_VERSION

        self.protocol_version = protocol_version
        self.message_factory = self.version_manager.get_message_factory_for_version(protocol_version)

        self.log_trace("Setting connection protocol version to {}".format(protocol_version))

        return True

    def pre_process_msg(self) -> ConnectionMessagePreview:
        """Make sure the protocol version is known before delegating to the base class."""
        success = self.set_protocol_version_and_message_factory()

        if not success:
            return ConnectionMessagePreview(False, True, None, None)

        return super(InternalNodeConnection, self).pre_process_msg()

    def enqueue_msg(self, msg, prepend=False):
        """
        Queue ``msg`` for sending, converting it to the peer's older protocol version if needed.
        Messages are dropped silently when the connection is no longer alive.
        """
        if not self.is_alive():
            return

        if self.protocol_version < self.version_manager.CURRENT_PROTOCOL_VERSION:
            versioned_message = self.version_manager.convert_message_to_older_version(self.protocol_version, msg)
        else:
            versioned_message = msg

        super(InternalNodeConnection, self).enqueue_msg(versioned_message, prepend)

    def pop_next_message(self, payload_len: int) -> AbstractMessage:
        """
        Pop the next message and, for peers on an older protocol version, convert it
        to the current version's format.
        """
        msg = super(InternalNodeConnection, self).pop_next_message(payload_len)

        if msg is None or self.protocol_version >= self.version_manager.CURRENT_PROTOCOL_VERSION:
            return msg

        versioned_msg = self.version_manager.convert_message_from_older_version(self.protocol_version, msg)

        return versioned_msg

    def check_ping_latency_for_network(self, network_num: int) -> None:
        """
        Send a ping whose round-trip time will be stored in ``sync_ping_latencies[network_num]``
        when the matching pong arrives (None until then).
        """
        ping_message = cast(PingMessage, self.ping_message())
        self.enqueue_msg(ping_message)
        self._nonce_to_network_num[ping_message.nonce()] = network_num
        self.sync_ping_latencies[network_num] = None

    def msg_hello(self, msg):
        """
        Handle a hello: close on network-number mismatch, otherwise record the peer's
        network number and schedule pings.
        """
        super(InternalNodeConnection, self).msg_hello(msg)

        if not self.is_alive():
            self.log_trace("Connection has been closed: {}, Ignoring: {} ", self, msg)
            return

        network_num = msg.network_num()

        # ALL_NETWORK_NUM on this node accepts peers from any network.
        if self.node.network_num != constants.ALL_NETWORK_NUM and network_num != self.node.network_num:
            self.log_warning(log_messages.NETWORK_NUMBER_MISMATCH, self.node.network_num, network_num)
            self.mark_for_close()
            return

        self.network_num = network_num

        self.schedule_pings()

    def peek_broadcast_msg_network_num(self, input_buffer):
        """Return the network number of a broadcast message; protocol v1 has none, so the default is used."""
        if self.protocol_version == 1:
            return constants.DEFAULT_NETWORK_NUM

        return BroadcastMessage.peek_network_num(input_buffer)

    # pylint: disable=arguments-differ
    def msg_ping(self, msg: PingMessage):
        """
        Answer a ping with a pong. The time embedded in the incoming nonce is used to
        estimate inbound latency, and our own nonce is sent back as the pong's timestamp.
        """
        nonce = msg.nonce()
        assumed_request_time = time.time() - nonce_generator.get_timestamp_from_nonce(nonce)
        self.inbound_peer_latency = assumed_request_time
        hooks.add_measurement(self.peer_desc, MeasurementType.PING_INCOMING, assumed_request_time, self.peer_id)

        self.enqueue_msg(PongMessage(nonce=nonce, timestamp=nonce_generator.get_nonce()))

    # pylint: disable=arguments-differ
    def msg_pong(self, msg: PongMessage):
        """
        Record ping round-trip measurements for a pong that matches an outstanding ping,
        update any per-network latency check, and refresh inbound latency from the
        pong's timestamp when present.
        """
        super(InternalNodeConnection, self).msg_pong(msg)

        nonce = msg.nonce()
        timestamp = msg.timestamp()
        if timestamp:
            self.inbound_peer_latency = time.time() - nonce_generator.get_timestamp_from_nonce(timestamp)

        if nonce in self.ping_message_timestamps.contents:
            request_msg_timestamp = self.ping_message_timestamps.contents[nonce]
            request_response_time = time.time() - request_msg_timestamp

            if nonce in self._nonce_to_network_num:
                self.sync_ping_latencies[self._nonce_to_network_num[nonce]] = request_response_time

            # Slow exchanges are logged at debug level, normal ones at trace level.
            if request_response_time > constants.PING_PONG_TRESHOLD:
                self.log_debug(
                    "Ping/pong exchange nonce {} took {:.2f} seconds to complete.",
                    nonce, request_response_time)
            else:
                self.log_trace(
                    "Ping/pong exchange nonce {} took {:.2f} seconds to complete.",
                    nonce, request_response_time)

            hooks.add_measurement(self.peer_desc, MeasurementType.PING, request_response_time, self.peer_id)
            if timestamp:
                assumed_peer_response_time = nonce_generator.get_timestamp_from_nonce(timestamp) - request_msg_timestamp
                hooks.add_measurement(
                    self.peer_desc, MeasurementType.PING_OUTGOING, assumed_peer_response_time, self.peer_id)

        elif nonce is not None:
            self.log_debug("Pong message had no matching ping request. Nonce: {}", nonce)

    def mark_for_close(self, should_retry: Optional[bool] = None):
        """Mark the connection for close and cancel any pending pong timeout."""
        super(InternalNodeConnection, self).mark_for_close(should_retry)
        self.cancel_pong_timeout()

    def is_gateway_connection(self):
        return self.CONNECTION_TYPE in ConnectionType.GATEWAY

    def is_external_gateway_connection(self):
        # self.CONNECTION_TYPE == ConnectionType.GATEWAY is equal True only for V1 gateways
        return self.CONNECTION_TYPE in ConnectionType.EXTERNAL_GATEWAY or self.CONNECTION_TYPE == ConnectionType.GATEWAY

    def is_internal_gateway_connection(self) -> bool:
        return self.CONNECTION_TYPE in ConnectionType.INTERNAL_GATEWAY

    def is_relay_connection(self):
        return self.CONNECTION_TYPE in ConnectionType.RELAY_ALL

    def is_proxy_connection(self) -> bool:
        return self.CONNECTION_TYPE in ConnectionType.RELAY_PROXY

    def update_tx_sync_complete(self, network_num: int):
        """Drop the ping-latency tracking state kept for ``network_num``."""
        self.sync_ping_latencies.pop(network_num, None)
        self._nonce_to_network_num = {
            nonce: other_network_num
            for nonce, other_network_num in self._nonce_to_network_num.items()
            if other_network_num != network_num
        }
class MessageTrackerTest(AbstractTestCase):
    """Checks that MessageTracker.empty_bytes stays in step with OutputBuffer.safe_empty."""

    def setUp(self) -> None:
        node_opts = helpers.get_common_opts(1001, external_ip="128.128.128.128")
        self.node = MockNode(node_opts)
        connection = MockConnection(MockSocketConnection(), self.node)
        self.tracker = MessageTracker(connection)
        self.output_buffer = OutputBuffer(enable_buffering=True)

    @staticmethod
    def _new_tx_message(payload_size: int) -> TxMessage:
        """Build a TxMessage from a generated hash and a generated payload of ``payload_size`` bytes."""
        return TxMessage(
            helpers.generate_object_hash(),
            5,
            tx_val=helpers.generate_bytearray(payload_size),
        )

    def test_empty_bytes_no_bytes_sent(self):
        message = self._new_tx_message(250)
        message_length = len(message.rawbytes())

        buffer = self.output_buffer
        buffer.enqueue_msgbytes(message.rawbytes())
        buffer.flush()
        self.tracker.append_message(message_length, message)

        # Nothing was sent, so everything is discarded on both sides.
        buffer.safe_empty()
        self.tracker.empty_bytes(buffer.length)

        self.assertEqual(0, buffer.length)
        self.assertEqual(0, self.tracker.bytes_remaining)
        self.assertEqual(0, len(self.tracker.messages))

    def test_empty_bytes(self):
        messages = [self._new_tx_message(250) for _ in range(3)]
        first, second, third = messages
        message_length = len(first.rawbytes())

        buffer = self.output_buffer
        buffer.enqueue_msgbytes(first.rawbytes())
        buffer.flush()
        buffer.enqueue_msgbytes(second.rawbytes())
        buffer.enqueue_msgbytes(third.rawbytes())
        for msg in messages:
            self.tracker.append_message(message_length, msg)

        sent_bytes = 120
        buffer.advance_buffer(sent_bytes)
        self.tracker.advance_bytes(sent_bytes)

        # Only the unsent tail of the partially sent first message survives.
        buffer.safe_empty()
        remaining = message_length - sent_bytes
        self.assertEqual(remaining, buffer.length)

        self.tracker.empty_bytes(buffer.length)
        self.assertEqual(1, len(self.tracker.messages))
        self.assertEqual(remaining, self.tracker.bytes_remaining)
        self.assertEqual(sent_bytes, self.tracker.messages[0].sent_bytes)

    def test_empty_bytes_more_bytes(self):
        buffer = self.output_buffer
        for _ in range(100):
            message = self._new_tx_message(2500)
            message_length = len(message.rawbytes())
            buffer.enqueue_msgbytes(message.rawbytes())
            self.tracker.append_message(message_length, message)

        sent_bytes = 3500
        buffer.advance_buffer(sent_bytes)
        self.tracker.advance_bytes(sent_bytes)

        buffer.safe_empty()
        self.tracker.empty_bytes(buffer.length)
        self.assertEqual(buffer.length, self.tracker.bytes_remaining)