async def test_too_much_data(self):
    """Large payloads are accepted for a while, then the limiter starts rejecting them.

    Two scenarios, each with a fresh RateLimiter: 500 KiB ``respond_transaction``
    messages and 1 MiB ``respond_block`` messages.
    """
    # Too much data
    limiter = RateLimiter()
    tx_msg = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
    for _ in range(10):
        assert limiter.process_msg_and_check(tx_msg)
    # Build a full list (no short-circuit) so every message is still fed to the limiter.
    tx_outcomes = [limiter.process_msg_and_check(tx_msg) for _ in range(300)]
    assert not all(tx_outcomes)

    limiter = RateLimiter()
    block_msg = make_msg(ProtocolMessageTypes.respond_block, bytes([1] * 1024 * 1024))
    for _ in range(10):
        assert limiter.process_msg_and_check(block_msg)
    block_outcomes = [limiter.process_msg_and_check(block_msg) for _ in range(40)]
    assert not all(block_outcomes)
async def test_too_many_messages(self):
    """Many small messages of one type eventually exceed that type's message-count limit.

    Covers a transaction-type message (``new_transaction``) and a non-tx type
    (``new_peak``), each against a fresh RateLimiter.
    """
    # Too many messages
    limiter = RateLimiter()
    tx_announce = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40))
    for _ in range(3000):
        assert limiter.process_msg_and_check(tx_announce)
    # Evaluate every call (list, not generator) so the limiter receives all 3000 messages.
    tx_outcomes = [limiter.process_msg_and_check(tx_announce) for _ in range(3000)]
    assert not all(tx_outcomes)

    # Non-tx message
    limiter = RateLimiter()
    peak_announce = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40))
    for _ in range(20):
        assert limiter.process_msg_and_check(peak_announce)
    peak_outcomes = [limiter.process_msg_and_check(peak_announce) for _ in range(200)]
    assert not all(peak_outcomes)
async def test_spam_tx(self, setup_two_nodes):
    """End-to-end spam check between two connected full nodes.

    Node 1 floods node 2 with ``new_transaction`` messages. While node 1's own
    outbound limiter is active, the connection stays open. Once that limiter is
    replaced by a very permissive one, node 2's inbound limiter must close the
    connection and ban the (fake) remote host.
    """
    nodes, _ = setup_two_nodes
    first_node, second_node = nodes
    server_1 = first_node.full_node.server
    server_2 = second_node.full_node.server

    await server_2.start_client(
        PeerInfo(self_hostname, uint16(server_1._port)), second_node.full_node.on_connect
    )

    assert len(server_1.all_connections) == 1

    ws_con: WSChiaConnection = list(server_1.all_connections.values())[0]
    ws_con_2: WSChiaConnection = list(server_2.all_connections.values())[0]

    # Pretend both ends are a non-localhost peer: localhost peers bypass the
    # outbound self-limiting and are never disconnected for inbound limits.
    fake_host = "1.2.3.4"
    ws_con.peer_host = fake_host
    ws_con_2.peer_host = fake_host

    new_tx_message = make_msg(
        ProtocolMessageTypes.new_transaction,
        full_node_protocol.NewTransaction(bytes([9] * 32), uint64(0), uint64(0)),
    )

    async def flood(count):
        # Push `count` messages directly through the connection, then give the peer a second to react.
        for _ in range(count):
            await ws_con._send_message(new_tx_message)
        await asyncio.sleep(1)

    await flood(4000)
    assert not ws_con.closed

    # Tests outbound rate limiting, we will not send too much data
    await flood(2000)
    assert not ws_con.closed

    # Remove outbound rate limiter to test inbound limits
    ws_con.outbound_rate_limiter = RateLimiter(incoming=True, percentage_of_limit=10000)
    await flood(6000)

    def is_closed():
        return ws_con.closed

    await time_out_assert(15, is_closed)

    assert ws_con.closed

    def is_banned():
        return fake_host in server_2.banned_peers

    await time_out_assert(15, is_banned)
async def test_non_tx_aggregate_limits(self):
    """Limits are enforced across several distinct non-tx message types in aggregate.

    Frequency: 450 + 450 messages of two request types are accepted, then at least
    one of 450 messages of a third type is rejected. Size: two ~49 MiB messages of
    one type are accepted, then at least one ~49 MiB message of another type is
    rejected.
    """

    def rejects_some(limiter, msg, attempts):
        # Materialize the list so all `attempts` calls run (no short-circuit).
        return not all([limiter.process_msg_and_check(msg) for _ in range(attempts)])

    # Frequency limits
    limiter = RateLimiter()
    message_1 = make_msg(ProtocolMessageTypes.request_additions, bytes([1] * 5 * 1024))
    message_2 = make_msg(ProtocolMessageTypes.request_removals, bytes([1] * 1024))
    message_3 = make_msg(ProtocolMessageTypes.respond_additions, bytes([1] * 1024))

    for _ in range(450):
        assert limiter.process_msg_and_check(message_1)
    for _ in range(450):
        assert limiter.process_msg_and_check(message_2)
    assert rejects_some(limiter, message_3, 450)

    # Size limits
    limiter = RateLimiter()
    message_4 = make_msg(ProtocolMessageTypes.respond_proof_of_weight, bytes([1] * 49 * 1024 * 1024))
    message_5 = make_msg(ProtocolMessageTypes.respond_blocks, bytes([1] * 49 * 1024 * 1024))

    for _ in range(2):
        assert limiter.process_msg_and_check(message_4)
    assert rejects_some(limiter, message_5, 2)
async def test_large_message(self):
    """A single oversized message is rejected even though smaller ones of its class pass.

    Transaction case: two 500 KiB ``respond_transaction`` messages pass, then a 3 MiB
    ``new_transaction`` is rejected. VDF case: two 5 KiB ``respond_signage_point``
    messages pass, then a 600 KiB one is rejected.
    """
    # Large tx
    small_tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
    large_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 3 * 1024 * 1024))

    limiter = RateLimiter()
    for _ in range(2):
        assert limiter.process_msg_and_check(small_tx_message)
    assert not limiter.process_msg_and_check(large_tx_message)

    small_vdf_message = make_msg(ProtocolMessageTypes.respond_signage_point, bytes([1] * 5 * 1024))
    large_vdf_message = make_msg(ProtocolMessageTypes.respond_signage_point, bytes([1] * 600 * 1024))

    limiter = RateLimiter()
    for _ in range(2):
        assert limiter.process_msg_and_check(small_vdf_message)
    assert not limiter.process_msg_and_check(large_vdf_message)
def __init__(
    self,
    local_type: NodeType,
    ws: Any,  # Websocket
    server_port: int,
    log: logging.Logger,
    is_outbound: bool,
    is_feeler: bool,  # Special type of connection, that disconnects after the handshake.
    peer_host,
    incoming_queue,
    close_callback: Callable,
    peer_id,
    inbound_rate_limit_percent: int,
    outbound_rate_limit_percent: int,
    close_event=None,
    session=None,
):
    """Wrap an established websocket ``ws`` to a peer.

    Records local/peer addressing, sets up the message queues and request
    bookkeeping, and creates one inbound and one outbound RateLimiter scaled by
    the given percentages.

    Raises:
        ValueError: if the websocket transport does not report a peer address.
    """
    # Local properties
    self.ws: Any = ws
    self.local_type = local_type
    self.local_port = server_port
    # Remote properties
    self.peer_host = peer_host

    peername = self.ws._writer.transport.get_extra_info("peername")
    if peername is None:
        # Report the actual writer object (the attribute name was previously misspelled).
        raise ValueError(f"Was not able to get peername from {self.ws._writer} at {self.peer_host}")

    connection_port = peername[1]
    self.peer_port = connection_port
    self.peer_server_port: Optional[uint16] = None
    self.peer_node_id = peer_id
    self.log = log

    # connection properties
    self.is_outbound = is_outbound
    self.is_feeler = is_feeler

    # ChiaConnection metrics
    self.creation_time = time.time()
    self.bytes_read = 0
    self.bytes_written = 0
    self.last_message_time: float = 0

    # Messaging
    self.incoming_queue: asyncio.Queue = incoming_queue
    self.outgoing_queue: asyncio.Queue = asyncio.Queue()

    self.inbound_task: Optional[asyncio.Task] = None
    self.outbound_task: Optional[asyncio.Task] = None
    self.active: bool = False  # once handshake is successful this will be changed to True
    self.close_event: Optional[asyncio.Event] = close_event
    self.session = session
    self.close_callback = close_callback

    # Request bookkeeping is keyed by the uint16 request nonce (message id), not a hash.
    self.pending_requests: Dict[uint16, asyncio.Event] = {}
    self.pending_timeouts: Dict[uint16, asyncio.Task] = {}
    self.request_results: Dict[uint16, Message] = {}
    self.closed = False
    self.connection_type: Optional[NodeType] = None
    if is_outbound:
        self.request_nonce: uint16 = uint16(0)
    else:
        # Different nonce to reduce chances of overlap. Each peer will increment the nonce by one for each
        # request. The receiving peer (not is_outbound), will use 2^15 to 2^16 - 1
        self.request_nonce = uint16(2**15)

    # This means that even if the other peer's boundaries for each minute are not aligned, we will not
    # disconnect. Also it allows a little flexibility.
    self.outbound_rate_limiter = RateLimiter(incoming=False, percentage_of_limit=outbound_rate_limit_percent)
    self.inbound_rate_limiter = RateLimiter(incoming=True, percentage_of_limit=inbound_rate_limit_percent)
class WSChiaConnection:
    """
    Represents a connection to another node. Local host and port are ours, while peer host and
    port are the host and port of the peer that we are connected to. Node_id and connection_type are
    set after the handshake is performed in this connection.
    """

    def __init__(
        self,
        local_type: NodeType,
        ws: Any,  # Websocket
        server_port: int,
        log: logging.Logger,
        is_outbound: bool,
        is_feeler: bool,  # Special type of connection, that disconnects after the handshake.
        peer_host,
        incoming_queue,
        close_callback: Callable,
        peer_id,
        inbound_rate_limit_percent: int,
        outbound_rate_limit_percent: int,
        close_event=None,
        session=None,
    ):
        """Wrap an established websocket ``ws`` to a peer.

        Raises:
            ValueError: if the websocket transport does not report a peer address.
        """
        # Local properties
        self.ws: Any = ws
        self.local_type = local_type
        self.local_port = server_port
        # Remote properties
        self.peer_host = peer_host

        peername = self.ws._writer.transport.get_extra_info("peername")
        if peername is None:
            # Must reference the real writer: thanks to __getattr__ below, a misspelled attribute would
            # silently resolve to an RPC stub instead of raising.
            raise ValueError(f"Was not able to get peername from {self.ws._writer} at {self.peer_host}")

        connection_port = peername[1]
        self.peer_port = connection_port
        self.peer_server_port: Optional[uint16] = None
        self.peer_node_id = peer_id
        self.log = log

        # connection properties
        self.is_outbound = is_outbound
        self.is_feeler = is_feeler

        # ChiaConnection metrics
        self.creation_time = time.time()
        self.bytes_read = 0
        self.bytes_written = 0
        self.last_message_time: float = 0

        # Messaging
        self.incoming_queue: asyncio.Queue = incoming_queue
        self.outgoing_queue: asyncio.Queue = asyncio.Queue()

        self.inbound_task: Optional[asyncio.Task] = None
        self.outbound_task: Optional[asyncio.Task] = None
        self.active: bool = False  # once handshake is successful this will be changed to True
        self.close_event: Optional[asyncio.Event] = close_event
        self.session = session
        self.close_callback = close_callback

        # Request bookkeeping is keyed by the uint16 request nonce (message id).
        self.pending_requests: Dict[uint16, asyncio.Event] = {}
        self.pending_timeouts: Dict[uint16, asyncio.Task] = {}
        self.request_results: Dict[uint16, Message] = {}
        self.closed = False
        self.connection_type: Optional[NodeType] = None
        if is_outbound:
            self.request_nonce: uint16 = uint16(0)
        else:
            # Different nonce to reduce chances of overlap. Each peer will increment the nonce by one for each
            # request. The receiving peer (not is_outbound), will use 2^15 to 2^16 - 1
            self.request_nonce = uint16(2**15)

        # This means that even if the other peer's boundaries for each minute are not aligned, we will not
        # disconnect. Also it allows a little flexibility.
        self.outbound_rate_limiter = RateLimiter(incoming=False, percentage_of_limit=outbound_rate_limit_percent)
        self.inbound_rate_limiter = RateLimiter(incoming=True, percentage_of_limit=inbound_rate_limit_percent)

    async def perform_handshake(self, network_id: str, protocol_version: str, server_port: int, local_type: NodeType):
        """Exchange Handshake messages with the peer and start the reader/writer tasks.

        An outbound connection sends its handshake first and then reads the peer's. An inbound
        connection reads first and then replies. The peer's server port and node type are recorded.

        Returns:
            True once the handshake succeeded and the handler tasks were started.

        Raises:
            ProtocolError: INVALID_HANDSHAKE for a missing/mistyped handshake message, or
                INCOMPATIBLE_NETWORK_ID when the peer's network id differs from ``network_id``.
        """
        if self.is_outbound:
            outbound_handshake = make_msg(
                ProtocolMessageTypes.handshake,
                Handshake(
                    network_id,
                    protocol_version,
                    chia_full_version_str(),
                    uint16(server_port),
                    uint8(local_type.value),
                    [(uint16(Capability.BASE.value), "1")],
                ),
            )
            assert outbound_handshake is not None
            await self._send_message(outbound_handshake)
            inbound_handshake_msg = await self._read_one_message()
            if inbound_handshake_msg is None:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            # Handle case of invalid ProtocolMessageType
            try:
                message_type: ProtocolMessageTypes = ProtocolMessageTypes(inbound_handshake_msg.type)
            except Exception:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            if message_type != ProtocolMessageTypes.handshake:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            # Only parse the payload once the message type says it is a handshake (same order as the inbound branch).
            inbound_handshake = Handshake.from_bytes(inbound_handshake_msg.data)
            if inbound_handshake.network_id != network_id:
                raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)

            self.peer_server_port = inbound_handshake.server_port
            self.connection_type = NodeType(inbound_handshake.node_type)

        else:
            try:
                message = await self._read_one_message()
            except Exception:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            if message is None:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            # Handle case of invalid ProtocolMessageType
            try:
                message_type = ProtocolMessageTypes(message.type)
            except Exception:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            if message_type != ProtocolMessageTypes.handshake:
                raise ProtocolError(Err.INVALID_HANDSHAKE)

            inbound_handshake = Handshake.from_bytes(message.data)
            if inbound_handshake.network_id != network_id:
                raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)
            outbound_handshake = make_msg(
                ProtocolMessageTypes.handshake,
                Handshake(
                    network_id,
                    protocol_version,
                    chia_full_version_str(),
                    uint16(server_port),
                    uint8(local_type.value),
                    [(uint16(Capability.BASE.value), "1")],
                ),
            )
            await self._send_message(outbound_handshake)
            self.peer_server_port = inbound_handshake.server_port
            self.connection_type = NodeType(inbound_handshake.node_type)

        self.outbound_task = asyncio.create_task(self.outbound_handler())
        self.inbound_task = asyncio.create_task(self.inbound_handler())
        return True

    async def close(self, ban_time: int = 0, ws_close_code: WSCloseCode = WSCloseCode.OK, error: Optional[Err] = None):
        """
        Closes the connection, and finally calls the close_callback on the server, so the connections gets removed
        from the global list.

        Idempotent: only the first call does any work. ``error`` (if given) is sent as the websocket
        close message, encoded as its decimal value. Exceptions raised while tearing down the
        websocket/session are logged and re-raised. ``close_event`` is still set, pending request
        timeouts are still cancelled (waking any waiter in ``create_request``), and
        ``close_callback(self, ban_time)`` is still called.
        """

        if self.closed:
            return None
        self.closed = True

        if error is None:
            message = b""
        else:
            message = str(int(error.value)).encode("utf-8")

        try:
            if self.inbound_task is not None:
                self.inbound_task.cancel()
            if self.outbound_task is not None:
                self.outbound_task.cancel()
            if self.ws is not None and self.ws._closed is False:
                await self.ws.close(code=ws_close_code, message=message)
            if self.session is not None:
                await self.session.close()
        except Exception:
            error_stack = traceback.format_exc()
            self.log.warning(f"Exception closing socket: {error_stack}")
            raise
        finally:
            # Always finish the bookkeeping, even if socket/session teardown failed, so that waiters on
            # close_event / pending requests are released and the server forgets this connection.
            if self.close_event is not None:
                self.close_event.set()
            self.cancel_pending_timeouts()
            self.close_callback(self, ban_time)

    def cancel_pending_timeouts(self):
        """Cancel every outstanding request-timeout task (each one then wakes its waiting request)."""
        for task in self.pending_timeouts.values():
            task.cancel()

    async def outbound_handler(self):
        """Drain ``outgoing_queue`` and write each message to the socket until the connection closes."""
        try:
            while not self.closed:
                msg = await self.outgoing_queue.get()
                if msg is not None:
                    await self._send_message(msg)
        except asyncio.CancelledError:
            pass
        except BrokenPipeError as e:
            self.log.warning(f"{e} {self.peer_host}")
        except ConnectionResetError as e:
            self.log.warning(f"{e} {self.peer_host}")
        except Exception as e:
            error_stack = traceback.format_exc()
            self.log.error(f"Exception: {e} with {self.peer_host}")
            self.log.error(f"Exception Stack: {error_stack}")

    async def inbound_handler(self):
        """Read messages until closed, routing responses to their waiting request or to ``incoming_queue``."""
        try:
            while not self.closed:
                message: Optional[Message] = await self._read_one_message()
                if message is None:
                    continue
                if message.id in self.pending_requests:
                    # A response to one of our own requests: store it and wake the waiter in create_request.
                    self.request_results[message.id] = message
                    event = self.pending_requests[message.id]
                    event.set()
                else:
                    await self.incoming_queue.put((message, self))
        except asyncio.CancelledError:
            self.log.debug("Inbound_handler task cancelled")
        except Exception as e:
            error_stack = traceback.format_exc()
            self.log.error(f"Exception: {e}")
            self.log.error(f"Exception Stack: {error_stack}")

    async def send_message(self, message: Message):
        """Send message sends a message with no tracking / callback."""
        if self.closed:
            return None
        await self.outgoing_queue.put(message)

    def __getattr__(self, attr_name: str):
        """Resolve any unknown attribute to an async RPC stub named after a ProtocolMessageTypes member.

        ``await conn.<attr_name>(request, timeout=60)`` sends ``request`` as a message of type
        ``ProtocolMessageTypes.<attr_name>`` and returns the response deserialized via the
        annotation of the matching local handler, or None if no response arrived.
        Note: because of this hook, misspelled attribute names on this class do not raise AttributeError
        at access time.
        """

        # TODO KWARGS
        async def invoke(*args, **kwargs):
            timeout = kwargs.get("timeout", 60)
            attribute = getattr(class_for_type(self.connection_type), attr_name, None)
            if attribute is None:
                raise AttributeError(f"Node type {self.connection_type} does not have method {attr_name}")

            msg = Message(uint8(getattr(ProtocolMessageTypes, attr_name).value), None, args[0])
            request_start_t = time.time()
            result = await self.create_request(msg, timeout)
            self.log.debug(
                f"Time for request {attr_name}: {self.get_peer_info()} = {time.time() - request_start_t}, "
                f"None? {result is None}"
            )
            if result is not None:
                ret_attr = getattr(class_for_type(self.local_type), ProtocolMessageTypes(result.type).name, None)

                # The response payload type is the last annotated parameter of the local handler
                # other than "peer" and the return annotation.
                req_annotations = ret_attr.__annotations__
                req = None
                for key in req_annotations:
                    if key == "return" or key == "peer":
                        continue
                    req = req_annotations[key]
                assert req is not None
                result = req.from_bytes(result.data)
            return result

        return invoke

    async def create_request(self, message_no_id: Message, timeout: int) -> Optional[Message]:
        """Sends a message and waits for a response.

        Returns the response Message, or None if the connection is closed, the timeout elapses, or
        pending timeouts are cancelled (e.g. by ``close``) before a response arrives.
        """
        if self.closed:
            return None

        # We will wait for this event, it will be set either by the response, or the timeout
        event = asyncio.Event()

        # The request nonce is an integer between 0 and 2**16 - 1, which is used to match requests to responses
        # If is_outbound, 0 <= nonce < 2^15, else 2^15 <= nonce < 2^16
        request_id = self.request_nonce
        if self.is_outbound:
            self.request_nonce = uint16(self.request_nonce + 1) if self.request_nonce != (2**15 - 1) else uint16(0)
        else:
            self.request_nonce = (
                uint16(self.request_nonce + 1) if self.request_nonce != (2**16 - 1) else uint16(2**15)
            )

        message = Message(message_no_id.type, request_id, message_no_id.data)

        self.pending_requests[message.id] = event
        await self.outgoing_queue.put(message)

        # If the timeout passes, we set the event
        async def time_out(req_id, req_timeout):
            try:
                await asyncio.sleep(req_timeout)
                if req_id in self.pending_requests:
                    self.pending_requests[req_id].set()
            except asyncio.CancelledError:
                if req_id in self.pending_requests:
                    self.pending_requests[req_id].set()
                raise

        timeout_task = asyncio.create_task(time_out(message.id, timeout))
        self.pending_timeouts[message.id] = timeout_task
        try:
            await event.wait()
        finally:
            self.pending_requests.pop(message.id, None)
            # The timer is no longer needed: drop it so pending_timeouts does not grow with every request,
            # and cancel it so it cannot later wake an unrelated request that reuses this (wrapped) nonce.
            # The pending_requests entry is removed first, so the cancellation handler sets nothing.
            if self.pending_timeouts.get(message.id) is timeout_task:
                del self.pending_timeouts[message.id]
            timeout_task.cancel()

        result: Optional[Message] = None
        if message.id in self.request_results:
            result = self.request_results.pop(message.id)
            assert result is not None
            self.log.debug(f"<- {ProtocolMessageTypes(result.type).name} from: {self.peer_host}:{self.peer_port}")

        return result

    async def reply_to_request(self, response: Message):
        """Queue a response message for sending (no-op once closed)."""
        if self.closed:
            return None
        await self.outgoing_queue.put(response)

    async def send_messages(self, messages: List[Message]):
        """Queue several messages for sending (no-op once closed)."""
        if self.closed:
            return None
        for message in messages:
            await self.outgoing_queue.put(message)

    async def _wait_and_retry(self, msg: Message, queue: asyncio.Queue):
        """Re-queue a self-rate-limited message after one second; failures are only logged at debug level."""
        try:
            await asyncio.sleep(1)
            await queue.put(msg)
        except Exception as e:
            self.log.debug(f"Exception {e} while waiting to retry sending rate limited message")
            return None

    async def _send_message(self, message: Message):
        """Write ``message`` to the websocket, applying our own outbound rate limit for non-localhost peers.

        A self-rate-limited message is retried later via ``_wait_and_retry``, except ``respond_peers``
        messages, which are dropped.
        """
        encoded: bytes = bytes(message)
        size = len(encoded)
        assert size < (2 ** (LENGTH_BYTES * 8))
        if not self.outbound_rate_limiter.process_msg_and_check(message):
            if not is_localhost(self.peer_host):
                self.log.debug(
                    f"Rate limiting ourselves. message type: {ProtocolMessageTypes(message.type).name}, "
                    f"peer: {self.peer_host}"
                )

                # TODO: fix this special case. This function has rate limits which are too low.
                if ProtocolMessageTypes(message.type) != ProtocolMessageTypes.respond_peers:
                    asyncio.create_task(self._wait_and_retry(message, self.outgoing_queue))

                return None
            else:
                self.log.debug(
                    f"Not rate limiting ourselves. message type: {ProtocolMessageTypes(message.type).name}, "
                    f"peer: {self.peer_host}"
                )

        await self.ws.send_bytes(encoded)
        self.log.debug(f"-> {ProtocolMessageTypes(message.type).name} to peer {self.peer_host} {self.peer_node_id}")
        self.bytes_written += size

    async def _read_one_message(self) -> Optional[Message]:
        """Receive and decode one binary message, or return None.

        Returns None on receive timeout, on websocket close/error frames (scheduling ``close``), and
        when a non-localhost peer exceeds our inbound rate limit on a full node (scheduling ``close(300)``).
        """
        try:
            message: WSMessage = await self.ws.receive(30)
        except asyncio.TimeoutError:
            # self.ws._closed if we didn't receive a ping / pong
            if self.ws._closed:
                asyncio.create_task(self.close())
                await asyncio.sleep(3)
                return None
            return None

        if self.connection_type is not None:
            connection_type_str = NodeType(self.connection_type).name.lower()
        else:
            connection_type_str = ""
        if message.type == WSMsgType.CLOSING:
            self.log.debug(
                f"Closing connection to {connection_type_str} {self.peer_host}:"
                f"{self.peer_server_port}/"
                f"{self.peer_port}"
            )
            asyncio.create_task(self.close())
            await asyncio.sleep(3)
        elif message.type == WSMsgType.CLOSE:
            self.log.debug(
                f"Peer closed connection {connection_type_str} {self.peer_host}:"
                f"{self.peer_server_port}/"
                f"{self.peer_port}"
            )
            asyncio.create_task(self.close())
            await asyncio.sleep(3)
        elif message.type == WSMsgType.CLOSED:
            if not self.closed:
                asyncio.create_task(self.close())
                await asyncio.sleep(3)
                return None
        elif message.type == WSMsgType.BINARY:
            data = message.data
            full_message_loaded: Message = Message.from_bytes(data)
            self.bytes_read += len(data)
            self.last_message_time = time.time()
            try:
                message_type = ProtocolMessageTypes(full_message_loaded.type).name
            except Exception:
                message_type = "Unknown"
            if not self.inbound_rate_limiter.process_msg_and_check(full_message_loaded):
                if self.local_type == NodeType.FULL_NODE and not is_localhost(self.peer_host):
                    self.log.error(
                        f"Peer has been rate limited and will be disconnected: {self.peer_host}, "
                        f"message: {message_type}"
                    )
                    # Only full node disconnects peers, to prevent abuse and crashing timelords, farmers, etc
                    asyncio.create_task(self.close(300))
                    await asyncio.sleep(3)
                    return None
                else:
                    self.log.warning(
                        f"Peer surpassed rate limit {self.peer_host}, message: {message_type}, "
                        f"port {self.peer_port} but not disconnecting"
                    )
                    return full_message_loaded
            return full_message_loaded
        elif message.type == WSMsgType.ERROR:
            self.log.error(f"WebSocket Error: {message}")
            # message.data is not guaranteed to carry a `.code` (it is presumably the underlying exception);
            # an unguarded access would raise here and kill the inbound handler without closing the connection.
            if getattr(message.data, "code", None) == WSCloseCode.MESSAGE_TOO_BIG:
                asyncio.create_task(self.close(300))
            else:
                asyncio.create_task(self.close())
            await asyncio.sleep(3)
        else:
            self.log.error(f"Unexpected WebSocket message type: {message}")
            asyncio.create_task(self.close())
            await asyncio.sleep(3)
        return None

    def get_peer_info(self) -> Optional[PeerInfo]:
        """Return the peer's host plus its server port (or connection port before the handshake), or None."""
        result = self.ws._writer.transport.get_extra_info("peername")
        if result is None:
            return None
        connection_host = result[0]
        port = self.peer_server_port if self.peer_server_port is not None else self.peer_port
        return PeerInfo(connection_host, port)
async def test_percentage_limits(self):
    """Limits scale down for a RateLimiter built as ``RateLimiter(60, 40)``.

    Presumably 60 is the reset window in seconds and 40 the percentage of the normal
    limits (TODO confirm against RateLimiter). Checks per-type count, per-type size,
    aggregate count and aggregate size limits, each with a fresh limiter.
    """

    def rejects_some(limiter, msg, attempts):
        # List (not generator) so every attempt is processed, exactly like a plain loop.
        return not all([limiter.process_msg_and_check(msg) for _ in range(attempts)])

    limiter = RateLimiter(60, 40)
    new_peak_message = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40))
    for _ in range(50):
        assert limiter.process_msg_and_check(new_peak_message)
    assert rejects_some(limiter, new_peak_message, 50)

    limiter = RateLimiter(60, 40)
    block_message = make_msg(ProtocolMessageTypes.respond_block, bytes([1] * 1024 * 1024))
    for _ in range(5):
        assert limiter.process_msg_and_check(block_message)
    assert rejects_some(limiter, block_message, 5)

    # Aggregate percentage limit count
    limiter = RateLimiter(60, 40)
    message_1 = make_msg(ProtocolMessageTypes.request_additions, bytes([1] * 5 * 1024))
    message_2 = make_msg(ProtocolMessageTypes.request_removals, bytes([1] * 1024))
    message_3 = make_msg(ProtocolMessageTypes.respond_additions, bytes([1] * 1024))
    for _ in range(180):
        assert limiter.process_msg_and_check(message_1)
    for _ in range(180):
        assert limiter.process_msg_and_check(message_2)
    assert rejects_some(limiter, message_3, 100)

    # Aggregate percentage limit max total size
    limiter = RateLimiter(60, 40)
    message_4 = make_msg(ProtocolMessageTypes.respond_proof_of_weight, bytes([1] * 18 * 1024 * 1024))
    message_5 = make_msg(ProtocolMessageTypes.respond_blocks, bytes([1] * 24 * 1024 * 1024))
    for _ in range(2):
        assert limiter.process_msg_and_check(message_4)
    assert rejects_some(limiter, message_5, 2)
async def test_periodic_reset(self):
    """After exceeding a limit, waiting out the reset period makes messages acceptable again.

    Uses ``RateLimiter(5)`` and sleeps 6 seconds. This assumes the first argument is the
    reset interval in seconds (TODO confirm). Covers both size-based
    (``respond_transaction``) and count-based (``new_transaction``) limits.
    """

    def rejects_some(limiter, msg, attempts):
        # Materialized list: all attempts must reach the limiter, as in the plain-loop form.
        return not all([limiter.process_msg_and_check(msg) for _ in range(attempts)])

    limiter = RateLimiter(5)
    tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
    for _ in range(10):
        assert limiter.process_msg_and_check(tx_message)
    assert rejects_some(limiter, tx_message, 300)
    assert not limiter.process_msg_and_check(tx_message)
    await asyncio.sleep(6)
    assert limiter.process_msg_and_check(tx_message)

    # Counts reset also
    limiter = RateLimiter(5)
    new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40))
    for _ in range(3000):
        assert limiter.process_msg_and_check(new_tx_message)
    assert rejects_some(limiter, new_tx_message, 3000)
    assert not limiter.process_msg_and_check(new_tx_message)
    await asyncio.sleep(6)
    assert limiter.process_msg_and_check(new_tx_message)