async def perform_handshake(self, network_id: str, protocol_version: str, server_port: int, local_type: NodeType):
    """Exchange handshake messages with the peer, then start the I/O handler tasks.

    Outbound connections send their handshake first and then read the peer's.
    Inbound connections read the peer's handshake first and reply only after it
    has been validated. On success, the peer's advertised server port and node
    type are stored on ``self`` and the outbound/inbound handler tasks are created.

    Args:
        network_id: Network this node is on; the peer must advertise the same value.
        protocol_version: Protocol version string placed in our handshake.
        server_port: Port advertised to the peer in our handshake.
        local_type: Our node type, advertised to the peer.

    Returns:
        True once the handshake has completed and the handler tasks are created.

    Raises:
        ProtocolError: ``Err.INVALID_HANDSHAKE`` when any of these happens:
            - the peer sends nothing;
            - the message type is unknown or is not ``handshake``;
            - the payload cannot be decoded as a ``Handshake``;
            - the advertised node type is not a valid ``NodeType``;
            - on inbound connections, reading the first message fails.
            ``Err.INCOMPATIBLE_NETWORK_ID`` when the peer advertises a different
            network id.
    """
    # Our handshake is identical for both directions; only the send order differs.
    outbound_handshake = make_msg(
        ProtocolMessageTypes.handshake,
        Handshake(
            network_id,
            protocol_version,
            chia_full_version_str(),
            uint16(server_port),
            uint8(local_type.value),
            [(uint16(Capability.BASE.value), "1")],
        ),
    )
    assert outbound_handshake is not None

    def validate_inbound(message):
        """Return ``(handshake, peer_node_type)`` or raise ``ProtocolError``."""
        if message is None:
            raise ProtocolError(Err.INVALID_HANDSHAKE)
        # An unknown type value makes the enum lookup raise ValueError; report it
        # as INVALID_HANDSHAKE instead of letting an unrelated error escape.
        try:
            message_type = ProtocolMessageTypes(message.type)
        except ValueError as e:
            raise ProtocolError(Err.INVALID_HANDSHAKE) from e
        if message_type != ProtocolMessageTypes.handshake:
            raise ProtocolError(Err.INVALID_HANDSHAKE)
        # Decode the payload only after it has been declared a handshake.
        try:
            handshake = Handshake.from_bytes(message.data)
        except Exception as e:
            raise ProtocolError(Err.INVALID_HANDSHAKE) from e
        if handshake.network_id != network_id:
            raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)
        # node_type comes straight from the peer; an unknown value is malformed input.
        try:
            peer_type = NodeType(handshake.node_type)
        except ValueError as e:
            raise ProtocolError(Err.INVALID_HANDSHAKE) from e
        return handshake, peer_type

    if self.is_outbound:
        await self._send_message(outbound_handshake)
        inbound_handshake, peer_type = validate_inbound(await self._read_one_message())
    else:
        try:
            message = await self._read_one_message()
        except Exception as e:
            raise ProtocolError(Err.INVALID_HANDSHAKE) from e
        inbound_handshake, peer_type = validate_inbound(message)
        # Reply only once the peer's handshake has been accepted.
        await self._send_message(outbound_handshake)

    self.peer_server_port = inbound_handshake.server_port
    self.connection_type = peer_type

    self.outbound_task = asyncio.create_task(self.outbound_handler())
    self.inbound_task = asyncio.create_task(self.inbound_handler())
    return True
async def check_node_alive(node: Node, timeout: float = 10) -> Tuple[bool, Node]:
    """Probe *node* with a handshake and report whether it answers as a mainnet peer.

    The probe works as follows:
        1. Open one TLS websocket connection to ``node.get_websocket_url()``.
        2. Send a mainnet introducer handshake.
        3. Read a single reply.
        4. Check that the reply decodes to a handshake whose network id is
           ``'mainnet'``.

    Args:
        node: The peer to probe.
        timeout: Seconds allowed for establishing the connection and, separately,
            for receiving the reply. Defaults to the previously hard-coded 10.

    Returns:
        ``(True, node)`` if the node replied with a mainnet handshake, otherwise
        ``(False, node)``. Errors raised while probing are logged and reported as
        ``False``; errors while building the SSL context (before the probe starts)
        propagate to the caller.
    """
    ssl_context = ssl_context_for_server('keys/chia_ca.crt', 'keys/chia_ca.key', 'keys/public_full_node.crt', 'keys/public_full_node.key')
    try:
        # Bound connect() with the timeout and then use that same connection.
        # Previously a throwaway first connection was opened here and never closed.
        websocket = await asyncio.wait_for(
            websockets.connect(node.get_websocket_url(), ssl=ssl_context), timeout=timeout
        )
        try:
            handshake = make_msg(
                ProtocolMessageTypes.handshake,
                Handshake(
                    'mainnet',
                    protocol_version,
                    '0.0.0',
                    uint16(8884),
                    uint8(NodeType.INTRODUCER.value),
                    [(uint16(Capability.BASE.value), '1')],
                ),
            )
            await websocket.send(bytes(handshake))
            # A peer that accepts the socket but never answers must not hang the probe.
            message = await asyncio.wait_for(websocket.recv(), timeout=timeout)
        finally:
            await websocket.close()

        if message is None:
            logging.warning('Node %s did not return anything', node.ip)
            return False, node
        inbound_handshake = Handshake.from_bytes(Message.from_bytes(message).data)
        if inbound_handshake.network_id != 'mainnet':
            logging.warning('Node %s is not on main net but is on mainnet port!', node.ip)
            return False, node
        logging.info('Node %s is up.', node.ip)
        return True, node
    except websockets.exceptions.ConnectionClosed:
        logging.warning('Node closed the connection')
        return False, node
    except asyncio.TimeoutError:
        # Must stay before OSError: on Python 3.11+ this is the builtin
        # TimeoutError, which is an OSError subclass.
        logging.warning('Node timeout : %s', node.ip)
        return False, node
    except (websockets.exceptions.InvalidMessage, OSError):
        return False, node
    except Exception as e:
        # Unexpected failure: log it with its traceback and treat the node as down.
        logging.exception(e)
        return False, node
async def test_invalid_protocol_handshake(self, setup_two_nodes):
    """A handshake wrapped in an unknown message type is rejected as INVALID_HANDSHAKE.

    The test uses server_2's certificates to open a websocket to server_1. It then
    sends an otherwise valid ``Handshake`` inside a ``Message`` whose type (2) is
    not a valid protocol message type. server_1 must close the socket with
    ``PROTOCOL_ERROR`` and ``Err.INVALID_HANDSHAKE`` as the close reason.
    """
    nodes, _ = setup_two_nodes
    server_1 = nodes[0].full_node.server
    server_2 = nodes[1].full_node.server
    server_1.invalid_protocol_ban_seconds = 10

    # Use the server_2 ssl information to connect to server_1
    url = f"wss://{self_hostname}:{server_1._port}/ws"
    ssl_context = ssl_context_for_client(
        server_2.chia_ca_crt_path, server_2.chia_ca_key_path, server_2.p2p_crt_path, server_2.p2p_key_path
    )

    # async with closes the socket and session even when an assertion below fails.
    async with ClientSession(timeout=ClientTimeout(total=10)) as session:
        async with session.ws_connect(
            url,
            autoclose=True,
            autoping=True,
            heartbeat=60,
            ssl=ssl_context,
            max_msg_size=100 * 1024 * 1024,
        ) as ws:
            # Construct an otherwise valid handshake message
            handshake: Handshake = Handshake("test", "0.0.32", "1.0.0.0", 3456, 1, [(1, "1")])
            outbound_handshake: Message = Message(2, None, bytes(handshake))  # 2 is an invalid ProtocolType
            await ws.send_bytes(bytes(outbound_handshake))

            response: WSMessage = await ws.receive()
            assert response.type == WSMsgType.CLOSE, response
            assert response.data == WSCloseCode.PROTOCOL_ERROR
            # We want INVALID_HANDSHAKE and not UNKNOWN
            assert response.extra == str(int(Err.INVALID_HANDSHAKE.value))

    await asyncio.sleep(1)  # give some time for cleanup to work