def test_client_actor_ref_basics(ray_start_regular):
    """Check ClientActorRef's type, public surface, equality and helpers.

    Two ClientActorRef instances are checked against a server-side ActorID
    built from the same 16 zero bytes. One is constructed from the raw bytes
    and one from an already-resolved Future.
    """
    with ray_start_client_server_pair() as (ray, _server):

        @ray.remote
        class Counter:
            def __init__(self):
                self.acc = 0

            def inc(self):
                self.acc += 1

            def get(self):
                return self.acc

        counter_handle = Counter.remote()
        actor_ref = counter_handle.actor_ref

        # A client-side actor ref must also be an ActorID.
        for expected_type in (ClientActorRef, ActorID):
            assert isinstance(actor_ref, expected_type)

        # A single-byte payload is not a valid actor id.
        with pytest.raises(Exception):
            ClientActorRef(b"\0")

        raw_id = bytes(16)
        resolved = Future()
        resolved.set_result(raw_id)
        server_ref = ActorID(raw_id)

        def public_names(obj):
            """Return the set of attribute names of obj not starting with '_'."""
            return {name for name in dir(obj) if not name.startswith("_")}

        server_members = public_names(server_ref)
        for client_ref in (ClientActorRef(raw_id), ClientActorRef(resolved)):
            client_members = public_names(client_ref)
            # The client ref exposes exactly one extra public member: `id`.
            assert client_members - server_members == {"id"}
            assert not server_members - client_members

            # Equality: equal to a ref with the same bytes, unequal to
            # the live actor's ref and to a plain server-side ActorID.
            assert client_ref == ClientActorRef(raw_id)
            assert client_ref != actor_ref
            assert client_ref != server_ref

            # Remaining helpers.
            assert repr(client_ref) == f"ClientActorRef({raw_id.hex()})"
            assert client_ref.binary() == raw_id
            assert client_ref.hex() == raw_id.hex()
            assert not client_ref.is_nil()
Esempio n. 2
0
async def test_api_manager_summary_actors(state_api_manager):
    """summarize_actors groups actors by class name and counts them per state.

    Also checks that the summary result survives a dict -> JSON -> dict
    round trip.
    """
    data_source_client = state_api_manager.data_source_client
    actor_ids = [ActorID((f"{i}" * 16).encode()) for i in range(9)]
    class_a = "A"
    class_b = "B"
    # (state, class name) for each actor in the mocked reply; the i-th entry
    # is paired with actor_ids[i].
    actor_specs = [
        (ActorTableData.ActorState.ALIVE, class_a),
        (ActorTableData.ActorState.DEAD, class_b),
        (ActorTableData.ActorState.PENDING_CREATION, class_b),
        (ActorTableData.ActorState.DEPENDENCIES_UNREADY, class_b),
        (ActorTableData.ActorState.RESTARTING, class_b),
        (ActorTableData.ActorState.RESTARTING, class_b),
    ]
    data_source_client.get_all_actor_info.return_value = GetAllActorInfoReply(
        actor_table_data=[
            generate_actor_data(
                actor_id.binary(),
                state=state,
                class_name=class_name,
            )
            for actor_id, (state, class_name) in zip(actor_ids, actor_specs)
        ]
    )
    result = await state_api_manager.summarize_actors(
        option=create_summary_options()
    )
    summary_by_node = result.result.node_id_to_summary
    assert "cluster" in summary_by_node
    data = summary_by_node["cluster"]
    assert data.total_actors == 6

    assert data.summary[class_a].class_name == class_a
    assert data.summary[class_a].state_counts["ALIVE"] == 1

    assert data.summary[class_b].class_name == class_b
    assert data.summary[class_b].state_counts["DEAD"] == 1
    assert data.summary[class_b].state_counts["DEPENDENCIES_UNREADY"] == 1
    assert data.summary[class_b].state_counts["PENDING_CREATION"] == 1
    assert data.summary[class_b].state_counts["RESTARTING"] == 2

    # Test if it can be correctly converted to a dictionary (and that the
    # dictionary is JSON-serializable without loss).
    print(result.result)
    result_in_dict = asdict(result.result)
    assert json.loads(json.dumps(result_in_dict)) == result_in_dict
Esempio n. 3
0
def test_client_actor_ref_basics(ray_start_regular):
    """Check ClientActorRef's type, equality semantics and basic helpers."""
    with ray_start_client_server_pair() as pair:
        # Only the client-side `ray` is used; the server handle is unused.
        ray, _server = pair

        @ray.remote
        class Counter:
            def __init__(self):
                self.acc = 0

            def inc(self):
                self.acc += 1

            def get(self):
                return self.acc

        counter = Counter.remote()
        ref = counter.actor_ref

        # Make sure ClientActorRef is a subclass of ActorID
        assert isinstance(ref, ClientActorRef)
        assert isinstance(ref, ActorID)

        # Invalid ref format (a 1-byte payload).
        with pytest.raises(Exception):
            ClientActorRef(b"\0")

        actor_id = b"\0" * 16
        client_ref = ClientActorRef(actor_id)

        # Test __eq__(): equal to another ref built from the same bytes,
        # unequal to the live actor's ref and to a plain ActorID.
        assert client_ref == ClientActorRef(actor_id)
        assert client_ref != ref
        assert client_ref != ActorID(actor_id)

        # Test other methods
        assert repr(client_ref) == f"ClientActorRef({actor_id.hex()})"
        assert client_ref.binary() == actor_id
        assert client_ref.hex() == actor_id.hex()
        assert not client_ref.is_nil()
Esempio n. 4
0
async def test_logs_manager_resolve_file(logs_manager):
    """Exercise logs_manager.resolve_filename for each way of selecting a log.

    Covers: an explicit filename, an actor id (missing, unscheduled and
    scheduled actor), a task id, a pid (missing and present), and the case
    where nothing identifying is given.

    Setup is kept outside every ``pytest.raises`` block so that only the
    ``resolve_filename`` call can produce the expected exception.
    """
    node_id = NodeID(b"1" * 28)

    # --- Filename is given. ---
    logs_client = logs_manager.data_source_client
    logs_client.get_all_registered_agent_ids = MagicMock()
    logs_client.get_all_registered_agent_ids.return_value = [node_id.hex()]
    expected_filename = "filename"
    log_file_name, n = await logs_manager.resolve_filename(
        node_id=node_id,
        log_filename=expected_filename,
        actor_id=None,
        task_id=None,
        pid=None,
        get_actor_fn=lambda _: True,
        timeout=10,
    )
    assert log_file_name == expected_filename
    assert n == node_id

    # --- Actor id is given. ---
    actor_id = ActorID(b"2" * 16)

    # Actor doesn't exist: the lookup returns None for the requested id.
    def get_missing_actor(requested_id):
        if requested_id == actor_id:
            return None
        # pytest.fail (unlike `assert False`) is not stripped under -O.
        pytest.fail("Not reachable.")

    with pytest.raises(ValueError):
        await logs_manager.resolve_filename(
            node_id=node_id,
            log_filename=None,
            actor_id=actor_id,
            task_id=None,
            pid=None,
            get_actor_fn=get_missing_actor,
            timeout=10,
        )

    # Actor exists, but it is not scheduled yet (its data is generated with
    # None where the scheduled case below passes a worker id).
    with pytest.raises(ValueError):
        await logs_manager.resolve_filename(
            node_id=node_id,
            log_filename=None,
            actor_id=actor_id,
            task_id=None,
            pid=None,
            get_actor_fn=lambda _: generate_actor_data(actor_id, node_id, None),
            timeout=10,
        )

    # Actor exists and is scheduled on a worker whose log file is listed.
    worker_id = WorkerID(b"3" * 28)
    logs_manager.list_logs = AsyncMock()
    logs_manager.list_logs.return_value = {
        "worker_out": [f"worker-{worker_id.hex()}-123-123.out"]
    }
    log_file_name, n = await logs_manager.resolve_filename(
        node_id=node_id.hex(),
        log_filename=None,
        actor_id=actor_id,
        task_id=None,
        pid=None,
        get_actor_fn=lambda _: generate_actor_data(actor_id, node_id, worker_id),
        timeout=10,
    )
    logs_manager.list_logs.assert_awaited_with(
        node_id.hex(), 10, glob_filter=f"*{worker_id.hex()}*"
    )
    assert log_file_name == f"worker-{worker_id.hex()}-123-123.out"
    assert n == node_id.hex()

    # --- Task id is given (expected to be unsupported). ---
    task_id = TaskID(b"2" * 24)
    with pytest.raises(NotImplementedError):
        await logs_manager.resolve_filename(
            node_id=node_id.hex(),
            log_filename=None,
            actor_id=None,
            task_id=task_id,
            pid=None,
            get_actor_fn=lambda _: generate_actor_data(actor_id, node_id, worker_id),
            timeout=10,
        )

    # --- Pid is given. ---
    # Pid doesn't exist: the only listed worker log belongs to another pid.
    pid = 456
    logs_manager.list_logs = AsyncMock()
    logs_manager.list_logs.return_value = {"worker_out": ["worker-123-123-123.out"]}
    with pytest.raises(FileNotFoundError):
        await logs_manager.resolve_filename(
            node_id=node_id.hex(),
            log_filename=None,
            actor_id=None,
            task_id=None,
            pid=pid,
            get_actor_fn=lambda _: generate_actor_data(actor_id, node_id, worker_id),
            timeout=10,
        )

    # Pid exists: provide a worker log file named with the requested pid.
    pid = 123
    logs_manager.list_logs = AsyncMock()
    logs_manager.list_logs.return_value = {"worker_out": [f"worker-123-123-{pid}.out"]}
    log_file_name, n = await logs_manager.resolve_filename(
        node_id=node_id.hex(),
        log_filename=None,
        actor_id=None,
        task_id=None,
        pid=pid,
        get_actor_fn=lambda _: generate_actor_data(actor_id, node_id, worker_id),
        timeout=10,
    )
    logs_manager.list_logs.assert_awaited_with(
        node_id.hex(), 10, glob_filter=f"*{pid}*"
    )
    assert log_file_name == f"worker-123-123-{pid}.out"

    # --- Nothing is given. ---
    with pytest.raises(FileNotFoundError):
        await logs_manager.resolve_filename(
            node_id=node_id.hex(),
            log_filename=None,
            actor_id=None,
            task_id=None,
            pid=None,
            get_actor_fn=lambda _: generate_actor_data(actor_id, node_id, worker_id),
            timeout=10,
        )
Esempio n. 5
0
import base64
import logging
from collections import defaultdict
from enum import Enum
from typing import List

import ray
from ray._private.internal_api import node_stats
from ray._raylet import ActorID, JobID, TaskID

logger = logging.getLogger(__name__)

# Byte widths of the ID types. These values are used to calculate whether
# objectRefs are actor handles.
TASKID_BYTES_SIZE, ACTORID_BYTES_SIZE, JOBID_BYTES_SIZE = (
    TaskID.size(),
    ActorID.size(),
    JobID.size(),
)
# NOTE(review): multiplying a byte count by 2 yields a count of hex digits
# (one byte is two hex characters), not bits as the *_BITS_SIZE names suggest
# (bits would be * 8). Presumably these are compared against hex-encoded IDs;
# confirm at the call sites before renaming anything.
TASKID_RANDOM_BITS_SIZE = 2 * (TASKID_BYTES_SIZE - ACTORID_BYTES_SIZE)
ACTORID_RANDOM_BITS_SIZE = 2 * (ACTORID_BYTES_SIZE - JOBID_BYTES_SIZE)


def decode_object_ref_if_needed(object_ref: str) -> bytes:
    """Decode objectRef bytes string.

    gRPC reply contains an objectRef that is encodded by Base64.
    This function is used to decode the objectRef.
    Note that there are times that objectRef is already decoded as
    a hex string. In this case, just convert it to a binary number.
    """
    if object_ref.endswith("="):
        # If the object ref ends with =, that means it is base64 encoded.