예제 #1
0
        async def invoke(*args, **kwargs):
            """Send ``args[0]`` to the peer as an ``attr_name`` request and decode the reply.

            The outgoing message type is ``ProtocolMessageTypes.<attr_name>``. If a reply is
            received, it is decoded with the request type annotated on the local handler whose
            name matches the reply's message type.

            Keyword Args:
                timeout: seconds to wait for the response (default 60).

            Returns:
                The decoded reply object, or None when ``create_request`` returned None.

            Raises:
                AttributeError: if the peer's node type has no method ``attr_name``, or the
                    local node type has no handler for the reply's message type.
            """
            timeout = kwargs.get("timeout", 60)
            attribute = getattr(class_for_type(self.connection_type),
                                attr_name, None)
            if attribute is None:
                raise AttributeError(
                    f"Node type {self.connection_type} does not have method {attr_name}"
                )

            msg = Message(
                uint8(getattr(ProtocolMessageTypes, attr_name).value), None,
                args[0])
            request_start_t = time.time()
            result = await self.create_request(msg, timeout)
            self.log.debug(
                f"Time for request {attr_name}: {self.get_peer_info()} = {time.time() - request_start_t}, "
                f"None? {result is None}")
            if result is not None:
                recv_type_name = ProtocolMessageTypes(result.type).name
                ret_attr = getattr(class_for_type(self.local_type),
                                   recv_type_name, None)
                if ret_attr is None:
                    # Without this check the failure surfaced as an opaque
                    # "'NoneType' object has no attribute '__annotations__'".
                    raise AttributeError(
                        f"Node type {self.local_type} does not have method {recv_type_name}"
                    )

                # The payload type is the handler's last annotated parameter that is neither
                # the return annotation nor the "peer" argument.
                req = None
                for key, annotation in ret_attr.__annotations__.items():
                    if key not in ("return", "peer"):
                        req = annotation
                assert req is not None
                result = req.from_bytes(result.data)
            return result
예제 #2
0
    async def create_request(self, message_no_id: Message,
                             timeout: int) -> Optional[Message]:
        """Send a message and wait for the matching response.

        The message is re-created with a fresh request id (nonce), registered in
        ``self.pending_requests`` and queued on ``self.outgoing_queue``. The call then waits
        until the registered event is set. That happens either when the response is recorded
        in ``self.request_results`` (done elsewhere, presumably by the read loop) or when
        ``timeout`` seconds elapse.

        Args:
            message_no_id: message to send; its id is replaced by the new nonce.
            timeout: seconds to wait for the response.

        Returns:
            The response message, or None if the connection is closed or no response
            was recorded before the event fired.
        """
        if self.closed:
            return None

        # We will wait for this event, it will be set either by the response, or the timeout
        event = asyncio.Event()

        # The request nonce is an integer between 0 and 2**16 - 1, which is used to match requests to responses
        # If is_outbound, 0 <= nonce < 2^15, else  2^15 <= nonce < 2^16
        request_id = self.request_nonce
        if self.is_outbound:
            self.request_nonce = uint16(self.request_nonce +
                                        1) if self.request_nonce != (
                                            2**15 - 1) else uint16(0)
        else:
            self.request_nonce = (uint16(self.request_nonce +
                                         1) if self.request_nonce !=
                                  (2**16 - 1) else uint16(2**15))

        message = Message(message_no_id.type, request_id, message_no_id.data)

        self.pending_requests[message.id] = event
        await self.outgoing_queue.put(message)

        # Wakes this request's waiter when the timeout elapses or when the task is
        # cancelled. It holds the event object directly instead of looking it up by id,
        # so a stale timer can never wake a newer request that reused this nonce after
        # wrap-around.
        async def time_out(req_timeout):
            try:
                await asyncio.sleep(req_timeout)
            finally:
                event.set()

        timeout_task = asyncio.create_task(time_out(timeout))
        self.pending_timeouts[message.id] = timeout_task
        try:
            await event.wait()
        finally:
            # Stop the timer and drop this request's bookkeeping right away. Otherwise both
            # linger for the full timeout after a response, or forever if this coroutine is
            # cancelled. Only remove entries that still belong to this request, because the
            # nonce space wraps and the id may already have been reassigned.
            if self.pending_requests.get(message.id) is event:
                self.pending_requests.pop(message.id)
            if self.pending_timeouts.get(message.id) is timeout_task:
                self.pending_timeouts.pop(message.id)
            timeout_task.cancel()

        result: Optional[Message] = None
        if message.id in self.request_results:
            result = self.request_results.pop(message.id)
            assert result is not None
            self.log.debug(
                f"<- {ProtocolMessageTypes(result.type).name} from: {self.peer_host}:{self.peer_port}"
            )

        return result
예제 #3
0
        async def invoke(*args, **kwargs):
            """Send ``args[0]`` to the peer as an ``attr_name`` request and decode the reply.

            The outgoing message type is ``ProtocolMessageTypes.<attr_name>``. If a reply is
            received, its type must be an acceptable response to the request type
            (``message_response_ok``). Otherwise the peer is banned for a protocol violation.
            A valid reply is decoded with the request type annotated on the local handler
            whose name matches the reply's message type.

            Keyword Args:
                timeout: seconds to wait for the response (default 60).

            Returns:
                The decoded reply object, or None when ``send_request`` returned None.

            Raises:
                AttributeError: if the peer's node type has no method ``attr_name``, or the
                    local node type has no handler for the reply's message type.
                ProtocolError: if the reply's type is not a valid response to the request.
            """
            timeout = kwargs.get("timeout", 60)
            attribute = getattr(class_for_type(self.connection_type),
                                attr_name, None)
            if attribute is None:
                raise AttributeError(
                    f"Node type {self.connection_type} does not have method {attr_name}"
                )

            msg: Message = Message(
                uint8(getattr(ProtocolMessageTypes, attr_name).value), None,
                args[0])
            request_start_t = time.time()
            result = await self.send_request(msg, timeout)
            self.log.debug(
                f"Time for request {attr_name}: {self.get_peer_logging()} = {time.time() - request_start_t}, "
                f"None? {result is None}")
            if result is not None:
                sent_message_type = ProtocolMessageTypes(msg.type)
                recv_message_type = ProtocolMessageTypes(result.type)
                if not message_response_ok(sent_message_type,
                                           recv_message_type):
                    # peer protocol violation
                    # Parenthesized so both f-string halves form one message. Previously the
                    # second half was a separate, discarded expression statement.
                    error_message = (
                        f"WSConnection.invoke sent message {sent_message_type.name} "
                        f"but received {recv_message_type.name}"
                    )
                    # Pass the local variable; ``self.error_message`` does not exist.
                    await self.ban_peer_bad_protocol(error_message)
                    raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE,
                                        [error_message])
                ret_attr = getattr(class_for_type(self.local_type),
                                   recv_message_type.name, None)
                if ret_attr is None:
                    # Without this check the failure surfaced as an opaque
                    # "'NoneType' object has no attribute '__annotations__'".
                    raise AttributeError(
                        f"Node type {self.local_type} does not have method {recv_message_type.name}"
                    )

                # The payload type is the handler's last annotated parameter that is neither
                # the return annotation nor the "peer" argument.
                req = None
                for key, annotation in ret_attr.__annotations__.items():
                    if key not in ("return", "peer"):
                        req = annotation
                assert req is not None
                result = req.from_bytes(result.data)
            return result
예제 #4
0
    async def test_invalid_protocol_handshake(self, setup_two_nodes):
        """A handshake sent with a non-handshake message type must be rejected.

        Connects to server_1 using server_2's SSL credentials and sends a well-formed
        Handshake payload wrapped in a Message of type 2. It expects the server to close the
        websocket with PROTOCOL_ERROR and ``Err.INVALID_HANDSHAKE`` as the close reason.
        """
        nodes, _ = setup_two_nodes
        server_1 = nodes[0].full_node.server
        server_2 = nodes[1].full_node.server

        server_1.invalid_protocol_ban_seconds = 10
        # Use the server_2 ssl information to connect to server_1
        url = f"wss://{self_hostname}:{server_1._port}/ws"
        ssl_context = ssl_context_for_client(server_2.chia_ca_crt_path,
                                             server_2.chia_ca_key_path,
                                             server_2.p2p_crt_path,
                                             server_2.p2p_key_path)

        # The context managers close the websocket and the session even when an
        # assertion below fails. Previously both leaked on failure.
        async with ClientSession(timeout=ClientTimeout(total=10)) as session:
            async with session.ws_connect(url,
                                          autoclose=True,
                                          autoping=True,
                                          heartbeat=60,
                                          ssl=ssl_context,
                                          max_msg_size=100 * 1024 * 1024) as ws:
                # Construct an otherwise valid handshake message
                handshake: Handshake = Handshake("test", "0.0.32", "1.0.0.0",
                                                 3456, 1, [(1, "1")])
                outbound_handshake: Message = Message(
                    2, None, bytes(handshake))  # 2 is an invalid ProtocolType
                await ws.send_bytes(bytes(outbound_handshake))

                response: WSMessage = await ws.receive()
                # The response is attached to the first assertion for diagnostics.
                assert response.type == WSMsgType.CLOSE, response
                assert response.data == WSCloseCode.PROTOCOL_ERROR, response
                # We want INVALID_HANDSHAKE and not UNKNOWN
                assert response.extra == str(int(Err.INVALID_HANDSHAKE.value)), response
        await asyncio.sleep(1)  # give some time for cleanup to work
예제 #5
0
            async def api_call(full_message: Message,
                               connection: WSChiaConnection, task_id):
                """Dispatch one incoming peer message to its API handler and send the reply.

                Looks up ``self.api.<message type name>`` and requires it to carry an
                ``api_function`` attribute. It runs the handler with a 600 second timeout,
                or with no timeout when the handler has ``execute_task``. Any non-None
                returned Message is sent back to the peer with the request's id. If the API
                reports ``api_ready is False`` the message is ignored. An exception closes
                the connection with ``PROTOCOL_ERROR`` / ``Err.UNKNOWN``. This task's
                entries in ``api_tasks``, ``tasks_from_peer`` and ``execute_tasks`` are
                always removed afterwards.
                """
                start_time = time.time()
                try:
                    if self.received_message_callback is not None:
                        await self.received_message_callback(connection)
                    connection.log.debug(
                        f"<- {ProtocolMessageTypes(full_message.type).name} from peer "
                        f"{connection.peer_node_id} {connection.peer_host}")
                    message_type: str = ProtocolMessageTypes(
                        full_message.type).name

                    f = getattr(self.api, message_type, None)

                    if f is None:
                        self.log.error(
                            f"Non existing function: {message_type}")
                        raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE,
                                            [message_type])

                    if not hasattr(f, "api_function"):
                        self.log.error(
                            f"Peer trying to call non api function {message_type}"
                        )
                        raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE,
                                            [message_type])

                    # If api is not ready ignore the request
                    if hasattr(self.api, "api_ready"):
                        if self.api.api_ready is False:
                            return None

                    timeout: Optional[int] = 600
                    if hasattr(f, "execute_task"):
                        # Don't timeout on methods with execute_task decorator, these need to run fully
                        self.execute_tasks.add(task_id)
                        timeout = None

                    if hasattr(f, "peer_required"):
                        coroutine = f(full_message.data, connection)
                    else:
                        coroutine = f(full_message.data)

                    async def wrapped_coroutine() -> Optional[Message]:
                        try:
                            result = await coroutine
                            return result
                        except asyncio.CancelledError:
                            # Deliberately swallowed: a cancelled handler yields no reply.
                            pass
                        except Exception as e:
                            tb = traceback.format_exc()
                            connection.log.error(
                                f"Exception: {e}, {connection.get_peer_info()}. {tb}"
                            )
                            # Bare raise re-raises the active exception unchanged.
                            raise
                        return None

                    response: Optional[Message] = await asyncio.wait_for(
                        wrapped_coroutine(), timeout=timeout)
                    connection.log.debug(
                        f"Time taken to process {message_type} from {connection.peer_node_id} is "
                        f"{time.time() - start_time} seconds")

                    if response is not None:
                        response_message = Message(response.type,
                                                   full_message.id,
                                                   response.data)
                        await connection.reply_to_request(response_message)
                except Exception as e:
                    if self.connection_close_task is None:
                        tb = traceback.format_exc()
                        connection.log.error(
                            f"Exception: {e} {type(e)}, closing connection {connection.get_peer_info()}. {tb}"
                        )
                    else:
                        connection.log.debug(
                            f"Exception: {e} while closing connection")
                    # TODO: actually throw one of the errors from errors.py and pass this to close
                    await connection.close(self.api_exception_ban_seconds,
                                           WSCloseCode.PROTOCOL_ERROR,
                                           Err.UNKNOWN)
                finally:
                    # Bookkeeping cleanup must never raise. A KeyError here would replace the
                    # task's real outcome, so every lookup tolerates missing entries.
                    self.api_tasks.pop(task_id, None)
                    peer_tasks = self.tasks_from_peer.get(connection.peer_node_id)
                    if peer_tasks is not None and task_id in peer_tasks:
                        peer_tasks.remove(task_id)
                    self.execute_tasks.discard(task_id)
예제 #6
0
    async def _read_one_message(self) -> Optional[Message]:
        """Receive one websocket frame and convert it into a protocol ``Message``.

        Waits up to 30 seconds for a frame. A binary frame is decoded into a Message,
        counted in ``bytes_read`` / ``last_message_time`` and checked against the inbound
        rate limiter. Close, closing, closed, error and unexpected frames schedule
        ``self.close()`` (``self.close(300)`` for the rate-limit and message-too-big cases)
        and then sleep 3 seconds.

        Returns:
            The decoded Message, or None when no message should be processed: timeout,
            close/error frames, a rate-limit disconnect, or an unexpected frame type.
        """
        try:
            message: WSMessage = await self.ws.receive(30)
        except asyncio.TimeoutError:
            # The websocket is marked closed if we didn't receive a ping / pong
            if self.ws.closed:
                asyncio.create_task(self.close())
                await asyncio.sleep(3)
            return None

        if self.connection_type is not None:
            connection_type_str = NodeType(self.connection_type).name.lower()
        else:
            connection_type_str = ""
        if message.type == WSMsgType.CLOSING:
            self.log.debug(
                f"Closing connection to {connection_type_str} {self.peer_host}:"
                f"{self.peer_server_port}/"
                f"{self.peer_port}")
            asyncio.create_task(self.close())
            await asyncio.sleep(3)
        elif message.type == WSMsgType.CLOSE:
            self.log.debug(
                f"Peer closed connection {connection_type_str} {self.peer_host}:"
                f"{self.peer_server_port}/"
                f"{self.peer_port}")
            asyncio.create_task(self.close())
            await asyncio.sleep(3)
        elif message.type == WSMsgType.CLOSED:
            if not self.closed:
                asyncio.create_task(self.close())
                await asyncio.sleep(3)
        elif message.type == WSMsgType.BINARY:
            data = message.data
            full_message_loaded: Message = Message.from_bytes(data)
            self.bytes_read += len(data)
            self.last_message_time = time.time()
            try:
                message_type = ProtocolMessageTypes(
                    full_message_loaded.type).name
            except Exception:
                message_type = "Unknown"
            if not self.inbound_rate_limiter.process_msg_and_check(
                    full_message_loaded):
                if self.local_type == NodeType.FULL_NODE and not is_localhost(
                        self.peer_host):
                    self.log.error(
                        f"Peer has been rate limited and will be disconnected: {self.peer_host}, "
                        f"message: {message_type}")
                    # Only full node disconnects peers, to prevent abuse and crashing timelords, farmers, etc
                    asyncio.create_task(self.close(300))
                    await asyncio.sleep(3)
                    return None
                self.log.warning(
                    f"Peer surpassed rate limit {self.peer_host}, message: {message_type}, "
                    f"port {self.peer_port} but not disconnecting")
            return full_message_loaded
        elif message.type == WSMsgType.ERROR:
            self.log.error(f"WebSocket Error: {message}")
            # For ERROR frames, message.data is the exception aiohttp caught. Only some
            # exception types (e.g. WebSocketError) carry a close ``code``, so don't assume
            # the attribute exists. Accessing it directly raised AttributeError otherwise.
            if getattr(message.data, "code", None) == WSCloseCode.MESSAGE_TOO_BIG:
                asyncio.create_task(self.close(300))
            else:
                asyncio.create_task(self.close())
            await asyncio.sleep(3)
        else:
            self.log.error(f"Unexpected WebSocket message type: {message}")
            asyncio.create_task(self.close())
            await asyncio.sleep(3)
        return None
예제 #7
0
    async def send_request(self, message_no_id: Message,
                           timeout: int) -> Optional[Message]:
        """Send a message and wait for the matching response.

        The message is re-created with a fresh request id (nonce), registered in
        ``self.pending_requests`` and queued on ``self.outgoing_queue``. The call then waits
        until the registered event is set. That happens either when the response is recorded
        in ``self.request_results`` (done elsewhere, presumably by the read loop) or when
        ``timeout`` seconds elapse.

        Args:
            message_no_id: message to send; its id is replaced by the new nonce.
            timeout: seconds to wait for the response.

        Returns:
            The response message, or None if the connection is closed or no response
            was recorded before the event fired.
        """
        if self.closed:
            return None

        # We will wait for this event, it will be set either by the response, or the timeout
        event = asyncio.Event()

        # The request nonce is an integer between 0 and 2**16 - 1, which is used to match requests to responses
        # If is_outbound, 0 <= nonce < 2^15, else  2^15 <= nonce < 2^16
        request_id = self.request_nonce
        if self.is_outbound:
            self.request_nonce = uint16(self.request_nonce +
                                        1) if self.request_nonce != (
                                            2**15 - 1) else uint16(0)
        else:
            self.request_nonce = (uint16(self.request_nonce +
                                         1) if self.request_nonce !=
                                  (2**16 - 1) else uint16(2**15))

        message = Message(message_no_id.type, request_id, message_no_id.data)

        # TODO: address hint error and remove ignore
        #       error: Invalid index type "Optional[uint16]" for "Dict[bytes32, Event]"; expected type "bytes32"
        #       [index]
        self.pending_requests[message.id] = event  # type: ignore[index]
        await self.outgoing_queue.put(message)

        # Wakes this request's waiter when the timeout elapses or when the task is
        # cancelled. It holds the event object directly instead of looking it up by id,
        # so a stale timer can never wake a newer request that reused this nonce after
        # wrap-around.
        async def time_out(req_timeout):
            try:
                await asyncio.sleep(req_timeout)
            finally:
                event.set()

        timeout_task = asyncio.create_task(time_out(timeout))
        # TODO: address hint error and remove ignore
        #       error: Invalid index type "Optional[uint16]" for "Dict[bytes32, Task[Any]]"; expected type "bytes32"
        #       [index]
        self.pending_timeouts[message.id] = timeout_task  # type: ignore[index]
        try:
            await event.wait()
        finally:
            # Stop the timer and drop this request's bookkeeping right away. Otherwise both
            # linger for the full timeout after a response, or forever if this coroutine is
            # cancelled. Only remove entries that still belong to this request, because the
            # nonce space wraps and the id may already have been reassigned.
            # TODO: the ignores below share the cause of the TODOs above: message.id is
            #       Optional[uint16] while these dicts are annotated with bytes32 keys.
            if self.pending_requests.get(message.id) is event:  # type: ignore[call-overload]
                self.pending_requests.pop(message.id)  # type: ignore[call-overload]
            if self.pending_timeouts.get(message.id) is timeout_task:  # type: ignore[call-overload]
                self.pending_timeouts.pop(message.id)  # type: ignore[call-overload]
            timeout_task.cancel()

        result: Optional[Message] = None
        if message.id in self.request_results:
            result = self.request_results.pop(message.id)  # type: ignore[call-overload]
            assert result is not None
            self.log.debug(
                f"<- {ProtocolMessageTypes(result.type).name} from: {self.peer_host}:{self.peer_port}"
            )

        return result