예제 #1
0
    async def perform_handshake(self, network_id: str, protocol_version: str, server_port: int, local_type: NodeType):
        """Exchange ``Handshake`` messages with the peer, then start the I/O tasks.

        Outbound connections send their handshake first and then read the
        peer's; inbound connections read and validate the peer's handshake
        first and only then send their own.

        Args:
            network_id: network identifier; the peer must report the same one.
            protocol_version: protocol version advertised to the peer.
            server_port: port advertised to the peer in our handshake.
            local_type: our node type, advertised to the peer.

        Returns:
            ``True`` once the handshake succeeded and the inbound/outbound
            handler tasks have been created.

        Raises:
            ProtocolError: ``Err.INVALID_HANDSHAKE`` if the peer's first message
                is missing, is not a handshake, cannot be decoded, or carries an
                unknown node type (and, on inbound connections, if reading it
                fails); ``Err.INCOMPATIBLE_NETWORK_ID`` if the peer reports a
                different network id.
        """

        def build_outbound_handshake():
            # Both directions send the same handshake payload.
            return make_msg(
                ProtocolMessageTypes.handshake,
                Handshake(
                    network_id,
                    protocol_version,
                    chia_full_version_str(),
                    uint16(server_port),
                    uint8(local_type.value),
                    [(uint16(Capability.BASE.value), "1")],
                ),
            )

        def parse_inbound_handshake(message):
            # Returns (handshake, peer NodeType) or raises ProtocolError.
            # The message type is validated BEFORE the payload is decoded, so
            # a non-handshake message is never parsed as a Handshake, and
            # decoding/enum errors surface as ProtocolError rather than a raw
            # ValueError or parser exception.
            if message is None:
                raise ProtocolError(Err.INVALID_HANDSHAKE)
            try:
                message_type = ProtocolMessageTypes(message.type)
            except ValueError as e:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from e
            if message_type != ProtocolMessageTypes.handshake:
                raise ProtocolError(Err.INVALID_HANDSHAKE)
            try:
                handshake = Handshake.from_bytes(message.data)
            except Exception as e:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from e
            if handshake.network_id != network_id:
                raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)
            try:
                peer_type = NodeType(handshake.node_type)
            except ValueError as e:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from e
            return handshake, peer_type

        if self.is_outbound:
            # We initiated the connection: speak first, then expect a reply.
            await self._send_message(build_outbound_handshake())
            inbound_handshake, peer_type = parse_inbound_handshake(await self._read_one_message())
        else:
            # The peer initiated: only answer once its handshake is valid.
            try:
                message = await self._read_one_message()
            except Exception as e:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from e
            inbound_handshake, peer_type = parse_inbound_handshake(message)
            await self._send_message(build_outbound_handshake())

        self.peer_server_port = inbound_handshake.server_port
        self.connection_type = peer_type

        self.outbound_task = asyncio.create_task(self.outbound_handler())
        self.inbound_task = asyncio.create_task(self.inbound_handler())
        return True
예제 #2
0
 async def validate_broadcast_message_type(self, messages: List[Message],
                                           node_type: NodeType):
     """Reject a broadcast that contains a message type requiring a reply.

     For the first message whose type ``message_requires_reply`` flags, this
     logs an error, closes every connection whose ``connection_type`` is
     ``node_type``, and raises. Per the original intent ("blocking messages to
     all peers"), the raise is meant to stop the caller from broadcasting.

     Args:
         messages: messages about to be broadcast.
         node_type: node type targeted by the broadcast; connections of this
             type are closed when an offending message is found.

     Raises:
         ProtocolError: ``Err.INTERNAL_PROTOCOL_ERROR`` carrying the offending
             message type.
     """
     for message in messages:
         if not message_requires_reply(ProtocolMessageTypes(message.type)):
             continue
         # Internal protocol logic error - we will raise, blocking messages to all peers
         self.log.error(
             f"Attempt to broadcast message requiring protocol response: {message.type}"
         )
         # Snapshot the connections before awaiting close(): closing may remove
         # entries from all_connections (e.g. via a close callback), and
         # mutating a dict while iterating it raises RuntimeError.
         for connection in list(self.all_connections.values()):
             if connection.connection_type is node_type:
                 await connection.close(
                     self.invalid_protocol_ban_seconds,
                     WSCloseCode.INTERNAL_ERROR,
                     Err.INTERNAL_PROTOCOL_ERROR,
                 )
         raise ProtocolError(Err.INTERNAL_PROTOCOL_ERROR,
                             [message.type])
예제 #3
0
        async def invoke(*args, **kwargs):
            """Send ``args[0]`` to the peer as an ``attr_name`` request and decode the reply.

            ``self`` and ``attr_name`` come from the enclosing scope. The request
            is only sent if the peer's API class
            (``class_for_type(self.connection_type)``) defines ``attr_name``.

            Args:
                *args: ``args[0]`` is used as the ``Message`` data payload.
                **kwargs: optional ``timeout`` (default 60), forwarded to
                    ``self.send_request``.

            Returns:
                ``None`` if ``send_request`` returned ``None``; otherwise the
                reply payload decoded with ``from_bytes`` of the parameter type
                annotated on the local handler named after the reply's type.

            Raises:
                AttributeError: if the peer's node type has no ``attr_name``
                    method, or the local node type has no handler for the
                    reply's message type.
                ProtocolError: ``Err.INVALID_PROTOCOL_MESSAGE`` if the reply
                    type is not a valid response to the request; the peer is
                    passed to ``ban_peer_bad_protocol`` first.
            """
            timeout = kwargs.get("timeout", 60)
            attribute = getattr(class_for_type(self.connection_type),
                                attr_name, None)
            if attribute is None:
                raise AttributeError(
                    f"Node type {self.connection_type} does not have method {attr_name}"
                )

            msg: Message = Message(
                uint8(getattr(ProtocolMessageTypes, attr_name).value), None,
                args[0])
            request_start_t = time.time()
            result = await self.send_request(msg, timeout)
            self.log.debug(
                f"Time for request {attr_name}: {self.get_peer_logging()} = {time.time() - request_start_t}, "
                f"None? {result is None}")
            if result is None:
                return None

            sent_message_type = ProtocolMessageTypes(msg.type)
            recv_message_type = ProtocolMessageTypes(result.type)
            if not message_response_ok(sent_message_type, recv_message_type):
                # Peer protocol violation. The adjacent f-strings are wrapped
                # in parentheses so they concatenate into a single message.
                error_message = (
                    f"WSConnection.invoke sent message {sent_message_type.name} "
                    f"but received {recv_message_type.name}"
                )
                await self.ban_peer_bad_protocol(error_message)
                raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE,
                                    [error_message])

            ret_attr = getattr(class_for_type(self.local_type),
                               recv_message_type.name, None)
            if ret_attr is None:
                raise AttributeError(
                    f"Node type {self.local_type} does not have method {recv_message_type.name}"
                )

            # The reply is decoded with the type of the handler's payload
            # parameter: the last annotated name other than `return`/`peer`.
            req = None
            for key, annotation in ret_attr.__annotations__.items():
                if key not in ("return", "peer"):
                    req = annotation
            assert req is not None
            return req.from_bytes(result.data)
예제 #4
0
            async def api_call(full_message: Message,
                               connection: WSChiaConnection, task_id):
                """Run the ``self.api`` handler for one request received from a peer.

                The handler is the ``self.api`` attribute named after the
                ``ProtocolMessageTypes`` member of ``full_message.type``. If the
                handler returns a message, it is sent back via
                ``connection.reply_to_request`` re-tagged with
                ``full_message.id``. Any ``Exception`` raised while dispatching
                (unknown or non-API handler, handler error, timeout, reply
                failure) is logged and the connection is closed with
                ``self.api_exception_ban_seconds``, ``WSCloseCode.PROTOCOL_ERROR``
                and ``Err.UNKNOWN``. On exit, ``task_id`` is always removed from
                ``self.api_tasks``, ``self.tasks_from_peer[...]`` and
                ``self.execute_tasks``.

                Args:
                    full_message: the request received from the peer.
                    connection: the connection the request arrived on.
                    task_id: key of this call in the task bookkeeping
                        collections; presumably the id the caller registered
                        the task under — TODO confirm against caller.

                Returns:
                    None in all paths; replies are sent as a side effect.
                """
                start_time = time.time()
                try:
                    # Optional hook notified for every incoming message.
                    if self.received_message_callback is not None:
                        await self.received_message_callback(connection)
                    connection.log.debug(
                        f"<- {ProtocolMessageTypes(full_message.type).name} from peer "
                        f"{connection.peer_node_id} {connection.peer_host}")
                    message_type: str = ProtocolMessageTypes(
                        full_message.type).name

                    # Handler lookup is by message type name on the API object.
                    f = getattr(self.api, message_type, None)

                    if f is None:
                        self.log.error(
                            f"Non existing function: {message_type}")
                        raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE,
                                            [message_type])

                    # Only attributes carrying an `api_function` marker may be
                    # invoked by peers (presumably set by a decorator — confirm).
                    if not hasattr(f, "api_function"):
                        self.log.error(
                            f"Peer trying to call non api function {message_type}"
                        )
                        raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE,
                                            [message_type])

                    # If api is not ready ignore the request
                    # (no reply is sent; the `finally` cleanup still runs).
                    if hasattr(self.api, "api_ready"):
                        if self.api.api_ready is False:
                            return None

                    timeout: Optional[int] = 600
                    if hasattr(f, "execute_task"):
                        # Don't timeout on methods with execute_task decorator, these need to run fully
                        self.execute_tasks.add(task_id)
                        timeout = None

                    # Handlers flagged `peer_required` also receive the connection.
                    if hasattr(f, "peer_required"):
                        coroutine = f(full_message.data, connection)
                    else:
                        coroutine = f(full_message.data)

                    async def wrapped_coroutine() -> Optional[Message]:
                        # Await the handler, logging (and re-raising) its errors.
                        # A cancelled handler yields None (no reply) instead of
                        # propagating the cancellation.
                        # NOTE(review): because CancelledError is swallowed here,
                        # asyncio.wait_for below may return None on timeout
                        # instead of raising TimeoutError on newer Python
                        # versions — confirm this is intended.
                        try:
                            result = await coroutine
                            return result
                        except asyncio.CancelledError:
                            pass
                        except Exception as e:
                            tb = traceback.format_exc()
                            connection.log.error(
                                f"Exception: {e}, {connection.get_peer_info()}. {tb}"
                            )
                            raise e
                        return None

                    response: Optional[Message] = await asyncio.wait_for(
                        wrapped_coroutine(), timeout=timeout)
                    connection.log.debug(
                        f"Time taken to process {message_type} from {connection.peer_node_id} is "
                        f"{time.time() - start_time} seconds")

                    # Reply with the handler's message, re-tagged with the request id.
                    if response is not None:
                        response_message = Message(response.type,
                                                   full_message.id,
                                                   response.data)
                        await connection.reply_to_request(response_message)
                except Exception as e:
                    # Full traceback at error level unless a connection close is
                    # already in progress (self.connection_close_task is set).
                    if self.connection_close_task is None:
                        tb = traceback.format_exc()
                        connection.log.error(
                            f"Exception: {e} {type(e)}, closing connection {connection.get_peer_info()}. {tb}"
                        )
                    else:
                        connection.log.debug(
                            f"Exception: {e} while closing connection")
                    # TODO: actually throw one of the errors from errors.py and pass this to close
                    await connection.close(self.api_exception_ban_seconds,
                                           WSCloseCode.PROTOCOL_ERROR,
                                           Err.UNKNOWN)
                finally:
                    # Always drop this task's bookkeeping entries.
                    # NOTE(review): tasks_from_peer[peer_node_id] is indexed
                    # directly; if that entry can be removed before this runs
                    # (e.g. on disconnect), this raises KeyError — confirm.
                    if task_id in self.api_tasks:
                        self.api_tasks.pop(task_id)
                    if task_id in self.tasks_from_peer[
                            connection.peer_node_id]:
                        self.tasks_from_peer[connection.peer_node_id].remove(
                            task_id)
                    if task_id in self.execute_tasks:
                        self.execute_tasks.remove(task_id)