Ejemplo n.º 1
0
    async def test_too_much_data(self):
        """
        Too much data: after an initial run of accepted messages, continuing to
        send large payloads must cause the limiter to reject at least one.
        """
        # 500 KiB transaction responses
        limiter = RateLimiter()
        tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
        for _ in range(10):
            assert limiter.process_msg_and_check(tx_message)

        # A list (not a generator) so that all 300 messages are processed
        # before the outcome is inspected.
        tx_outcomes = [limiter.process_msg_and_check(tx_message) for _ in range(300)]
        assert not all(tx_outcomes)

        # 1 MiB block responses, against a fresh limiter
        limiter = RateLimiter()
        block_message = make_msg(ProtocolMessageTypes.respond_block, bytes([1] * 1024 * 1024))
        for _ in range(10):
            assert limiter.process_msg_and_check(block_message)

        block_outcomes = [limiter.process_msg_and_check(block_message) for _ in range(40)]
        assert not all(block_outcomes)
Ejemplo n.º 2
0
    async def test_too_many_messages(self):
        """
        Too many messages: small payloads sent often enough must eventually be
        rejected, both for a transaction message and for a non-tx message.
        """
        small_payload = bytes([1] * 40)

        # Transaction announcements: the first 3000 must all be accepted.
        limiter = RateLimiter()
        new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, small_payload)
        for _ in range(3000):
            assert limiter.process_msg_and_check(new_tx_message)

        # Process all 3000 further messages, then require at least one refusal.
        tx_outcomes = [limiter.process_msg_and_check(new_tx_message) for _ in range(3000)]
        assert not all(tx_outcomes)

        # Non-tx message
        limiter = RateLimiter()
        new_peak_message = make_msg(ProtocolMessageTypes.new_peak, small_payload)
        for _ in range(20):
            assert limiter.process_msg_and_check(new_peak_message)

        peak_outcomes = [limiter.process_msg_and_check(new_peak_message) for _ in range(200)]
        assert not all(peak_outcomes)
Ejemplo n.º 3
0
    async def test_periodic_reset(self):
        """
        A limiter that has started rejecting messages must accept them again
        after waiting 6 seconds. Both limiters are built as RateLimiter(5);
        presumably 5 is the reset period in seconds -- confirm against
        RateLimiter.
        """
        # Size-based rejection resets
        limiter = RateLimiter(5)
        tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
        for _ in range(10):
            assert limiter.process_msg_and_check(tx_message)

        # Every call is executed; the limiter's state feeds the checks below.
        tx_outcomes = [limiter.process_msg_and_check(tx_message) for _ in range(300)]
        assert not all(tx_outcomes)
        assert not limiter.process_msg_and_check(tx_message)
        await asyncio.sleep(6)
        assert limiter.process_msg_and_check(tx_message)

        # Counts reset also
        limiter = RateLimiter(5)
        new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40))
        for _ in range(3000):
            assert limiter.process_msg_and_check(new_tx_message)

        count_outcomes = [limiter.process_msg_and_check(new_tx_message) for _ in range(3000)]
        assert not all(count_outcomes)
        assert not limiter.process_msg_and_check(new_tx_message)
        await asyncio.sleep(6)
        assert limiter.process_msg_and_check(new_tx_message)
Ejemplo n.º 4
0
    async def test_large_message(self):
        """
        A single oversized message must be rejected even when the limiter has
        only seen a couple of small, accepted messages beforehand.
        """
        # Large tx
        small_tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
        large_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 3 * 1024 * 1024))

        limiter = RateLimiter()
        for _ in range(2):
            assert limiter.process_msg_and_check(small_tx_message)
        assert not limiter.process_msg_and_check(large_tx_message)

        # Signage point responses: 5 KiB is accepted, 600 KiB is not
        small_vdf_message = make_msg(ProtocolMessageTypes.respond_signage_point, bytes([1] * 5 * 1024))
        large_vdf_message = make_msg(ProtocolMessageTypes.respond_signage_point, bytes([1] * 600 * 1024))

        limiter = RateLimiter()
        for _ in range(2):
            assert limiter.process_msg_and_check(small_vdf_message)
        assert not limiter.process_msg_and_check(large_vdf_message)
Ejemplo n.º 5
0
    async def test_non_tx_aggregate_limits(self):
        """
        Limits that apply across several non-tx message types together: two
        message types are accepted in bulk, after which a third type must be
        rejected at least once.
        """
        # Frequency limits
        limiter = RateLimiter()
        message_1 = make_msg(ProtocolMessageTypes.request_additions, bytes([1] * 5 * 1024))
        message_2 = make_msg(ProtocolMessageTypes.request_removals, bytes([1] * 1024))
        message_3 = make_msg(ProtocolMessageTypes.respond_additions, bytes([1] * 1024))

        # 450 of each of the first two types must all be accepted, in this order.
        for accepted_message in (message_1, message_2):
            for _ in range(450):
                assert limiter.process_msg_and_check(accepted_message)

        # All 450 calls are made before the outcome is inspected.
        frequency_outcomes = [limiter.process_msg_and_check(message_3) for _ in range(450)]
        assert not all(frequency_outcomes)

        # Size limits
        limiter = RateLimiter()
        message_4 = make_msg(ProtocolMessageTypes.respond_proof_of_weight, bytes([1] * 49 * 1024 * 1024))
        message_5 = make_msg(ProtocolMessageTypes.respond_blocks, bytes([1] * 49 * 1024 * 1024))

        for _ in range(2):
            assert limiter.process_msg_and_check(message_4)

        size_outcomes = [limiter.process_msg_and_check(message_5) for _ in range(2)]
        assert not all(size_outcomes)
Ejemplo n.º 6
0
class WSChiaConnection:
    """
    Represents a connection to another node. Local host and port are ours, while peer host and
    port are the host and port of the peer that we are connected to. Node_id and connection_type are
    set after the handshake is performed in this connection.

    Outgoing messages are placed on an internal queue that is drained by `outbound_handler`;
    incoming messages are read by `inbound_handler` and either matched to a pending request or
    forwarded to the shared `incoming_queue`. Both handlers are started by `perform_handshake`.

    NOTE: this class defines `__getattr__`, so any attribute name that is not found through
    normal lookup resolves to a dynamically built request coroutine instead of raising
    AttributeError. A misspelled attribute therefore fails late (or not at all).
    """
    def __init__(
        self,
        local_type: NodeType,
        ws: Any,  # Websocket
        server_port: int,
        log: logging.Logger,
        is_outbound: bool,
        is_feeler:
        bool,  # Special type of connection, that disconnects after the handshake.
        peer_host,
        incoming_queue,
        close_callback: Callable,
        peer_id,
        inbound_rate_limit_percent: int,
        outbound_rate_limit_percent: int,
        close_event=None,
        session=None,
    ):
        """
        Store the connection parameters and initialise all bookkeeping state.

        Raises ValueError if the websocket transport cannot report the peer's address.
        """
        # Local properties
        self.ws: Any = ws
        self.local_type = local_type
        self.local_port = server_port
        # Remote properties
        self.peer_host = peer_host

        peername = self.ws._writer.transport.get_extra_info("peername")
        if peername is None:
            # Report the writer object itself. (A misspelled attribute here would be
            # silently resolved by __getattr__ and print a function repr instead.)
            raise ValueError(
                f"Was not able to get peername from {self.ws._writer} at {self.peer_host}"
            )

        connection_port = peername[1]
        self.peer_port = connection_port
        self.peer_server_port: Optional[uint16] = None
        self.peer_node_id = peer_id

        self.log = log

        # connection properties
        self.is_outbound = is_outbound
        self.is_feeler = is_feeler

        # ChiaConnection metrics
        self.creation_time = time.time()
        self.bytes_read = 0
        self.bytes_written = 0
        self.last_message_time: float = 0

        # Messaging
        self.incoming_queue: asyncio.Queue = incoming_queue
        self.outgoing_queue: asyncio.Queue = asyncio.Queue()

        self.inbound_task: Optional[asyncio.Task] = None
        self.outbound_task: Optional[asyncio.Task] = None
        self.active: bool = False  # once handshake is successful this will be changed to True
        # May be None; close() checks before setting it.
        self.close_event: Optional[asyncio.Event] = close_event
        self.session = session
        self.close_callback = close_callback

        # Keyed by request id: the event a requester waits on, its timeout task, and the response.
        self.pending_requests: Dict[bytes32, asyncio.Event] = {}
        self.pending_timeouts: Dict[bytes32, asyncio.Task] = {}
        self.request_results: Dict[bytes32, Message] = {}
        self.closed = False
        self.connection_type: Optional[NodeType] = None
        if is_outbound:
            self.request_nonce: uint16 = uint16(0)
        else:
            # Different nonce to reduce chances of overlap. Each peer will increment the nonce by one for each
            # request. The receiving peer (not is_outbound), will use 2^15 to 2^16 - 1
            self.request_nonce = uint16(2**15)

        # This means that even if the other peer's boundaries for each minute are not aligned, we will not
        # disconnect. Also it allows a little flexibility.
        self.outbound_rate_limiter = RateLimiter(
            incoming=False, percentage_of_limit=outbound_rate_limit_percent)
        self.inbound_rate_limiter = RateLimiter(
            incoming=True, percentage_of_limit=inbound_rate_limit_percent)

    async def perform_handshake(self, network_id: str, protocol_version: str,
                                server_port: int, local_type: NodeType):
        """
        Exchange handshake messages with the peer and start the inbound/outbound handler tasks.

        An outbound connection sends its handshake first and then reads the peer's; an inbound
        connection reads first and then replies. On success `peer_server_port` and
        `connection_type` are set from the peer's handshake and True is returned.

        Raises ProtocolError(Err.INVALID_HANDSHAKE) if no message is received or the message is
        not a handshake, and ProtocolError(Err.INCOMPATIBLE_NETWORK_ID) if the peer's network id
        differs from `network_id`.
        """
        if self.is_outbound:
            outbound_handshake = make_msg(
                ProtocolMessageTypes.handshake,
                Handshake(
                    network_id,
                    protocol_version,
                    chia_full_version_str(),
                    uint16(server_port),
                    uint8(local_type.value),
                    [(uint16(Capability.BASE.value), "1")],
                ),
            )
            assert outbound_handshake is not None
            await self._send_message(outbound_handshake)
            inbound_handshake_msg = await self._read_one_message()
            if inbound_handshake_msg is None:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            # Handle case of invalid ProtocolMessageType
            try:
                message_type: ProtocolMessageTypes = ProtocolMessageTypes(
                    inbound_handshake_msg.type)
            except Exception:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            if message_type != ProtocolMessageTypes.handshake:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            # Parse the payload only after the type check (same order as the inbound branch),
            # so a non-handshake message is reported as INVALID_HANDSHAKE rather than as a
            # parsing failure.
            inbound_handshake = Handshake.from_bytes(
                inbound_handshake_msg.data)

            if inbound_handshake.network_id != network_id:
                raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)

            self.peer_server_port = inbound_handshake.server_port
            self.connection_type = NodeType(inbound_handshake.node_type)

        else:
            try:
                message = await self._read_one_message()
            except Exception:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            if message is None:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            # Handle case of invalid ProtocolMessageType
            try:
                message_type = ProtocolMessageTypes(message.type)
            except Exception:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            if message_type != ProtocolMessageTypes.handshake:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            inbound_handshake = Handshake.from_bytes(message.data)
            if inbound_handshake.network_id != network_id:
                raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)
            outbound_handshake = make_msg(
                ProtocolMessageTypes.handshake,
                Handshake(
                    network_id,
                    protocol_version,
                    chia_full_version_str(),
                    uint16(server_port),
                    uint8(local_type.value),
                    [(uint16(Capability.BASE.value), "1")],
                ),
            )
            await self._send_message(outbound_handshake)
            self.peer_server_port = inbound_handshake.server_port
            self.connection_type = NodeType(inbound_handshake.node_type)

        self.outbound_task = asyncio.create_task(self.outbound_handler())
        self.inbound_task = asyncio.create_task(self.inbound_handler())
        return True

    async def close(self,
                    ban_time: int = 0,
                    ws_close_code: WSCloseCode = WSCloseCode.OK,
                    error: Optional[Err] = None):
        """
        Closes the connection, and finally calls the close_callback on the server, so the connections gets removed
        from the global list.

        Calling close() on an already closed connection does nothing. `ban_time` is passed
        through to the close_callback. If `error` is given, its integer value is sent as the
        websocket close message. If closing raises, the exception is logged, the close_callback
        is still invoked, and the exception is re-raised.
        """

        if self.closed:
            return None
        self.closed = True

        if error is None:
            message = b""
        else:
            message = str(int(error.value)).encode("utf-8")

        try:
            if self.inbound_task is not None:
                self.inbound_task.cancel()
            if self.outbound_task is not None:
                self.outbound_task.cancel()
            if self.ws is not None and self.ws._closed is False:
                await self.ws.close(code=ws_close_code, message=message)
            if self.session is not None:
                await self.session.close()
            if self.close_event is not None:
                self.close_event.set()
            self.cancel_pending_timeouts()
        except Exception:
            error_stack = traceback.format_exc()
            self.log.warning(f"Exception closing socket: {error_stack}")
            self.close_callback(self, ban_time)
            raise
        self.close_callback(self, ban_time)

    def cancel_pending_timeouts(self):
        """
        Cancel the timeout task of every request still in flight. Each cancelled task sets its
        request's event (see create_request), which wakes up the waiting requester.
        """
        for task in self.pending_timeouts.values():
            task.cancel()

    async def outbound_handler(self):
        """
        Drain the outgoing queue and send each message until the connection is closed.
        Cancellation ends the loop silently; any other exception is logged and also ends it.
        """
        try:
            while not self.closed:
                msg = await self.outgoing_queue.get()
                if msg is not None:
                    await self._send_message(msg)
        except asyncio.CancelledError:
            pass
        except (BrokenPipeError, ConnectionResetError) as e:
            self.log.warning(f"{e} {self.peer_host}")
        except Exception as e:
            error_stack = traceback.format_exc()
            self.log.error(f"Exception: {e} with {self.peer_host}")
            self.log.error(f"Exception Stack: {error_stack}")

    async def inbound_handler(self):
        """
        Read messages until the connection is closed. A message whose id matches a pending
        request is stored as that request's result and its event is set; every other message is
        put on the incoming queue together with this connection.
        """
        try:
            while not self.closed:
                message: Message = await self._read_one_message()
                if message is not None:
                    if message.id in self.pending_requests:
                        self.request_results[message.id] = message
                        event = self.pending_requests[message.id]
                        event.set()
                    else:
                        await self.incoming_queue.put((message, self))
        except asyncio.CancelledError:
            self.log.debug("Inbound_handler task cancelled")
        except Exception as e:
            error_stack = traceback.format_exc()
            self.log.error(f"Exception: {e}")
            self.log.error(f"Exception Stack: {error_stack}")

    async def send_message(self, message: Message):
        """Send message sends a message with no tracking / callback."""
        if self.closed:
            return None
        await self.outgoing_queue.put(message)

    def __getattr__(self, attr_name: str):
        """
        Fallback for attributes not found by normal lookup: returns a coroutine function that
        sends the first positional argument as a request of type `attr_name` and awaits the
        response.

        The returned coroutine accepts an optional `timeout` keyword (default 60). When called
        it raises AttributeError if the peer's node type has no method called `attr_name`.
        If a response arrives, it is decoded with the type taken from the annotations of the
        local handler named after the response's message type; otherwise None is returned.
        """
        # TODO KWARGS
        async def invoke(*args, **kwargs):
            timeout = kwargs.get("timeout", 60)
            attribute = getattr(class_for_type(self.connection_type),
                                attr_name, None)
            if attribute is None:
                raise AttributeError(
                    f"Node type {self.connection_type} does not have method {attr_name}"
                )

            msg = Message(
                uint8(getattr(ProtocolMessageTypes, attr_name).value), None,
                args[0])
            request_start_t = time.time()
            result = await self.create_request(msg, timeout)
            self.log.debug(
                f"Time for request {attr_name}: {self.get_peer_info()} = {time.time() - request_start_t}, "
                f"None? {result is None}")
            if result is not None:
                ret_attr = getattr(class_for_type(self.local_type),
                                   ProtocolMessageTypes(result.type).name,
                                   None)

                # The request type is the annotation of the handler parameter that is neither
                # "return" nor "peer" (the last such parameter, if there are several).
                req_annotations = ret_attr.__annotations__
                req = None
                for key in req_annotations:
                    if key == "return" or key == "peer":
                        continue
                    else:
                        req = req_annotations[key]
                assert req is not None
                result = req.from_bytes(result.data)
            return result

        return invoke

    async def create_request(self, message_no_id: Message,
                             timeout: int) -> Optional[Message]:
        """
        Sends a message and waits for a response.

        Returns the response message, or None if the connection is closed, the timeout
        elapses, or the pending timeout is cancelled (as close() does) before a response
        arrives.
        """
        if self.closed:
            return None

        # We will wait for this event, it will be set either by the response, or the timeout
        event = asyncio.Event()

        # The request nonce is an integer between 0 and 2**16 - 1, which is used to match requests to responses
        # If is_outbound, 0 <= nonce < 2^15, else  2^15 <= nonce < 2^16
        request_id = self.request_nonce
        if self.is_outbound:
            self.request_nonce = uint16(self.request_nonce +
                                        1) if self.request_nonce != (
                                            2**15 - 1) else uint16(0)
        else:
            self.request_nonce = (uint16(self.request_nonce +
                                         1) if self.request_nonce !=
                                  (2**16 - 1) else uint16(2**15))

        message = Message(message_no_id.type, request_id, message_no_id.data)

        self.pending_requests[message.id] = event
        await self.outgoing_queue.put(message)

        # If the timeout passes, we set the event
        async def time_out(req_id, req_timeout):
            try:
                await asyncio.sleep(req_timeout)
                if req_id in self.pending_requests:
                    self.pending_requests[req_id].set()
            except asyncio.CancelledError:
                if req_id in self.pending_requests:
                    self.pending_requests[req_id].set()
                raise

        timeout_task = asyncio.create_task(time_out(message.id, timeout))
        self.pending_timeouts[message.id] = timeout_task
        try:
            await event.wait()
        finally:
            # Always release the bookkeeping for this request, including when the waiter itself
            # is cancelled. The pending request is removed first so that cancelling the timer
            # below cannot set an event nobody is waiting on.
            self.pending_requests.pop(message.id, None)
            # Only drop the timeout entry if it is still ours.
            if self.pending_timeouts.get(message.id) is timeout_task:
                del self.pending_timeouts[message.id]
            # Stop the timer once the request is finished; otherwise it would keep running
            # until expiry and could fire for a later request reusing the same nonce.
            timeout_task.cancel()

        result: Optional[Message] = None
        if message.id in self.request_results:
            result = self.request_results[message.id]
            assert result is not None
            self.log.debug(
                f"<- {ProtocolMessageTypes(result.type).name} from: {self.peer_host}:{self.peer_port}"
            )
            self.request_results.pop(result.id)

        return result

    async def reply_to_request(self, response: Message):
        """Queue a response message for sending; does nothing if the connection is closed."""
        if self.closed:
            return None
        await self.outgoing_queue.put(response)

    async def send_messages(self, messages: List[Message]):
        """Queue several messages for sending, in order; does nothing if the connection is closed."""
        if self.closed:
            return None
        for message in messages:
            await self.outgoing_queue.put(message)

    async def _wait_and_retry(self, msg: Message, queue: asyncio.Queue):
        """Wait one second and put `msg` back on `queue`; any exception is logged at debug level."""
        try:
            await asyncio.sleep(1)
            await queue.put(msg)
        except Exception as e:
            self.log.debug(
                f"Exception {e} while waiting to retry sending rate limited message"
            )
            return None

    async def _send_message(self, message: Message):
        """
        Serialize `message` and write it to the websocket, subject to the outbound rate limiter.

        If the limiter rejects the message and the peer is not localhost, the message is not
        sent now: it is re-queued after a one second delay, except for respond_peers messages,
        which are dropped. For localhost peers the message is sent regardless.
        """
        encoded: bytes = bytes(message)
        size = len(encoded)
        assert len(encoded) < (2**(LENGTH_BYTES * 8))
        if not self.outbound_rate_limiter.process_msg_and_check(message):
            if not is_localhost(self.peer_host):
                self.log.debug(
                    f"Rate limiting ourselves. message type: {ProtocolMessageTypes(message.type).name}, "
                    f"peer: {self.peer_host}")

                # TODO: fix this special case. This function has rate limits which are too low.
                if ProtocolMessageTypes(
                        message.type) != ProtocolMessageTypes.respond_peers:
                    asyncio.create_task(
                        self._wait_and_retry(message, self.outgoing_queue))

                return None
            else:
                self.log.debug(
                    f"Not rate limiting ourselves. message type: {ProtocolMessageTypes(message.type).name}, "
                    f"peer: {self.peer_host}")

        await self.ws.send_bytes(encoded)
        self.log.debug(
            f"-> {ProtocolMessageTypes(message.type).name} to peer {self.peer_host} {self.peer_node_id}"
        )
        self.bytes_written += size

    async def _read_one_message(self) -> Optional[Message]:
        """
        Receive one websocket frame (waiting up to 30) and return the decoded Message, or None.

        Only BINARY frames produce a Message. Close/closing/closed, error and unexpected frames
        schedule close() as a separate task, pause for 3 seconds and yield None. When the
        inbound rate limiter rejects a message, a full node talking to a non-localhost peer
        closes the connection with a ban time of 300 and returns None; in every other case the
        message is still returned after logging a warning.
        """
        try:
            message: WSMessage = await self.ws.receive(30)
        except asyncio.TimeoutError:
            # self.ws._closed if we didn't receive a ping / pong
            if self.ws._closed:
                asyncio.create_task(self.close())
                await asyncio.sleep(3)
                return None
            return None

        if self.connection_type is not None:
            connection_type_str = NodeType(self.connection_type).name.lower()
        else:
            connection_type_str = ""
        if message.type == WSMsgType.CLOSING:
            self.log.debug(
                f"Closing connection to {connection_type_str} {self.peer_host}:"
                f"{self.peer_server_port}/"
                f"{self.peer_port}")
            asyncio.create_task(self.close())
            await asyncio.sleep(3)
        elif message.type == WSMsgType.CLOSE:
            self.log.debug(
                f"Peer closed connection {connection_type_str} {self.peer_host}:"
                f"{self.peer_server_port}/"
                f"{self.peer_port}")
            asyncio.create_task(self.close())
            await asyncio.sleep(3)
        elif message.type == WSMsgType.CLOSED:
            if not self.closed:
                asyncio.create_task(self.close())
                await asyncio.sleep(3)
                return None
        elif message.type == WSMsgType.BINARY:
            data = message.data
            full_message_loaded: Message = Message.from_bytes(data)
            self.bytes_read += len(data)
            self.last_message_time = time.time()
            # The type name is only used for logging, so an unknown type is tolerated here.
            try:
                message_type = ProtocolMessageTypes(
                    full_message_loaded.type).name
            except Exception:
                message_type = "Unknown"
            if not self.inbound_rate_limiter.process_msg_and_check(
                    full_message_loaded):
                if self.local_type == NodeType.FULL_NODE and not is_localhost(
                        self.peer_host):
                    self.log.error(
                        f"Peer has been rate limited and will be disconnected: {self.peer_host}, "
                        f"message: {message_type}")
                    # Only full node disconnects peers, to prevent abuse and crashing timelords, farmers, etc
                    asyncio.create_task(self.close(300))
                    await asyncio.sleep(3)
                    return None
                else:
                    self.log.warning(
                        f"Peer surpassed rate limit {self.peer_host}, message: {message_type}, "
                        f"port {self.peer_port} but not disconnecting")
                    return full_message_loaded
            return full_message_loaded
        elif message.type == WSMsgType.ERROR:
            self.log.error(f"WebSocket Error: {message}")
            if message.data.code == WSCloseCode.MESSAGE_TOO_BIG:
                asyncio.create_task(self.close(300))
            else:
                asyncio.create_task(self.close())
            await asyncio.sleep(3)

        else:
            self.log.error(f"Unexpected WebSocket message type: {message}")
            asyncio.create_task(self.close())
            await asyncio.sleep(3)
        return None

    def get_peer_info(self) -> Optional[PeerInfo]:
        """
        Return the peer's address as reported by the transport, paired with the server port the
        peer advertised in its handshake (or the connection's remote port if none is known yet).
        Returns None if the transport cannot report the peer's address.
        """
        result = self.ws._writer.transport.get_extra_info("peername")
        if result is None:
            return None
        connection_host = result[0]
        port = self.peer_server_port if self.peer_server_port is not None else self.peer_port
        return PeerInfo(connection_host, port)
Ejemplo n.º 7
0
    async def test_percentage_limits(self):
        """
        Exercise limiters built as RateLimiter(60, 40) with counts and sizes
        smaller than those used by the other tests. Presumably 40 is a
        percentage applied to the default limits -- confirm against
        RateLimiter.
        """
        # Per-message-type count
        limiter = RateLimiter(60, 40)
        new_peak_message = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40))
        for _ in range(50):
            assert limiter.process_msg_and_check(new_peak_message)

        # Lists (not generators) so every message is processed before checking.
        peak_outcomes = [limiter.process_msg_and_check(new_peak_message) for _ in range(50)]
        assert not all(peak_outcomes)

        # Per-message-type size
        limiter = RateLimiter(60, 40)
        block_message = make_msg(ProtocolMessageTypes.respond_block, bytes([1] * 1024 * 1024))
        for _ in range(5):
            assert limiter.process_msg_and_check(block_message)

        block_outcomes = [limiter.process_msg_and_check(block_message) for _ in range(5)]
        assert not all(block_outcomes)

        # Aggregate percentage limit count
        limiter = RateLimiter(60, 40)
        message_1 = make_msg(ProtocolMessageTypes.request_additions, bytes([1] * 5 * 1024))
        message_2 = make_msg(ProtocolMessageTypes.request_removals, bytes([1] * 1024))
        message_3 = make_msg(ProtocolMessageTypes.respond_additions, bytes([1] * 1024))

        for accepted_message in (message_1, message_2):
            for _ in range(180):
                assert limiter.process_msg_and_check(accepted_message)

        aggregate_count_outcomes = [limiter.process_msg_and_check(message_3) for _ in range(100)]
        assert not all(aggregate_count_outcomes)

        # Aggregate percentage limit max total size
        limiter = RateLimiter(60, 40)
        message_4 = make_msg(ProtocolMessageTypes.respond_proof_of_weight, bytes([1] * 18 * 1024 * 1024))
        message_5 = make_msg(ProtocolMessageTypes.respond_blocks, bytes([1] * 24 * 1024 * 1024))

        for _ in range(2):
            assert limiter.process_msg_and_check(message_4)

        aggregate_size_outcomes = [limiter.process_msg_and_check(message_5) for _ in range(2)]
        assert not all(aggregate_size_outcomes)