async def test_api_manager_list_objects(state_api_manager):
    """Exercise ``list_objects``: aggregation, limit, and failure handling.

    Object info is queried from every registered raylet. One failed raylet
    must yield a partial-failure warning. If all raylets fail, the call
    must raise ``DataSourceUnavailable``.
    """
    client = state_api_manager.data_source_client
    first_obj_id = b"1" * 28
    second_obj_id = b"2" * 28

    def _node_stats_replies():
        # One GetNodeStats reply per raylet, each carrying a single object.
        return [
            GetNodeStatsReply(
                core_workers_stats=[generate_object_info(first_obj_id)]),
            GetNodeStatsReply(
                core_workers_stats=[generate_object_info(second_obj_id)]),
        ]

    client.get_all_registered_raylet_ids = MagicMock()
    client.get_all_registered_raylet_ids.return_value = ["1", "2"]
    client.get_object_info = AsyncMock()
    client.get_object_info.side_effect = _node_stats_replies()

    result = await state_api_manager.list_objects(option=list_api_options())
    objects = result.result
    for raylet_id in ("1", "2"):
        client.get_object_info.assert_any_await(
            raylet_id, timeout=DEFAULT_RPC_TIMEOUT)
    objects = list(objects.values())
    assert len(objects) == 2
    for obj in objects:
        verify_schema(ObjectState, obj)

    # Limit
    client.get_object_info.side_effect = _node_stats_replies()
    limited = await state_api_manager.list_objects(
        option=list_api_options(limit=1))
    assert len(limited.result) == 1

    # Error handling: a single raylet failure is reported as a warning.
    client.get_object_info.side_effect = [
        DataSourceUnavailable(),
        GetNodeStatsReply(
            core_workers_stats=[generate_object_info(second_obj_id)]),
    ]
    partial = await state_api_manager.list_objects(
        option=list_api_options(limit=1))
    expected_warning = NODE_QUERY_FAILURE_WARNING.format(
        type="raylet",
        total=2,
        network_failures=1,
        log_command="raylet.out",
    )
    assert expected_warning in partial.partial_failure_warning

    # Error handling: when every RPC fails, the exception propagates.
    client.get_object_info.side_effect = [
        DataSourceUnavailable(),
        DataSourceUnavailable(),
    ]
    with pytest.raises(DataSourceUnavailable):
        await state_api_manager.list_objects(option=list_api_options(limit=1))
async def test_logs_manager_list_logs(logs_manager):
    """Exercise ``list_logs``: node validation, glob filtering, error propagation."""
    client = logs_manager.data_source_client
    client.get_all_registered_agent_ids = MagicMock()
    client.get_all_registered_agent_ids.return_value = ["1", "2"]
    # Mock side effects are consumed in call order. The first listing call
    # returns a gcs log, and the second raises.
    client.list_logs.side_effect = [
        generate_list_logs(["gcs_server.out"]),
        DataSourceUnavailable(),
    ]

    # An unregistered node id must be rejected with DataSourceUnavailable.
    with pytest.raises(DataSourceUnavailable):
        await logs_manager.list_logs(
            node_id="3", timeout=30, glob_filter="*gcs*")

    listing = await logs_manager.list_logs(
        node_id="2", timeout=30, glob_filter="*gcs*")
    assert len(listing) == 1
    assert listing["gcs_server"] == ["gcs_server.out"]
    assert listing["raylet"] == []
    client.get_all_registered_agent_ids.assert_called()
    client.list_logs.assert_awaited_with("2", "*gcs*", timeout=30)

    # The next listing call hits the DataSourceUnavailable side effect,
    # which must surface to the caller unchanged.
    with pytest.raises(DataSourceUnavailable):
        await logs_manager.list_logs(
            node_id="1", timeout=30, glob_filter="*gcs*")
async def api_with_network_error_handler(*args, **kwargs):
    """Call the wrapped API and translate transient gRPC failures.

    Returns:
        Whatever the wrapped ``func`` returns when the RPC succeeds.

    Raises:
        DataSourceUnavailable: the RPC failed with ``DEADLINE_EXCEEDED`` or
            ``UNAVAILABLE``, meaning the source is down or a slow network
            caused a timeout.
        grpc.aio.AioRpcError: any other gRPC failure is logged and re-raised.
    """
    # NOTE(review): ``func`` is a free variable, presumably the callable
    # bound by an enclosing decorator that is not visible here.
    # TODO(sang): Add a retry policy.
    try:
        return await func(*args, **kwargs)
    except grpc.aio.AioRpcError as rpc_error:
        status = rpc_error.code()
        transient_statuses = (
            grpc.StatusCode.DEADLINE_EXCEEDED,
            grpc.StatusCode.UNAVAILABLE,
        )
        if status not in transient_statuses:
            logger.exception(rpc_error)
            raise rpc_error
        raise DataSourceUnavailable(
            "Failed to query the data source. "
            "It is either there's a network issue, or the source is down."
        )
async def test_api_manager_list_workers(state_api_manager):
    """Exercise ``list_workers``: schema, limit, and GCS failure handling."""
    client = state_api_manager.data_source_client
    worker_id = b"1234"
    client.get_all_worker_info.return_value = GetAllWorkerInfoReply(
        worker_table_data=[
            generate_worker_data(worker_id),
            generate_worker_data(b"12345"),
        ])

    listed = await state_api_manager.list_workers(option=list_api_options())
    first_worker = list(listed.result.values())[0]
    verify_schema(WorkerState, first_worker)

    # Limit
    assert len(listed.result) == 2
    limited = await state_api_manager.list_workers(
        option=list_api_options(limit=1))
    assert len(limited.result) == 1

    # Error handling: a GCS outage surfaces with the GCS failure message.
    client.get_all_worker_info.side_effect = DataSourceUnavailable()
    with pytest.raises(DataSourceUnavailable) as exc_info:
        await state_api_manager.list_workers(option=list_api_options(limit=1))
    assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING
def _verify_node_registered(self, node_id: str):
    """Make sure ``node_id`` is one of the currently registered agent ids.

    Args:
        node_id: Node id to look up in
            ``self.client.get_all_registered_agent_ids()``.

    Raises:
        DataSourceUnavailable: ``node_id`` is not in the registered set.
    """
    if node_id in self.client.get_all_registered_agent_ids():
        # Only reached for registered ids. None could pass the membership
        # check only if the registry itself contained None.
        assert node_id is not None
        return
    raise DataSourceUnavailable(
        f"Given node id {node_id} is not available. "
        "It's either the node is dead, or it is not registered. "
        "Use `ray list nodes` "
        "to see the node status. If the node is registered, "
        "it is highly likely "
        "a transient issue. Try again.")
def get_job_info(self) -> Optional[Dict[str, JobInfo]]:
    """Return info for every job known to the job client.

    Returns:
        The value of ``self._job_client.get_all_jobs()``.

    Raises:
        DataSourceUnavailable: the call failed with gRPC status
            ``DEADLINE_EXCEEDED`` or ``UNAVAILABLE`` (network issue or the
            source is down).
        grpc.aio.AioRpcError: any other gRPC failure, logged then re-raised.
    """
    # Cannot use @handle_grpc_network_errors because async def is not supported yet.
    # TODO(sang): Support timeout & make it async
    try:
        return self._job_client.get_all_jobs()
    except grpc.aio.AioRpcError as e:
        # ``code`` is a method on AioRpcError and must be called. Comparing
        # the bound method itself to a StatusCode is always False, so
        # transient failures would never be translated.
        if e.code() in (
            grpc.StatusCode.DEADLINE_EXCEEDED,
            grpc.StatusCode.UNAVAILABLE,
        ):
            raise DataSourceUnavailable(
                "Failed to query the data source. "
                "It is either there's a network issue, or the source is down."
            ) from e
        logger.exception(e)
        raise
async def test_api_manager_list_workers(state_api_manager):
    """Exercise ``list_workers``: schema, limit, filters, GCS failure handling."""
    client = state_api_manager.data_source_client
    worker_id = b"1234"
    client.get_all_worker_info.return_value = GetAllWorkerInfoReply(
        worker_table_data=[
            generate_worker_data(worker_id, pid=1),
            generate_worker_data(b"12345", pid=2),
        ])

    async def _list(**option_kwargs):
        # Thin shortcut: build the options and issue the list call.
        return await state_api_manager.list_workers(
            option=create_api_options(**option_kwargs))

    listed = await _list()
    verify_schema(WorkerState, listed.result[0])

    # Limit
    assert len(listed.result) == 2
    limited = await _list(limit=1)
    assert len(limited.result) == 1

    # Filters: filtering on an unsupported column must raise.
    with pytest.raises(ValueError):
        await _list(filters=[("stat", "DEAD")])
    by_worker_id = await _list(
        filters=[("worker_id", bytearray(worker_id).hex())])
    assert len(by_worker_id.result) == 1
    # Filtering must also work for an int-valued column.
    by_pid = await _list(filters=[("pid", 2)])
    assert len(by_pid.result) == 1

    # Error handling: a GCS outage surfaces with the GCS failure message.
    client.get_all_worker_info.side_effect = DataSourceUnavailable()
    with pytest.raises(DataSourceUnavailable) as exc_info:
        await _list(limit=1)
    assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING
async def test_api_manager_list_pgs(state_api_manager):
    """Exercise ``list_placement_groups``: schema, limit, filters, failures."""
    client = state_api_manager.data_source_client
    pg_id = b"1234"
    client.get_all_placement_group_info.return_value = (
        GetAllPlacementGroupReply(placement_group_table_data=[
            generate_pg_data(pg_id),
            generate_pg_data(b"12345"),
        ]))

    async def _list(**option_kwargs):
        # Thin shortcut: build the options and issue the list call.
        return await state_api_manager.list_placement_groups(
            option=create_api_options(**option_kwargs))

    listed = await _list()
    verify_schema(PlacementGroupState, listed.result[0])

    # Limit
    assert len(listed.result) == 2
    limited = await _list(limit=1)
    assert len(limited.result) == 1

    # Filters: filtering on an unsupported column must raise.
    with pytest.raises(ValueError):
        await _list(filters=[("stat", "DEAD")])
    by_pg_id = await _list(
        filters=[("placement_group_id", bytearray(pg_id).hex())])
    assert len(by_pg_id.result) == 1

    # Error handling: a GCS outage surfaces with the GCS failure message.
    client.get_all_placement_group_info.side_effect = DataSourceUnavailable()
    with pytest.raises(DataSourceUnavailable) as exc_info:
        await _list(limit=1)
    assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING
async def test_api_manager_list_actors(state_api_manager):
    """Exercise ``list_actors``: schema, limit, filters, GCS failure handling."""
    client = state_api_manager.data_source_client
    actor_id = b"1234"
    client.get_all_actor_info.return_value = GetAllActorInfoReply(
        actor_table_data=[
            generate_actor_data(actor_id),
            generate_actor_data(
                b"12345", state=ActorTableData.ActorState.DEAD),
        ])

    async def _list(**option_kwargs):
        # Thin shortcut: build the options and issue the list call.
        return await state_api_manager.list_actors(
            option=create_api_options(**option_kwargs))

    listed = await _list()
    actors = listed.result
    verify_schema(ActorState, actors[0])

    # Limit
    assert len(actors) == 2
    limited = await _list(limit=1)
    assert len(limited.result) == 1

    # Filters: filtering on an unsupported column must raise.
    with pytest.raises(ValueError):
        await _list(filters=[("stat", "DEAD")])
    dead_only = await _list(filters=[("state", "DEAD")])
    assert len(dead_only.result) == 1

    # Error handling: a GCS outage surfaces with the GCS failure message.
    client.get_all_actor_info.side_effect = DataSourceUnavailable()
    with pytest.raises(DataSourceUnavailable) as exc_info:
        await _list(limit=1)
    assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING
async def test_api_manager_list_runtime_envs(state_api_manager):
    """Exercise ``list_runtime_envs``: aggregation, ordering, limit, filters, failures.

    Runtime env info is queried from every registered agent. One failed
    agent must yield a partial-failure warning. If all agents fail, the
    call must raise ``DataSourceUnavailable``.
    """
    data_source_client = state_api_manager.data_source_client
    data_source_client.get_all_registered_agent_ids = MagicMock()
    data_source_client.get_all_registered_agent_ids.return_value = [
        "1", "2", "3"
    ]
    data_source_client.get_runtime_envs_info = AsyncMock()
    data_source_client.get_runtime_envs_info.side_effect = [
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["requests"]})),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["tensorflow"]}),
                                  creation_time=15),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]}),
                                  creation_time=10),
    ]
    result = await state_api_manager.list_runtime_envs(
        option=create_api_options())
    data = result.result
    for agent_id in ("1", "2", "3"):
        data_source_client.get_runtime_envs_info.assert_any_await(
            agent_id, timeout=DEFAULT_RPC_TIMEOUT)
    assert len(data) == 3
    for entry in data:
        verify_schema(RuntimeEnvState, entry)

    # Make sure the higher creation time is sorted first. The entry without
    # a creation time leads, followed by the newer and then the older env.
    assert "creation_time_ms" not in data[0]
    # This comparison must be asserted; otherwise the result is discarded
    # and the ordering is never checked.
    assert data[1]["creation_time_ms"] > data[2]["creation_time_ms"]

    # Limit
    data_source_client.get_runtime_envs_info.side_effect = [
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["requests"]})),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["tensorflow"]}),
                                  creation_time=15),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})),
    ]
    result = await state_api_manager.list_runtime_envs(
        option=create_api_options(limit=1))
    assert len(result.result) == 1

    # Filters
    data_source_client.get_runtime_envs_info.side_effect = [
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["requests"]}),
                                  success=True),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["tensorflow"]}),
                                  creation_time=15,
                                  success=True),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]}),
                                  success=False),
    ]
    result = await state_api_manager.list_runtime_envs(
        option=create_api_options(filters=[("success", False)]))
    assert len(result.result) == 1

    # Error handling: a single agent failure is reported as a warning.
    data_source_client.get_runtime_envs_info.side_effect = [
        DataSourceUnavailable(),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})),
    ]
    result = await state_api_manager.list_runtime_envs(
        option=create_api_options(limit=1))
    warning = result.partial_failure_warning
    assert (NODE_QUERY_FAILURE_WARNING.format(
        type="agent",
        total=3,
        network_failures=1,
        log_command="dashboard_agent.log") in warning)

    # Error handling: when every RPC fails, the exception propagates.
    data_source_client.get_runtime_envs_info.side_effect = [
        DataSourceUnavailable(),
        DataSourceUnavailable(),
        DataSourceUnavailable(),
    ]
    with pytest.raises(DataSourceUnavailable):
        await state_api_manager.list_runtime_envs(
            option=create_api_options(limit=1))
async def test_api_manager_list_tasks(state_api_manager):
    """Exercise ``list_tasks``: aggregation, limit, filters, and failures.

    Task info is queried from every registered raylet. One failed raylet
    must yield a partial-failure warning. If all raylets fail, the call
    must raise ``DataSourceUnavailable``.
    """
    data_source_client = state_api_manager.data_source_client
    data_source_client.get_all_registered_raylet_ids = MagicMock()
    data_source_client.get_all_registered_raylet_ids.return_value = ["1", "2"]
    first_task_name = "1"
    second_task_name = "2"
    data_source_client.get_task_info = AsyncMock()
    # Named ``task_id`` rather than ``id`` to avoid shadowing the builtin.
    task_id = b"1234"

    def _task_replies():
        # One reply per raylet; AsyncMock hands them out in call order.
        return [
            generate_task_data(task_id, first_task_name),
            generate_task_data(b"2345", second_task_name),
        ]

    data_source_client.get_task_info.side_effect = _task_replies()
    result = await state_api_manager.list_tasks(option=create_api_options())
    for raylet_id in ("1", "2"):
        data_source_client.get_task_info.assert_any_await(
            raylet_id, timeout=DEFAULT_RPC_TIMEOUT)
    data = result.result
    assert len(data) == 2
    verify_schema(TaskState, data[0])
    verify_schema(TaskState, data[1])

    # Limit
    data_source_client.get_task_info.side_effect = _task_replies()
    result = await state_api_manager.list_tasks(option=create_api_options(
        limit=1))
    assert len(result.result) == 1

    # Filters
    data_source_client.get_task_info.side_effect = _task_replies()
    result = await state_api_manager.list_tasks(option=create_api_options(
        filters=[("task_id", bytearray(task_id).hex())]))
    assert len(result.result) == 1

    # Error handling: a single raylet failure is reported as a warning.
    data_source_client.get_task_info.side_effect = [
        DataSourceUnavailable(),
        generate_task_data(b"2345", second_task_name),
    ]
    result = await state_api_manager.list_tasks(option=create_api_options(
        limit=1))
    warning = result.partial_failure_warning
    assert (NODE_QUERY_FAILURE_WARNING.format(type="raylet",
                                              total=2,
                                              network_failures=1,
                                              log_command="raylet.out")
            in warning)

    # Error handling: when every RPC fails, the exception propagates.
    data_source_client.get_task_info.side_effect = [
        DataSourceUnavailable(),
        DataSourceUnavailable(),
    ]
    with pytest.raises(DataSourceUnavailable):
        await state_api_manager.list_tasks(option=create_api_options(limit=1))