def __init__(self, sock: AbstractSocketConnectionProtocol, node: Node) -> None:
        """
        Initialize an internal node connection on top of the parent connection state.

        :param sock: socket connection wrapper, forwarded to the parent initializer
        :param node: owning node; its options and network number are read here
        """
        super(InternalNodeConnection, self).__init__(sock, node)

        # Buffered sending is driven by node options and only enabled on internal connections.
        buffered_send = node.opts.enable_buffered_send
        self.enable_buffered_send = buffered_send
        self.outputbuf = OutputBuffer(enable_buffering=buffered_send)

        self.network_num = node.network_num

        version_manager = bloxroute_version_manager
        self.version_manager = version_manager

        # Default protocol version; overridden once the hello message is received.
        self.protocol_version = version_manager.CURRENT_PROTOCOL_VERSION

        self.pong_message = PongMessage()
        self.ack_message = AckMessage()

        self.can_send_pings = True
        self.pong_timeout_enabled = True

        # Entries expire after REQUEST_EXPIRATION_TIME; the name embeds this connection's str().
        self.ping_message_timestamps = ExpiringDict(
            self.node.alarm_queue,
            constants.REQUEST_EXPIRATION_TIME,
            f"{str(self)}_ping_timestamps",
        )

        self.sync_ping_latencies: Dict[int, Optional[float]] = {}
        self._nonce_to_network_num: Dict[int, int] = {}
        self.message_validator = BloxrouteMessageValidator(None, self.protocol_version)
        self.tx_sync_service = TxSyncService(self)
        # NOTE(review): initialised with a wall-clock timestamp rather than a latency value.
        self.inbound_peer_latency: float = time.time()
Exemple #2
0
    def __init__(self, sock: SocketConnection, address, node, from_me=False):
        """
        Populate the connection attributes directly; no parent initializer is called.

        :param sock: socket connection; the result of its fileno() is recorded
        :param address: (ip, port) pair of the peer at socket creation time
        :param node: owning node; external ip/port and blockchain network number come from node.opts
        :param from_me: whether this side initiated the connection
        """
        self.socket_connection = sock
        self.fileno = sock.fileno()

        # Peer address as known when the socket was created. A new application level port
        # may arrive in the version message if the connection is not from me.
        peer_ip, peer_port = address
        self.peer_ip = peer_ip
        self.peer_port = peer_port
        self.peer_id = None

        opts = node.opts
        self.my_ip = opts.external_ip
        self.my_port = opts.external_port

        # True when I initiated the connection.
        self.from_me = from_me

        self.outputbuf = OutputBuffer()
        self.inputbuf = InputBuffer()
        self.node = node

        self.is_persistent = False
        self.state = ConnectionState.CONNECTING

        # Consecutive bad messages received so far.
        self.num_bad_messages = 0
        self.peer_desc = "%s %d" % (peer_ip, peer_port)
        self.message_handlers = None
        self.network_num = opts.blockchain_network_num

        self.enqueued_messages = []
    def test_queuing_messages_cleared_after_timeout(self):
        """
        Messages queued while the remote blockchain connection is down are dropped
        once the message TTL has elapsed, while messages queued afterwards are
        still delivered to the next connection that becomes ready.
        """
        node = self._initialize_gateway(True, True, True)
        remote_blockchain_conn = next(iter(node.connection_pool.get_by_connection_types([ConnectionType.REMOTE_BLOCKCHAIN_NODE])))
        remote_blockchain_conn.mark_for_close()

        queued_message = PingMessage(12345)
        node.send_msg_to_remote_node(queued_message)
        self.assertEqual(1, len(node.remote_node_msg_queue._queue))
        self.assertEqual(queued_message, node.remote_node_msg_queue._queue[0])

        # Move the clock just past the message TTL and fire the alarms.
        # time.time is replaced globally, so register a cleanup that puts the current
        # function back; otherwise the frozen clock leaks into every later test.
        self.addCleanup(setattr, time, "time", time.time)
        time.time = MagicMock(return_value=time.time() + node.opts.remote_blockchain_message_ttl + 0.1)
        node.alarm_queue.fire_alarms()

        node.on_connection_added(MockSocketConnection(3, node, ip_address=LOCALHOST, port=8003))
        next_conn = next(iter(node.connection_pool.get_by_connection_types([ConnectionType.REMOTE_BLOCKCHAIN_NODE])))
        next_conn.outputbuf = OutputBuffer()  # clear buffer

        # queue has been cleared, so nothing is flushed to the new connection
        node.on_remote_blockchain_connection_ready(next_conn)
        self.assertEqual(0, next_conn.outputbuf.length)

        next_conn.mark_for_close()

        queued_message = PingMessage(12345)
        node.send_msg_to_remote_node(queued_message)
        self.assertEqual(1, len(node.remote_node_msg_queue._queue))
        self.assertEqual(queued_message, node.remote_node_msg_queue._queue[0])

        node.on_connection_added(MockSocketConnection(4, node, ip_address=LOCALHOST, port=8003))
        reestablished_conn = next(iter(node.connection_pool.get_by_connection_types([ConnectionType.REMOTE_BLOCKCHAIN_NODE])))
        reestablished_conn.outputbuf = OutputBuffer()  # clear buffer

        # A message queued after the expiry must still reach the re-established connection.
        node.on_remote_blockchain_connection_ready(reestablished_conn)
        self.assertEqual(queued_message.rawbytes().tobytes(), reestablished_conn.outputbuf.get_buffer().tobytes())
    def __init__(self, sock: AbstractSocketConnectionProtocol, node) -> None:
        """
        Populate the connection attributes directly; no parent initializer is called.

        :param sock: socket connection exposing file_no, endpoint and direction
        :param node: owning node; external ip/port and blockchain network number come from node.opts
        """
        self.socket_connection = sock
        self.file_no = sock.file_no

        # Peer address as known when the socket was created. A new application level port
        # may arrive in the version message if the connection is not from me.
        endpoint = sock.endpoint
        self.peer_ip, self.peer_port = endpoint
        self.endpoint = sock.endpoint
        self.peer_id = None

        opts = node.opts
        self.my_ip = opts.external_ip
        self.my_port = opts.external_port
        self.direction = self.socket_connection.direction

        # An outbound socket means I initiated the connection.
        self.from_me = self.direction == NetworkDirection.OUTBOUND

        self.outputbuf = OutputBuffer()
        self.inputbuf = InputBuffer()
        self.node = node

        self.is_persistent = False
        self.state = ConnectionState.CONNECTING

        # Consecutive bad messages received so far.
        self.num_bad_messages = 0
        self.peer_desc = "{} {}".format(self.peer_ip, self.peer_port)
        self.message_handlers = None
        self.network_num = opts.blockchain_network_num
        self.format_connection()

        self.enqueued_messages = []
        self.node_privileges = "general"
        self.subscribed_broadcasts = [BroadcastMessageType.BLOCK]
Exemple #5
0
    def __init__(self, socket_connection, address, node: Node, from_me=False):
        """
        Set up the common state of a connection.

        :param socket_connection: must be a SocketConnection instance
        :param address: (ip, port) pair of the peer at socket creation time
        :param node: owning node; options and network number are read from it
        :param from_me: whether this side initiated the connection
        :raises ValueError: if socket_connection is not a SocketConnection
        """
        if not isinstance(socket_connection, SocketConnection):
            actual_type = type(socket_connection)
            raise ValueError(
                "SocketConnection type is expected for socket_connection arg but was {0}."
                .format(actual_type))

        self.socket_connection = socket_connection
        self.fileno = socket_connection.fileno()

        # (IP, Port) at time of socket creation. When the version/hello message carries a
        # different port (i.e. connection is not from me), these are updated to match it.
        peer_ip, peer_port = address
        self.peer_ip = peer_ip
        self.peer_port = peer_port
        self.peer_id: Optional[str] = None

        opts = node.opts
        self.external_ip = opts.external_ip
        self.external_port = opts.external_port

        # True when I initiated the connection.
        self.from_me = from_me

        # The tracker attribute is only assigned when detailed tracking is switched on.
        if opts.track_detailed_sent_messages:
            self.message_tracker = MessageTracker(self)
        self.outputbuf = OutputBuffer()
        self.inputbuf = InputBuffer()
        self.node = node

        self.state = ConnectionState.CONNECTING

        # Consecutive bad messages received so far.
        self.num_bad_messages = 0
        self.peer_desc = "%s %d" % (peer_ip, peer_port)

        self.can_send_pings = False

        self.hello_messages = []
        self.header_size = 0
        self.message_factory = None
        self.message_handlers = None

        self.log_throughput = True

        self.ping_message = None
        self.pong_message = None
        self.ack_message = None

        # Starts as the current node's network number; may change after the hello message.
        self.network_num = node.network_num

        self.message_validator = DefaultMessageValidator()

        self._debug_message_tracker = defaultdict(int)
        self._last_debug_message_log_time = time.time()
        self.ping_interval_s: int = constants.PING_INTERVAL_S
        self.peer_model: Optional[OutboundPeerModel] = None

        self.log_debug("Connection initialized.")
    def test_safe_empty(self):
        """
        After two 10-byte messages are enqueued and 5 bytes advanced, safe_empty()
        is expected to leave 5 bytes in an unbuffered OutputBuffer.
        """
        self.output_buffer = OutputBuffer(enable_buffering=False)

        payloads = [helpers.generate_bytearray(10) for _ in range(2)]
        for payload in payloads:
            self.output_buffer.enqueue_msgbytes(payload)

        # 20 bytes queued, 5 consumed.
        self.output_buffer.advance_buffer(5)
        self.assertEqual(15, len(self.output_buffer))

        # presumably the remainder of the partially sent first message is kept - TODO confirm
        self.output_buffer.safe_empty()
        self.assertEqual(5, len(self.output_buffer))
Exemple #7
0
    async def test_queuing_messages_no_blockchain_connection(self):
        """
        A message sent while there is no blockchain connection is queued and then
        written to the next blockchain connection that becomes ready.
        """

        def first_blockchain_connection():
            # First connection of type BLOCKCHAIN_NODE currently in the pool.
            matching = node.connection_pool.get_by_connection_types(
                [ConnectionType.BLOCKCHAIN_NODE])
            return next(iter(matching))

        node = self._initialize_gateway(True, True)
        blockchain_conn = first_blockchain_connection()
        blockchain_conn.mark_for_close()

        self.assertIsNone(node.node_conn)

        queued_message = PingMessage(12345)
        node.send_msg_to_node(queued_message)
        pending = node.node_msg_queue._queue
        self.assertEqual(1, len(pending))
        self.assertEqual(queued_message, node.node_msg_queue._queue[0])

        new_socket = MockSocketConnection(ip_address=LOCALHOST, port=8001)
        node.on_connection_added(new_socket)
        next_conn = first_blockchain_connection()
        next_conn.outputbuf = OutputBuffer()  # clear buffer

        node.on_blockchain_connection_ready(next_conn)
        expected_bytes = queued_message.rawbytes().tobytes()
        self.assertEqual(expected_bytes,
                         next_conn.outputbuf.get_buffer().tobytes())
Exemple #8
0
    def __init__(self, sock, address, node, from_me=False):
        """
        Initialize an internal node connection on top of the parent connection state.

        :param sock: socket connection, forwarded to the parent initializer
        :param address: peer address, forwarded to the parent initializer
        :param node: owning node; its options and network number are read here
        :param from_me: whether this side initiated the connection
        """
        super(InternalNodeConnection, self).__init__(sock, address, node, from_me)

        # Buffered sending is driven by node options and only enabled on internal connections.
        buffered_send = node.opts.enable_buffered_send
        self.enable_buffered_send = buffered_send
        self.outputbuf = OutputBuffer(enable_buffering=buffered_send)

        self.network_num = node.network_num

        version_manager = bloxroute_version_manager
        self.version_manager = version_manager

        # Defaults for message factory and protocol version; overridden when the hello
        # message is received.
        self.message_factory = bloxroute_message_factory
        self.protocol_version = version_manager.CURRENT_PROTOCOL_VERSION

        self.ping_message = PingMessage()
        self.pong_message = PongMessage()
        self.ack_message = AckMessage()

        self.can_send_pings = True
        self.ping_message_timestamps = ExpiringDict(
            self.node.alarm_queue,
            constants.REQUEST_EXPIRATION_TIME,
        )
        self.message_validator = BloxrouteMessageValidator(None, self.protocol_version)
Exemple #9
0
class MockConnection(AbstractConnection, SpecialMemoryProperties):
    """
    Connection double that records every enqueued message in `enqueued_messages`
    and never calls the parent initializer.
    """

    CONNECTION_TYPE = MockConnectionType.MOCK

    def __init__(self, sock: SocketConnection, address, node, from_me=False):
        """
        :param sock: socket connection; the result of its fileno() is recorded
        :param address: (ip, port) pair of the peer at socket creation time
        :param node: owning node; external ip/port and blockchain network number come from node.opts
        :param from_me: whether this side initiated the connection
        """
        self.socket_connection = sock
        self.fileno = sock.fileno()

        # Peer address as known when the socket was created. A new application level port
        # may arrive in the version message if the connection is not from me.
        peer_ip, peer_port = address
        self.peer_ip = peer_ip
        self.peer_port = peer_port
        self.peer_id = None

        opts = node.opts
        self.my_ip = opts.external_ip
        self.my_port = opts.external_port

        # True when I initiated the connection.
        self.from_me = from_me

        self.outputbuf = OutputBuffer()
        self.inputbuf = InputBuffer()
        self.node = node

        self.is_persistent = False
        self.state = ConnectionState.CONNECTING

        # Consecutive bad messages received so far.
        self.num_bad_messages = 0
        self.peer_desc = "%s %d" % (peer_ip, peer_port)
        self.message_handlers = None
        self.network_num = opts.blockchain_network_num

        self.enqueued_messages = []

    def __repr__(self):
        return f"MockConnection<fileno: {self.fileno}, address: ({self.peer_ip}, {self.peer_port}), " \
               f"network_num: {self.network_num}>"

    def is_active(self):
        """Return True when ESTABLISHED is set in the state and MARK_FOR_CLOSE is not."""
        if self.state & ConnectionState.ESTABLISHED != ConnectionState.ESTABLISHED:
            return False
        return not self.state & ConnectionState.MARK_FOR_CLOSE

    def is_sendable(self):
        """Sendable exactly when active."""
        return self.is_active()

    def mark_for_close(self, force_destroy_now=False):
        """Set the MARK_FOR_CLOSE flag; `force_destroy_now` is accepted but ignored."""
        self.state = self.state | ConnectionState.MARK_FOR_CLOSE

    def add_received_bytes(self, bytes_received):
        """Append the bytes to the input buffer, then mark the connection for close."""
        self.inputbuf.add_bytes(bytes_received)
        self.mark_for_close()

    def get_bytes_to_send(self):
        """Return the first entry of the output buffer's message list."""
        pending = self.outputbuf.output_msgs
        return pending[0]

    def advance_sent_bytes(self, bytes_sent):
        self.advance_bytes_on_buffer(self.outputbuf, bytes_sent)

    def advance_bytes_on_buffer(self, buf, bytes_written):
        buf.advance_buffer(bytes_written)

    def enqueue_msg(self, msg, _prepend_to_queue=False):
        """Queue the message's raw bytes and record the message, unless marked for close."""
        if not self.state & ConnectionState.MARK_FOR_CLOSE:
            self.outputbuf.enqueue_msgbytes(msg.rawbytes())
            self.enqueued_messages.append(msg)

    def enqueue_msg_bytes(self, msg_bytes, prepend=False):
        """Queue raw bytes and record them, unless marked for close; `prepend` is ignored."""
        if not self.state & ConnectionState.MARK_FOR_CLOSE:
            self.outputbuf.enqueue_msgbytes(msg_bytes)
            self.enqueued_messages.append(msg_bytes)

    def process_message(self):
        """No-op in the mock."""

    def send_ping(self):
        """Send nothing; just return the ping interval constant."""
        return PING_INTERVAL_S

    def special_memory_size(self, ids: Optional[Set[int]] = None) -> SpecialTuple:
        """Combine the special memory accounting of the input and output buffers."""
        return memory_utils.add_special_objects(self.inputbuf, self.outputbuf, ids=ids)
 def test_safe_empty_no_contents(self):
     """safe_empty() on a fresh unbuffered OutputBuffer must not raise."""
     empty_buffer = OutputBuffer(enable_buffering=False)
     self.output_buffer = empty_buffer
     self.output_buffer.safe_empty()
class MockConnection(AbstractConnection, SpecialMemoryProperties):
    """
    Connection double that records every enqueued message in `enqueued_messages`
    and never calls the parent initializer.
    """

    CONNECTION_TYPE = ConnectionType.EXTERNAL_GATEWAY

    # pylint: disable=super-init-not-called
    def __init__(self, sock: AbstractSocketConnectionProtocol, node) -> None:
        """
        :param sock: socket connection exposing file_no, endpoint and direction
        :param node: owning node; external ip/port and blockchain network number come from node.opts
        """
        self.socket_connection = sock
        self.file_no = sock.file_no

        # Peer address as known when the socket was created. A new application level port
        # may arrive in the version message if the connection is not from me.
        endpoint = sock.endpoint
        self.peer_ip, self.peer_port = endpoint
        self.endpoint = sock.endpoint
        self.peer_id = None

        opts = node.opts
        self.my_ip = opts.external_ip
        self.my_port = opts.external_port
        self.direction = self.socket_connection.direction

        # An outbound socket means I initiated the connection.
        self.from_me = self.direction == NetworkDirection.OUTBOUND

        self.outputbuf = OutputBuffer()
        self.inputbuf = InputBuffer()
        self.node = node

        self.is_persistent = False
        self.state = ConnectionState.CONNECTING

        # Consecutive bad messages received so far.
        self.num_bad_messages = 0
        self.peer_desc = "{} {}".format(self.peer_ip, self.peer_port)
        self.message_handlers = None
        self.network_num = opts.blockchain_network_num
        self.format_connection()

        self.enqueued_messages = []
        self.node_privileges = "general"
        self.subscribed_broadcasts = [BroadcastMessageType.BLOCK]

    def __repr__(self):
        return f"MockConnection<file_no: {self.file_no}, address: ({self.peer_ip}, {self.peer_port}), " \
            f"network_num: {self.network_num}>"

    def ping_message(self) -> AbstractMessage:
        """Mock implementation; returns None."""

    def add_received_bytes(self, bytes_received):
        """Append the bytes to the input buffer, then mark the connection for close."""
        self.inputbuf.add_bytes(bytes_received)
        self.mark_for_close()

    def get_bytes_to_send(self):
        """Return the first entry of the output buffer's message list."""
        pending = self.outputbuf.output_msgs
        return pending[0]

    def advance_sent_bytes(self, bytes_sent):
        self.advance_bytes_on_buffer(self.outputbuf, bytes_sent)

    def advance_bytes_on_buffer(self, buf, bytes_written):
        buf.advance_buffer(bytes_written)

    def enqueue_msg(self, msg: AbstractMessage, prepend: bool = False):
        """Queue the message's raw bytes and record the message while the connection is alive."""
        if self.is_alive():
            self.outputbuf.enqueue_msgbytes(msg.rawbytes())
            self.enqueued_messages.append(msg)

    def enqueue_msg_bytes(self,
                          msg_bytes: Union[bytearray, memoryview],
                          prepend: bool = False,
                          full_message: Optional[AbstractMessage] = None):
        """
        Queue raw bytes and record them while the connection is alive.

        `prepend` and `full_message` are accepted but not used by the mock.
        """
        if self.is_alive():
            self.outputbuf.enqueue_msgbytes(msg_bytes)
            self.enqueued_messages.append(msg_bytes)

    def process_message(self):
        """No-op in the mock."""

    def send_ping(self):
        """Send nothing; just return the ping interval constant."""
        return PING_INTERVAL_S

    def special_memory_size(self,
                            ids: Optional[Set[int]] = None) -> SpecialTuple:
        """Combine the special memory accounting of the input and output buffers."""
        buffers = (self.inputbuf, self.outputbuf)
        return memory_utils.add_special_objects(*buffers, ids=ids)
 def setUp(self):
     """Give each test a fresh OutputBuffer with buffering enabled."""
     buffering_enabled = True
     self.output_buffer = OutputBuffer(enable_buffering=buffering_enabled)
class AbstractConnection(Generic[Node]):
    __metaclass__ = ABCMeta

    CONNECTION_TYPE: ClassVar[ConnectionType] = ConnectionType.NONE
    node: Node
    message_factory: AbstractMessageFactory
    format_connection_desc: str
    connection_repr: str

    # performance critical attribute, has been pulled out of connection state
    established: bool

    def __init__(self, socket_connection: AbstractSocketConnectionProtocol, node: Node) -> None:
        """
        Set up the common state of a connection.

        :param socket_connection: socket wrapper exposing file_no, endpoint and direction
        :param node: owning node; options and network number are read from it
        """
        self.socket_connection = socket_connection
        self.file_no = socket_connection.file_no

        # (IP, Port) at time of socket creation. When the version/hello message carries a
        # different port (i.e. connection is not from me), these are updated to match it.
        self.endpoint = self.socket_connection.endpoint
        peer_ip, peer_port = self.endpoint
        self.peer_ip = peer_ip
        self.peer_port = peer_port
        self.peer_id: Optional[str] = None

        opts = node.opts
        self.external_ip = opts.external_ip
        self.external_port = opts.external_port
        self.direction = self.socket_connection.direction
        # An outbound socket means I initiated the connection.
        self.from_me = self.direction == NetworkDirection.OUTBOUND

        self.outputbuf = OutputBuffer()
        self.inputbuf = InputBuffer()
        self.node = node

        self.state = ConnectionState.CONNECTING
        self.established = False

        # Consecutive bad messages received so far.
        self.num_bad_messages = 0
        self.peer_desc = repr(self.endpoint)

        # Pings and pong timeouts are off by default.
        self.can_send_pings = False
        self.pong_timeout_enabled = False

        self.hello_messages = []
        self.header_size = 0
        # Supplied by the concrete connection class.
        self.message_factory = self.connection_message_factory()
        self.message_handlers = {}

        self.log_throughput = True

        self.pong_message = None
        self.ack_message = None

        self.ping_alarm_id: Optional[AlarmId] = None
        self.ping_interval_s = constants.PING_INTERVAL_S
        self.pong_timeout_alarm_id: Optional[AlarmId] = None

        # Starts as the current node's network number; may change after the hello message.
        self.network_num = node.network_num

        self.message_validator = DefaultMessageValidator()

        # Per-message-type counters, reported by process_message.
        self._debug_message_tracker = defaultdict(int)
        self._last_debug_message_log_time = time.time()

        self.processing_message_index = 0

        self.peer_model: Optional[OutboundPeerModel] = None

        self._is_authenticated = False
        self.account_id: Optional[str] = None
        self.tier_name: Optional[str] = None

        self._close_waiter: Optional[Future] = None
        self.format_connection()

        self.log_debug("Connection initialized.")

    def __repr__(self):
        """
        Describe the connection as "<CONNECTION_TYPE> (<details>)".

        The network number is only included when DEBUG logging is enabled.
        """
        if not logger.isEnabledFor(LogLevel.DEBUG):
            details = f"file_no: {self.file_no}, address: {self.peer_desc}"
        else:
            details = f"file_no: {self.file_no}, address: {self.peer_desc}, network_num: {self.network_num}"

        return f"{self.CONNECTION_TYPE} ({details})"

    @abstractmethod
    def ping_message(self) -> AbstractMessage:
        """
        Define ping message characteristics for pinging on connection.

        This function may have side-effects; only call this if the ping message will
        be used.

        :return: the message to send as a ping on this connection
        """

    @abstractmethod
    def connection_message_factory(self) -> AbstractMessageFactory:
        """
        Provide the message factory for this connection.

        The initializer stores the returned value in `message_factory`.

        :return: message factory for this connection type
        """

    def log_trace(self, message, *args, **kwargs):
        """Log at TRACE level; does nothing when TRACE is not enabled on the logger."""
        if not logger.isEnabledFor(LogLevel.TRACE):
            return
        self._log_message(LogLevel.TRACE, message, *args, **kwargs)

    def log_debug(self, message, *args, **kwargs):
        """Log at DEBUG level; does nothing when DEBUG is not enabled on the logger."""
        if not logger.isEnabledFor(LogLevel.DEBUG):
            return
        self._log_message(LogLevel.DEBUG, message, *args, **kwargs)

    def log_info(self, message, *args, **kwargs):
        """Forward the message and arguments to _log_message at INFO level (no level check here)."""
        self._log_message(LogLevel.INFO, message, *args, **kwargs)

    def log_warning(self, message, *args, **kwargs):
        """Forward the message and arguments to _log_message at WARNING level (no level check here)."""
        self._log_message(LogLevel.WARNING, message, *args, **kwargs)

    def log_error(self, message, *args, **kwargs):
        """Forward the message and arguments to _log_message at ERROR level (no level check here)."""
        self._log_message(LogLevel.ERROR, message, *args, **kwargs)

    def log(self, level: LogLevel, message, *args, **kwargs):
        """
        Forward the message and arguments to _log_message at a caller-chosen level.

        :param level: log level to use
        :param message: message template
        """
        self._log_message(level, message, *args, **kwargs)

    def is_active(self) -> bool:
        """
        Tell whether the connection is established and its socket is alive,
        i.e. ready for normal messages.

        Called very frequently: keep this to plain attribute checks, with no
        inline function calls and no flag arithmetic.
        """
        sock = self.socket_connection
        return self.established and sock is not None and sock.alive

    def is_alive(self) -> bool:
        """
        Tell whether the connection's socket is alive; False when there is no socket.

        Called very frequently: keep this to plain attribute checks, with no
        inline function calls and no flag arithmetic.
        """
        sock = self.socket_connection
        return False if sock is None else sock.alive

    def on_connection_established(self):
        """
        Mark the connection as established, unless it is already active.

        Sets the handshake and ESTABLISHED state flags, resets the retry counter
        for this peer address on the node, and links the connection to a matching
        outbound peer model.
        """
        if self.is_active():
            return

        self.state |= ConnectionState.HELLO_RECVD
        self.state |= ConnectionState.HELLO_ACKD
        self.state |= ConnectionState.ESTABLISHED
        self.established = True

        self.log_info("Connection established.")

        # Reset num_retries when a connection established in order to support resetting the Fibonnaci logic
        # to determine next retry
        peer_address = (self.peer_ip, self.peer_port)
        self.node.num_retries_by_ip[peer_address] = 0

        # A peer model matches by ip and port, or by node id. There is no early exit,
        # so when several models match the last one wins.
        for candidate in self.node.outbound_peers:
            same_endpoint = candidate.ip == self.peer_ip and candidate.port == self.peer_port
            if same_endpoint or candidate.node_id == self.peer_id:
                self.peer_model = candidate

    def add_received_bytes(self, bytes_received: Union[bytearray, bytes]):
        """
        Adds bytes received from socket connection to input buffer

        The connection must be alive; this is enforced with an assert, so the
        check disappears when Python runs with assertions disabled.

        :param bytes_received: new bytes received from socket connection
        """
        assert self.is_alive()

        self.inputbuf.add_bytes(bytes_received)

    def get_bytes_to_send(self):
        """
        Return the output buffer's current buffer via outputbuf.get_buffer().

        The connection must be alive; this is enforced with an assert, so the
        check disappears when Python runs with assertions disabled.
        """
        assert self.is_alive()

        return self.outputbuf.get_buffer()

    def advance_sent_bytes(self, bytes_sent):
        """
        Advance the output buffer by the number of bytes sent.

        Delegates to advance_bytes_on_buffer with this connection's output buffer.

        :param bytes_sent: number of bytes to advance by
        """
        self.advance_bytes_on_buffer(self.outputbuf, bytes_sent)

    def enqueue_msg(self, msg: AbstractMessage, prepend: bool = False):
        """
        Log the message at its own log level, then hand its raw bytes to
        enqueue_msg_bytes, which queues them on the outputbuf and triggers a send.

        :param msg: message
        :param prepend: if the message should be bumped to the front of the outputbuf
        """
        level = msg.log_level()
        self._log_message(level, "Enqueued message: {}", msg)

        raw_bytes = msg.rawbytes()
        self.enqueue_msg_bytes(raw_bytes, prepend)

    def enqueue_msg_bytes(
        self,
        msg_bytes: Union[bytearray, memoryview],
        prepend: bool = False,
    ):
        """
        Put the raw bytes of a message on the outputbuf and ask the socket
        connection to send. Nothing happens when the socket is not alive.

        Called very frequently: avoid complex operations, inline function calls
        and flags here.

        :param msg_bytes: message bytes
        :param prepend: if the message should be bumped to the front of the outputbuf
        """
        if not self.socket_connection.alive:
            return

        self.log_trace("Enqueued {} bytes.", len(msg_bytes))

        # Pick the buffer operation once, then apply it.
        add_to_buffer = self.outputbuf.prepend_msgbytes if prepend else self.outputbuf.enqueue_msgbytes
        add_to_buffer(msg_bytes)

        self.socket_connection.send()

    def pre_process_msg(self) -> ConnectionMessagePreview:
        """
        Peek at the header of the next message on the input buffer.

        :return: preview carrying the full-message flag, message type and payload
            length from the message factory; its second field is always True here
        """
        header_preview = self.message_factory.get_message_header_preview_from_input_buffer(self.inputbuf)
        is_full_msg, msg_type, payload_len = header_preview

        return ConnectionMessagePreview(is_full_msg, True, msg_type, payload_len)

    def process_msg_type(self, message_type, is_full_msg, payload_len):
        """
        Processes messages that require changes to the regular message handling flow
        (pop off single message, process it, continue on with the stream)

        The implementation here is empty. process_message calls it for every message
        that passes validation, before checking whether the message is complete.

        :param message_type: message type
        :param is_full_msg: flag indicating if full message is available on input buffer
        :param payload_len: length of payload
        :return: None
        """

    def process_message(self):
        """
        Processes the next bytes on the socket's inputbuffer.

        Loops over the input buffer: previews the next message header, validates it,
        pops the full message, checks it is allowed in the current handshake state,
        records throughput, and dispatches it to the registered handler. The loop
        ends when only a partial message remains, or returns early when the socket
        is no longer alive, the connection is marked for close, or
        _report_bad_message() returns a truthy value.

        Returns 0 in order to avoid being rescheduled if this was an alarm.
        NOTE(review): every return visible here is a bare `return` (None), not 0.

        :raises MemoryError: re-raised after logging
        """
        # pylint: disable=too-many-return-statements, too-many-branches, too-many-statements

        logger.trace("START PROCESSING from {}", self)

        start_time = time.time()
        messages_processed = defaultdict(int)
        total_bytes_processed = 0

        self.processing_message_index = 0

        while True:
            input_buffer_len_before = self.inputbuf.length
            is_full_msg = False
            payload_len = None
            msg = None
            msg_type = None

            try:
                # abort message processing if connection has been closed
                if not self.socket_connection.alive:
                    return

                is_full_msg, should_process, msg_type, payload_len = self.pre_process_msg()

                # Complete message that should be skipped: drop its bytes and move on.
                if not should_process and is_full_msg:
                    self.pop_next_bytes(payload_len)
                    continue

                self.message_validator.validate(
                    is_full_msg,
                    msg_type,
                    self.header_size,
                    payload_len,
                    self.inputbuf
                )

                self.process_msg_type(msg_type, is_full_msg, payload_len)

                # Only a partial message is buffered; wait for more bytes.
                if not is_full_msg:
                    break

                msg = self.pop_next_message(payload_len)

                # If there was some error in parsing this message, then continue the loop.
                if msg is None:
                    if self._report_bad_message():
                        return
                    continue

                # Count bytes only after the None check: calling rawbytes() on a failed
                # parse raised AttributeError and bypassed the bad-message handling above.
                total_bytes_processed += len(msg.rawbytes())

                # Full messages must be one of the handshake messages if the connection isn't established yet.
                if (
                    not self.established and msg_type not in self.hello_messages
                ):
                    self.log_warning(log_messages.UNEXPECTED_MESSAGE, msg_type)
                    self.mark_for_close()
                    return

                if self.log_throughput:
                    hooks.add_throughput_event(
                        NetworkDirection.INBOUND,
                        msg_type,
                        len(msg.rawbytes()),
                        self.peer_desc,
                        self.peer_id
                    )

                # Messages whose own level is not logged are counted (when INFO is enabled)
                # and reported in one summary line once a loggable message arrives.
                if not logger.isEnabledFor(msg.log_level()) and logger.isEnabledFor(LogLevel.INFO):
                    self._debug_message_tracker[msg_type] += 1
                elif len(self._debug_message_tracker) > 0:
                    self.log_debug(
                        "Processed the following messages types: {} over {:.2f} seconds.",
                        self._debug_message_tracker,
                        time.time() - self._last_debug_message_log_time
                    )
                    self._debug_message_tracker.clear()
                    self._last_debug_message_log_time = time.time()

                self._log_message(msg.log_level(), "Processing message: {}", msg)

                # Messages without a registered handler are still counted below.
                if msg_type in self.message_handlers:
                    msg_handler = self.message_handlers[msg_type]

                    handler_start = time.time()
                    msg_handler(msg)
                    performance_utils.log_operation_duration(
                        msg_handling_logger,
                        "Single message handler",
                        handler_start,
                        constants.MSG_HANDLERS_CYCLE_DURATION_WARN_THRESHOLD_S,
                        connection=self,
                        handler=msg_handler,
                        message=msg
                    )
                messages_processed[msg_type] += 1

            # TODO: Investigate possible solutions to recover from PayloadLenError errors
            except PayloadLenError as e:
                self.log_error(log_messages.COULD_NOT_PARSE_MESSAGE, e.msg)
                self.mark_for_close()
                return

            except MemoryError as e:
                self.log_error(log_messages.OUT_OF_MEMORY, e, exc_info=True)
                self.log_debug(
                    "Failed message bytes: {}",
                    self._get_last_msg_bytes(msg, input_buffer_len_before, payload_len)
                )
                raise

            except UnauthorizedMessageError as e:
                self.log_error(log_messages.UNAUTHORIZED_MESSAGE, e.msg.MESSAGE_TYPE, self.peer_desc)
                self.log_debug(
                    "Failed message bytes: {}",
                    self._get_last_msg_bytes(msg, input_buffer_len_before, payload_len)
                )

                # give connection a chance to restore its state and get ready to process next message
                self.clean_up_current_msg(payload_len, input_buffer_len_before == self.inputbuf.length)

                if self._report_bad_message():
                    return

            except MessageValidationError as e:
                # Gateway node types log validation failures at debug level only.
                if self.node.NODE_TYPE not in NodeType.GATEWAY_TYPE:
                    if isinstance(e, ControlFlagValidationError):
                        if e.is_cancelled_cut_through:
                            self.log_debug(
                                "Message validation failed for {} message: {}. Probably cut-through cancellation",
                                msg_type, e.msg)
                        else:
                            self.log_warning(log_messages.MESSAGE_VALIDATION_FAILED, msg_type, e.msg)
                    else:
                        self.log_warning(log_messages.MESSAGE_VALIDATION_FAILED, msg_type, e.msg)
                else:
                    self.log_debug("Message validation failed for {} message: {}.", msg_type, e.msg)
                self.log_debug("Failed message bytes: {}",
                               self._get_last_msg_bytes(msg, input_buffer_len_before, payload_len))

                if is_full_msg:
                    self.clean_up_current_msg(payload_len, input_buffer_len_before == self.inputbuf.length)
                else:
                    self.log_error(log_messages.UNABLE_TO_RECOVER_PARTIAL_MESSAGE)
                    self.mark_for_close()
                    return

                if self._report_bad_message():
                    return

            except NonVersionMessageError as e:
                if e.is_known:
                    self.log_debug("Received invalid handshake request on {}:{}, {}", self.peer_ip, self.peer_port,
                                   e.msg)
                else:
                    self.log_warning(log_messages.INVALID_HANDSHAKE, self.peer_ip, self.peer_port, e.msg)
                self.log_debug("Failed message bytes: {}",
                               self._get_last_msg_bytes(msg, input_buffer_len_before, payload_len))

                self.mark_for_close()
                return

            # TODO: Throw custom exception for any errors that come from input that has not been
            # validated and only catch that subclass of exceptions
            # pylint: disable=broad-except
            except Exception as e:

                # Attempt to recover connection by removing bad full message
                if is_full_msg:
                    self.log_error(log_messages.TRYING_TO_RECOVER_MESSAGE, e, exc_info=True)
                    self.log_debug("Failed message bytes: {}",
                                   self._get_last_msg_bytes(msg, input_buffer_len_before, payload_len))

                    # give connection a chance to restore its state and get ready to process next message
                    self.clean_up_current_msg(payload_len, input_buffer_len_before == self.inputbuf.length)

                # Connection is unable to recover from message processing error if incomplete message is received
                else:
                    self.log_error(log_messages.UNABLE_TO_RECOVER_FULL_MESSAGE, e, exc_info=True)
                    self.log_debug("Failed message bytes: {}",
                                   self._get_last_msg_bytes(msg, input_buffer_len_before, payload_len))
                    self.mark_for_close()
                    return

                if self._report_bad_message():
                    return
            else:
                # A message went through without raising: reset the bad-message streak.
                self.num_bad_messages = 0

            self.processing_message_index += 1

        performance_utils.log_operation_duration(msg_handling_logger,
                                                 "Message handlers",
                                                 start_time,
                                                 constants.MSG_HANDLERS_DURATION_WARN_THRESHOLD_S,
                                                 connection=self, count=messages_processed)
        duration_ms = (time.time() - start_time) * 1000
        logger.trace("DONE PROCESSING from {}. Bytes processed: {}. Messages processed: {}. Duration: {}",
                     self, total_bytes_processed, messages_processed, stats_format.duration(duration_ms))

    def pop_next_message(self, payload_len: int) -> AbstractMessage:
        """
        Take the next complete message off the input buffer and parse it.

        The raw bytes are removed through pop_next_bytes and then handed to the
        message factory. The caller is responsible for making sure a full, valid
        message is available on the buffer before calling this.

        :param payload_len: length of payload
        :return: message object
        """
        factory = self.message_factory
        raw_message = self.pop_next_bytes(payload_len)
        return factory.create_message_from_buffer(raw_message)

    def pop_next_bytes(self, payload_len: int) -> Union[memoryview, bytearray, bytes]:
        """
        Remove one whole message (header plus payload) from the input buffer.

        :param payload_len: length of payload, header excluded
        :return: the bytes removed from the input buffer
        """
        header_len = self.message_factory.base_message_type.HEADER_LENGTH
        return self.inputbuf.remove_bytes(header_len + payload_len)

    def advance_bytes_on_buffer(self, buf, bytes_written):
        """
        Record an outbound throughput event and advance the given buffer.

        :param buf: buffer object exposing advance_buffer
        :param bytes_written: number of bytes to advance the buffer by
        :raises RuntimeError: if the buffer rejects the advance with a ValueError
        """
        hooks.add_throughput_event(NetworkDirection.OUTBOUND, None, bytes_written, self.peer_desc, self.peer_id)
        try:
            buf.advance_buffer(bytes_written)
        except ValueError as error:
            failure = "Connection: {}, Failed to advance buffer".format(self)
            raise RuntimeError(failure) from error

    def schedule_pings(self) -> None:
        """
        (Re)schedule the ping alarm for this connection.

        Any previously registered ping alarm is removed first, so calling this
        repeatedly overrides the existing alarm. Nothing is registered when the
        connection is not allowed to send pings.
        """
        self._unschedule_pings()

        if not self.can_send_pings:
            return

        alarm_queue = self.node.alarm_queue
        self.ping_alarm_id = alarm_queue.register_alarm(self.ping_interval_s, self.send_ping)

    def send_ping(self) -> float:
        """
        Send a ping on the connection.

        When pings are allowed and the connection is alive, a ping message is
        enqueued and, if pong timeouts are enabled, a pong timeout is scheduled.

        :return: the ping interval (so an alarm-queue caller reschedules), or
                 constants.CANCEL_ALARMS when no ping was sent
        """
        if not (self.can_send_pings and self.is_alive()):
            return constants.CANCEL_ALARMS

        self.enqueue_msg(self.ping_message())
        if self.pong_timeout_enabled:
            self.schedule_pong_timeout()

        return self.ping_interval_s

    def msg_hello(self, msg):
        """
        Handle a hello message from the peer.

        Records that the hello was received, adopts the peer id from the message
        when none is set yet, resolves duplicate connections to the same peer id,
        replies with an ack and completes the handshake once the connection is
        also initialized and our own hello has been acked.

        :param msg: hello message exposing node_id()
        """
        self.state |= ConnectionState.HELLO_RECVD
        if msg.node_id() is None:
            self.log_debug("Received hello message without peer id.")
        if self.peer_id is None:
            self.peer_id = msg.node_id()

        pool = self.node.connection_pool
        # This should only be necessary for pre1.6 connections.
        # NOTE(review): peers whose address starts with "172" are exempt from the
        # duplicate check; the reason is not visible here, confirm.
        if self.peer_id in pool.by_node_id and \
                self.socket_connection.endpoint.ip_address[0:3] != "172":
            duplicate = pool.get_by_node_id(self.peer_id)
            if duplicate != self:
                # Whichever side of the pair was initiated locally gets closed.
                if self.from_me:
                    self.mark_for_close()

                if duplicate.from_me:
                    duplicate.mark_for_close()

                self.log_warning(log_messages.DUPLICATE_CONNECTION, self.peer_id, duplicate)

        self.enqueue_msg(self.ack_message)
        ready_state = ConnectionState.INITIALIZED | ConnectionState.HELLO_ACKD
        if ready_state in self.state:
            self.on_connection_established()

    def msg_ack(self, _msg):
        """
        Handle an Ack Message.

        Records that our hello was acked and completes the handshake when the
        connection is also initialized and the peer's hello has been received.

        :param _msg: ack message (unused)
        """
        self.state |= ConnectionState.HELLO_ACKD
        ready_state = ConnectionState.INITIALIZED | ConnectionState.HELLO_RECVD
        if ready_state in self.state:
            self.on_connection_established()

    def msg_ping(self, _msg):
        """
        Handle a ping message by enqueueing the connection's pong message.

        :param _msg: ping message (unused)
        """
        reply = self.pong_message
        self.enqueue_msg(reply)

    def msg_pong(self, _msg):
        """
        Handle a pong message by cancelling the pending pong timeout, if any.

        :param _msg: pong message (unused)
        """
        self.cancel_pong_timeout()

    def mark_for_close(self, should_retry: Optional[bool] = None):
        """
        Marks a connection for close, so AbstractNode can dispose of this class.
        Use this where possible for a clean shutdown.

        A close waiter future is created for wait_closed to await. If a waiter is
        already pending (the connection was marked for close before), it is kept:
        replacing it would strand any coroutine already awaiting the old future,
        because dispose only resolves the waiter currently stored.

        :param should_retry: whether the connection should be retried; defaults
                             to self.from_me when not given
        """
        waiter = self._close_waiter
        if waiter is None or waiter.done():
            loop = asyncio.get_event_loop()
            self._close_waiter = loop.create_future()

        if should_retry is None:
            should_retry = self.from_me

        self.log_debug("Marking connection for close, should_retry: {}.", should_retry)
        self.socket_connection.mark_for_close(should_retry)

    def dispose(self):
        """
        Performs any needed operations after the connection object has been discarded by the AbstractNode.

        Resolves the close waiter (if one is pending), unregisters the ping and
        pong-timeout alarms, and drops the references to the node and the socket
        connection.
        """
        waiter = self._close_waiter
        # set_result raises InvalidStateError on a future that is already done
        # (resolved earlier or cancelled by its awaiter), so only resolve a pending one.
        if waiter is not None and not waiter.done():
            waiter.set_result(True)

        self._unschedule_pings()
        # Also drop a pending pong-timeout alarm while self.node is still available,
        # so it cannot fire on a disposed connection.
        self.cancel_pong_timeout()

        self.node = None
        self.socket_connection = None

    def clean_up_current_msg(self, payload_len: int, msg_is_in_input_buffer: bool) -> None:
        """
        Drop the message currently being processed from the input buffer so the
        connection is ready for the next message. Used while handling message
        processing exceptions.

        :param payload_len: length of the payload of the currently processing message
        :param msg_is_in_input_buffer: flag indicating if message bytes are still in the input buffer
        :return:
        """
        if not msg_is_in_input_buffer:
            return

        message_len = self.header_size + payload_len
        self.inputbuf.remove_bytes(message_len)

    def on_input_received(self) -> bool:
        """
        Handle an input event from the event loop.

        This implementation does no work and always reports the connection
        as receivable.

        :return: True if the connection is receivable, otherwise False
        """
        return True

    def log_connection_mem_stats(self) -> None:
        """
        Report memory stats for the connection's input and output buffers.

        Each buffer is reported with a placeholder size of 0 (not an actual size)
        and its number of queued items.
        """
        owner = self.__class__.__name__

        input_size = memory_utils.ObjectSize("input_buffer", 0, is_actual_size=False)
        hooks.add_obj_mem_stats(
            owner,
            self.network_num,
            self.inputbuf,
            "input_buffer",
            input_size,
            object_item_count=len(self.inputbuf.input_list),
            object_type=memory_utils.ObjectType.BASE,
            size_type=memory_utils.SizeType.TRUE
        )

        output_size = memory_utils.ObjectSize("output_buffer", 0, is_actual_size=False)
        hooks.add_obj_mem_stats(
            owner,
            self.network_num,
            self.outputbuf,
            "output_buffer",
            output_size,
            object_item_count=len(self.outputbuf.output_msgs),
            object_type=memory_utils.ObjectType.BASE,
            size_type=memory_utils.SizeType.TRUE
        )

    def update_model(self, model: OutboundPeerModel):
        """
        Replace the peer model associated with this connection.

        The new model is logged at trace level before it is stored.

        :param model: new peer model for the connection
        """
        self.log_trace("Updated connection model: {}", model)
        self.peer_model = model

    def schedule_pong_timeout(self) -> None:
        """
        Register the pong reply timeout alarm unless one is already pending.
        """
        if self.pong_timeout_alarm_id is not None:
            return

        self.log_trace(
            "Schedule pong reply timeout for ping message in {} seconds",
            constants.PING_PONG_REPLY_TIMEOUT_S
        )
        alarm_queue = self.node.alarm_queue
        self.pong_timeout_alarm_id = alarm_queue.register_alarm(
            constants.PING_PONG_REPLY_TIMEOUT_S, self._pong_msg_timeout
        )

    def cancel_pong_timeout(self):
        """
        Unregister the pending pong reply timeout alarm, if there is one.
        """
        alarm_id = self.pong_timeout_alarm_id
        if alarm_id is None:
            return

        self.node.alarm_queue.unregister_alarm(alarm_id)
        self.pong_timeout_alarm_id = None

    def on_connection_authenticated(self, peer_info: AuthenticatedPeerInfo) -> None:
        """
        Apply the authenticated peer details to this connection.

        Adopts the peer id and account id, asks the connection pool to update the
        connection type when it differs from CONNECTION_TYPE, and marks the
        connection as authenticated.

        :param peer_info: authenticated peer details (peer id, connection type, account id)
        """
        self.peer_id = peer_info.peer_id

        authenticated_type = peer_info.connection_type
        if self.CONNECTION_TYPE != authenticated_type:
            self.node.connection_pool.update_connection_type(self, authenticated_type)

        self.account_id = peer_info.account_id
        self._is_authenticated = True

    async def wait_closed(self):
        """
        Wait until the connection has been closed.

        Awaits the close waiter when one exists and clears it afterwards;
        otherwise yields control to the event loop once.

        :raises ConnectionStateError: if the connection is still alive afterwards
        """
        waiter = self._close_waiter
        if waiter is None:
            await asyncio.sleep(0)
        else:
            await waiter
            self._close_waiter = None

        if self.is_alive():
            raise ConnectionStateError("Connection is still alive after closed", self)

    def set_account_id(self, account_id: Optional[str], tier_name: str):
        """
        Set the account id and tier name of the connection.

        On an authenticated connection the account id is never changed and must
        match the one already stored; on an unauthenticated one it is overwritten.
        The tier name is updated in both cases.

        :param account_id: account id to associate with the connection
        :param tier_name: tier name to associate with the connection
        :raises ConnectionAuthenticationError: if the connection is authenticated
                                               under a different account id
        """
        if self._is_authenticated:
            if account_id != self.account_id:
                raise ConnectionAuthenticationError(
                    f"Invalid account id {account_id} is different than connection account id: {self.account_id}")
        else:
            self.account_id = account_id

        self.tier_name = tier_name

    def get_backlog_size(self) -> int:
        """
        Compute the number of bytes still waiting to be sent.

        :return: output buffer length plus the socket connection's write buffer size
        """
        queued = self.outputbuf.length
        in_socket = self.socket_connection.get_write_buffer_size()
        self.log_trace("Output backlog: {}, socket backlog: {}", queued, in_socket)
        return queued + in_socket

    def format_connection(self) -> None:
        """
        Refresh the cached textual descriptions of the connection.

        Stores "<peer_desc> - <short connection type>" in format_connection_desc
        and the current repr of the connection in connection_repr.
        """
        type_label = self.CONNECTION_TYPE.format_short()
        self.format_connection_desc = "{} - {}".format(self.peer_desc, type_label)
        self.connection_repr = repr(self)

    def _unschedule_pings(self) -> None:
        """
        Unregister the currently scheduled ping alarm, if there is one.
        """
        alarm_id = self.ping_alarm_id
        if not alarm_id:
            return

        self.node.alarm_queue.unregister_alarm(alarm_id)
        self.ping_alarm_id = None

    def _pong_msg_timeout(self) -> None:
        """
        Alarm callback fired when the peer did not answer a ping in time.

        The stored alarm id is always cleared, because the alarm has already fired
        by the time this runs. Leaving it set on a connection that is no longer
        alive would keep schedule_pong_timeout from registering a new alarm and
        make cancel_pong_timeout unregister an alarm that no longer exists.
        If the connection is still alive it is marked for close.
        """
        self.pong_timeout_alarm_id = None

        if self.is_alive():
            self.log_info(
                "Connection appears to be broken. Peer did not reply to PING message within allocated time. "
                "Closing connection."
            )
            self.mark_for_close()

    def _report_bad_message(self):
        """
        Account for one more bad message on this connection.

        While the counter has not reached constants.MAX_BAD_MESSAGES it is
        incremented; once it has, the connection is marked for close instead.

        :return: if connection should be closed
        """
        if self.num_bad_messages != constants.MAX_BAD_MESSAGES:
            self.num_bad_messages += 1
            return False

        self.log_warning(log_messages.TOO_MANY_BAD_MESSAGES)
        self.mark_for_close()
        return True

    def _get_last_msg_bytes(self, msg, input_buffer_len_before, payload_len):
        """
        Build a hex dump of the message that failed processing, for logging.

        :param msg: parsed message, or None if it was never popped off the buffer
        :param input_buffer_len_before: input buffer length before processing started
        :param payload_len: payload length of the message, or None if unknown
        :return: hex string of at most constants.MAX_LOGGED_BYTES_LEN bytes, or
                 "<not available>" when the bytes can no longer be obtained
        """
        if msg is not None:
            return convert.bytes_to_hex(msg.rawbytes()[:constants.MAX_LOGGED_BYTES_LEN])

        # The bytes can only be peeked if nothing was consumed from the input buffer.
        buffer_untouched = input_buffer_len_before == self.inputbuf.length
        if not buffer_untouched or payload_len is None:
            return "<not available>"

        peek_len = min(self.header_size + payload_len, constants.MAX_LOGGED_BYTES_LEN)
        return convert.bytes_to_hex(self.inputbuf.peek_message(peek_len))

    def _log_message(self, level: LogLevel, message, *args, **kwargs):
        """
        Log a message prefixed with the cached connection representation.

        :param level: log level to emit the message at
        :param message: message template
        :param args: positional values forwarded to the logger
        :param kwargs: keyword values forwarded to the logger
        """
        prefix = self.connection_repr
        logger.log(level, message, HAS_PREFIX, prefix, *args, **kwargs)
    def __init__(self, socket_connection: AbstractSocketConnectionProtocol, node: Node) -> None:
        """
        Initialize the connection state for the given socket connection.

        Copies the endpoint and direction from the socket connection, creates the
        input/output buffers, sets defaults for handshake, ping/pong, validation
        and authentication state, then builds the cached connection descriptions
        and logs that the connection was initialized.

        :param socket_connection: socket connection this connection object wraps
        :param node: node that owns the connection; its opts and network_num are read
        """
        self.socket_connection = socket_connection
        self.file_no = socket_connection.file_no

        # (IP, Port) at time of socket creation.
        # If the version/hello message contains a different port (i.e. connection is not from me), this will
        # be updated to the one in the message.
        self.endpoint = self.socket_connection.endpoint
        self.peer_ip, self.peer_port = self.endpoint
        self.peer_id: Optional[str] = None
        self.external_ip = node.opts.external_ip
        self.external_port = node.opts.external_port
        self.direction = self.socket_connection.direction
        # The connection counts as initiated locally when the socket direction is outbound.
        self.from_me = self.direction == NetworkDirection.OUTBOUND

        self.outputbuf = OutputBuffer()
        self.inputbuf = InputBuffer()
        self.node = node

        self.state = ConnectionState.CONNECTING
        self.established = False

        # Number of bad messages I've received in a row.
        self.num_bad_messages = 0
        self.peer_desc = repr(self.endpoint)

        # Pings and pong timeouts are off by default.
        self.can_send_pings = False
        self.pong_timeout_enabled = False

        self.hello_messages = []
        self.header_size = 0
        self.message_factory = self.connection_message_factory()
        self.message_handlers = {}

        self.log_throughput = True

        self.pong_message = None
        self.ack_message = None

        # Alarm ids are None while no ping / pong-timeout alarm is registered.
        self.ping_alarm_id: Optional[AlarmId] = None
        self.ping_interval_s = constants.PING_INTERVAL_S
        self.pong_timeout_alarm_id: Optional[AlarmId] = None

        # Default network number to network number of current node. But it can change after hello message is received
        self.network_num = node.network_num

        self.message_validator = DefaultMessageValidator()

        self._debug_message_tracker = defaultdict(int)
        self._last_debug_message_log_time = time.time()

        self.processing_message_index = 0

        self.peer_model: Optional[OutboundPeerModel] = None

        self._is_authenticated = False
        self.account_id: Optional[str] = None
        self.tier_name: Optional[str] = None

        # Future created by mark_for_close and resolved by dispose.
        self._close_waiter: Optional[Future] = None
        # Builds format_connection_desc and connection_repr from the attributes set above.
        self.format_connection()

        self.log_debug("Connection initialized.")
class TestOutputBuffer(unittest.TestCase):
    """Tests for OutputBuffer: enqueueing, prepending, flushing, advancing and emptying."""

    def setUp(self):
        # Most tests run with buffering enabled; the safe_empty tests that need an
        # unbuffered instance create their own.
        self.output_buffer = OutputBuffer(enable_buffering=True)

    def _enqueue_and_flush(self, data):
        """Enqueue ``data`` on the buffer under test and call flush()."""
        self.output_buffer.enqueue_msgbytes(data)
        self.output_buffer.flush()

    def test_get_buffer(self):
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())

        data1 = bytearray(range(20))
        self._enqueue_and_flush(data1)
        self.assertEqual(data1, self.output_buffer.get_buffer())

        # A second message does not change what get_buffer returns: still the first one.
        data2 = bytearray(range(20, 40))
        self._enqueue_and_flush(data2)
        self.assertEqual(data1, self.output_buffer.get_buffer())

        new_index = 10
        self.output_buffer.index = new_index
        self.assertEqual(data1[new_index:], self.output_buffer.get_buffer())

    def test_advance_buffer(self):
        # Advancing an empty buffer is rejected.
        with self.assertRaises(ValueError):
            self.output_buffer.advance_buffer(5)

        self._enqueue_and_flush(bytearray(range(20)))
        self._enqueue_and_flush(bytearray(range(20, 40)))

        self.output_buffer.advance_buffer(10)
        self.assertEqual(10, self.output_buffer.index)
        self.assertEqual(30, self.output_buffer.length)

        # Finishing the first message resets the index and leaves one message queued.
        self.output_buffer.advance_buffer(10)
        self.assertEqual(0, self.output_buffer.index)
        self.assertEqual(1, len(self.output_buffer.output_msgs))

    def test_at_msg_boundary(self):
        self.assertTrue(self.output_buffer.at_msg_boundary())
        self.output_buffer.index = 1
        self.assertFalse(self.output_buffer.at_msg_boundary())

    def test_enqueue_msgbytes(self):
        # Only byte-like messages are accepted.
        with self.assertRaises(ValueError):
            self.output_buffer.enqueue_msgbytes("f")

        data1 = bytearray(range(20))
        self._enqueue_and_flush(data1)
        self.assertEqual(data1, self.output_buffer.get_buffer())

        data2 = bytearray(range(20, 40))
        self._enqueue_and_flush(data2)
        self.assertEqual(data1, self.output_buffer.get_buffer())

        new_index = 10
        self.output_buffer.index = new_index
        self.assertEqual(data1[new_index:], self.output_buffer.get_buffer())

    def test_prepend_msgbytes(self):
        with self.assertRaises(ValueError):
            self.output_buffer.prepend_msgbytes("f")

        data1 = bytearray(range(20))
        self.output_buffer.prepend_msgbytes(data1)

        data2 = bytearray(range(20, 40))
        self.output_buffer.prepend_msgbytes(data2)

        self.assertEqual(deque([data2, data1]), self.output_buffer.output_msgs)
        self.assertEqual(40, self.output_buffer.length)

        self.output_buffer.advance_buffer(10)

        # With the first message partially sent, a prepended message goes right behind it.
        data3 = bytearray(range(40, 60))
        self.output_buffer.prepend_msgbytes(data3)

        self.assertEqual(deque([data2, data3, data1]), self.output_buffer.output_msgs)
        self.assertEqual(50, self.output_buffer.length)

    def test_has_more_bytes(self):
        self.assertFalse(self.output_buffer.has_more_bytes())
        self.output_buffer.length = 1
        self.assertTrue(self.output_buffer.has_more_bytes())

    def test_flush_get_buffer_on_time(self):
        data1 = bytearray(range(20))
        self.output_buffer.enqueue_msgbytes(data1)
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())

        # Restore the real clock after the test so the patched time.time does not
        # leak into other tests running in the same process.
        real_time = time.time
        self.addCleanup(setattr, time, "time", real_time)
        time.time = MagicMock(return_value=real_time() +
                              OUTPUT_BUFFER_BATCH_MAX_HOLD_TIME + 0.001)
        self.assertEqual(data1, self.output_buffer.get_buffer())

    def test_flush_get_buffer_on_size(self):
        data1 = bytearray(range(20))
        self.output_buffer.enqueue_msgbytes(data1)
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())

        data2 = bytearray([1] * OUTPUT_BUFFER_MIN_SIZE)
        self.output_buffer.enqueue_msgbytes(data2)
        self.assertNotEqual(OutputBuffer.EMPTY,
                            self.output_buffer.get_buffer())

    def test_safe_empty(self):
        self.output_buffer = OutputBuffer(enable_buffering=False)
        messages = [
            helpers.generate_bytearray(10),
            helpers.generate_bytearray(10)
        ]
        for message in messages:
            self.output_buffer.enqueue_msgbytes(message)

        self.output_buffer.advance_buffer(5)
        self.assertEqual(15, len(self.output_buffer))

        # 5 bytes are expected to remain after safe_empty on a partially sent buffer.
        self.output_buffer.safe_empty()
        self.assertEqual(5, len(self.output_buffer))

    def test_safe_empty_no_contents(self):
        # safe_empty on an empty buffer must not raise.
        self.output_buffer = OutputBuffer(enable_buffering=False)
        self.output_buffer.safe_empty()

    def test_safe_empty_buffering(self):
        messages = [
            helpers.generate_bytearray(10),
            helpers.generate_bytearray(10)
        ]
        for message in messages:
            self.output_buffer.enqueue_msgbytes(message)

        self.assertEqual(20, len(self.output_buffer))
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())

        self.output_buffer.safe_empty()
        self.assertEqual(0, len(self.output_buffer))
        self.assertEqual(OutputBuffer.EMPTY, self.output_buffer.get_buffer())
# Example #16
# 0
class AbstractConnection(Generic[Node]):
    __metaclass__ = ABCMeta

    CONNECTION_TYPE: ClassVar[ConnectionType] = ConnectionType.NONE
    node: Node

    def __init__(self, socket_connection, address, node: Node, from_me=False):
        """
        Initialize the connection state for the given socket connection.

        :param socket_connection: SocketConnection instance this connection wraps
        :param address: (ip, port) pair of the peer at time of socket creation
        :param node: node that owns the connection; its opts and network_num are read
        :param from_me: whether or not this node initiated the connection
        :raises ValueError: if socket_connection is not a SocketConnection
        """
        if not isinstance(socket_connection, SocketConnection):
            raise ValueError(
                "SocketConnection type is expected for socket_connection arg but was {0}."
                .format(type(socket_connection)))

        self.socket_connection = socket_connection
        self.fileno = socket_connection.fileno()

        # (IP, Port) at time of socket creation.
        # If the version/hello message contains a different port (i.e. connection is not from me), this will
        # be updated to the one in the message.
        self.peer_ip, self.peer_port = address
        self.peer_id: Optional[str] = None
        self.external_ip = node.opts.external_ip
        self.external_port = node.opts.external_port

        self.from_me = from_me  # Whether or not I initiated the connection

        # Always define message_tracker: the send path (advance_sent_bytes, enqueue_msg,
        # enqueue_msg_bytes) tests `if self.message_tracker:` and would raise
        # AttributeError if the attribute only existed when detailed tracking is on.
        self.message_tracker: Optional[MessageTracker] = None
        if node.opts.track_detailed_sent_messages:
            self.message_tracker = MessageTracker(self)
        self.outputbuf = OutputBuffer()
        self.inputbuf = InputBuffer()
        self.node = node

        self.state = ConnectionState.CONNECTING

        # Number of bad messages I've received in a row.
        self.num_bad_messages = 0
        self.peer_desc = "%s %d" % (self.peer_ip, self.peer_port)

        self.can_send_pings = False

        # Message types, factory and handlers start empty at this level.
        self.hello_messages = []
        self.header_size = 0
        self.message_factory = None
        self.message_handlers = None

        self.log_throughput = True

        self.ping_message = None
        self.pong_message = None
        self.ack_message = None

        # Default network number to network number of current node. But it can change after hello message is received
        self.network_num = node.network_num

        self.message_validator = DefaultMessageValidator()

        # Per-type counters of messages not individually logged, reported in batches by process_message.
        self._debug_message_tracker = defaultdict(int)
        self._last_debug_message_log_time = time.time()
        self.ping_interval_s: int = constants.PING_INTERVAL_S
        self.peer_model: Optional[OutboundPeerModel] = None

        self.log_debug("Connection initialized.")

    def __repr__(self):
        """
        Describe the connection by type, file number and peer address.

        The network number is included only when DEBUG logging is enabled.
        """
        if not logger.isEnabledFor(LogLevel.DEBUG):
            details = f"fileno: {self.fileno}, address: {self.peer_desc}"
        else:
            details = f"fileno: {self.fileno}, address: {self.peer_desc}, network_num: {self.network_num}"

        return f"{self.CONNECTION_TYPE}({details})"

    def _log_message(self, level: LogLevel, message, *args, **kwargs):
        """
        Log a message prefixed with this connection's string representation.

        :param level: log level to emit the message at
        :param message: message template
        :param args: positional values forwarded to the logger
        :param kwargs: keyword values forwarded to the logger
        """
        prefixed = f"[{self}] {message}"
        logger.log(level, prefixed, *args, **kwargs)

    def log_trace(self, message, *args, **kwargs):
        """Log a connection-prefixed message at TRACE level."""
        level = LogLevel.TRACE
        self._log_message(level, message, *args, **kwargs)

    def log_debug(self, message, *args, **kwargs):
        """Log a connection-prefixed message at DEBUG level."""
        level = LogLevel.DEBUG
        self._log_message(level, message, *args, **kwargs)

    def log_info(self, message, *args, **kwargs):
        """Log a connection-prefixed message at INFO level."""
        level = LogLevel.INFO
        self._log_message(level, message, *args, **kwargs)

    def log_warning(self, message, *args, **kwargs):
        """Log a connection-prefixed message at WARNING level."""
        level = LogLevel.WARNING
        self._log_message(level, message, *args, **kwargs)

    def log_error(self, message, *args, **kwargs):
        """Log a connection-prefixed message at ERROR level."""
        level = LogLevel.ERROR
        self._log_message(level, message, *args, **kwargs)

    def is_active(self):
        """
        Tell whether the connection is established and not marked for close.

        :return: True when the ESTABLISHED flag is fully set and MARK_FOR_CLOSE is not
        """
        state = self.state
        if state & ConnectionState.ESTABLISHED != ConnectionState.ESTABLISHED:
            return False
        return not state & ConnectionState.MARK_FOR_CLOSE

    def is_sendable(self):
        """
        Tell whether the connection should send bytes on broadcast.

        At this level a connection is sendable exactly when it is active.
        """
        active = self.is_active()
        return active

    def on_connection_established(self):
        """
        Flag the connection as established and log it at INFO level.
        """
        self.state = self.state | ConnectionState.ESTABLISHED
        self.log_info("Connection established.")

    def add_received_bytes(self, bytes_received: int):
        """
        Append data received from the socket connection to the input buffer.

        Must not be called once the connection is marked for close.

        NOTE(review): the parameter is annotated as int but is passed straight to
        inputbuf.add_bytes, which presumably expects the received bytes themselves;
        confirm against callers.

        :param bytes_received: new bytes received from socket connection
        """
        marked_for_close = self.state & ConnectionState.MARK_FOR_CLOSE
        assert not marked_for_close

        self.inputbuf.add_bytes(bytes_received)

    def get_bytes_to_send(self):
        """
        Fetch the bytes the output buffer currently offers for sending.

        Must not be called once the connection is marked for close.

        :return: result of the output buffer's get_buffer()
        """
        marked_for_close = self.state & ConnectionState.MARK_FOR_CLOSE
        assert not marked_for_close

        return self.outputbuf.get_buffer()

    def advance_sent_bytes(self, bytes_sent):
        """
        Advance the output buffer, and the message tracker when one is set,
        by the number of bytes that were sent.

        :param bytes_sent: number of bytes sent on the socket
        """
        self.advance_bytes_on_buffer(self.outputbuf, bytes_sent)

        tracker = self.message_tracker
        if tracker:
            tracker.advance_bytes(bytes_sent)

    def enqueue_msg(self, msg: AbstractMessage, prepend: bool = False):
        """
        Log a message at its own log level and enqueue its raw bytes for sending.

        The message object itself is passed along for detailed tracking only when
        a message tracker is set.

        :param msg: message
        :param prepend: if the message should be bumped to the front of the outputbuf
        """
        self._log_message(msg.log_level(), "Enqueued message: {}", msg)

        tracked_message = msg if self.message_tracker else None
        self.enqueue_msg_bytes(msg.rawbytes(), prepend, tracked_message)

    def enqueue_msg_bytes(self,
                          msg_bytes: Union[bytearray, memoryview],
                          prepend: bool = False,
                          full_message: Optional[AbstractMessage] = None):
        """
        Put the raw bytes of a message on the output buffer and trigger a send
        on the socket connection. Nothing happens when the connection is already
        marked for close.

        :param msg_bytes: message bytes
        :param prepend: if the message should be bumped to the front of the outputbuf
        :param full_message: full message for detailed logging
        """
        if self.state & ConnectionState.MARK_FOR_CLOSE:
            return

        size = len(msg_bytes)
        self.log_trace("Enqueued {} bytes.", size)

        # Queue on the output buffer first, then mirror the operation on the tracker.
        if prepend:
            self.outputbuf.prepend_msgbytes(msg_bytes)
        else:
            self.outputbuf.enqueue_msgbytes(msg_bytes)

        tracker = self.message_tracker
        if tracker:
            if prepend:
                tracker.prepend_message(size, full_message)
            else:
                tracker.append_message(size, full_message)

        # TODO: temporary fix for some situations where, see https://bloxroute.atlassian.net/browse/BX-1153
        self.socket_connection.can_send = True
        self.socket_connection.send()

    def pre_process_msg(self):
        """
        Preview the header of the next message on the input buffer.

        :return: tuple of (is full message available, message type, payload length)
        """
        factory = self.message_factory
        full, kind, length = factory.get_message_header_preview_from_input_buffer(self.inputbuf)
        return full, kind, length

    def process_msg_type(self, message_type, is_full_msg, payload_len):
        """
        Hook for message types that need to deviate from the regular handling flow
        (pop off single message, process it, continue on with the stream).

        This implementation does nothing.

        :param message_type: message type
        :param is_full_msg: flag indicating if full message is available on input buffer
        :param payload_len: length of payload
        :return:
        """
        return None

    def process_message(self):
        """
        Processes the next bytes on the socket's inputbuffer.
        Returns 0 in order to avoid being rescheduled if this was an alarm.
        """

        start_time = time.time()
        messages_processed = defaultdict(int)

        while True:
            input_buffer_len_before = self.inputbuf.length
            is_full_msg = False
            payload_len = 0
            msg = None
            msg_type = None

            try:
                # abort message processing if connection has been closed
                if self.state & ConnectionState.MARK_FOR_CLOSE:
                    return

                is_full_msg, msg_type, payload_len = self.pre_process_msg()

                self.message_validator.validate(is_full_msg, msg_type,
                                                self.header_size, payload_len,
                                                self.inputbuf)

                self.process_msg_type(msg_type, is_full_msg, payload_len)

                if not is_full_msg:
                    break

                msg = self.pop_next_message(payload_len)

                # If there was some error in parsing this message, then continue the loop.
                if msg is None:
                    if self._report_bad_message():
                        return
                    continue

                # Full messages must be one of the handshake messages if the connection isn't established yet.
                if not (self.state & ConnectionState.ESTABLISHED == ConnectionState.ESTABLISHED) \
                        and msg_type not in self.hello_messages:
                    self.log_warning(
                        "Received unexpected message ({}) before handshake completed. Closing.",
                        msg_type)
                    self.mark_for_close()
                    return

                if self.log_throughput:
                    hooks.add_throughput_event(Direction.INBOUND, msg_type,
                                               len(msg.rawbytes()),
                                               self.peer_desc)

                if not logger.isEnabledFor(
                        msg.log_level()) and logger.isEnabledFor(
                            LogLevel.INFO):
                    self._debug_message_tracker[msg_type] += 1
                elif len(self._debug_message_tracker) > 0:
                    self.log_debug(
                        "Processed the following messages types: {} over {:.2f} seconds.",
                        self._debug_message_tracker,
                        time.time() - self._last_debug_message_log_time)
                    self._debug_message_tracker.clear()
                    self._last_debug_message_log_time = time.time()

                self._log_message(msg.log_level(), "Processing message: {}",
                                  msg)

                if msg_type in self.message_handlers:
                    msg_handler = self.message_handlers[msg_type]
                    msg_handler(msg)

                messages_processed[msg_type] += 1

            # TODO: Investigate possible solutions to recover from PayloadLenError errors
            except PayloadLenError as e:
                self.log_error("Could not parse message. Error: {}", e.msg)
                self.mark_for_close()
                return

            except MemoryError as e:
                self.log_error(
                    "Out of memory error occurred during message processing. Error: {}. ",
                    e,
                    exc_info=True)
                self.log_debug(
                    "Failed message bytes: {}",
                    self._get_last_msg_bytes(msg, input_buffer_len_before,
                                             payload_len))
                raise

            except UnauthorizedMessageError as e:
                self.log_error("Unauthorized message {} from {}.",
                               e.msg.MESSAGE_TYPE, self.peer_desc)
                self.log_debug(
                    "Failed message bytes: {}",
                    self._get_last_msg_bytes(msg, input_buffer_len_before,
                                             payload_len))

                # give connection a chance to restore its state and get ready to process next message
                self.clean_up_current_msg(
                    payload_len,
                    input_buffer_len_before == self.inputbuf.length)

                if self._report_bad_message():
                    return

            except MessageValidationError as e:
                self.log_warning(
                    "Message validation failed for {} message: {}.", msg_type,
                    e.msg)
                self.log_debug(
                    "Failed message bytes: {}",
                    self._get_last_msg_bytes(msg, input_buffer_len_before,
                                             payload_len))

                if is_full_msg:
                    self.clean_up_current_msg(
                        payload_len,
                        input_buffer_len_before == self.inputbuf.length)
                else:
                    self.log_error(
                        "Unable to recover after message that failed validation. Closing connection."
                    )
                    self.mark_for_close()
                    return

                if self._report_bad_message():
                    return

            # TODO: Throw custom exception for any errors that come from input that has not been validated and only catch that subclass of exceptions
            except Exception as e:

                # Attempt to recover connection by removing bad full message
                if is_full_msg:
                    self.log_error(
                        "Message processing error; trying to recover. Error: {}.",
                        e,
                        exc_info=True)
                    self.log_debug(
                        "Failed message bytes: {}",
                        self._get_last_msg_bytes(msg, input_buffer_len_before,
                                                 payload_len))

                    # give connection a chance to restore its state and get ready to process next message
                    self.clean_up_current_msg(
                        payload_len,
                        input_buffer_len_before == self.inputbuf.length)

                # Connection is unable to recover from message processing error if incomplete message is received
                else:
                    self.log_error(
                        "Message processing error; unable to recover. Error: {}.",
                        e,
                        exc_info=True)
                    self.log_debug(
                        "Failed message bytes: {}",
                        self._get_last_msg_bytes(msg, input_buffer_len_before,
                                                 payload_len))
                    self.mark_for_close()
                    return

                if self._report_bad_message():
                    return
            else:
                self.num_bad_messages = 0

        time_elapsed = time.time() - start_time
        self.log_trace("Processed {} messages in {:.2f} seconds",
                       messages_processed, time_elapsed)

    def pop_next_message(self, payload_len):
        """
        Remove one whole message (header plus payload) from the input buffer and parse it.

        Taking exactly one message off the buffer preserves the invariant that self.inputbuf
        always contains the start of a valid message.

        :param payload_len: length of the message payload, excluding the header
        :return: message object created by the connection's message factory
        """
        header_len = self.message_factory.base_message_type.HEADER_LENGTH
        raw_message = self.inputbuf.remove_bytes(header_len + payload_len)
        parsed = self.message_factory.create_message_from_buffer(raw_message)
        return parsed

    def advance_bytes_on_buffer(self, buf, bytes_written):
        """
        Record an outbound throughput event, then advance the given buffer.

        :param buf: buffer whose advance_buffer method is called
        :param bytes_written: number of bytes to report and to advance the buffer by
        :raises RuntimeError: if the buffer raises ValueError while advancing
        """
        hooks.add_throughput_event(
            Direction.OUTBOUND, None, bytes_written, self.peer_desc)

        try:
            buf.advance_buffer(bytes_written)
        except ValueError as advance_error:
            # Wrap the buffer error so the failing connection is identified in the message
            failure = RuntimeError(
                "Connection: {}, Failed to advance buffer".format(self))
            raise failure from advance_error

    def send_ping(self):
        """
        Enqueue a ping unless pings are disabled or the connection is marked for close.

        :return: self.ping_interval_s after enqueueing a ping (reschedules when called from the
                 alarm queue), otherwise constants.CANCEL_ALARMS
        """
        pings_blocked = (not self.can_send_pings
                         or self.state & ConnectionState.MARK_FOR_CLOSE)
        if pings_blocked:
            return constants.CANCEL_ALARMS

        self.enqueue_msg(self.ping_message)
        return self.ping_interval_s

    def msg_hello(self, msg):
        """
        Handle a hello message: store the peer's node id, index this connection by it and ack.

        When the connection pool already holds more than one connection for the same node id,
        no ack is sent; in addition, the connection is marked for close if this side initiated it.

        :param msg: hello message received from the peer
        """
        self.state |= ConnectionState.HELLO_RECVD

        peer_node_id = msg.node_id()
        if peer_node_id is None:
            self.log_debug("Received hello message without peer id.")
        self.peer_id = peer_node_id

        pool = self.node.connection_pool
        pool.index_conn_node_id(self.peer_id, self)

        connections_to_peer = pool.get_by_node_id(self.peer_id)
        if len(connections_to_peer) > 1:
            if self.from_me:
                self.log_info(
                    "Received duplicate connection from: {}. Closing.",
                    self.peer_id)
                self.mark_for_close()
            return

        self.enqueue_msg(self.ack_message)
        if self.is_active():
            self.on_connection_established()

    def msg_ack(self, _msg):
        """
        Handle an ack message: flag the hello as acknowledged and, if the connection is
        active afterwards, run the connection-established hook.

        :param _msg: ack message (unused)
        """
        self.state |= ConnectionState.HELLO_ACKD
        if not self.is_active():
            return
        self.on_connection_established()

    def msg_ping(self, msg):
        """
        Handle a ping message by enqueueing the connection's pong message.

        :param msg: ping message (its contents are not inspected here)
        """
        reply = self.pong_message
        self.enqueue_msg(reply)

    def msg_pong(self, _msg):
        """
        Handle a pong message. This implementation does nothing.

        :param _msg: pong message (unused)
        """
        return None

    def mark_for_close(self):
        """
        Flag this connection as pending close.

        Prefer this method over AbstractConnection#destroy_conn for closing a connection, as it
        allows a cleaner shutdown and lets message processing finish.
        """
        closing_state = self.state | ConnectionState.MARK_FOR_CLOSE
        self.state = closing_state
        self.log_debug("Marking connection for close.")

    def close(self):
        """
        Clean up connection state once the socket has been terminated.

        This implementation only asserts that the connection was marked for close beforehand.
        Do not call this directly from connection event handlers.
        """
        marked_for_close = self.state & ConnectionState.MARK_FOR_CLOSE
        assert marked_for_close

    def clean_up_current_msg(self, payload_len: int,
                             msg_is_in_input_buffer: bool) -> None:
        """
        Drop the message currently being processed from the input buffer so the connection is
        ready for the next message. Used while handling message processing exceptions.

        :param payload_len: payload length of the message currently being processed
        :param msg_is_in_input_buffer: whether the message bytes are still in the input buffer
        :return:
        """
        if not msg_is_in_input_buffer:
            # Bytes were already consumed; there is nothing to remove
            return

        full_msg_len = self.header_size + payload_len
        self.inputbuf.remove_bytes(full_msg_len)

    def on_input_received(self) -> bool:
        """
        Handle an input event from the event loop.

        This implementation does no work and always reports the connection as receivable.

        :return: True if the connection is receivable, otherwise False
        """
        is_receivable = True
        return is_receivable

    def log_connection_mem_stats(self) -> None:
        """
        Report memory stats for the connection's input buffer and then its output buffer.
        """
        reporter_name = self.__class__.__name__

        # (stat name, buffer, attribute holding the buffer's items) in reporting order
        buffer_specs = (
            ("input_buffer", self.inputbuf, "input_list"),
            ("output_buffer", self.outputbuf, "output_msgs"),
        )

        for stat_name, buf, items_attr in buffer_specs:
            hooks.add_obj_mem_stats(
                reporter_name,
                self.network_num,
                buf,
                stat_name,
                memory_utils.ObjectSize(
                    stat_name,
                    memory_utils.get_special_size(buf).size,
                    is_actual_size=True),
                object_item_count=len(getattr(buf, items_attr)),
                object_type=memory_utils.ObjectType.BASE,
                size_type=memory_utils.SizeType.TRUE)

    def update_model(self, model: OutboundPeerModel):
        """
        Log the new peer model at trace level and store it on the connection.

        :param model: model to store in self.peer_model
        """
        trace_format = "Updated connection model: {}"
        self.log_trace(trace_format, model)
        self.peer_model = model

    def _report_bad_message(self):
        """
        Count one more bad message, or close the connection once the counter has reached
        constants.MAX_BAD_MESSAGES.

        :return: True if the connection was marked for close, False if only the counter grew
        """
        limit_reached = self.num_bad_messages == constants.MAX_BAD_MESSAGES
        if not limit_reached:
            self.num_bad_messages += 1
            return False

        self.log_warning("Received too many bad messages. Closing.")
        self.mark_for_close()
        return True

    def _get_last_msg_bytes(self, msg, input_buffer_len_before, payload_len):

        if msg is not None:
            return convert.bytes_to_hex(
                msg.rawbytes()[:constants.MAX_LOGGED_BYTES_LEN])

        # bytes still available on input buffer
        if input_buffer_len_before == self.inputbuf.length and payload_len is not None:
            return convert.bytes_to_hex(
                self.inputbuf.peek_message(
                    min(self.header_size + payload_len,
                        constants.MAX_LOGGED_BYTES_LEN)))

        return "<not available>"
class InternalNodeConnection(AbstractConnection[Node]):
    """
    Base connection for peers that use the bloxroute message factory and version manager.

    On top of AbstractConnection it adds optional buffered sending, per-connection protocol
    version selection (converting messages to and from older versions) and ping/pong latency
    measurement.
    """
    # NOTE(review): Python 3 ignores `__metaclass__`; left untouched so instantiation behavior
    # does not change.
    __metaclass__ = ABCMeta

    def __init__(self, sock: AbstractSocketConnectionProtocol,
                 node: Node) -> None:
        # Zero-argument super() binds to this class object itself instead of looking the class
        # name up in module globals when the call runs.
        super().__init__(sock, node)

        # Enable buffering only on internal connections
        self.enable_buffered_send = node.opts.enable_buffered_send
        self.outputbuf = OutputBuffer(
            enable_buffering=self.enable_buffered_send)

        self.network_num = node.network_num
        self.version_manager = bloxroute_version_manager

        # Setting default protocol version; override when hello message received
        self.protocol_version = self.version_manager.CURRENT_PROTOCOL_VERSION

        self.pong_message = PongMessage()
        self.ack_message = AckMessage()

        self.can_send_pings = True
        self.pong_timeout_enabled = True

        # creation time of each ping, keyed by its nonce
        self.ping_message_timestamps = ExpiringDict(
            self.node.alarm_queue, constants.REQUEST_EXPIRATION_TIME,
            f"{str(self)}_ping_timestamps")

        # network number -> ping round-trip time (None until the matching pong arrives)
        self.sync_ping_latencies: Dict[int, Optional[float]] = {}
        # nonce of a latency-check ping -> network number it was sent for
        self._nonce_to_network_num: Dict[int, int] = {}
        self.message_validator = BloxrouteMessageValidator(
            None, self.protocol_version)
        self.tx_sync_service = TxSyncService(self)
        self.inbound_peer_latency: float = time.time()

    def connection_message_factory(self) -> AbstractMessageFactory:
        """
        :return: the bloxroute message factory
        """
        return bloxroute_message_factory

    def ping_message(self) -> AbstractMessage:
        """
        Create a ping with a fresh nonce and record the current time under that nonce.

        :return: new PingMessage carrying the nonce
        """
        nonce = nonce_generator.get_nonce()
        self.ping_message_timestamps.add(nonce, time.time())
        return PingMessage(nonce)

    def disable_buffering(self):
        """
        Disable buffering on this particular connection.

        Flushes the output buffer, switches buffering off and triggers a send on the socket
        connection.
        :return:
        """
        self.enable_buffered_send = False
        self.outputbuf.flush()
        self.outputbuf.enable_buffering = False
        self.socket_connection.send()

    def set_protocol_version_and_message_factory(self) -> bool:
        """
        Gets protocol version from the first bytes of hello message if not known.
        Sets protocol version and creates message factory for that protocol version

        :return: True when the version is known and set; False when it could not be read yet
                 or is unsupported (the connection is marked for close in the latter case)
        """

        # Nothing to do once the hello message has been received
        if ConnectionState.HELLO_RECVD in self.state:
            return True

        protocol_version = self.version_manager.get_connection_protocol_version(
            self.inputbuf)

        if protocol_version is None:
            return False

        if not self.version_manager.is_protocol_supported(protocol_version):
            self.log_debug(
                "Protocol version {} of remote node '{}' is not supported. Closing connection.",
                protocol_version, self.peer_desc)
            self.mark_for_close()
            return False

        # Never speak a version newer than the one this node implements
        if protocol_version > self.version_manager.CURRENT_PROTOCOL_VERSION:
            self.log_debug(
                "Got message protocol {} that is higher the current version {}. Using current protocol version",
                protocol_version,
                self.version_manager.CURRENT_PROTOCOL_VERSION)
            protocol_version = self.version_manager.CURRENT_PROTOCOL_VERSION

        self.protocol_version = protocol_version
        self.message_factory = self.version_manager.get_message_factory_for_version(
            protocol_version)

        # Pass the value as an argument, like the other log calls, instead of pre-formatting
        self.log_trace("Setting connection protocol version to {}",
                       protocol_version)

        return True

    def pre_process_msg(self) -> ConnectionMessagePreview:
        """
        Settle the protocol version and message factory, then delegate to the base class.

        :return: ConnectionMessagePreview(False, True, None, None) when the version could not be
                 set, otherwise the base class result
        """
        success = self.set_protocol_version_and_message_factory()

        if not success:
            return ConnectionMessagePreview(False, True, None, None)

        return super().pre_process_msg()

    def enqueue_msg(self, msg, prepend=False):
        """
        Enqueue a message, converting it first when the peer uses an older protocol version.
        Does nothing if the connection is not alive.

        :param msg: message to send
        :param prepend: forwarded unchanged to the base class enqueue_msg
        """
        if not self.is_alive():
            return

        if self.protocol_version < self.version_manager.CURRENT_PROTOCOL_VERSION:
            versioned_message = self.version_manager.convert_message_to_older_version(
                self.protocol_version, msg)
        else:
            versioned_message = msg

        super().enqueue_msg(versioned_message, prepend)

    def pop_next_message(self, payload_len: int) -> AbstractMessage:
        """
        Pop the next message and convert it from the peer's older protocol version if needed.

        :param payload_len: length of payload
        :return: message object (None is passed through unchanged)
        """
        msg = super().pop_next_message(payload_len)

        if msg is None or self.protocol_version >= self.version_manager.CURRENT_PROTOCOL_VERSION:
            return msg

        versioned_msg = self.version_manager.convert_message_from_older_version(
            self.protocol_version, msg)

        return versioned_msg

    def check_ping_latency_for_network(self, network_num: int) -> None:
        """
        Send a ping whose round-trip time will be stored for the given network number.

        :param network_num: network number the latency is measured for
        """
        ping_message = cast(PingMessage, self.ping_message())
        self.enqueue_msg(ping_message)
        self._nonce_to_network_num[ping_message.nonce()] = network_num
        # None marks the measurement as pending until the pong arrives
        self.sync_ping_latencies[network_num] = None

    def msg_hello(self, msg):
        """
        Handle a hello message: run the base handling, verify the network number and
        schedule pings.

        The connection is marked for close when the peer's network number differs from this
        node's, unless this node uses constants.ALL_NETWORK_NUM.

        :param msg: hello message received from the peer
        """
        super().msg_hello(msg)

        if not self.is_alive():
            self.log_trace("Connection has been closed: {}, Ignoring: {} ",
                           self, msg)
            return

        network_num = msg.network_num()

        if self.node.network_num != constants.ALL_NETWORK_NUM and network_num != self.node.network_num:
            self.log_warning(log_messages.NETWORK_NUMBER_MISMATCH,
                             self.node.network_num, network_num)
            self.mark_for_close()
            return

        self.network_num = network_num

        self.schedule_pings()

    def peek_broadcast_msg_network_num(self, input_buffer):
        """
        Read the network number of a broadcast message from the input buffer.

        :param input_buffer: buffer holding the broadcast message
        :return: constants.DEFAULT_NETWORK_NUM for protocol version 1, otherwise the value
                 peeked from the buffer
        """

        if self.protocol_version == 1:
            return constants.DEFAULT_NETWORK_NUM

        return BroadcastMessage.peek_network_num(input_buffer)

    # pylint: disable=arguments-differ
    def msg_ping(self, msg: PingMessage):
        """
        Handle a ping: record the inbound latency derived from the nonce and reply with a pong
        that echoes the nonce and carries a fresh nonce as its timestamp.

        :param msg: ping message received from the peer
        """
        nonce = msg.nonce()
        # Time elapsed since the timestamp encoded in the peer's nonce
        assumed_request_time = time.time(
        ) - nonce_generator.get_timestamp_from_nonce(nonce)
        self.inbound_peer_latency = assumed_request_time
        hooks.add_measurement(self.peer_desc, MeasurementType.PING_INCOMING,
                              assumed_request_time, self.peer_id)

        self.enqueue_msg(
            PongMessage(nonce=nonce, timestamp=nonce_generator.get_nonce()))

    # pylint: disable=arguments-differ
    def msg_pong(self, msg: PongMessage):
        """
        Handle a pong: update latency figures and report ping measurements.

        A pong whose nonce has no recorded ping is only logged.

        :param msg: pong message received from the peer
        """
        super().msg_pong(msg)

        nonce = msg.nonce()
        timestamp = msg.timestamp()
        if timestamp:
            self.inbound_peer_latency = time.time(
            ) - nonce_generator.get_timestamp_from_nonce(timestamp)
        if nonce in self.ping_message_timestamps.contents:
            request_msg_timestamp = self.ping_message_timestamps.contents[
                nonce]
            request_response_time = time.time() - request_msg_timestamp

            # Store the round trip for a pending per-network latency check
            if nonce in self._nonce_to_network_num:
                self.sync_ping_latencies[
                    self._nonce_to_network_num[nonce]] = request_response_time

            # Slow exchanges are logged at debug level, the rest at trace level
            if request_response_time > constants.PING_PONG_TRESHOLD:
                self.log_debug(
                    "Ping/pong exchange nonce {} took {:.2f} seconds to complete.",
                    nonce, request_response_time)
            else:
                self.log_trace(
                    "Ping/pong exchange nonce {} took {:.2f} seconds to complete.",
                    nonce, request_response_time)

            hooks.add_measurement(self.peer_desc, MeasurementType.PING,
                                  request_response_time, self.peer_id)
            if timestamp:
                assumed_peer_response_time = nonce_generator.get_timestamp_from_nonce(
                    timestamp) - request_msg_timestamp
                hooks.add_measurement(self.peer_desc,
                                      MeasurementType.PING_OUTGOING,
                                      assumed_peer_response_time, self.peer_id)

        elif nonce is not None:
            self.log_debug(
                "Pong message had no matching ping request. Nonce: {}", nonce)

    def mark_for_close(self, should_retry: Optional[bool] = None):
        """
        Mark the connection for close and cancel the pong timeout.

        :param should_retry: forwarded unchanged to the base class mark_for_close
        """
        super().mark_for_close(should_retry)
        self.cancel_pong_timeout()

    def is_gateway_connection(self):
        """
        :return: whether CONNECTION_TYPE is contained in ConnectionType.GATEWAY
        """
        return self.CONNECTION_TYPE in ConnectionType.GATEWAY

    def is_external_gateway_connection(self):
        """
        :return: whether CONNECTION_TYPE is contained in ConnectionType.EXTERNAL_GATEWAY or
                 equals ConnectionType.GATEWAY
        """
        # self.CONNECTION_TYPE == ConnectionType.GATEWAY is equal True only for V1 gateways
        return self.CONNECTION_TYPE in ConnectionType.EXTERNAL_GATEWAY or self.CONNECTION_TYPE == ConnectionType.GATEWAY

    def is_internal_gateway_connection(self) -> bool:
        """
        :return: whether CONNECTION_TYPE is contained in ConnectionType.INTERNAL_GATEWAY
        """
        return self.CONNECTION_TYPE in ConnectionType.INTERNAL_GATEWAY

    def is_relay_connection(self):
        """
        :return: whether CONNECTION_TYPE is contained in ConnectionType.RELAY_ALL
        """
        return self.CONNECTION_TYPE in ConnectionType.RELAY_ALL

    def is_proxy_connection(self) -> bool:
        """
        :return: whether CONNECTION_TYPE is contained in ConnectionType.RELAY_PROXY
        """
        return self.CONNECTION_TYPE in ConnectionType.RELAY_PROXY

    def update_tx_sync_complete(self, network_num: int):
        """
        Drop the latency bookkeeping kept for a network.

        :param network_num: network number whose latency entry and ping nonces are removed
        """
        self.sync_ping_latencies.pop(network_num, None)
        self._nonce_to_network_num = {
            nonce: other_network_num
            for nonce, other_network_num in self._nonce_to_network_num.items()
            if other_network_num != network_num
        }
class MessageTrackerTest(AbstractTestCase):
    """
    Tests for MessageTracker.empty_bytes, driven side by side with an OutputBuffer.
    """

    def setUp(self) -> None:
        self.node = MockNode(
            helpers.get_common_opts(1001, external_ip="128.128.128.128"))
        self.tracker = MessageTracker(
            MockConnection(MockSocketConnection(), self.node))
        self.output_buffer = OutputBuffer(enable_buffering=True)

    def _new_tx_message(self, tx_val_len: int) -> TxMessage:
        """
        Build a TxMessage with a generated hash, network number 5 and generated contents.

        :param tx_val_len: length passed to helpers.generate_bytearray for the tx contents
        :return: new TxMessage
        """
        return TxMessage(
            helpers.generate_object_hash(),
            5,
            tx_val=helpers.generate_bytearray(tx_val_len),
        )

    def test_empty_bytes_no_bytes_sent(self):
        message = self._new_tx_message(250)
        message_length = len(message.rawbytes())
        self.output_buffer.enqueue_msgbytes(message.rawbytes())
        self.output_buffer.flush()
        self.tracker.append_message(message_length, message)

        self.output_buffer.safe_empty()
        self.tracker.empty_bytes(self.output_buffer.length)

        self.assertEqual(0, self.output_buffer.length)
        self.assertEqual(0, self.tracker.bytes_remaining)
        self.assertEqual(0, len(self.tracker.messages))

    def test_empty_bytes(self):
        message1 = self._new_tx_message(250)
        message2 = self._new_tx_message(250)
        message3 = self._new_tx_message(250)
        # A single length is used for all three messages
        message_length = len(message1.rawbytes())

        self.output_buffer.enqueue_msgbytes(message1.rawbytes())
        self.output_buffer.flush()
        self.output_buffer.enqueue_msgbytes(message2.rawbytes())
        self.output_buffer.enqueue_msgbytes(message3.rawbytes())

        self.tracker.append_message(message_length, message1)
        self.tracker.append_message(message_length, message2)
        self.tracker.append_message(message_length, message3)

        # Advance partway into the first message before emptying
        self.output_buffer.advance_buffer(120)
        self.tracker.advance_bytes(120)

        self.output_buffer.safe_empty()
        self.assertEqual(message_length - 120, self.output_buffer.length)

        self.tracker.empty_bytes(self.output_buffer.length)

        self.assertEqual(1, len(self.tracker.messages))
        self.assertEqual(message_length - 120, self.tracker.bytes_remaining)
        self.assertEqual(120, self.tracker.messages[0].sent_bytes)

    def test_empty_bytes_more_bytes(self):
        for _ in range(100):
            message = self._new_tx_message(2500)
            message_length = len(message.rawbytes())
            self.output_buffer.enqueue_msgbytes(message.rawbytes())
            self.tracker.append_message(message_length, message)

        self.output_buffer.advance_buffer(3500)
        self.tracker.advance_bytes(3500)

        self.output_buffer.safe_empty()
        self.tracker.empty_bytes(self.output_buffer.length)

        self.assertEqual(self.output_buffer.length,
                         self.tracker.bytes_remaining)
Exemple #19
0
class InternalNodeConnection(AbstractConnection[Node]):
    """
    Base connection for peers that use the bloxroute message factory and version manager.

    On top of AbstractConnection it adds optional buffered sending, protocol version detection
    for incoming connections (converting messages to and from older versions), ping/pong
    response time measurement and the transaction service sync messages.
    """
    # NOTE(review): Python 3 ignores `__metaclass__`; left untouched so instantiation behavior
    # does not change.
    __metaclass__ = ABCMeta

    def __init__(self, sock, address, node, from_me=False):
        # Zero-argument super() binds to this class object itself instead of looking the class
        # name up in module globals when the call runs.
        super().__init__(sock, address, node, from_me)

        # Enable buffering only on internal connections
        self.enable_buffered_send = node.opts.enable_buffered_send
        self.outputbuf = OutputBuffer(enable_buffering=self.enable_buffered_send)

        self.network_num = node.network_num
        self.version_manager = bloxroute_version_manager

        # Setting default protocol version and message factory; override when hello message received
        self.message_factory = bloxroute_message_factory
        self.protocol_version = self.version_manager.CURRENT_PROTOCOL_VERSION

        self.ping_message = PingMessage()
        self.pong_message = PongMessage()
        self.ack_message = AckMessage()

        self.can_send_pings = True
        # send time of each ping, keyed by its nonce
        self.ping_message_timestamps = ExpiringDict(self.node.alarm_queue, constants.REQUEST_EXPIRATION_TIME)
        self.message_validator = BloxrouteMessageValidator(None, self.protocol_version)

    def disable_buffering(self):
        """
        Disable buffering on this particular connection.

        Flushes the output buffer, switches buffering off and triggers a send on the socket
        connection.
        :return:
        """
        self.enable_buffered_send = False
        self.outputbuf.flush()
        self.outputbuf.enable_buffering = False
        self.socket_connection.send()

    def set_protocol_version_and_message_factory(self):
        """
        Gets protocol version from the first bytes of hello message if not known.
        Sets protocol version and creates message factory for that protocol version

        :return: truthy when the version is known and set; False when it could not be read yet
                 or is unsupported (the connection is marked for close in the latter case)
        """

        # Outgoing connections use current version of protocol and message factory
        if self.from_me or self.state & ConnectionState.HELLO_RECVD:
            return True

        protocol_version = self.version_manager.get_connection_protocol_version(self.inputbuf)

        if protocol_version is None:
            return False

        if not self.version_manager.is_protocol_supported(protocol_version):
            self.log_debug(
                "Protocol version {} of remote node '{}' is not supported. Closing connection.",
                protocol_version,
                self.peer_desc
            )
            self.mark_for_close()
            return False

        self.protocol_version = protocol_version
        self.message_factory = self.version_manager.get_message_factory_for_version(protocol_version)

        # Pass the value as an argument, like the other log calls, instead of pre-formatting
        self.log_trace("Detected incoming connection with protocol version {}", protocol_version)

        return True

    def pre_process_msg(self):
        """
        Settle the protocol version and message factory, then delegate to the base class.

        :return: (False, None, None) when the version could not be set, otherwise the base
                 class result
        """
        success = self.set_protocol_version_and_message_factory()

        if not success:
            return False, None, None

        return super().pre_process_msg()

    def enqueue_msg(self, msg, prepend=False):
        """
        Enqueue a message, converting it first when the peer uses an older protocol version.
        Does nothing if the connection is marked for close.

        :param msg: message to send
        :param prepend: forwarded unchanged to the base class enqueue_msg
        """
        if self.state & ConnectionState.MARK_FOR_CLOSE:
            return

        if self.protocol_version < self.version_manager.CURRENT_PROTOCOL_VERSION:
            versioned_message = self.version_manager.convert_message_to_older_version(self.protocol_version, msg)
        else:
            versioned_message = msg

        super().enqueue_msg(versioned_message, prepend)

    def pop_next_message(self, payload_len):
        """
        Pop the next message and convert it from the peer's older protocol version if needed.

        :param payload_len: length of payload
        :return: message object (None is passed through unchanged)
        """
        msg = super().pop_next_message(payload_len)

        if msg is None or self.protocol_version >= self.version_manager.CURRENT_PROTOCOL_VERSION:
            return msg

        versioned_msg = self.version_manager.convert_message_from_older_version(self.protocol_version, msg)

        return versioned_msg

    def msg_hello(self, msg):
        """
        Handle a hello message: run the base handling, verify the network number and register
        the ping alarm.

        The connection is marked for close when the peer's network number differs from this
        node's, unless this node uses constants.ALL_NETWORK_NUM.

        :param msg: hello message received from the peer
        """
        super().msg_hello(msg)

        if self.state & ConnectionState.MARK_FOR_CLOSE:
            self.log_trace("Connection has been closed: {}, Ignoring: {} ", self, msg)
            return

        network_num = msg.network_num()

        if self.node.network_num != constants.ALL_NETWORK_NUM and network_num != self.node.network_num:
            self.log_warning(
                "Network number mismatch. Current network num {}, remote network num {}. Closing connection.",
                self.node.network_num, network_num)
            self.mark_for_close()
            return

        self.network_num = network_num
        self.node.alarm_queue.register_alarm(self.ping_interval_s, self.send_ping)

    def peek_broadcast_msg_network_num(self, input_buffer):
        """
        Read the network number of a broadcast message from the input buffer.

        :param input_buffer: buffer holding the broadcast message
        :return: constants.DEFAULT_NETWORK_NUM for protocol version 1, otherwise the value
                 peeked from the buffer
        """

        if self.protocol_version == 1:
            return constants.DEFAULT_NETWORK_NUM

        return BroadcastMessage.peek_network_num(input_buffer)

    def send_ping(self):
        """
        Send a ping (and reschedule if called from alarm queue)

        :return: self.ping_interval_s after enqueueing a ping, otherwise constants.CANCEL_ALARMS
        """
        if self.can_send_pings and not self.state & ConnectionState.MARK_FOR_CLOSE:
            nonce = nonce_generator.get_nonce()
            msg = PingMessage(nonce=nonce)
            self.enqueue_msg(msg)
            # Remember the send time so the matching pong can be timed
            self.ping_message_timestamps.add(nonce, time.time())
            return self.ping_interval_s
        return constants.CANCEL_ALARMS

    def msg_ping(self, msg):
        """
        Handle a ping by replying with a pong that echoes its nonce.

        :param msg: ping message received from the peer
        """
        nonce = msg.nonce()
        self.enqueue_msg(PongMessage(nonce=nonce))

    def msg_pong(self, msg: PongMessage):
        """
        Handle a pong: report the response time of the matching ping.

        A pong whose nonce has no recorded ping is only logged.

        :param msg: pong message received from the peer
        """
        nonce = msg.nonce()
        if nonce in self.ping_message_timestamps.contents:
            request_msg_timestamp = self.ping_message_timestamps.contents[nonce]
            request_response_time = time.time() - request_msg_timestamp
            self.log_trace("Pong for nonce {} had response time: {}", nonce, request_response_time)
            hooks.add_measurement(self.peer_desc, MeasurementType.PING, request_response_time)
        elif nonce is not None:
            self.log_debug("Pong message had no matching ping request. Nonce: {}", nonce)

    def msg_tx_service_sync_txs(self, msg: TxServiceSyncTxsMessage):
        """
        Transaction service sync message receive txs data

        Stores the contents and short ids of every transaction in the message in the
        transaction service of the message's network.
        """
        network_num = msg.network_num()
        self.node.last_sync_message_received_by_network[network_num] = time.time()
        tx_service = self.node.get_tx_service(network_num)
        txs_content_short_ids = msg.txs_content_short_ids()

        for tx_content_short_ids in txs_content_short_ids:
            tx_hash = tx_content_short_ids.tx_hash
            tx_service.set_transaction_contents(tx_hash, tx_content_short_ids.tx_content)
            for short_id in tx_content_short_ids.short_ids:
                tx_service.assign_short_id(tx_hash, short_id)

    def _create_txs_service_msg(self, network_num: int, tx_service_snap: List[Sha256Hash]) -> List[TxContentShortIds]:
        """
        Pop hashes off the snapshot and collect their contents and short ids for one message.

        Stops when the snapshot is empty or once the serialized size reaches
        constants.TXS_MSG_SIZE. The snapshot list is consumed in place.

        :param network_num: network whose transaction service is queried
        :param tx_service_snap: remaining transaction hashes to send
        :return: entries for one sync message
        """
        txs_content_short_ids: List[TxContentShortIds] = []
        if not tx_service_snap:
            return txs_content_short_ids

        txs_msg_len = 0
        # Look the service up once instead of twice per transaction
        tx_service = self.node.get_tx_service(network_num)

        while tx_service_snap:
            tx_hash = tx_service_snap.pop()
            tx_content_short_ids = TxContentShortIds(
                tx_hash,
                tx_service.get_transaction_by_hash(tx_hash),
                tx_service.get_short_ids(tx_hash)
            )

            txs_msg_len += txs_serializer.get_serialized_tx_content_short_ids_bytes_len(tx_content_short_ids)

            txs_content_short_ids.append(tx_content_short_ids)
            if txs_msg_len >= constants.TXS_MSG_SIZE:
                break

        return txs_content_short_ids

    def send_tx_service_sync_req(self, network_num: int):
        """
        sending transaction service sync request
        """
        self.enqueue_msg(TxServiceSyncReqMessage(network_num))

    def send_tx_service_sync_complete(self, network_num: int):
        """
        Enqueue a transaction service sync complete message for the network.
        """
        self.enqueue_msg(TxServiceSyncCompleteMessage(network_num))

    def send_tx_service_sync_blocks_short_ids(self, network_num: int):
        """
        Enqueue one message with the short ids seen in each block, as reported by the
        network's transaction service.
        """
        start_time = time.time()
        blocks_short_ids: List[BlockShortIds] = [
            BlockShortIds(block_hash, short_ids)
            for block_hash, short_ids in self.node.get_tx_service(network_num).iter_short_ids_seen_in_block()
        ]

        block_short_ids_msg = TxServiceSyncBlocksShortIdsMessage(network_num, blocks_short_ids)
        duration = time.time() - start_time
        self.log_trace("Sending {} block short ids took {:.3f} seconds.", len(blocks_short_ids), duration)
        self.enqueue_msg(block_short_ids_msg)

    def send_tx_service_sync_txs(self, network_num: int, tx_service_snap: List[Sha256Hash], duration: float = 0,
                                 msgs_count: int = 0, total_tx_count: int = 0,
                                 sending_tx_msgs_start_time: float = 0):
        """
        Send one batch of transactions from the snapshot and reschedule itself through the
        alarm queue until the snapshot is empty, then send the sync complete message.

        Sends the sync complete message without sending anything once the time limit passed.

        :param network_num: network being synced
        :param tx_service_snap: remaining transaction hashes; consumed in place
        :param duration: accumulated time spent building and enqueueing messages
        :param msgs_count: number of messages sent so far
        :param total_tx_count: number of transactions sent so far
        :param sending_tx_msgs_start_time: time the sync started
        """
        # NOTE(review): the elapsed time is in seconds but is compared with a constant named
        # *_MS, and the default start time of 0 makes the elapsed time the full epoch time;
        # confirm the constant's unit and that callers always pass a start time.
        if (time.time() - sending_tx_msgs_start_time) < constants.SENDING_TX_MSGS_TIMEOUT_MS:
            if tx_service_snap:
                start = time.time()
                txs_content_short_ids = self._create_txs_service_msg(network_num, tx_service_snap)
                self.enqueue_msg(TxServiceSyncTxsMessage(network_num, txs_content_short_ids))
                duration += time.time() - start
                msgs_count += 1
                total_tx_count += len(txs_content_short_ids)
            # checks again if tx_snap in case we still have msgs to send, else no need to wait
            # TX_SERVICE_SYNC_TXS_S seconds
            if tx_service_snap:
                self.node.alarm_queue.register_alarm(
                    constants.TX_SERVICE_SYNC_TXS_S, self.send_tx_service_sync_txs, network_num, tx_service_snap,
                    duration, msgs_count, total_tx_count, sending_tx_msgs_start_time
                )
            else:   # if all txs were sent, send complete msg
                self.log_trace("Sending {} transactions and {} messages took {:.3f} seconds.",
                               total_tx_count, msgs_count, duration)
                self.send_tx_service_sync_complete(network_num)
        else:   # if time is up - upgrade this node as synced - giving up
            self.log_trace("Sending {} transactions and {} messages took more than {} seconds. Giving up.",
                           total_tx_count, msgs_count, constants.SENDING_TX_MSGS_TIMEOUT_MS)
            self.send_tx_service_sync_complete(network_num)

    def msg_tx_service_sync_complete(self, msg: TxServiceSyncCompleteMessage):
        """
        Handle a sync complete message: forget the network's last sync time and notify the
        node that its transaction service is fully updated.
        """
        network_num = msg.network_num()
        self.node.last_sync_message_received_by_network.pop(network_num, None)
        self.node.on_fully_updated_tx_service()
 def setUp(self) -> None:
     """
     Create the mock node, the MessageTracker under test and a buffering OutputBuffer.
     """
     common_opts = helpers.get_common_opts(
         1001, external_ip="128.128.128.128")
     self.node = MockNode(common_opts)

     mock_connection = MockConnection(MockSocketConnection(), self.node)
     self.tracker = MessageTracker(mock_connection)

     self.output_buffer = OutputBuffer(enable_buffering=True)