Beispiel #1
0
            async def api_call(full_message: Message, connection: WSChiaConnection, task_id):
                """
                Route one inbound message to its API handler and send back any reply.

                The handler is looked up on ``self.api`` under the name of the
                message's protocol type and must carry an ``api_function``
                attribute; handlers that also carry ``peer_required`` receive the
                connection as a second argument. The handler is awaited with a
                300 second timeout. Any exception closes the connection, and
                ``task_id`` is always removed from ``self.api_tasks`` and from the
                peer's entry in ``self.tasks_from_peer``.
                """
                started_at = time.time()
                try:
                    on_message = self.received_message_callback
                    if on_message is not None:
                        await on_message(connection)
                    connection.log.info(
                        f"<- {ProtocolMessageTypes(full_message.type).name} from peer "
                        f"{connection.peer_node_id} {connection.peer_host}"
                    )
                    message_type: str = ProtocolMessageTypes(full_message.type).name

                    handler = getattr(self.api, message_type, None)
                    if handler is None:
                        self.log.error(f"Non existing function: {message_type}")
                        raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [message_type])
                    if not hasattr(handler, "api_function"):
                        # Only explicitly marked methods may be invoked by a peer.
                        self.log.error(f"Peer trying to call non api function {message_type}")
                        raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [message_type])

                    if not hasattr(handler, "peer_required"):
                        coroutine = handler(full_message.data)
                    else:
                        coroutine = handler(full_message.data, connection)

                    async def wrapped_coroutine():
                        # Runs the handler; cancellation is swallowed (yielding None),
                        # every other failure is logged with its traceback and re-raised.
                        try:
                            return await coroutine
                        except asyncio.CancelledError:
                            return None
                        except Exception as e:
                            tb = traceback.format_exc()
                            connection.log.error(f"Exception: {e}, {connection.get_peer_info()}. {tb}")
                            raise e

                    response: Optional[Message] = await asyncio.wait_for(wrapped_coroutine(), timeout=300)
                    connection.log.debug(
                        f"Time taken to process {message_type} from {connection.peer_node_id} is "
                        f"{time.time() - started_at} seconds"
                    )

                    if response is not None:
                        # The reply reuses the request id so the peer can correlate it.
                        await connection.reply_to_request(
                            Message(response.type, response.data, full_message.id)
                        )
                except Exception as e:
                    if self.connection_close_task is not None:
                        connection.log.debug(f"Exception: {e} while closing connection")
                    else:
                        tb = traceback.format_exc()
                        connection.log.error(f"Exception: {e}, closing connection {connection.get_peer_info()}. {tb}")
                    # TODO: actually throw one of the errors from errors.py and pass this to close
                    await connection.close(self.api_exception_ban_seconds, WSCloseCode.PROTOCOL_ERROR, Err.UNKNOWN)
                finally:
                    if task_id in self.api_tasks:
                        self.api_tasks.pop(task_id)
                    peer_task_ids = self.tasks_from_peer[connection.peer_node_id]
                    if task_id in peer_task_ids:
                        peer_task_ids.remove(task_id)
            async def api_call(payload: Payload, connection: WSChiaConnection,
                               task_id):
                """
                Route the message carried by ``payload`` to its API handler.

                The handler is looked up on ``self.api`` under the name of the
                message's protocol type and must carry an ``api_function``
                attribute; handlers that also carry ``peer_required`` receive the
                connection as a second argument. The handler is awaited with a
                300 second timeout and a non-None result is sent back wrapped in a
                ``Payload`` with the request's id. Any exception closes the
                connection; ``task_id`` is always removed from the bookkeeping.
                """
                started_at = time.time()
                try:
                    on_message = self.received_message_callback
                    if on_message is not None:
                        await on_message(connection)
                    full_message = payload.msg
                    connection.log.info(
                        f"<- {ProtocolMessageTypes(full_message.type).name} from peer "
                        f"{connection.peer_node_id} {connection.peer_host}")
                    message_type: str = ProtocolMessageTypes(full_message.type).name

                    handler = getattr(self.api, message_type, None)
                    if handler is None:
                        self.log.error(f"Non existing function: {message_type}")
                        raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [message_type])
                    if not hasattr(handler, "api_function"):
                        # Only explicitly marked methods may be invoked by a peer.
                        self.log.error(
                            f"Peer trying to call non api function {message_type}"
                        )
                        raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [message_type])

                    if not hasattr(handler, "peer_required"):
                        coroutine = handler(full_message.data)
                    else:
                        coroutine = handler(full_message.data, connection)
                    response: Optional[Message] = await asyncio.wait_for(coroutine, timeout=300)
                    connection.log.debug(
                        f"Time taken to process {message_type} from {connection.peer_node_id} is "
                        f"{time.time() - started_at} seconds")

                    if response is not None:
                        # The reply reuses the request's payload id.
                        await connection.reply_to_request(Payload(response, payload.id))
                except Exception as e:
                    if self.connection_close_task is not None:
                        connection.log.debug(
                            f"Exception: {e} while closing connection")
                    else:
                        tb = traceback.format_exc()
                        connection.log.error(
                            f"Exception: {e}, closing connection {connection.get_peer_info()}. {tb}"
                        )
                    await connection.close()
                finally:
                    if task_id in self.api_tasks:
                        self.api_tasks.pop(task_id)
                    peer_task_ids = self.tasks_from_peer[connection.peer_node_id]
                    if task_id in peer_task_ids:
                        peer_task_ids.remove(task_id)
Beispiel #3
0
async def handle_message(
    triple: Tuple[ChiaConnection, Message, PeerConnections], api: Any
) -> AsyncGenerator[Tuple[ChiaConnection, OutboundMessage, PeerConnections],
                    None]:
    """
    Dispatch a single inbound message and yield the outbound messages it produces.

    ``triple`` is ``(connection, message, global_connections)``. "ping" is
    answered with a "pong" response and "pong" is ignored. Any other function
    name is resolved on ``api``, preferring a ``<name>_with_peer_name`` variant
    (called with the peer name as a second argument). If the handler returns
    an async generator its items are yielded; otherwise the result is awaited.
    Any exception is logged and the connection is closed through
    ``global_connections``.
    """
    connection, full_message, global_connections = triple

    try:
        function_name = full_message.function
        # Empty names and names starting with "_" are rejected so that a peer
        # cannot remotely call private methods.
        if len(function_name) == 0 or function_name.startswith("_"):
            raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [function_name])

        connection.log.info(
            f"<- {full_message.function} from peer {connection.get_peername()}"
        )
        if function_name == "pong":
            return
        if function_name == "ping":
            ping_msg = Ping(full_message.data["nonce"])
            assert connection.connection_type
            pong_reply = OutboundMessage(
                connection.connection_type,
                Message("pong", Pong(ping_msg.nonce)),
                Delivery.RESPOND,
            )
            yield connection, pong_reply, global_connections
            return

        peer_aware_handler = getattr(api, function_name + "_with_peer_name", None)
        if peer_aware_handler is None:
            handler = getattr(api, function_name, None)
            if handler is None:
                raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [function_name])
            result = handler(full_message.data)
        else:
            result = peer_aware_handler(full_message.data, connection.get_peername())

        if not isinstance(result, AsyncGenerator):
            await result
        else:
            async for outbound_message in result:
                yield connection, outbound_message, global_connections
    except Exception:
        tb = traceback.format_exc()
        connection.log.error(f"Error, closing connection {connection}. {tb}")
        # TODO: Exception means peer gave us invalid information, so ban this peer.
        global_connections.close(connection)
    async def perform_handshake(self, network_id, protocol_version, node_id,
                                server_port, local_type):
        """
        Exchange handshake messages with the peer and start the I/O handlers.

        An outbound connection sends its handshake first and then reads the
        peer's; an inbound connection reads first and then answers. On success
        the peer's node id, server port and node type are stored on ``self``
        and the outbound/inbound handler tasks are created.

        Args:
            network_id: network identifier placed in our handshake.
            protocol_version: protocol version placed in our handshake.
            node_id: our own node id; also used to detect self connections.
            server_port: our listening port (converted with ``uint16``).
            local_type: our node type.

        Returns:
            True once the handshake succeeded and the handler tasks exist.

        Raises:
            ProtocolError: ``Err.INVALID_HANDSHAKE`` if the first message is
                not a valid handshake, ``Err.SELF_CONNECTION`` if the peer
                reports our own node id.
        """

        async def send_handshake():
            # Build and send our side of the handshake.
            outbound_handshake = Message(
                "handshake",
                Handshake(
                    network_id,
                    protocol_version,
                    node_id,
                    uint16(server_port),
                    local_type,
                ),
            )
            await self._send_message(Payload(outbound_handshake, None))

        async def receive_handshake():
            # Read and validate the peer's side of the handshake.
            payload = await self._read_one_message()
            # Check the function name BEFORE building the Handshake: the data
            # of an unrelated message would otherwise fail inside the
            # constructor instead of producing INVALID_HANDSHAKE.
            if payload.msg.function != "handshake":
                raise ProtocolError(Err.INVALID_HANDSHAKE)
            inbound = Handshake(**payload.msg.data)
            if not inbound or not inbound.node_type:
                raise ProtocolError(Err.INVALID_HANDSHAKE)
            return inbound

        if self.is_outbound:
            await send_handshake()
            inbound_handshake = await receive_handshake()
        else:
            inbound_handshake = await receive_handshake()
            await send_handshake()

        self.peer_node_id = inbound_handshake.node_id
        self.peer_server_port = int(inbound_handshake.server_port)
        self.connection_type = inbound_handshake.node_type

        if self.peer_node_id == node_id:
            raise ProtocolError(Err.SELF_CONNECTION)

        self.outbound_task = asyncio.create_task(self.outbound_handler())
        self.inbound_task = asyncio.create_task(self.inbound_handler())
        return True
Beispiel #5
0
    def add(self, connection: ChiaConnection) -> bool:
        """
        Register ``connection`` and signal the state change.

        Args:
            connection: the connection to track.

        Returns:
            For a full node connection when ``self.introducer_peers`` is set,
            the result of ``self.introducer_peers.add(...)``; otherwise True.

        Raises:
            ProtocolError: ``Err.MAX_INBOUND_CONNECTIONS_REACHED`` if an
                inbound connection of a known type is not accepted,
                ``Err.DUPLICATE_CONNECTION`` if a connection with the same
                node id is already tracked.
        """
        if (
            not connection.is_outbound
            and connection.connection_type is not None
            and not self.accept_inbound_connections(connection.connection_type)
        ):
            raise ProtocolError(Err.MAX_INBOUND_CONNECTIONS_REACHED)

        if any(c.node_id == connection.node_id for c in self._all_connections):
            raise ProtocolError(Err.DUPLICATE_CONNECTION, [False])
        self._all_connections.append(connection)

        # Signal the change exactly once. Previously a full node connection
        # without introducer peers triggered the notification twice.
        self._state_changed("add_connection")
        if (
            connection.connection_type == NodeType.FULL_NODE
            and self.introducer_peers is not None
        ):
            return self.introducer_peers.add(connection.get_peer_info())
        return True
Beispiel #6
0
    async def perform_handshake(
        self, network_id: bytes32, protocol_version: str, server_port: int, local_type: NodeType
    ):
        """
        Exchange handshake messages with the peer and start the I/O handlers.

        An outbound connection sends its handshake first and then reads the
        peer's; an inbound connection reads first and then answers. On success
        the peer's server port and node type are stored on ``self`` and the
        outbound/inbound handler tasks are created.

        Args:
            network_id: network identifier placed in our handshake.
            protocol_version: our protocol version; the peer's must be equal.
            server_port: our listening port.
            local_type: our node type.

        Returns:
            True once the handshake succeeded and the handler tasks exist.

        Raises:
            ProtocolError: ``Err.INVALID_HANDSHAKE`` if no message arrives, the
                message type is unknown or not a handshake, or (inbound only)
                reading fails; ``Err.INCOMPATIBLE_PROTOCOL_VERSION`` if the
                peer's protocol version differs from ours.
        """

        def build_outbound_payload() -> Payload:
            # Our side of the handshake, wrapped for sending.
            outbound_handshake = make_msg(
                ProtocolMessageTypes.handshake,
                Handshake(
                    network_id,
                    protocol_version,
                    chia_full_version_str(),
                    uint16(server_port),
                    uint8(local_type.value),
                ),
            )
            return Payload(outbound_handshake, None)

        def parse_inbound_handshake(payload: Optional[Payload]) -> Handshake:
            # Validate the peer's message and decode it into a Handshake.
            if payload is None:
                raise ProtocolError(Err.INVALID_HANDSHAKE)
            # Validate the type BEFORE decoding the body, and report a type
            # value that is not a known ProtocolMessageTypes member as an
            # invalid handshake rather than leaking a raw ValueError.
            try:
                message_type = ProtocolMessageTypes(payload.msg.type)
            except ValueError as err:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from err
            if message_type != ProtocolMessageTypes.handshake:
                raise ProtocolError(Err.INVALID_HANDSHAKE)
            inbound = Handshake.from_bytes(payload.msg.data)
            if inbound.protocol_version != protocol_version:
                raise ProtocolError(Err.INCOMPATIBLE_PROTOCOL_VERSION)
            return inbound

        if self.is_outbound:
            await self._send_message(build_outbound_payload())
            inbound_handshake = parse_inbound_handshake(await self._read_one_message())
        else:
            try:
                inbound_payload = await self._read_one_message()
            except Exception as err:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from err
            inbound_handshake = parse_inbound_handshake(inbound_payload)
            await self._send_message(build_outbound_payload())

        self.peer_server_port = inbound_handshake.server_port
        self.connection_type = NodeType(inbound_handshake.node_type)

        self.outbound_task = asyncio.create_task(self.outbound_handler())
        self.inbound_task = asyncio.create_task(self.inbound_handler())
        return True
Beispiel #7
0
async def perform_handshake(
    pair: Tuple[Connection, PeerConnections],
    srwt_aiter: push_aiter,
    outbound_handshake: Message,
) -> AsyncGenerator[Tuple[Connection, PeerConnections], None]:
    """
    Perform the handshake with a new connection and yield it on success.

    Sends ``outbound_handshake``, reads and validates the peer's handshake,
    registers the connection in ``global_connections``, exchanges
    "handshake_ack" messages and finally checks the protocol version. If any
    step fails (including a self connection or a connection that cannot be
    added), the failure is logged, the connection is closed and removed from
    ``global_connections``, and nothing is yielded.
    """
    connection, global_connections = pair

    try:
        # Send handshake message
        await connection.send(outbound_handshake)

        # Read handshake message. The function name is checked BEFORE the
        # Handshake is built, so an unrelated message is reported as an
        # invalid handshake instead of failing inside the constructor.
        full_message = await connection.read_one_message()
        if full_message.function != "handshake":
            raise ProtocolError(Err.INVALID_HANDSHAKE)
        inbound_handshake = Handshake(**full_message.data)
        if not inbound_handshake or not inbound_handshake.node_type:
            raise ProtocolError(Err.INVALID_HANDSHAKE)

        if inbound_handshake.node_id == outbound_handshake.data.node_id:
            raise ProtocolError(Err.SELF_CONNECTION)

        # Makes sure that we only start one connection with each peer
        connection.node_id = inbound_handshake.node_id
        connection.peer_server_port = int(inbound_handshake.server_port)
        connection.connection_type = inbound_handshake.node_type

        if srwt_aiter.is_stopped():
            raise Exception("No longer accepting handshakes, closing.")

        if not global_connections.add(connection):
            raise ProtocolError(Err.DUPLICATE_CONNECTION, [False])

        # Send Ack message
        await connection.send(Message("handshake_ack", HandshakeAck()))

        # Read Ack message
        full_message = await connection.read_one_message()
        if full_message.function != "handshake_ack":
            raise ProtocolError(Err.INVALID_ACK)

        if inbound_handshake.version != protocol_version:
            raise ProtocolError(
                Err.INCOMPATIBLE_PROTOCOL_VERSION,
                [protocol_version, inbound_handshake.version],
            )

        connection.log.info(
            (
                f"Handshake with {NodeType(connection.connection_type).name} {connection.get_peername()} "
                f"{connection.node_id}"
                f" established"
            )
        )
        # Only yield a connection if the handshake is successful and the connection is not a duplicate.
        yield connection, global_connections
    except Exception as e:
        # ``Exception`` already covers ProtocolError, IncompleteReadError and
        # OSError, so a single clause handles every failure of the handshake.
        connection.log.warning(f"{e}, handshake not completed. Connection not created.")
        # Make sure to close the connection even if it's not in global connections
        connection.close()
        # Remove the connection from global connections
        global_connections.close(connection)
    async def perform_handshake(self, network_id: str, protocol_version: str,
                                server_port: int, local_type: NodeType):
        """
        Exchange handshake messages with the peer and start the I/O handlers.

        An outbound connection sends its handshake first and then reads the
        peer's; an inbound connection reads first and then answers. On success
        the peer's server port and node type are stored on ``self`` and the
        outbound/inbound handler tasks are created.

        Args:
            network_id: our network id; the peer's must be equal.
            protocol_version: protocol version placed in our handshake.
            server_port: our listening port.
            local_type: our node type.

        Returns:
            True once the handshake succeeded and the handler tasks exist.

        Raises:
            ProtocolError: ``Err.INVALID_HANDSHAKE`` if no message arrives, the
                message type is unknown or not a handshake, or (inbound only)
                reading fails; ``Err.INCOMPATIBLE_NETWORK_ID`` if the peer is
                on a different network.
        """

        def build_outbound_handshake():
            # Our side of the handshake, advertising the BASE capability.
            return make_msg(
                ProtocolMessageTypes.handshake,
                Handshake(
                    network_id,
                    protocol_version,
                    chia_full_version_str(),
                    uint16(server_port),
                    uint8(local_type.value),
                    [(uint16(Capability.BASE.value), "1")],
                ),
            )

        def parse_inbound_handshake(message):
            # Validate the peer's message and decode it into a Handshake.
            if message is None:
                raise ProtocolError(Err.INVALID_HANDSHAKE)
            # Validate the type BEFORE decoding the body, and report a type
            # value that is not a known ProtocolMessageTypes member as an
            # invalid handshake rather than leaking a raw ValueError.
            try:
                message_type = ProtocolMessageTypes(message.type)
            except ValueError as err:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from err
            if message_type != ProtocolMessageTypes.handshake:
                raise ProtocolError(Err.INVALID_HANDSHAKE)
            inbound = Handshake.from_bytes(message.data)
            if inbound.network_id != network_id:
                raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)
            return inbound

        if self.is_outbound:
            outbound_handshake = build_outbound_handshake()
            assert outbound_handshake is not None
            await self._send_message(outbound_handshake)
            inbound_handshake = parse_inbound_handshake(await self._read_one_message())
        else:
            try:
                message = await self._read_one_message()
            except Exception as err:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from err
            inbound_handshake = parse_inbound_handshake(message)
            await self._send_message(build_outbound_handshake())

        self.peer_server_port = inbound_handshake.server_port
        self.connection_type = NodeType(inbound_handshake.node_type)

        self.outbound_task = asyncio.create_task(self.outbound_handler())
        self.inbound_task = asyncio.create_task(self.inbound_handler())
        return True