Ejemplo n.º 1
0
    async def perform_handshake(
        self, network_id: bytes32, protocol_version: str, server_port: int, local_type: NodeType
    ):
        """Exchange Handshake messages with the peer, then start the message handler tasks.

        Outbound connections send their Handshake first and then wait for the peer's.
        Inbound connections wait for the peer's Handshake first and only then reply.

        Args:
            network_id: network identifier placed in our outgoing Handshake.
            protocol_version: our protocol version; the peer's must match it exactly.
            server_port: port advertised to the peer in our Handshake.
            local_type: our node type, advertised to the peer.

        Returns:
            True once the handshake completed and the inbound/outbound handler tasks
            were created.

        Raises:
            ProtocolError(Err.INVALID_HANDSHAKE): the peer sent nothing, reading failed,
                the first message had an unknown or non-handshake type, or its body
                could not be parsed as a Handshake.
            ProtocolError(Err.INCOMPATIBLE_PROTOCOL_VERSION): the peer's protocol_version
                differs from ours.
        """

        def build_outbound_payload() -> Payload:
            """Wrap our own Handshake message in a Payload with no request id."""
            outbound_handshake = make_msg(
                ProtocolMessageTypes.handshake,
                Handshake(
                    network_id,
                    protocol_version,
                    chia_full_version_str(),
                    uint16(server_port),
                    uint8(local_type.value),
                ),
            )
            return Payload(outbound_handshake, None)

        async def read_inbound_handshake() -> Handshake:
            """Read the peer's first message and return it as a validated Handshake."""
            try:
                payload = await self._read_one_message()
            except asyncio.CancelledError:
                # Cancellation must propagate unchanged rather than be reported as a bad handshake.
                raise
            except Exception as e:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from e
            if payload is None:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            # Check the (peer-controlled) message type BEFORE deserialising the body, so an
            # unknown type id or a non-handshake message is reported as INVALID_HANDSHAKE
            # instead of escaping as an arbitrary ValueError / parse error.
            try:
                message_type = ProtocolMessageTypes(payload.msg.type)
            except ValueError as e:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from e
            if message_type != ProtocolMessageTypes.handshake:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            try:
                inbound_handshake = Handshake.from_bytes(payload.msg.data)
            except Exception as e:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from e
            if inbound_handshake.protocol_version != protocol_version:
                raise ProtocolError(Err.INCOMPATIBLE_PROTOCOL_VERSION)
            return inbound_handshake

        if self.is_outbound:
            # Outbound: we speak first, then expect the peer's handshake.
            await self._send_message(build_outbound_payload())
            inbound_handshake = await read_inbound_handshake()
        else:
            # Inbound: the peer speaks first; we reply only after its handshake validated.
            inbound_handshake = await read_inbound_handshake()
            await self._send_message(build_outbound_payload())

        self.peer_server_port = inbound_handshake.server_port
        self.connection_type = NodeType(inbound_handshake.node_type)

        self.outbound_task = asyncio.create_task(self.outbound_handler())
        self.inbound_task = asyncio.create_task(self.inbound_handler())
        return True
Ejemplo n.º 2
0
    async def test_transaction_freeze(self, wallet_node_30_freeze):
        """A spend is refused by the full node early in the chain and accepted after more blocks.

        Farms a few blocks to fund the wallet, then checks that a spend is refused through
        send_transaction (FAILED ack), new_transaction (no request) and respond_transaction
        (no response). It then farms 26 more blocks and checks the node requests the
        transaction and accepts a fresh spend. The fixture name suggests a transaction
        freeze configured until around height 30 -- TODO confirm against the fixture.
        """
        num_blocks = 5
        full_nodes, wallets = wallet_node_30_freeze
        full_node_api: FullNodeSimulator = full_nodes[0]
        full_node_server = full_node_api.server
        wallet_node, server_2 = wallets[0]
        wallet = wallet_node.wallet_state_manager.main_wallet
        ph = await wallet.get_new_puzzlehash()

        # The dummy connection supplies a `peer` object for respond_transaction below;
        # its incoming message queue is not inspected by this test.
        _, node_id = await add_dummy_connection(full_node_server, 12312)

        await server_2.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)
        for _ in range(num_blocks):
            await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph))

        funds = sum(
            calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) for i in range(1, num_blocks)
        )
        # time_out_assert polls for up to 10 s, so no fixed sleep is needed beforehand.
        await time_out_assert(10, wallet.get_confirmed_balance, funds)

        # While frozen, every submission path must refuse the spend.
        tx: TransactionRecord = await wallet.generate_signed_transaction(100, ph, 0)
        spend = wallet_protocol.SendTransaction(tx.spend_bundle)
        response = await full_node_api.send_transaction(spend)
        assert response is not None
        assert wallet_protocol.TransactionAck.from_bytes(response.data).status == MempoolInclusionStatus.FAILED

        new_spend = full_node_protocol.NewTransaction(tx.spend_bundle.name(), 1, 0)
        response = await full_node_api.new_transaction(new_spend)
        assert response is None

        peer = full_node_server.all_connections[node_id]
        new_spend = full_node_protocol.RespondTransaction(tx.spend_bundle)
        response = await full_node_api.respond_transaction(new_spend, peer=peer)
        assert response is None

        for _ in range(26):
            await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph))

        # After the extra blocks the node asks for the announced transaction...
        new_spend = full_node_protocol.NewTransaction(tx.spend_bundle.name(), 1, 0)
        response = await full_node_api.new_transaction(new_spend)
        assert response is not None
        assert ProtocolMessageTypes(response.type) == ProtocolMessageTypes.request_transaction

        # ...and accepts a newly signed spend.
        tx = await wallet.generate_signed_transaction(100, ph, 0)
        spend = wallet_protocol.SendTransaction(tx.spend_bundle)
        response = await full_node_api.send_transaction(spend)
        assert response is not None
        # Check the reply type before decoding its body as a TransactionAck.
        assert ProtocolMessageTypes(response.type) == ProtocolMessageTypes.transaction_ack
        assert wallet_protocol.TransactionAck.from_bytes(response.data).status == MempoolInclusionStatus.SUCCESS
Ejemplo n.º 3
0
        async def invoke(*args, **kwargs):
            """Send the request named ``attr_name`` to the peer and decode its reply.

            ``args[0]`` is used as the raw message data. ``kwargs["timeout"]`` (default 60)
            is passed to ``self.create_request``.

            Returns:
                None when ``create_request`` returns None. Otherwise the reply body decoded
                with the request type annotated on our own handler for the reply's message type.

            Raises:
                AttributeError: the peer's node type has no method named ``attr_name``.
                ProtocolError(Err.INVALID_PROTOCOL_MESSAGE): the reply carries an unknown
                    message type, or our node type has no handler for that type, so the
                    reply cannot be decoded.
            """
            timeout = kwargs.get("timeout", 60)
            if getattr(class_for_type(self.connection_type), attr_name, None) is None:
                raise AttributeError(
                    f"Node type {self.connection_type} does not have method {attr_name}"
                )

            msg = Message(
                uint8(getattr(ProtocolMessageTypes, attr_name).value), args[0],
                None)
            request_start_t = time.time()
            result = await self.create_request(msg, timeout)
            self.log.debug(
                f"Time for request {attr_name}: {self.get_peer_info()} = {time.time() - request_start_t}, "
                f"None? {result is None}")
            if result is None:
                return None

            # The reply type is chosen by the peer, so validate it instead of trusting it:
            # an unknown id would raise a bare ValueError, and a type we have no handler for
            # would crash with AttributeError on `None.__annotations__`.
            try:
                reply_type_name = ProtocolMessageTypes(result.type).name
            except ValueError as e:
                raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [str(result.type)]) from e
            ret_attr = getattr(class_for_type(self.local_type), reply_type_name, None)
            if ret_attr is None:
                raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [reply_type_name])

            # The decode type is the last annotated parameter that is neither the return
            # annotation nor the `peer` argument.
            req = None
            for key, annotation in ret_attr.__annotations__.items():
                if key not in ("return", "peer"):
                    req = annotation
            assert req is not None
            return req.from_bytes(result.data)
Ejemplo n.º 4
0
            async def api_call(full_message: Message, connection: WSChiaConnection, task_id):
                """Dispatch one inbound message from ``connection`` to ``self.api`` and reply.

                The handler is looked up on ``self.api`` by the message type's name and must
                carry the ``api_function`` marker. Handlers marked ``peer_required`` also
                receive the connection. A non-None handler result is sent back with the
                request's id. Any exception is logged and the connection is closed.
                Bookkeeping for ``task_id`` is always removed in ``finally``.
                """
                start_time = time.time()
                try:
                    if self.received_message_callback is not None:
                        await self.received_message_callback(connection)
                    message_type: str = ProtocolMessageTypes(full_message.type).name
                    connection.log.info(
                        f"<- {message_type} from peer "
                        f"{connection.peer_node_id} {connection.peer_host}"
                    )

                    f = getattr(self.api, message_type, None)

                    if f is None:
                        self.log.error(f"Non existing function: {message_type}")
                        raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [message_type])

                    # Only methods explicitly marked as API functions may be called by a peer.
                    if not hasattr(f, "api_function"):
                        self.log.error(f"Peer trying to call non api function {message_type}")
                        raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [message_type])

                    if hasattr(f, "peer_required"):
                        coroutine = f(full_message.data, connection)
                    else:
                        coroutine = f(full_message.data)

                    async def wrapped_coroutine():
                        # Logs handler failures with peer info before re-raising them.
                        # NOTE(review): a cancelled handler is turned into a `None` result here,
                        # i.e. treated as "no response" -- confirm this is intended.
                        try:
                            result = await coroutine
                            return result
                        except asyncio.CancelledError:
                            pass
                        except Exception as e:
                            tb = traceback.format_exc()
                            connection.log.error(f"Exception: {e}, {connection.get_peer_info()}. {tb}")
                            raise e

                    response: Optional[Message] = await asyncio.wait_for(wrapped_coroutine(), timeout=300)
                    connection.log.debug(
                        f"Time taken to process {message_type} from {connection.peer_node_id} is "
                        f"{time.time() - start_time} seconds"
                    )

                    if response is not None:
                        # Reuse the request's id on the reply.
                        response_message = Message(response.type, response.data, full_message.id)
                        await connection.reply_to_request(response_message)
                except Exception as e:
                    # Log loudly unless a close task is already set (then log at debug only).
                    if self.connection_close_task is None:
                        tb = traceback.format_exc()
                        connection.log.error(f"Exception: {e}, closing connection {connection.get_peer_info()}. {tb}")
                    else:
                        connection.log.debug(f"Exception: {e} while closing connection")
                    # TODO: actually throw one of the errors from errors.py and pass this to close
                    await connection.close(self.api_exception_ban_seconds, WSCloseCode.PROTOCOL_ERROR, Err.UNKNOWN)
                finally:
                    self.api_tasks.pop(task_id, None)
                    # The peer's task collection may already have been removed elsewhere; use
                    # .get() so this cleanup can never raise KeyError out of `finally`.
                    peer_tasks = self.tasks_from_peer.get(connection.peer_node_id)
                    if peer_tasks is not None and task_id in peer_tasks:
                        peer_tasks.remove(task_id)
Ejemplo n.º 5
0
            async def api_call(payload: Payload, connection: WSChiaConnection,
                               task_id):
                """Dispatch one inbound payload from ``connection`` to ``self.api`` and reply.

                The handler is looked up on ``self.api`` by the message type's name and must
                carry the ``api_function`` marker. Handlers marked ``peer_required`` also
                receive the connection. A non-None handler result is sent back wrapped in a
                Payload carrying the request's id. Any exception is logged and the connection
                is closed. Bookkeeping for ``task_id`` is always removed in ``finally``.
                """
                start_time = time.time()
                try:
                    if self.received_message_callback is not None:
                        await self.received_message_callback(connection)
                    full_message = payload.msg
                    message_type: str = ProtocolMessageTypes(
                        full_message.type).name
                    connection.log.info(
                        f"<- {message_type} from peer "
                        f"{connection.peer_node_id} {connection.peer_host}")

                    f = getattr(self.api, message_type, None)

                    if f is None:
                        self.log.error(
                            f"Non existing function: {message_type}")
                        raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE,
                                            [message_type])

                    # Only methods explicitly marked as API functions may be called by a peer.
                    if not hasattr(f, "api_function"):
                        self.log.error(
                            f"Peer trying to call non api function {message_type}"
                        )
                        raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE,
                                            [message_type])

                    if hasattr(f, "peer_required"):
                        coroutine = f(full_message.data, connection)
                    else:
                        coroutine = f(full_message.data)
                    response: Optional[Message] = await asyncio.wait_for(
                        coroutine, timeout=300)
                    connection.log.debug(
                        f"Time taken to process {message_type} from {connection.peer_node_id} is "
                        f"{time.time() - start_time} seconds")

                    if response is not None:
                        # Reuse the request's payload id on the reply.
                        response_payload = Payload(response, payload.id)
                        await connection.reply_to_request(response_payload)
                except Exception as e:
                    # Log loudly unless a close task is already set (then log at debug only).
                    if self.connection_close_task is None:
                        tb = traceback.format_exc()
                        connection.log.error(
                            f"Exception: {e}, closing connection {connection.get_peer_info()}. {tb}"
                        )
                    else:
                        connection.log.debug(
                            f"Exception: {e} while closing connection")
                    await connection.close()
                finally:
                    self.api_tasks.pop(task_id, None)
                    # The peer's task collection may already have been removed elsewhere; use
                    # .get() so this cleanup can never raise KeyError out of `finally`.
                    peer_tasks = self.tasks_from_peer.get(connection.peer_node_id)
                    if peer_tasks is not None and task_id in peer_tasks:
                        peer_tasks.remove(task_id)
Ejemplo n.º 6
0
 async def bool_f():
     """Consume ``count`` queued messages and report whether every one is ``msg_name``.

     ``incoming_queue``, ``count`` and ``msg_name`` come from the enclosing scope.
     While fewer than ``count`` messages are queued, returns False without dequeuing
     anything. Otherwise dequeues one message at a time and returns False at the first
     whose protocol type name differs from ``msg_name``. Messages already dequeued are
     not put back. Returns True after ``count`` matching messages.
     """
     if incoming_queue.qsize() < count:
         # Not everything has arrived yet; the caller is expected to poll again.
         return False
     for _index in range(count):
         queued_item = await incoming_queue.get()
         received_name = ProtocolMessageTypes(queued_item[0].type).name
         if received_name != msg_name:
             return False
     return True
Ejemplo n.º 7
0
    def process_msg_and_check(self, message: Message) -> bool:
        """
        Count ``message`` against the current rate-limit window and decide whether to allow it.

        The window index is ``time.time() // self.reset_seconds``. When it changes, all
        counters are reset. Per-type count and size totals are updated before any check,
        so a rejected message still counts. Types in ``rate_limits_other`` are also charged
        to the shared non-transaction budget. Every limit is scaled by
        ``self.percentage_of_limit / 100``.

        Returns True if the message is within all applicable limits (or its type id is not
        a known ProtocolMessageTypes value, which is logged and allowed). Returns False as
        soon as any limit is exceeded.
        """
        window = time.time() // self.reset_seconds
        if window != self.current_minute:
            # New window: forget everything counted during the previous one.
            self.current_minute = window
            self.message_counts = Counter()
            self.message_cumulative_sizes = Counter()
            self.non_tx_message_counts = 0
            self.non_tx_cumulative_size = 0

        try:
            msg_type = ProtocolMessageTypes(message.type)
        except Exception as e:
            log.warning(f"Invalid message: {message.type}, {e}")
            return True

        size = len(message.data)
        self.message_counts[msg_type] += 1
        self.message_cumulative_sizes[msg_type] += size
        scale = self.percentage_of_limit / 100

        if msg_type in rate_limits_tx:
            limits = rate_limits_tx[msg_type]
        elif msg_type in rate_limits_other:
            limits = rate_limits_other[msg_type]
            # Non-transaction traffic also shares one aggregate budget.
            self.non_tx_message_counts += 1
            self.non_tx_cumulative_size += size
            over_non_tx_budget = (
                self.non_tx_message_counts > NON_TX_FREQ * scale
                or self.non_tx_cumulative_size > NON_TX_MAX_TOTAL_SIZE * scale
            )
            if over_non_tx_budget:
                return False
        else:
            log.warning(
                f"Message type {msg_type} not found in rate limits")
            limits = DEFAULT_SETTINGS

        # An unset total-size cap defaults to frequency * per-message max size.
        total_cap = limits.max_total_size
        if total_cap is None:
            total_cap = limits.frequency * limits.max_size

        if self.message_counts[msg_type] > limits.frequency * scale:
            return False
        if size > limits.max_size:
            return False
        return self.message_cumulative_sizes[msg_type] <= total_cap * scale
Ejemplo n.º 8
0
    async def perform_handshake(self, network_id: str, protocol_version: str,
                                server_port: int, local_type: NodeType):
        """Exchange Handshake messages with the peer, then start the message handler tasks.

        Outbound connections send their Handshake first and then wait for the peer's.
        Inbound connections wait for the peer's Handshake first and only then reply.
        Our Handshake advertises the BASE capability at version "1".

        Args:
            network_id: our network id; the peer's Handshake must carry the same value.
            protocol_version: protocol version placed in our outgoing Handshake.
            server_port: port advertised to the peer.
            local_type: our node type, advertised to the peer.

        Returns:
            True once the handshake completed and the inbound/outbound handler tasks
            were created.

        Raises:
            ProtocolError(Err.INVALID_HANDSHAKE): the peer sent nothing, reading failed,
                the first message had an unknown or non-handshake type, or its body
                could not be parsed as a Handshake.
            ProtocolError(Err.INCOMPATIBLE_NETWORK_ID): the peer's network_id differs.
        """

        def build_outbound_handshake():
            """Build our own Handshake message."""
            return make_msg(
                ProtocolMessageTypes.handshake,
                Handshake(
                    network_id,
                    protocol_version,
                    chia_full_version_str(),
                    uint16(server_port),
                    uint8(local_type.value),
                    [(uint16(Capability.BASE.value), "1")],
                ),
            )

        async def read_inbound_handshake() -> Handshake:
            """Read the peer's first message and return it as a validated Handshake."""
            try:
                message = await self._read_one_message()
            except asyncio.CancelledError:
                # Cancellation must propagate unchanged rather than be reported as a bad handshake.
                raise
            except Exception as e:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from e
            if message is None:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            # Check the (peer-controlled) message type BEFORE deserialising the body, so an
            # unknown type id or a non-handshake message is reported as INVALID_HANDSHAKE
            # instead of escaping as an arbitrary ValueError / parse error.
            try:
                message_type = ProtocolMessageTypes(message.type)
            except ValueError as e:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from e
            if message_type != ProtocolMessageTypes.handshake:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            try:
                inbound_handshake = Handshake.from_bytes(message.data)
            except Exception as e:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from e
            if inbound_handshake.network_id != network_id:
                raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)
            return inbound_handshake

        if self.is_outbound:
            # Outbound: we speak first, then expect the peer's handshake.
            await self._send_message(build_outbound_handshake())
            inbound_handshake = await read_inbound_handshake()
        else:
            # Inbound: the peer speaks first; we reply only after its handshake validated.
            inbound_handshake = await read_inbound_handshake()
            await self._send_message(build_outbound_handshake())

        self.peer_server_port = inbound_handshake.server_port
        self.connection_type = NodeType(inbound_handshake.node_type)

        self.outbound_task = asyncio.create_task(self.outbound_handler())
        self.inbound_task = asyncio.create_task(self.inbound_handler())
        return True