Esempio n. 1
0
async def test_api_manager_list_objects(state_api_manager):
    """Check list_objects fan-out, limit handling and partial/total failures.

    The mocked data source reports two raylets ("1" and "2"), and each
    get_object_info reply carries a single generated object.
    """
    client = state_api_manager.data_source_client
    first_object_id = b"1" * 28
    second_object_id = b"2" * 28
    client.get_all_registered_raylet_ids = MagicMock(return_value=["1", "2"])
    client.get_object_info = AsyncMock()

    def replies_for_both_raylets():
        # side_effect entries are consumed by calls, so every query needs a
        # fresh list with one reply per raylet.
        return [
            GetNodeStatsReply(
                core_workers_stats=[generate_object_info(object_id)])
            for object_id in (first_object_id, second_object_id)
        ]

    # Default options: every raylet is queried and both objects come back.
    client.get_object_info.side_effect = replies_for_both_raylets()
    response = await state_api_manager.list_objects(
        option=list_api_options())
    for raylet_id in ("1", "2"):
        client.get_object_info.assert_any_await(
            raylet_id, timeout=DEFAULT_RPC_TIMEOUT)
    objects = list(response.result.values())
    assert len(objects) == 2
    for obj in objects:
        verify_schema(ObjectState, obj)

    # Limit: only a single entry may be returned.
    client.get_object_info.side_effect = replies_for_both_raylets()
    response = await state_api_manager.list_objects(
        option=list_api_options(limit=1))
    assert len(response.result) == 1

    # Partial failure: one raylet is unreachable, so a warning is attached.
    client.get_object_info.side_effect = [
        DataSourceUnavailable(),
        GetNodeStatsReply(
            core_workers_stats=[generate_object_info(second_object_id)]),
    ]
    response = await state_api_manager.list_objects(
        option=list_api_options(limit=1))
    expected_warning = NODE_QUERY_FAILURE_WARNING.format(
        type="raylet",
        total=2,
        network_failures=1,
        log_command="raylet.out")
    assert expected_warning in response.partial_failure_warning

    # Total failure: when every raylet is unreachable the error propagates.
    client.get_object_info.side_effect = [
        DataSourceUnavailable(),
        DataSourceUnavailable(),
    ]
    with pytest.raises(DataSourceUnavailable):
        await state_api_manager.list_objects(
            option=list_api_options(limit=1))
Esempio n. 2
0
async def test_logs_manager_list_logs(logs_manager):
    """Check list_logs node validation, per-category grouping and errors."""
    client = logs_manager.data_source_client
    client.get_all_registered_agent_ids = MagicMock(return_value=["1", "2"])

    # Entries are consumed in order by calls to the mocked client: the
    # first yields a log listing, the next raises DataSourceUnavailable.
    client.list_logs.side_effect = [
        generate_list_logs(["gcs_server.out"]),
        DataSourceUnavailable(),
    ]

    # A node id that is not registered must raise DataSourceUnavailable.
    with pytest.raises(DataSourceUnavailable):
        await logs_manager.list_logs(
            node_id="3", timeout=30, glob_filter="*gcs*")

    logs = await logs_manager.list_logs(
        node_id="2", timeout=30, glob_filter="*gcs*")
    assert len(logs) == 1
    assert logs["gcs_server"] == ["gcs_server.out"]
    assert logs["raylet"] == []
    client.get_all_registered_agent_ids.assert_called()
    client.list_logs.assert_awaited_with("2", "*gcs*", timeout=30)

    # The remaining queued entry is DataSourceUnavailable, which must reach
    # the caller as DataSourceUnavailable.
    with pytest.raises(DataSourceUnavailable):
        await logs_manager.list_logs(
            node_id="1", timeout=30, glob_filter="*gcs*")
Esempio n. 3
0
    async def api_with_network_error_handler(*args, **kwargs):
        """Apply the network error handling logic to each API call,
        such as retry or exception policies.

        Returns:
            Whatever the wrapped function returns when the RPC succeeds.

        Raises:
            DataSourceUnavailable: if the gRPC status is DEADLINE_EXCEEDED or
                UNAVAILABLE, i.e. the source is down or a slow network caused
                a timeout. The original gRPC error is chained as the cause.
            grpc.aio.AioRpcError: any other gRPC failure is logged and
                re-raised unchanged. Exceptions that are not gRPC errors
                propagate untouched.
        """
        # TODO(sang): Add a retry policy.
        try:
            return await func(*args, **kwargs)
        except grpc.aio.AioRpcError as e:
            if e.code() in (grpc.StatusCode.DEADLINE_EXCEEDED,
                            grpc.StatusCode.UNAVAILABLE):
                # Chain the gRPC error so its status and details stay visible
                # in the traceback.
                raise DataSourceUnavailable(
                    "Failed to query the data source. "
                    "It is either there's a network issue, or the source is down."
                ) from e
            logger.exception(e)
            # A bare raise re-raises the active exception with its original
            # traceback.
            raise
Esempio n. 4
0
async def test_api_manager_list_workers(state_api_manager):
    """Check list_workers schema, limit handling and GCS failure reporting."""
    client = state_api_manager.data_source_client
    worker_id = b"1234"
    client.get_all_worker_info.return_value = GetAllWorkerInfoReply(
        worker_table_data=[
            generate_worker_data(worker_id),
            generate_worker_data(b"12345"),
        ])

    response = await state_api_manager.list_workers(option=list_api_options())
    first_worker = list(response.result.values())[0]
    verify_schema(WorkerState, first_worker)

    # Limit: both workers by default, a single one with limit=1.
    assert len(response.result) == 2
    response = await state_api_manager.list_workers(
        option=list_api_options(limit=1))
    assert len(response.result) == 1

    # Error handling: a GCS failure surfaces with the GCS warning message.
    client.get_all_worker_info.side_effect = DataSourceUnavailable()
    with pytest.raises(DataSourceUnavailable) as exc_info:
        await state_api_manager.list_workers(
            option=list_api_options(limit=1))
    assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING
Esempio n. 5
0
 def _verify_node_registered(self, node_id: str):
     """Raise DataSourceUnavailable unless node_id is a registered agent id.

     Registration is determined by membership in
     ``self.client.get_all_registered_agent_ids()``.
     """
     registered_agent_ids = self.client.get_all_registered_agent_ids()
     if node_id in registered_agent_ids:
         # Only reachable for ids the client lists as registered; guards
         # against a None id slipping through that list.
         assert node_id is not None
         return
     raise DataSourceUnavailable(
         f"Given node id {node_id} is not available. "
         "It's either the node is dead, or it is not registered. "
         "Use `ray list nodes` "
         "to see the node status. If the node is registered, "
         "it is highly likely "
         "a transient issue. Try again.")
Esempio n. 6
0
 def get_job_info(self) -> Optional[Dict[str, JobInfo]]:
     """Return the result of ``self._job_client.get_all_jobs()``.

     Raises:
         DataSourceUnavailable: if the call fails with gRPC status
             DEADLINE_EXCEEDED or UNAVAILABLE (the source is down or the
             network is slow). The original gRPC error is chained as the cause.
         grpc.aio.AioRpcError: any other gRPC failure, logged and re-raised.
     """
     # Cannot use @handle_grpc_network_errors because async def is not supported yet.
     # TODO(sang): Support timeout & make it async
     try:
         return self._job_client.get_all_jobs()
     except grpc.aio.AioRpcError as e:
         # `code` is a method on AioRpcError. Comparing the bound method itself
         # to a StatusCode is always False, so it must be called.
         if e.code() in (grpc.StatusCode.DEADLINE_EXCEEDED,
                         grpc.StatusCode.UNAVAILABLE):
             raise DataSourceUnavailable(
                 "Failed to query the data source. "
                 "It is either there's a network issue, or the source is down."
             ) from e
         logger.exception(e)
         raise
Esempio n. 7
0
async def test_api_manager_list_workers(state_api_manager):
    """Check list_workers schema, limit, filters and GCS failure reporting."""
    client = state_api_manager.data_source_client
    worker_id = b"1234"
    client.get_all_worker_info.return_value = GetAllWorkerInfoReply(
        worker_table_data=[
            generate_worker_data(worker_id, pid=1),
            generate_worker_data(b"12345", pid=2),
        ])

    response = await state_api_manager.list_workers(
        option=create_api_options())
    verify_schema(WorkerState, response.result[0])

    # Limit: both workers by default, a single one with limit=1.
    assert len(response.result) == 2
    response = await state_api_manager.list_workers(
        option=create_api_options(limit=1))
    assert len(response.result) == 1

    # Filters: filtering on an unsupported column raises ValueError.
    with pytest.raises(ValueError):
        await state_api_manager.list_workers(
            option=create_api_options(filters=[("stat", "DEAD")]))
    id_filter = [("worker_id", bytearray(worker_id).hex())]
    response = await state_api_manager.list_workers(
        option=create_api_options(filters=id_filter))
    assert len(response.result) == 1
    # Filtering also works with an int-typed value.
    response = await state_api_manager.list_workers(
        option=create_api_options(filters=[("pid", 2)]))
    assert len(response.result) == 1

    # Error handling: a GCS failure surfaces with the GCS warning message.
    client.get_all_worker_info.side_effect = DataSourceUnavailable()
    with pytest.raises(DataSourceUnavailable) as exc_info:
        await state_api_manager.list_workers(
            option=create_api_options(limit=1))
    assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING
Esempio n. 8
0
async def test_api_manager_list_pgs(state_api_manager):
    """Check list_placement_groups schema, limit, filters and GCS failures."""
    client = state_api_manager.data_source_client
    pg_id = b"1234"
    client.get_all_placement_group_info.return_value = (
        GetAllPlacementGroupReply(placement_group_table_data=[
            generate_pg_data(pg_id),
            generate_pg_data(b"12345"),
        ]))

    response = await state_api_manager.list_placement_groups(
        option=create_api_options())
    verify_schema(PlacementGroupState, response.result[0])

    # Limit: both groups by default, a single one with limit=1.
    assert len(response.result) == 2
    response = await state_api_manager.list_placement_groups(
        option=create_api_options(limit=1))
    assert len(response.result) == 1

    # Filters: an unsupported column raises ValueError. Filtering on the
    # hex-encoded id matches exactly one group.
    with pytest.raises(ValueError):
        await state_api_manager.list_placement_groups(
            option=create_api_options(filters=[("stat", "DEAD")]))
    id_filter = [("placement_group_id", bytearray(pg_id).hex())]
    response = await state_api_manager.list_placement_groups(
        option=create_api_options(filters=id_filter))
    assert len(response.result) == 1

    # Error handling: a GCS failure surfaces with the GCS warning message.
    client.get_all_placement_group_info.side_effect = DataSourceUnavailable()
    with pytest.raises(DataSourceUnavailable) as exc_info:
        await state_api_manager.list_placement_groups(
            option=create_api_options(limit=1))
    assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING
Esempio n. 9
0
async def test_api_manager_list_actors(state_api_manager):
    """Check list_actors schema, limit, state filtering and GCS failures."""
    client = state_api_manager.data_source_client
    alive_actor_id = b"1234"
    client.get_all_actor_info.return_value = GetAllActorInfoReply(
        actor_table_data=[
            generate_actor_data(alive_actor_id),
            generate_actor_data(b"12345",
                                state=ActorTableData.ActorState.DEAD),
        ])

    actors = (await state_api_manager.list_actors(
        option=create_api_options())).result
    verify_schema(ActorState, actors[0])

    # Limit: both actors by default, a single one with limit=1.
    assert len(actors) == 2
    response = await state_api_manager.list_actors(
        option=create_api_options(limit=1))
    assert len(response.result) == 1

    # Filters: an unsupported column raises ValueError. Filtering on state
    # picks out the single DEAD actor.
    with pytest.raises(ValueError):
        await state_api_manager.list_actors(
            option=create_api_options(filters=[("stat", "DEAD")]))
    response = await state_api_manager.list_actors(
        option=create_api_options(filters=[("state", "DEAD")]))
    assert len(response.result) == 1

    # Error handling: a GCS failure surfaces with the GCS warning message.
    client.get_all_actor_info.side_effect = DataSourceUnavailable()
    with pytest.raises(DataSourceUnavailable) as exc_info:
        await state_api_manager.list_actors(
            option=create_api_options(limit=1))
    assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING
Esempio n. 10
0
async def test_api_manager_list_runtime_envs(state_api_manager):
    """Check list_runtime_envs fan-out, ordering, limit, filters and failures.

    The mocked data source reports three agents ("1", "2", "3"), and each
    get_runtime_envs_info reply is one generated runtime env info.
    """
    data_source_client = state_api_manager.data_source_client
    data_source_client.get_all_registered_agent_ids = MagicMock()
    data_source_client.get_all_registered_agent_ids.return_value = [
        "1", "2", "3"
    ]

    data_source_client.get_runtime_envs_info = AsyncMock()
    data_source_client.get_runtime_envs_info.side_effect = [
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["requests"]})),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["tensorflow"]}),
                                  creation_time=15),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]}),
                                  creation_time=10),
    ]
    result = await state_api_manager.list_runtime_envs(
        option=create_api_options())
    data = result.result
    for agent_id in ("1", "2", "3"):
        data_source_client.get_runtime_envs_info.assert_any_await(
            agent_id, timeout=DEFAULT_RPC_TIMEOUT)
    assert len(data) == 3
    for entry in data:
        verify_schema(RuntimeEnvState, entry)

    # The entry without a creation time comes first, and the rest are
    # ordered by descending creation time. The comparison below must be an
    # assert; a bare expression would check nothing.
    assert "creation_time_ms" not in data[0]
    assert data[1]["creation_time_ms"] > data[2]["creation_time_ms"]

    # Limit: only a single entry may be returned.
    data_source_client.get_runtime_envs_info.side_effect = [
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["requests"]})),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["tensorflow"]}),
                                  creation_time=15),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})),
    ]
    result = await state_api_manager.list_runtime_envs(
        option=create_api_options(limit=1))
    assert len(result.result) == 1

    # Filters: only the single failed runtime env matches success=False.
    data_source_client.get_runtime_envs_info.side_effect = [
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["requests"]}),
                                  success=True),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["tensorflow"]}),
                                  creation_time=15,
                                  success=True),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]}),
                                  success=False),
    ]
    result = await state_api_manager.list_runtime_envs(
        option=create_api_options(filters=[("success", False)]))
    assert len(result.result) == 1

    # Partial failure: one agent is unreachable, so a warning is attached.
    data_source_client.get_runtime_envs_info.side_effect = [
        DataSourceUnavailable(),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})),
        generate_runtime_env_info(RuntimeEnv(**{"pip": ["ray"]})),
    ]
    result = await state_api_manager.list_runtime_envs(
        option=create_api_options(limit=1))
    warning = result.partial_failure_warning
    assert (NODE_QUERY_FAILURE_WARNING.format(
        type="agent",
        total=3,
        network_failures=1,
        log_command="dashboard_agent.log") in warning)

    # Total failure: when every agent is unreachable the error propagates.
    data_source_client.get_runtime_envs_info.side_effect = [
        DataSourceUnavailable(),
        DataSourceUnavailable(),
        DataSourceUnavailable(),
    ]
    with pytest.raises(DataSourceUnavailable):
        await state_api_manager.list_runtime_envs(
            option=create_api_options(limit=1))
Esempio n. 11
0
async def test_api_manager_list_tasks(state_api_manager):
    """Check list_tasks fan-out, limit, filters and partial/total failures.

    The mocked data source reports two raylets ("1" and "2"), and each
    get_task_info reply is one generated task-data reply.
    """
    data_source_client = state_api_manager.data_source_client
    data_source_client.get_all_registered_raylet_ids = MagicMock()
    data_source_client.get_all_registered_raylet_ids.return_value = ["1", "2"]

    first_task_name = "1"
    second_task_name = "2"
    # Named task_id rather than `id` to avoid shadowing the builtin.
    task_id = b"1234"
    other_task_id = b"2345"
    data_source_client.get_task_info = AsyncMock()

    def replies_for_both_raylets():
        # side_effect entries are consumed by calls, so every query needs a
        # fresh list with one reply per raylet.
        return [
            generate_task_data(task_id, first_task_name),
            generate_task_data(other_task_id, second_task_name),
        ]

    # Default options: every raylet is queried and both tasks come back.
    data_source_client.get_task_info.side_effect = replies_for_both_raylets()
    result = await state_api_manager.list_tasks(option=create_api_options())
    data_source_client.get_task_info.assert_any_await(
        "1", timeout=DEFAULT_RPC_TIMEOUT)
    data_source_client.get_task_info.assert_any_await(
        "2", timeout=DEFAULT_RPC_TIMEOUT)
    data = result.result
    assert len(data) == 2
    verify_schema(TaskState, data[0])
    verify_schema(TaskState, data[1])

    # Limit: only a single entry may be returned.
    data_source_client.get_task_info.side_effect = replies_for_both_raylets()
    result = await state_api_manager.list_tasks(option=create_api_options(
        limit=1))
    assert len(result.result) == 1

    # Filters: filtering on the hex-encoded task id matches one task.
    data_source_client.get_task_info.side_effect = replies_for_both_raylets()
    result = await state_api_manager.list_tasks(option=create_api_options(
        filters=[("task_id", bytearray(task_id).hex())]))
    assert len(result.result) == 1

    # Partial failure: one raylet is unreachable, so a warning is attached.
    data_source_client.get_task_info.side_effect = [
        DataSourceUnavailable(),
        generate_task_data(other_task_id, second_task_name),
    ]
    result = await state_api_manager.list_tasks(option=create_api_options(
        limit=1))
    warning = result.partial_failure_warning
    assert (NODE_QUERY_FAILURE_WARNING.format(type="raylet",
                                              total=2,
                                              network_failures=1,
                                              log_command="raylet.out")
            in warning)

    # Total failure: when every raylet is unreachable the error propagates.
    data_source_client.get_task_info.side_effect = [
        DataSourceUnavailable(),
        DataSourceUnavailable(),
    ]
    with pytest.raises(DataSourceUnavailable):
        await state_api_manager.list_tasks(option=create_api_options(
            limit=1))