コード例 #1
0
    async def perform_handshake(self, network_id: str, protocol_version: str, server_port: int, local_type: NodeType):
        """Exchange ``Handshake`` messages with the peer and start the I/O handler tasks.

        Outbound connections send our handshake first and then read the peer's; inbound
        connections read the peer's handshake first and reply afterwards. On success
        ``peer_server_port`` and ``connection_type`` are populated from the peer's handshake
        and the outbound/inbound handler tasks are created.

        Args:
            network_id: network identifier the peer's handshake must match.
            protocol_version: protocol version advertised in our handshake.
            server_port: port advertised as our server port.
            local_type: our own node type, advertised to the peer.

        Returns:
            True once the handshake completed and the handler tasks were created.

        Raises:
            ProtocolError: ``Err.INVALID_HANDSHAKE`` if reading fails, no message arrives, the
                message type is unknown or not ``handshake``, or the payload cannot be
                decoded; ``Err.INCOMPATIBLE_NETWORK_ID`` if the peer's network id differs.
        """
        outbound_handshake = make_msg(
            ProtocolMessageTypes.handshake,
            Handshake(
                network_id,
                protocol_version,
                chia_full_version_str(),
                uint16(server_port),
                uint8(local_type.value),
                [(uint16(Capability.BASE.value), "1")],
            ),
        )
        assert outbound_handshake is not None

        async def read_peer_handshake() -> Handshake:
            # Both connection directions apply exactly the same validation.
            try:
                message = await self._read_one_message()
            except Exception as e:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from e
            if message is None:
                raise ProtocolError(Err.INVALID_HANDSHAKE)
            # Validate the type *before* decoding. An unknown type code (Enum lookup raises
            # ValueError) or a non-handshake payload must surface as a protocol error, not
            # as a ValueError or a deserialization failure.
            try:
                message_type = ProtocolMessageTypes(message.type)
            except ValueError as e:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from e
            if message_type != ProtocolMessageTypes.handshake:
                raise ProtocolError(Err.INVALID_HANDSHAKE)
            try:
                handshake = Handshake.from_bytes(message.data)
            except Exception as e:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from e
            if handshake.network_id != network_id:
                raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)
            return handshake

        if self.is_outbound:
            await self._send_message(outbound_handshake)
            inbound_handshake = await read_peer_handshake()
        else:
            inbound_handshake = await read_peer_handshake()
            await self._send_message(outbound_handshake)

        self.peer_server_port = inbound_handshake.server_port
        self.connection_type = NodeType(inbound_handshake.node_type)

        self.outbound_task = asyncio.create_task(self.outbound_handler())
        self.inbound_task = asyncio.create_task(self.inbound_handler())
        return True
コード例 #2
0
    async def _send_message(self, message: Message):
        """Write ``message`` to the websocket unless the outbound rate limiter forbids it.

        If the limiter rejects the message and the peer is not localhost, nothing is sent
        and None is returned. For every type except ``respond_peers``, a
        ``_wait_and_retry`` task is scheduled first. Localhost peers are never throttled.
        After a successful write, ``bytes_written`` grows by the encoded size.
        """
        payload: bytes = bytes(message)
        payload_size = len(payload)
        assert payload_size < (2**(LENGTH_BYTES * 8))

        within_limit = self.outbound_rate_limiter.process_msg_and_check(message)
        if not within_limit:
            if is_localhost(self.peer_host):
                self.log.debug(
                    f"Not rate limiting ourselves. message type: {ProtocolMessageTypes(message.type).name}, "
                    f"peer: {self.peer_host}")
            else:
                self.log.debug(
                    f"Rate limiting ourselves. message type: {ProtocolMessageTypes(message.type).name}, "
                    f"peer: {self.peer_host}")
                # TODO: fix this special case. This function has rate limits which are too low.
                is_respond_peers = ProtocolMessageTypes(message.type) == ProtocolMessageTypes.respond_peers
                if not is_respond_peers:
                    asyncio.create_task(self._wait_and_retry(message, self.outgoing_queue))
                return None

        await self.ws.send_bytes(payload)
        self.log.debug(
            f"-> {ProtocolMessageTypes(message.type).name} to peer {self.peer_host} {self.peer_node_id}"
        )
        self.bytes_written += payload_size
コード例 #3
0
        async def invoke(*args, **kwargs):
            """Send ``args[0]`` to the peer as an ``attr_name`` request and decode the reply.

            ``self`` and ``attr_name`` come from the enclosing scope. The optional ``timeout``
            keyword (default 60) is forwarded to ``create_request``.

            Returns:
                None if ``create_request`` returned None. Otherwise, the reply payload decoded
                with the type annotated on the local handler named after the reply's type.

            Raises:
                AttributeError: if the peer's node type has no ``attr_name`` method, or the
                    local node type has no handler for the reply's message type.
            """
            timeout = kwargs.get("timeout", 60)
            attribute = getattr(class_for_type(self.connection_type), attr_name, None)
            if attribute is None:
                raise AttributeError(
                    f"Node type {self.connection_type} does not have method {attr_name}"
                )

            msg = Message(uint8(getattr(ProtocolMessageTypes, attr_name).value), None, args[0])
            request_start_t = time.time()
            result = await self.create_request(msg, timeout)
            self.log.debug(
                f"Time for request {attr_name}: {self.get_peer_info()} = {time.time() - request_start_t}, "
                f"None? {result is None}")
            if result is None:
                return None

            response_type_name = ProtocolMessageTypes(result.type).name
            ret_attr = getattr(class_for_type(self.local_type), response_type_name, None)
            if ret_attr is None:
                # Without this check an unexpected reply type fails with an opaque
                # "'NoneType' object has no attribute '__annotations__'".
                raise AttributeError(
                    f"Node type {self.local_type} does not have method {response_type_name}"
                )

            # The payload type is the annotation of the handler's last parameter that is
            # neither the return annotation nor ``peer``.
            payload_types = [
                annotation
                for name, annotation in ret_attr.__annotations__.items()
                if name not in ("return", "peer")
            ]
            req = payload_types[-1] if payload_types else None
            assert req is not None
            return req.from_bytes(result.data)
コード例 #4
0
 async def bool_f():
     """Report whether the next ``count`` queued messages all have type ``msg_name``.

     While fewer than ``count`` items are queued, returns False without consuming
     anything. Otherwise dequeues items one at a time, up to ``count`` of them, and
     returns False at the first one whose type name differs from ``msg_name``.
     """
     if incoming_queue.qsize() >= count:
         for _ in range(count):
             queued = await incoming_queue.get()
             if ProtocolMessageTypes(queued[0].type).name != msg_name:
                 return False
         return True
     return False
コード例 #5
0
        async def invoke(*args, **kwargs):
            """Send ``args[0]`` to the peer as an ``attr_name`` request and decode the reply.

            ``self`` and ``attr_name`` come from the enclosing scope. The optional ``timeout``
            keyword (default 60) is forwarded to ``send_request``.

            Returns:
                None if ``send_request`` returned None. Otherwise, the reply payload decoded
                with the type annotated on the local handler named after the reply's type.

            Raises:
                AttributeError: if the peer's node type has no ``attr_name`` method, or the
                    local node type has no handler for the reply's message type.
                ProtocolError: ``Err.INVALID_PROTOCOL_MESSAGE`` if the reply type is not an
                    acceptable response to the request (the peer is banned first).
            """
            timeout = kwargs.get("timeout", 60)
            attribute = getattr(class_for_type(self.connection_type), attr_name, None)
            if attribute is None:
                raise AttributeError(
                    f"Node type {self.connection_type} does not have method {attr_name}"
                )

            msg: Message = Message(uint8(getattr(ProtocolMessageTypes, attr_name).value), None, args[0])
            request_start_t = time.time()
            result = await self.send_request(msg, timeout)
            self.log.debug(
                f"Time for request {attr_name}: {self.get_peer_logging()} = {time.time() - request_start_t}, "
                f"None? {result is None}")
            if result is None:
                return None

            sent_message_type = ProtocolMessageTypes(msg.type)
            recv_message_type = ProtocolMessageTypes(result.type)
            if not message_response_ok(sent_message_type, recv_message_type):
                # peer protocol violation
                # Parenthesised so both halves form ONE string; previously the second
                # f-string was a separate, discarded expression statement.
                error_message = (
                    f"WSConnection.invoke sent message {sent_message_type.name} "
                    f"but received {recv_message_type.name}"
                )
                # Pass the local message; ``self.error_message`` does not exist.
                await self.ban_peer_bad_protocol(error_message)
                raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [error_message])

            ret_attr = getattr(class_for_type(self.local_type), recv_message_type.name, None)
            if ret_attr is None:
                # Without this check a missing local handler fails with an opaque
                # "'NoneType' object has no attribute '__annotations__'".
                raise AttributeError(
                    f"Node type {self.local_type} does not have method {recv_message_type.name}"
                )

            # The payload type is the annotation of the handler's last parameter that is
            # neither the return annotation nor ``peer``.
            payload_types = [
                annotation
                for name, annotation in ret_attr.__annotations__.items()
                if name not in ("return", "peer")
            ]
            req = payload_types[-1] if payload_types else None
            assert req is not None
            return req.from_bytes(result.data)
コード例 #6
0
    def process_msg_and_check(self, message: Message) -> bool:
        """
        Count ``message`` in the current rate-limit window and decide whether it is allowed.

        The window index is ``time.time() // self.reset_seconds``; entering a new window
        clears every counter. A message whose type is not a valid ``ProtocolMessageTypes``
        value is logged and allowed. Counters are incremented before any limit is checked,
        so a rejected message still counts toward the window.

        Returns True if message can be processed successfully, false if a rate limit is passed.
        """
        window = time.time() // self.reset_seconds
        if window != self.current_minute:
            self.current_minute = window
            self.message_counts = Counter()
            self.message_cumulative_sizes = Counter()
            self.non_tx_message_counts = 0
            self.non_tx_cumulative_size = 0

        try:
            message_type = ProtocolMessageTypes(message.type)
        except Exception as e:
            log.warning(f"Invalid message: {message.type}, {e}")
            return True

        payload_len = len(message.data)
        self.message_counts[message_type] += 1
        self.message_cumulative_sizes[message_type] += payload_len
        scale = self.percentage_of_limit / 100

        if message_type in rate_limits_tx:
            limits = rate_limits_tx[message_type]
        elif message_type in rate_limits_other:
            limits = rate_limits_other[message_type]
            self.non_tx_message_counts += 1
            self.non_tx_cumulative_size += payload_len
            # All non-transaction types also share one combined count/size budget.
            if (
                self.non_tx_message_counts > NON_TX_FREQ * scale
                or self.non_tx_cumulative_size > NON_TX_MAX_TOTAL_SIZE * scale
            ):
                return False
        else:
            limits = DEFAULT_SETTINGS
            log.warning(f"Message type {message_type} not found in rate limits")

        if limits.max_total_size is None:
            # Default the per-window size cap to frequency * per-message size cap.
            limits = dataclasses.replace(limits, max_total_size=limits.frequency * limits.max_size)

        exceeded = (
            self.message_counts[message_type] > limits.frequency * scale
            or payload_len > limits.max_size
            or self.message_cumulative_sizes[message_type] > limits.max_total_size * scale
        )
        return not exceeded
コード例 #7
0
ファイル: server.py プロジェクト: olivernyc/chia-blockchain
 async def validate_broadcast_message_type(self, messages: List[Message],
                                           node_type: NodeType):
     """Refuse to broadcast any message whose type requires a protocol reply.

     Broadcasting such a message is treated as an internal logic error. Every
     connection whose ``connection_type`` is ``node_type`` is closed with
     ``WSCloseCode.INTERNAL_ERROR`` / ``Err.INTERNAL_PROTOCOL_ERROR``, and then an error
     is raised.

     Raises:
         ProtocolError: ``Err.INTERNAL_PROTOCOL_ERROR`` for the first offending message.
     """
     for message in messages:
         if not message_requires_reply(ProtocolMessageTypes(message.type)):
             continue
         # Internal protocol logic error - we will raise, blocking messages to all peers
         self.log.error(
             f"Attempt to broadcast message requiring protocol response: {message.type}"
         )
         # Iterate over a snapshot: close() is awaited inside the loop, and if it removes
         # the connection from all_connections (not visible here - confirm), iterating the
         # live dict would raise "dictionary changed size during iteration".
         for connection in list(self.all_connections.values()):
             if connection.connection_type is node_type:
                 await connection.close(
                     self.invalid_protocol_ban_seconds,
                     WSCloseCode.INTERNAL_ERROR,
                     Err.INTERNAL_PROTOCOL_ERROR,
                 )
         raise ProtocolError(Err.INTERNAL_PROTOCOL_ERROR,
                             [message.type])
コード例 #8
0
            async def api_call(full_message: Message,
                               connection: WSChiaConnection, task_id):
                """Run the API handler for one incoming request and send back its response.

                The handler is the attribute of ``self.api`` named after the message type,
                and it must carry an ``api_function`` attribute. If ``self.api`` has
                ``api_ready`` set to False, the request is silently ignored.

                Handlers marked ``execute_task`` run with no timeout, and ``task_id`` is
                recorded in ``self.execute_tasks``; all other handlers are limited to 600
                seconds via ``asyncio.wait_for``. Handlers marked ``peer_required`` also
                receive ``connection``. A non-None handler result is replied with the
                request's id.

                Any exception, including a ``wait_for`` timeout, closes the connection with
                ``self.api_exception_ban_seconds``, ``WSCloseCode.PROTOCOL_ERROR`` and
                ``Err.UNKNOWN``; it is not re-raised. In every case ``task_id`` is removed
                from ``api_tasks``, ``tasks_from_peer`` and ``execute_tasks``.
                """
                start_time = time.time()
                try:
                    if self.received_message_callback is not None:
                        await self.received_message_callback(connection)
                    # ProtocolMessageTypes(...) raises ValueError for an unknown type code;
                    # the outer ``except Exception`` then closes the connection.
                    connection.log.debug(
                        f"<- {ProtocolMessageTypes(full_message.type).name} from peer "
                        f"{connection.peer_node_id} {connection.peer_host}")
                    message_type: str = ProtocolMessageTypes(
                        full_message.type).name

                    f = getattr(self.api, message_type, None)

                    if f is None:
                        self.log.error(
                            f"Non existing function: {message_type}")
                        raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE,
                                            [message_type])

                    # Only methods carrying the ``api_function`` attribute may be invoked by peers.
                    if not hasattr(f, "api_function"):
                        self.log.error(
                            f"Peer trying to call non api function {message_type}"
                        )
                        raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE,
                                            [message_type])

                    # If api is not ready ignore the request
                    if hasattr(self.api, "api_ready"):
                        if self.api.api_ready is False:
                            return None

                    timeout: Optional[int] = 600
                    if hasattr(f, "execute_task"):
                        # Don't timeout on methods with execute_task decorator, these need to run fully
                        self.execute_tasks.add(task_id)
                        timeout = None

                    if hasattr(f, "peer_required"):
                        coroutine = f(full_message.data, connection)
                    else:
                        coroutine = f(full_message.data)

                    # Awaits the handler. CancelledError is swallowed and yields None; any other
                    # exception is logged with its traceback and re-raised to the outer handler.
                    # NOTE(review): swallowing CancelledError means a cancelled handler resolves
                    # to None instead of propagating cancellation - confirm this is intended.
                    async def wrapped_coroutine() -> Optional[Message]:
                        try:
                            result = await coroutine
                            return result
                        except asyncio.CancelledError:
                            pass
                        except Exception as e:
                            tb = traceback.format_exc()
                            connection.log.error(
                                f"Exception: {e}, {connection.get_peer_info()}. {tb}"
                            )
                            raise e
                        return None

                    response: Optional[Message] = await asyncio.wait_for(
                        wrapped_coroutine(), timeout=timeout)
                    connection.log.debug(
                        f"Time taken to process {message_type} from {connection.peer_node_id} is "
                        f"{time.time() - start_time} seconds")

                    if response is not None:
                        # Echo the request id so the peer can match the reply to its request.
                        response_message = Message(response.type,
                                                   full_message.id,
                                                   response.data)
                        await connection.reply_to_request(response_message)
                except Exception as e:
                    # Log at error level only when no connection-close task is already running;
                    # otherwise the exception is expected noise from the shutdown.
                    if self.connection_close_task is None:
                        tb = traceback.format_exc()
                        connection.log.error(
                            f"Exception: {e} {type(e)}, closing connection {connection.get_peer_info()}. {tb}"
                        )
                    else:
                        connection.log.debug(
                            f"Exception: {e} while closing connection")
                    # TODO: actually throw one of the errors from errors.py and pass this to close
                    await connection.close(self.api_exception_ban_seconds,
                                           WSCloseCode.PROTOCOL_ERROR,
                                           Err.UNKNOWN)
                finally:
                    # Bookkeeping cleanup for this task id, whatever the outcome.
                    if task_id in self.api_tasks:
                        self.api_tasks.pop(task_id)
                    # NOTE(review): indexing raises KeyError here if the peer has no entry,
                    # unless tasks_from_peer is a defaultdict - confirm.
                    if task_id in self.tasks_from_peer[
                            connection.peer_node_id]:
                        self.tasks_from_peer[connection.peer_node_id].remove(
                            task_id)
                    if task_id in self.execute_tasks:
                        self.execute_tasks.remove(task_id)
コード例 #9
0
ファイル: rate_limits.py プロジェクト: amuDev/Chia-Pooling
    def process_msg_and_check(self, message: Message) -> bool:
        """
        Check ``message`` against the rate limits of the current window.

        Updated counter values are staged in locals first. They are committed only if the
        message is allowed, or unconditionally when ``self.incoming`` is truthy, because an
        incoming message has already been received and must always count. A message with
        an unknown type is logged and allowed without touching the counters.

        Returns True if message can be processed successfully, false if a rate limit is passed.
        """
        window = int(time.time() // self.reset_seconds)
        if window != self.current_minute:
            self.current_minute = window
            self.message_counts = Counter()
            self.message_cumulative_sizes = Counter()
            self.non_tx_message_counts = 0
            self.non_tx_cumulative_size = 0

        try:
            message_type = ProtocolMessageTypes(message.type)
        except Exception as e:
            log.warning(f"Invalid message: {message.type}, {e}")
            return True

        payload_len = len(message.data)
        staged_count: int = self.message_counts[message_type] + 1
        staged_size: int = self.message_cumulative_sizes[message_type] + payload_len
        staged_non_tx_count: int = self.non_tx_message_counts
        staged_non_tx_size: int = self.non_tx_cumulative_size
        scale: float = self.percentage_of_limit / 100

        allowed: bool = False
        try:
            if message_type in rate_limits_tx:
                limits = rate_limits_tx[message_type]
            elif message_type in rate_limits_other:
                limits = rate_limits_other[message_type]
                staged_non_tx_count += 1
                staged_non_tx_size += payload_len
                # All non-transaction types also share one combined count/size budget.
                if (
                    staged_non_tx_count > NON_TX_FREQ * scale
                    or staged_non_tx_size > NON_TX_MAX_TOTAL_SIZE * scale
                ):
                    return False
            else:
                limits = DEFAULT_SETTINGS
                log.warning(
                    f"Message type {message_type} not found in rate limits")

            if limits.max_total_size is None:
                # Default the per-window size cap to frequency * per-message size cap.
                limits = dataclasses.replace(limits,
                                             max_total_size=limits.frequency *
                                             limits.max_size)
            assert limits.max_total_size is not None

            allowed = not (
                staged_count > limits.frequency * scale
                or payload_len > limits.max_size
                or staged_size > limits.max_total_size * scale
            )
            return allowed
        finally:
            # Outgoing messages consume budget only once we have decided to send them;
            # incoming ones were already received, so they always count.
            if self.incoming or allowed:
                self.message_counts[message_type] = staged_count
                self.message_cumulative_sizes[message_type] = staged_size
                self.non_tx_message_counts = staged_non_tx_count
                self.non_tx_cumulative_size = staged_non_tx_size
コード例 #10
0
    async def _read_one_message(self) -> Optional[Message]:
        """Receive one websocket frame and turn it into a ``Message`` when possible.

        Waits up to 30 seconds for a frame.

        Binary frames are decoded into a ``Message``. They update ``bytes_read`` and
        ``last_message_time`` and are checked against the inbound rate limiter. A
        rate-limited peer is disconnected (``close(300)``) only when this node is a full
        node and the peer is not localhost; otherwise a warning is logged and the message
        is still returned.

        Close/closing/closed frames, error frames and unexpected frame types schedule
        ``self.close()`` and then sleep 3 seconds.

        Returns:
            The decoded ``Message`` for an accepted binary frame, otherwise None.
        """
        try:
            message: WSMessage = await self.ws.receive(30)
        except asyncio.TimeoutError:
            # self.ws._closed if we didn't receive a ping / pong
            if self.ws._closed:
                asyncio.create_task(self.close())
                await asyncio.sleep(3)
            return None

        if self.connection_type is not None:
            connection_type_str = NodeType(self.connection_type).name.lower()
        else:
            connection_type_str = ""

        if message.type == WSMsgType.CLOSING:
            self.log.debug(
                f"Closing connection to {connection_type_str} {self.peer_host}:"
                f"{self.peer_server_port}/"
                f"{self.peer_port}")
            asyncio.create_task(self.close())
            await asyncio.sleep(3)
        elif message.type == WSMsgType.CLOSE:
            self.log.debug(
                f"Peer closed connection {connection_type_str} {self.peer_host}:"
                f"{self.peer_server_port}/"
                f"{self.peer_port}")
            asyncio.create_task(self.close())
            await asyncio.sleep(3)
        elif message.type == WSMsgType.CLOSED:
            if not self.closed:
                asyncio.create_task(self.close())
                await asyncio.sleep(3)
        elif message.type == WSMsgType.BINARY:
            data = message.data
            full_message_loaded: Message = Message.from_bytes(data)
            self.bytes_read += len(data)
            self.last_message_time = time.time()
            try:
                message_type = ProtocolMessageTypes(full_message_loaded.type).name
            except Exception:
                message_type = "Unknown"
            if not self.inbound_rate_limiter.process_msg_and_check(full_message_loaded):
                if self.local_type == NodeType.FULL_NODE and not is_localhost(self.peer_host):
                    self.log.error(
                        f"Peer has been rate limited and will be disconnected: {self.peer_host}, "
                        f"message: {message_type}")
                    # Only full node disconnects peers, to prevent abuse and crashing timelords, farmers, etc
                    asyncio.create_task(self.close(300))
                    await asyncio.sleep(3)
                    return None
                self.log.warning(
                    f"Peer surpassed rate limit {self.peer_host}, message: {message_type}, "
                    f"port {self.peer_port} but not disconnecting")
            return full_message_loaded
        elif message.type == WSMsgType.ERROR:
            self.log.error(f"WebSocket Error: {message}")
            # For ERROR frames ``message.data`` is the exception the websocket reader
            # caught. It is not guaranteed to have a ``code`` attribute, so read it
            # defensively instead of raising AttributeError here.
            if getattr(message.data, "code", None) == WSCloseCode.MESSAGE_TOO_BIG:
                asyncio.create_task(self.close(300))
            else:
                asyncio.create_task(self.close())
            await asyncio.sleep(3)
        else:
            self.log.error(f"Unexpected WebSocket message type: {message}")
            asyncio.create_task(self.close())
            await asyncio.sleep(3)
        return None
コード例 #11
0
    async def test_transaction_freeze(self, wallet_node_30_freeze):
        """Transactions are refused before INITIAL_FREEZE_END_TIMESTAMP and accepted after it."""
        num_blocks = 5
        full_nodes, wallets = wallet_node_30_freeze
        full_node_api: FullNodeSimulator = full_nodes[0]
        full_node_server = full_node_api.server
        wallet_node, server_2 = wallets[0]
        wallet = wallet_node.wallet_state_manager.main_wallet
        ph = await wallet.get_new_puzzlehash()
        incoming_queue, node_id = await add_dummy_connection(full_node_server, 12312)

        await server_2.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)
        for _ in range(num_blocks):
            await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph))

        # Expected confirmed balance: pool + base farmer rewards for heights 1..num_blocks-1.
        funds = sum(
            calculate_pool_reward(uint32(height)) + calculate_base_farmer_reward(uint32(height))
            for height in range(1, num_blocks)
        )
        await asyncio.sleep(2)
        await time_out_assert(10, wallet.get_confirmed_balance, funds)
        assert int(time.time()) < full_node_api.full_node.constants.INITIAL_FREEZE_END_TIMESTAMP

        # While frozen: a wallet SendTransaction is acknowledged as FAILED ...
        tx: TransactionRecord = await wallet.generate_signed_transaction(100, ph, 0)
        spend = wallet_protocol.SendTransaction(tx.spend_bundle)
        response = await full_node_api.send_transaction(spend)
        assert response is not None
        assert wallet_protocol.TransactionAck.from_bytes(response.data).status == MempoolInclusionStatus.FAILED

        # ... and peer announcements / deliveries of the transaction get no reply.
        new_spend = full_node_protocol.NewTransaction(tx.spend_bundle.name(), 1, 0)
        response = await full_node_api.new_transaction(new_spend)
        assert response is None

        peer = full_node_server.all_connections[node_id]
        new_spend = full_node_protocol.RespondTransaction(tx.spend_bundle)
        response = await full_node_api.respond_transaction(new_spend, peer=peer)
        assert response is None

        # Keep farming until the wall clock passes the freeze end (bounded to 26 rounds).
        for _ in range(26):
            await asyncio.sleep(2)
            await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph))
            if int(time.time()) > full_node_api.full_node.constants.INITIAL_FREEZE_END_TIMESTAMP:
                break

        # After the freeze: the announcement triggers a request_transaction ...
        new_spend = full_node_protocol.NewTransaction(tx.spend_bundle.name(), 1, 0)
        response = await full_node_api.new_transaction(new_spend)
        assert response is not None
        assert ProtocolMessageTypes(response.type) == ProtocolMessageTypes.request_transaction

        # ... and a fresh transaction is accepted. Check the reply type before decoding it.
        tx = await wallet.generate_signed_transaction(100, ph, 0)
        spend = wallet_protocol.SendTransaction(tx.spend_bundle)
        response = await full_node_api.send_transaction(spend)
        assert response is not None
        assert ProtocolMessageTypes(response.type) == ProtocolMessageTypes.transaction_ack
        assert wallet_protocol.TransactionAck.from_bytes(response.data).status == MempoolInclusionStatus.SUCCESS