async def perform_handshake(self, network_id: str, protocol_version: str, server_port: int, local_type: NodeType):
    """Exchange ``Handshake`` messages with the peer, then start the I/O handler tasks.

    Outbound connections send our handshake first and then read the peer's.
    Inbound connections read and validate the peer's handshake first and only
    then reply with ours.

    Args:
        network_id: network the peer must also be on.
        protocol_version: protocol version advertised in our handshake.
        server_port: our listening port, advertised to the peer.
        local_type: our node type, advertised to the peer.

    Returns:
        True once both handshakes have been exchanged. By then
        ``self.peer_server_port`` and ``self.connection_type`` are set from the
        peer's handshake, and ``self.outbound_task`` / ``self.inbound_task``
        have been created.

    Raises:
        ProtocolError(Err.INVALID_HANDSHAKE): reading failed or returned nothing,
            the message type is unknown or not ``handshake``, or the payload
            cannot be parsed as a ``Handshake``.
        ProtocolError(Err.INCOMPATIBLE_NETWORK_ID): the peer's ``network_id``
            differs from ``network_id``.
    """

    def build_handshake_message():
        # Our handshake is the same for both connection directions.
        return make_msg(
            ProtocolMessageTypes.handshake,
            Handshake(
                network_id,
                protocol_version,
                chia_full_version_str(),
                uint16(server_port),
                uint8(local_type.value),
                [(uint16(Capability.BASE.value), "1")],
            ),
        )

    async def receive_peer_handshake():
        # Read one message and validate it as the peer's handshake.
        try:
            message = await self._read_one_message()
        except Exception as e:
            raise ProtocolError(Err.INVALID_HANDSHAKE) from e
        if message is None:
            raise ProtocolError(Err.INVALID_HANDSHAKE)
        # Check the type BEFORE deserializing. An unknown type value makes the
        # enum lookup raise ValueError, which must surface as a bad handshake.
        # A non-handshake payload must never be parsed as a Handshake.
        try:
            message_type = ProtocolMessageTypes(message.type)
        except ValueError as e:
            raise ProtocolError(Err.INVALID_HANDSHAKE) from e
        if message_type != ProtocolMessageTypes.handshake:
            raise ProtocolError(Err.INVALID_HANDSHAKE)
        try:
            handshake = Handshake.from_bytes(message.data)
        except Exception as e:
            raise ProtocolError(Err.INVALID_HANDSHAKE) from e
        if handshake.network_id != network_id:
            raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)
        return handshake

    if self.is_outbound:
        outbound_handshake = build_handshake_message()
        assert outbound_handshake is not None
        await self._send_message(outbound_handshake)
        inbound_handshake = await receive_peer_handshake()
    else:
        # Only reply once the peer's handshake has been validated.
        inbound_handshake = await receive_peer_handshake()
        await self._send_message(build_handshake_message())

    self.peer_server_port = inbound_handshake.server_port
    self.connection_type = NodeType(inbound_handshake.node_type)
    self.outbound_task = asyncio.create_task(self.outbound_handler())
    self.inbound_task = asyncio.create_task(self.inbound_handler())
    return True
async def _send_message(self, message: Message):
    """Serialize ``message`` and write it to the websocket, honouring the outbound rate limiter.

    If ``self.outbound_rate_limiter`` rejects the message and the peer is not
    localhost, nothing is written:
    - a ``respond_peers`` message is dropped;
    - any other message is handed to ``self._wait_and_retry`` in a background task.

    For localhost peers, a rejection is only logged and the message is sent anyway.
    """
    encoded: bytes = bytes(message)
    size = len(encoded)
    assert size < (2**(LENGTH_BYTES * 8))

    allowed = self.outbound_rate_limiter.process_msg_and_check(message)
    if not allowed:
        if is_localhost(self.peer_host):
            self.log.debug(
                f"Not rate limiting ourselves. message type: {ProtocolMessageTypes(message.type).name}, "
                f"peer: {self.peer_host}"
            )
        else:
            self.log.debug(
                f"Rate limiting ourselves. message type: {ProtocolMessageTypes(message.type).name}, "
                f"peer: {self.peer_host}"
            )
            # TODO: fix this special case. This function has rate limits which are too low.
            if ProtocolMessageTypes(message.type) != ProtocolMessageTypes.respond_peers:
                # NOTE(review): the task handle is not kept; asyncio recommends holding
                # a reference so the task cannot be garbage-collected mid-run.
                asyncio.create_task(self._wait_and_retry(message, self.outgoing_queue))
            return None

    await self.ws.send_bytes(encoded)
    self.log.debug(f"-> {ProtocolMessageTypes(message.type).name} to peer {self.peer_host} {self.peer_node_id}")
    self.bytes_written += size
async def invoke(*args, **kwargs):
    """Send ``args[0]`` to the peer as an ``attr_name`` request and decode the reply.

    ``self`` and ``attr_name`` come from the enclosing scope. An optional
    ``timeout`` keyword (default 60) is passed to ``self.create_request``.

    Returns None when no reply arrives. Otherwise returns the reply payload,
    deserialized with the type annotated on the local API handler named after
    the reply's message type. That type is the last annotation other than
    ``return`` and ``peer``.

    Raises AttributeError if the peer's API class has no ``attr_name`` method.
    """
    timeout = kwargs.get("timeout", 60)
    if getattr(class_for_type(self.connection_type), attr_name, None) is None:
        raise AttributeError(f"Node type {self.connection_type} does not have method {attr_name}")

    request_type = getattr(ProtocolMessageTypes, attr_name)
    msg = Message(uint8(request_type.value), None, args[0])

    started_at = time.time()
    result = await self.create_request(msg, timeout)
    self.log.debug(
        f"Time for request {attr_name}: {self.get_peer_info()} = {time.time() - started_at}, "
        f"None? {result is None}"
    )
    if result is None:
        return result

    handler = getattr(class_for_type(self.local_type), ProtocolMessageTypes(result.type).name, None)
    payload_type = None
    for param_name, annotation in handler.__annotations__.items():
        if param_name not in ("return", "peer"):
            payload_type = annotation
    assert payload_type is not None
    return payload_type.from_bytes(result.data)
async def bool_f():
    """Return True iff the next ``count`` queued messages all have type name ``msg_name``.

    ``incoming_queue``, ``count`` and ``msg_name`` come from the enclosing scope.
    If fewer than ``count`` items are queued, returns False without consuming
    anything. Otherwise items are dequeued one at a time, stopping at the first
    mismatch, so any later items stay in the queue.
    """
    if incoming_queue.qsize() < count:
        return False
    for _ in range(count):
        queued_item = await incoming_queue.get()
        received_name = ProtocolMessageTypes(queued_item[0].type).name
        if received_name != msg_name:
            return False
    return True
async def invoke(*args, **kwargs):
    """Send ``args[0]`` to the peer as an ``attr_name`` request and decode the reply.

    ``self`` and ``attr_name`` come from the enclosing scope. An optional
    ``timeout`` keyword (default 60) is passed to ``self.send_request``.

    Returns None when no reply arrives. Otherwise returns the reply payload,
    deserialized with the type annotated on the local API handler named after
    the reply's message type. That type is the last annotation other than
    ``return`` and ``peer``.

    Raises:
        AttributeError: the peer's API class has no ``attr_name`` method.
        ProtocolError(Err.INVALID_PROTOCOL_MESSAGE): ``message_response_ok``
            rejects the reply type for the request we sent. The peer is banned
            via ``self.ban_peer_bad_protocol`` first.
    """
    timeout = 60
    if "timeout" in kwargs:
        timeout = kwargs["timeout"]
    attribute = getattr(class_for_type(self.connection_type), attr_name, None)
    if attribute is None:
        raise AttributeError(f"Node type {self.connection_type} does not have method {attr_name}")

    msg: Message = Message(uint8(getattr(ProtocolMessageTypes, attr_name).value), None, args[0])
    request_start_t = time.time()
    result = await self.send_request(msg, timeout)
    self.log.debug(
        f"Time for request {attr_name}: {self.get_peer_logging()} = {time.time() - request_start_t}, "
        f"None? {result is None}"
    )
    if result is None:
        return None

    sent_message_type = ProtocolMessageTypes(msg.type)
    recv_message_type = ProtocolMessageTypes(result.type)
    if not message_response_ok(sent_message_type, recv_message_type):
        # Peer protocol violation: the reply does not answer the request we sent.
        error_message = (
            f"WSConnection.invoke sent message {sent_message_type.name} "
            f"but received {recv_message_type.name}"
        )
        # Pass the local error_message. It is not an attribute of self, so
        # reading self.error_message would fail before the peer is banned.
        await self.ban_peer_bad_protocol(error_message)
        raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [error_message])

    # NOTE(review): if the local API has no handler named after the reply type,
    # ret_attr is None and the next line raises AttributeError.
    ret_attr = getattr(class_for_type(self.local_type), recv_message_type.name, None)
    req_annotations = ret_attr.__annotations__
    req = None
    for key in req_annotations:
        if key == "return" or key == "peer":
            continue
        req = req_annotations[key]
    assert req is not None
    return req.from_bytes(result.data)
def process_msg_and_check(self, message: Message) -> bool:
    """Count ``message`` in the current window and report whether it is within the rate limits.

    Returns True if message can be processed successfully, False if a rate
    limit is passed.

    - All counters reset whenever ``time.time() // self.reset_seconds`` changes.
    - The message is added to the counters BEFORE the checks, so a rejected
      message still consumes quota.
    - Message types that fail to parse are logged and allowed.
    - Limits are scaled by ``self.percentage_of_limit / 100``.
    """
    window = time.time() // self.reset_seconds
    if window != self.current_minute:
        # A new window has started: every counter starts from zero.
        self.current_minute = window
        self.message_counts = Counter()
        self.message_cumulative_sizes = Counter()
        self.non_tx_message_counts = 0
        self.non_tx_cumulative_size = 0

    try:
        message_type = ProtocolMessageTypes(message.type)
    except Exception as e:
        log.warning(f"Invalid message: {message.type}, {e}")
        return True

    payload_size = len(message.data)
    self.message_counts[message_type] += 1
    self.message_cumulative_sizes[message_type] += payload_size
    scale = self.percentage_of_limit / 100

    if message_type in rate_limits_tx:
        limits = rate_limits_tx[message_type]
    elif message_type in rate_limits_other:
        limits = rate_limits_other[message_type]
        # Non-transaction messages also draw from a shared aggregate budget.
        self.non_tx_message_counts += 1
        self.non_tx_cumulative_size += payload_size
        if (
            self.non_tx_message_counts > NON_TX_FREQ * scale
            or self.non_tx_cumulative_size > NON_TX_MAX_TOTAL_SIZE * scale
        ):
            return False
    else:
        log.warning(f"Message type {message_type} not found in rate limits")
        limits = DEFAULT_SETTINGS

    if limits.max_total_size is None:
        # With no explicit total-size cap, derive one as frequency * max_size.
        limits = dataclasses.replace(limits, max_total_size=limits.frequency * limits.max_size)

    within_limits = (
        self.message_counts[message_type] <= limits.frequency * scale
        and payload_size <= limits.max_size
        and self.message_cumulative_sizes[message_type] <= limits.max_total_size * scale
    )
    return within_limits
async def validate_broadcast_message_type(self, messages: List[Message], node_type: NodeType):
    """Refuse to broadcast any message whose type requires a reply.

    Broadcasting such a message is an internal protocol logic error. On the
    first offending message:
    - every connection whose ``connection_type`` is ``node_type`` is closed
      with ``self.invalid_protocol_ban_seconds``, ``WSCloseCode.INTERNAL_ERROR``
      and ``Err.INTERNAL_PROTOCOL_ERROR``;
    - then ``ProtocolError`` is raised, which blocks the broadcast to all peers.

    Raises:
        ProtocolError(Err.INTERNAL_PROTOCOL_ERROR, [message.type])
    """
    for message in messages:
        if not message_requires_reply(ProtocolMessageTypes(message.type)):
            continue
        # Internal protocol logic error - we will raise, blocking messages to all peers
        self.log.error(f"Attempt to broadcast message requiring protocol response: {message.type}")
        # Iterate a snapshot: each close() is awaited, and while suspended the
        # dict may be mutated (close callbacks, other tasks). Iterating a dict
        # that changes size raises RuntimeError.
        for connection in list(self.all_connections.values()):
            if connection.connection_type is node_type:
                await connection.close(
                    self.invalid_protocol_ban_seconds,
                    WSCloseCode.INTERNAL_ERROR,
                    Err.INTERNAL_PROTOCOL_ERROR,
                )
        raise ProtocolError(Err.INTERNAL_PROTOCOL_ERROR, [message.type])
async def api_call(full_message: Message, connection: WSChiaConnection, task_id):
    """Dispatch one inbound peer message to the matching API handler and reply if it returns a message.

    ``self`` (and through it ``api``, ``api_tasks``, ``tasks_from_peer``,
    ``execute_tasks``) comes from the enclosing scope.

    - The handler is ``getattr(self.api, <message type name>)``. It must exist
      and carry an ``api_function`` attribute.
    - Handlers marked ``peer_required`` also receive ``connection``.
    - Handlers marked ``execute_task`` run without a timeout; all others get 600 s.
    - A non-None handler result is sent back via ``connection.reply_to_request``
      with the request's id.
    - Any exception, including the ``ProtocolError``s raised here and a
      ``wait_for`` timeout, is logged and ``connection`` is closed with
      ``WSCloseCode.PROTOCOL_ERROR`` / ``Err.UNKNOWN``.
    - ``task_id`` is always removed from the bookkeeping collections in ``finally``.
    """
    start_time = time.time()
    try:
        if self.received_message_callback is not None:
            await self.received_message_callback(connection)
        connection.log.debug(
            f"<- {ProtocolMessageTypes(full_message.type).name} from peer "
            f"{connection.peer_node_id} {connection.peer_host}"
        )
        message_type: str = ProtocolMessageTypes(full_message.type).name

        f = getattr(self.api, message_type, None)

        if f is None:
            self.log.error(f"Non existing function: {message_type}")
            raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [message_type])

        # Only methods flagged as API functions may be invoked by peers.
        if not hasattr(f, "api_function"):
            self.log.error(f"Peer trying to call non api function {message_type}")
            raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [message_type])

        # If api is not ready ignore the request
        if hasattr(self.api, "api_ready"):
            if self.api.api_ready is False:
                return None

        timeout: Optional[int] = 600
        if hasattr(f, "execute_task"):
            # Don't timeout on methods with execute_task decorator, these need to run fully
            self.execute_tasks.add(task_id)
            timeout = None

        if hasattr(f, "peer_required"):
            coroutine = f(full_message.data, connection)
        else:
            coroutine = f(full_message.data)

        async def wrapped_coroutine() -> Optional[Message]:
            # Runs the handler. Cancellation is swallowed and yields None, so no
            # reply is sent. Other errors are logged with a traceback and re-raised.
            try:
                result = await coroutine
                return result
            except asyncio.CancelledError:
                pass
            except Exception as e:
                tb = traceback.format_exc()
                connection.log.error(f"Exception: {e}, {connection.get_peer_info()}. {tb}")
                raise e
            return None

        # With timeout=None (execute_task handlers) wait_for never times out.
        response: Optional[Message] = await asyncio.wait_for(wrapped_coroutine(), timeout=timeout)
        connection.log.debug(
            f"Time taken to process {message_type} from {connection.peer_node_id} is "
            f"{time.time() - start_time} seconds"
        )

        if response is not None:
            # Echo the request id so the peer can match the reply to its request.
            response_message = Message(response.type, full_message.id, response.data)
            await connection.reply_to_request(response_message)
    except Exception as e:
        # While a connection close is already in progress, errors are expected;
        # log them at debug level only.
        if self.connection_close_task is None:
            tb = traceback.format_exc()
            connection.log.error(f"Exception: {e} {type(e)}, closing connection {connection.get_peer_info()}. {tb}")
        else:
            connection.log.debug(f"Exception: {e} while closing connection")
        # TODO: actually throw one of the errors from errors.py and pass this to close
        await connection.close(self.api_exception_ban_seconds, WSCloseCode.PROTOCOL_ERROR, Err.UNKNOWN)
    finally:
        if task_id in self.api_tasks:
            self.api_tasks.pop(task_id)
        # NOTE(review): assumes tasks_from_peer has an entry for this peer (or is a
        # defaultdict); otherwise this lookup raises KeyError — confirm.
        if task_id in self.tasks_from_peer[connection.peer_node_id]:
            self.tasks_from_peer[connection.peer_node_id].remove(task_id)
        if task_id in self.execute_tasks:
            self.execute_tasks.remove(task_id)
def process_msg_and_check(self, message: Message) -> bool:
    """
    Returns True if message can be processed successfully, false if a rate limit is passed.

    The candidate counter values are computed first and only written back in
    ``finally``. They are committed when the message is accepted, or always when
    ``self.incoming`` is truthy, because an incoming message has already been
    received. A rejected outgoing message therefore does not consume quota.
    All counters reset whenever ``int(time.time() // self.reset_seconds)``
    changes. Message types that fail to parse are logged and allowed.
    """
    current_minute = int(time.time() // self.reset_seconds)
    if current_minute != self.current_minute:
        # A new window has started: reset all counters.
        self.current_minute = current_minute
        self.message_counts = Counter()
        self.message_cumulative_sizes = Counter()
        self.non_tx_message_counts = 0
        self.non_tx_cumulative_size = 0
    try:
        message_type = ProtocolMessageTypes(message.type)
    except Exception as e:
        log.warning(f"Invalid message: {message.type}, {e}")
        return True

    # Tentative values; they are written back only in the finally clause below.
    new_message_counts: int = self.message_counts[message_type] + 1
    new_cumulative_size: int = self.message_cumulative_sizes[message_type] + len(message.data)
    new_non_tx_count: int = self.non_tx_message_counts
    new_non_tx_size: int = self.non_tx_cumulative_size
    proportion_of_limit: float = self.percentage_of_limit / 100

    # Set to True only on the accept path; read by the finally clause.
    ret: bool = False
    try:
        limits = DEFAULT_SETTINGS
        if message_type in rate_limits_tx:
            limits = rate_limits_tx[message_type]
        elif message_type in rate_limits_other:
            limits = rate_limits_other[message_type]
            # Non-transaction messages also share an aggregate budget.
            new_non_tx_count = self.non_tx_message_counts + 1
            new_non_tx_size = self.non_tx_cumulative_size + len(message.data)
            if new_non_tx_count > NON_TX_FREQ * proportion_of_limit:
                return False
            if new_non_tx_size > NON_TX_MAX_TOTAL_SIZE * proportion_of_limit:
                return False
        else:
            log.warning(f"Message type {message_type} not found in rate limits")

        if limits.max_total_size is None:
            # No explicit total-size cap: derive one as frequency * max_size.
            limits = dataclasses.replace(limits, max_total_size=limits.frequency * limits.max_size)
        assert limits.max_total_size is not None

        if new_message_counts > limits.frequency * proportion_of_limit:
            return False
        if len(message.data) > limits.max_size:
            return False
        if new_cumulative_size > limits.max_total_size * proportion_of_limit:
            return False

        ret = True
        return True
    finally:
        if self.incoming or ret:
            # now that we determined that it's OK to send the message, commit the
            # updates to the counters. Alternatively, if this was an
            # incoming message, we already received it and it should
            # increment the counters unconditionally
            self.message_counts[message_type] = new_message_counts
            self.message_cumulative_sizes[message_type] = new_cumulative_size
            self.non_tx_message_counts = new_non_tx_count
            self.non_tx_cumulative_size = new_non_tx_size
async def _read_one_message(self) -> Optional[Message]:
    """Receive one websocket frame (30 s timeout) and return it as a ``Message``, or None.

    - Only ``WSMsgType.BINARY`` frames yield a ``Message``.
    - Close-related frames, errors and unexpected types schedule ``self.close()``
      (``self.close(300)`` for ``MESSAGE_TOO_BIG`` errors), sleep 3 s and return None.
    - A receive timeout returns None; it also closes first if ``self.ws._closed``.
    - Binary frames update ``bytes_read`` / ``last_message_time`` and go through
      ``self.inbound_rate_limiter``. A full node drops and disconnects
      (``close(300)``) a rate-limited non-localhost peer; other node types only
      log a warning and still return the message.
    - ``Message.from_bytes`` is not guarded here, so a malformed binary frame
      propagates its exception to the caller.
    """
    try:
        message: WSMessage = await self.ws.receive(30)
    except asyncio.TimeoutError:
        # self.ws._closed if we didn't receive a ping / pong
        if self.ws._closed:
            # NOTE(review): the created tasks in this method are not referenced;
            # asyncio recommends keeping a reference to avoid premature GC.
            asyncio.create_task(self.close())
            await asyncio.sleep(3)
            return None
        return None

    # Only used in the log lines below.
    if self.connection_type is not None:
        connection_type_str = NodeType(self.connection_type).name.lower()
    else:
        connection_type_str = ""
    if message.type == WSMsgType.CLOSING:
        self.log.debug(
            f"Closing connection to {connection_type_str} {self.peer_host}:"
            f"{self.peer_server_port}/"
            f"{self.peer_port}"
        )
        asyncio.create_task(self.close())
        await asyncio.sleep(3)
    elif message.type == WSMsgType.CLOSE:
        self.log.debug(
            f"Peer closed connection {connection_type_str} {self.peer_host}:"
            f"{self.peer_server_port}/"
            f"{self.peer_port}"
        )
        asyncio.create_task(self.close())
        await asyncio.sleep(3)
    elif message.type == WSMsgType.CLOSED:
        # Only schedule a close if we have not already closed.
        if not self.closed:
            asyncio.create_task(self.close())
            await asyncio.sleep(3)
            return None
    elif message.type == WSMsgType.BINARY:
        data = message.data
        full_message_loaded: Message = Message.from_bytes(data)
        self.bytes_read += len(data)
        self.last_message_time = time.time()
        # The type name is only used for logging; unknown types are tolerated here.
        try:
            message_type = ProtocolMessageTypes(full_message_loaded.type).name
        except Exception:
            message_type = "Unknown"
        if not self.inbound_rate_limiter.process_msg_and_check(full_message_loaded):
            if self.local_type == NodeType.FULL_NODE and not is_localhost(self.peer_host):
                self.log.error(
                    f"Peer has been rate limited and will be disconnected: {self.peer_host}, "
                    f"message: {message_type}"
                )
                # Only full node disconnects peers, to prevent abuse and crashing timelords, farmers, etc
                asyncio.create_task(self.close(300))
                await asyncio.sleep(3)
                return None
            else:
                self.log.warning(
                    f"Peer surpassed rate limit {self.peer_host}, message: {message_type}, "
                    f"port {self.peer_port} but not disconnecting"
                )
                return full_message_loaded
        return full_message_loaded
    elif message.type == WSMsgType.ERROR:
        self.log.error(f"WebSocket Error: {message}")
        # NOTE(review): assumes message.data exposes a .code attribute for ERROR
        # frames; otherwise this raises AttributeError — confirm against aiohttp.
        if message.data.code == WSCloseCode.MESSAGE_TOO_BIG:
            asyncio.create_task(self.close(300))
        else:
            asyncio.create_task(self.close())
        await asyncio.sleep(3)
    else:
        self.log.error(f"Unexpected WebSocket message type: {message}")
        asyncio.create_task(self.close())
        await asyncio.sleep(3)
    return None
async def test_transaction_freeze(self, wallet_node_30_freeze):
    """Transactions are refused before INITIAL_FREEZE_END_TIMESTAMP and accepted after it.

    Phase 1: farm a few blocks and wait until the wallet's confirmed balance
    matches the expected rewards.

    Phase 2 (freeze active): the wallet spend is acked with FAILED, and both
    ``new_transaction`` and ``respond_transaction`` return None.

    Phase 3: keep farming, sleeping 2 s before each block and stopping after at
    most 26 blocks, until the wall clock passes the freeze end. Then
    ``new_transaction`` answers with ``request_transaction``, and a newly signed
    spend is acked with SUCCESS.
    """
    num_blocks = 5
    full_nodes, wallets = wallet_node_30_freeze
    full_node_api: FullNodeSimulator = full_nodes[0]
    full_node_server = full_node_api.server
    wallet_node, server_2 = wallets[0]
    wallet = wallet_node.wallet_state_manager.main_wallet
    ph = await wallet.get_new_puzzlehash()

    def freeze_end() -> int:
        return full_node_api.full_node.constants.INITIAL_FREEZE_END_TIMESTAMP

    incoming_queue, node_id = await add_dummy_connection(full_node_server, 12312)
    await server_2.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)

    for _ in range(num_blocks):
        await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph))

    # Expected balance: pool + base farmer rewards for heights 1 .. num_blocks - 1.
    funds = sum(
        calculate_pool_reward(uint32(height)) + calculate_base_farmer_reward(uint32(height))
        for height in range(1, num_blocks)
    )

    await asyncio.sleep(2)
    print(await wallet.get_confirmed_balance(), funds)
    await time_out_assert(10, wallet.get_confirmed_balance, funds)

    assert int(time.time()) < freeze_end()

    # During the freeze.
    tx: TransactionRecord = await wallet.generate_signed_transaction(100, ph, 0)
    response = await full_node_api.send_transaction(wallet_protocol.SendTransaction(tx.spend_bundle))
    assert wallet_protocol.TransactionAck.from_bytes(response.data).status == MempoolInclusionStatus.FAILED

    response = await full_node_api.new_transaction(full_node_protocol.NewTransaction(tx.spend_bundle.name(), 1, 0))
    assert response is None

    peer = full_node_server.all_connections[node_id]
    response = await full_node_api.respond_transaction(
        full_node_protocol.RespondTransaction(tx.spend_bundle), peer=peer
    )
    assert response is None

    # Farm until the freeze has ended (bounded to 26 blocks).
    for _ in range(26):
        await asyncio.sleep(2)
        await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph))
        if int(time.time()) > freeze_end():
            break

    # After the freeze.
    response = await full_node_api.new_transaction(full_node_protocol.NewTransaction(tx.spend_bundle.name(), 1, 0))
    assert response is not None
    assert ProtocolMessageTypes(response.type) == ProtocolMessageTypes.request_transaction

    tx = await wallet.generate_signed_transaction(100, ph, 0)
    response = await full_node_api.send_transaction(wallet_protocol.SendTransaction(tx.spend_bundle))
    assert response is not None
    assert wallet_protocol.TransactionAck.from_bytes(response.data).status == MempoolInclusionStatus.SUCCESS
    assert ProtocolMessageTypes(response.type) == ProtocolMessageTypes.transaction_ack