Example #1
0
def test_shutdown_disconnect_global_state():
    """Check that a global state query raises after ``ray.shutdown()``."""
    ray.init(num_cpus=0)
    ray.shutdown()

    # Ray has been shut down again, so the query below is expected to fail
    # with the "ray.init has been called." message.
    with pytest.raises(Exception) as exc_info:
        ray.objects()

    message = str(exc_info.value)
    assert message.endswith("ray.init has been called.")
Example #2
0
def test_lease_request_leak(shutdown_only):
    """Run pairs of tasks on one CPU and check that no objects are left.

    Each pair of tasks shares one ``ray.put`` argument whose local reference
    is dropped right after submission. Once all tasks have finished and a
    short pause has elapsed, the object table must be empty again.
    """
    internal_config = json.dumps({
        "initial_reconstruction_timeout_milliseconds": 200
    })
    ray.init(num_cpus=1, _internal_config=internal_config)
    assert len(ray.objects()) == 0

    @ray.remote
    def f(x):
        time.sleep(0.1)
        return

    # Submit pairs of tasks. Tasks in a pair can reuse the same worker leased
    # from the raylet.
    pending = []
    for _ in range(10):
        shared_arg = ray.put(1)
        for _ in range(2):
            pending.append(f.remote(shared_arg))
        del shared_arg
    ray.get(pending)

    # Pause for longer than the 200 ms reconstruction timeout set above.
    time.sleep(1)
    assert len(ray.objects()) == 0, ray.objects()
Example #3
0
def test_lease_request_leak(shutdown_only):
    """Run pairs of tasks on one CPU and check that no objects are left.

    Each pair of tasks shares one ``ray.put`` argument whose local reference
    is dropped right after submission. Once all tasks have finished and a
    short pause has elapsed, the object table must be empty again.
    """
    system_config = {
        # This test uses ray.objects(), which only works with the GCS-based
        # object directory
        "ownership_based_object_directory_enabled": False,
        "object_timeout_milliseconds": 200
    }
    ray.init(num_cpus=1, _system_config=system_config)
    assert len(ray.objects()) == 0

    @ray.remote
    def f(x):
        time.sleep(0.1)
        return

    # Submit pairs of tasks. Tasks in a pair can reuse the same worker leased
    # from the raylet.
    pending = []
    for _ in range(10):
        shared_arg = ray.put(1)
        for _ in range(2):
            pending.append(f.remote(shared_arg))
        del shared_arg
    ray.get(pending)

    # Pause for longer than the 200 ms object timeout configured above.
    time.sleep(1)
    assert len(ray.objects()) == 0, ray.objects()
Example #4
0
def test_global_state_task_object_api(shutdown_only):
    """Check the task and object table entries created by one remote task.

    A single task ``f`` is submitted with two plain arguments and one
    ``ray.put`` object. The test then verifies the task spec recorded for it
    (actor id, arguments, job id, return ids) and that per-id lookups through
    ``ray.tasks(task_id)`` / ``ray.objects(object_id)`` agree with the full
    tables.
    """
    ray.init()

    job_id = ray.utils.compute_job_id_from_driver(
        ray.WorkerID(ray.worker.global_worker.worker_id))
    driver_task_id = ray.worker.global_worker.current_task_id.hex()

    nil_actor_id_hex = ray.ActorID.nil().hex()

    @ray.remote
    def f(*xs):
        return 1

    x_id = ray.put(1)
    result_id = f.remote(1, "hi", x_id)

    # Wait for one additional task to complete.
    wait_for_num_tasks(1 + 1)
    task_table = ray.tasks()
    assert len(task_table) == 1 + 1
    # Drop the driver's own entry; whatever remains is the task for ``f``.
    task_id_set = set(task_table.keys())
    task_id_set.remove(driver_task_id)
    task_id = list(task_id_set)[0]

    task_spec = task_table[task_id]["TaskSpec"]
    assert task_spec["ActorID"] == nil_actor_id_hex
    assert task_spec["Args"] == [
        signature.DUMMY_TYPE, 1, signature.DUMMY_TYPE, "hi",
        signature.DUMMY_TYPE, x_id
    ]
    assert task_spec["JobID"] == job_id.hex()
    assert task_spec["ReturnObjectIDs"] == [result_id]

    assert task_table[task_id] == ray.tasks(task_id)

    # Wait for two objects, one for the x_id and one for result_id.
    wait_for_num_objects(2)

    # NOTE: a nested ``wait_for_object_table`` helper used to be defined here
    # but was never called; it was removed as dead code. The checks below are
    # the same ones that ran before.
    object_table = ray.objects()
    assert len(object_table) == 2

    assert object_table[x_id] == ray.objects(x_id)
    object_table_entry = ray.objects(result_id)
    assert object_table[result_id] == object_table_entry
Example #5
0
def test_global_state_api(shutdown_only):
    """Exercise the global state queries before and after ``ray.init``.

    Before initialization each query is expected to raise with the
    "cannot be used before ray.init" message. After initialization the test
    checks the cluster resources, the empty object table, and the node, actor
    and job tables for one driver that created one actor.
    """
    not_initialized = ("The ray global state API cannot be used "
                       "before ray.init has been called.")

    for query in (ray.objects, ray.actors, ray.nodes, ray.jobs):
        with pytest.raises(Exception, match=not_initialized):
            query()

    ray.init(num_cpus=5, num_gpus=3, resources={"CustomResource": 1})

    expected_resources = {"CPU": 5, "GPU": 3, "CustomResource": 1}
    for resource_name, amount in expected_resources.items():
        assert ray.cluster_resources()[resource_name] == amount

    assert ray.objects() == {}

    driver_worker_id = ray.WorkerID(ray.worker.global_worker.worker_id)
    job_id = ray.utils.compute_job_id_from_driver(driver_worker_id)

    node_table = ray.nodes()
    node_ip_address = ray.worker.global_worker.node_ip_address

    assert len(node_table) == 1
    assert node_table[0]["NodeManagerAddress"] == node_ip_address

    @ray.remote
    class Actor:
        def __init__(self):
            pass

    # The handle is bound to a name only so that it stays referenced.
    actor_handle = Actor.remote()  # noqa: F841
    # Wait for actor to be created
    wait_for_num_actors(1)

    actor_table = ray.actors()
    assert len(actor_table) == 1

    (actor_info,) = actor_table.values()
    actor_address = actor_info["Address"]
    owner_address = actor_info["OwnerAddress"]
    assert actor_info["JobID"] == job_id.hex()
    assert "IPAddress" in actor_address
    assert "IPAddress" in owner_address
    assert actor_address["Port"] != owner_address["Port"]

    job_table = ray.jobs()

    assert len(job_table) == 1
    first_job = job_table[0]
    assert first_job["JobID"] == job_id.hex()
    assert first_job["NodeManagerAddress"] == node_ip_address
Example #6
0
def test_cleanup_on_driver_exit(call_ray_start):
    """Check that a finished driver's objects leave the object table.

    A second driver is run as a script against the same cluster address. It
    puts 1000 objects, waits until all of them are visible in the object
    table, and exits. This test then polls until the object table is empty.

    Both polling loops sleep briefly between queries instead of spinning, so
    the state API is not queried in a tight loop for up to 30 seconds.

    Raises:
        Exception: if the objects are not all removed within 30 seconds.
    """
    # This test will create a driver that creates a bunch of objects and then
    # exits. The entries in the object table should be cleaned up.
    address = call_ray_start

    ray.init(address=address)

    # Define a driver that creates a bunch of objects and exits.
    driver_script = """
import time
import ray
import numpy as np
ray.init(address="{}")
object_refs = [ray.put(np.zeros(200 * 1024, dtype=np.uint8))
              for i in range(1000)]
start_time = time.time()
while time.time() - start_time < 30:
    if len(ray.objects()) == 1000:
        break
    time.sleep(0.1)
else:
    raise Exception("Objects did not appear in object table.")
print("success")
""".format(address)

    run_string_as_driver(driver_script)

    # Make sure the objects are removed from the object table.
    start_time = time.time()
    while time.time() - start_time < 30:
        if len(ray.objects()) == 0:
            break
        # Back off between polls rather than busy-waiting.
        time.sleep(0.1)
    else:
        raise Exception("Objects were not all removed from object table.")
Example #7
0
def wait_for_num_objects(num_objects, timeout=10):
    """Poll ``ray.objects()`` until it has at least ``num_objects`` entries.

    Args:
        num_objects: Minimum number of object table entries to wait for.
        timeout: Maximum time to keep polling, in seconds.

    Raises:
        RayTestTimeoutException: if the count is not reached in time.
    """
    began_at = time.time()
    while time.time() - began_at < timeout:
        object_count = len(ray.objects())
        if object_count >= num_objects:
            return
        time.sleep(0.1)
    raise RayTestTimeoutException("Timed out while waiting for global state.")
Example #8
0
def test_delete_refs_on_disconnect(ray_start_regular):
    """Check that closing the client releases the references it held.

    Inside the block ``ray`` is the client returned by
    ``ray_start_client_server``, while ``real_ray`` is used to inspect the
    object table.
    """
    with ray_start_client_server() as ray:

        @ray.remote
        def f(x):
            return x + 2

        task_result_ref = f.remote(6)  # noqa
        put_ref = ray.put("Hello World")  # noqa

        # One put, one function -- the function result task_result_ref is
        # in a different category, according to the raylet.
        assert len(real_ray.objects()) == 2
        # But we're maintaining the reference
        has_three_refs = server_object_ref_count(3)
        assert has_three_refs()
        # And can get the data
        assert ray.get(task_result_ref) == 8

        # Close the client
        ray.close()

        wait_for_condition(server_object_ref_count(0), timeout=5)

        def object_table_is_empty():
            return len(real_ray.objects()) == 0

        wait_for_condition(object_table_is_empty, timeout=5)
Example #9
0
def test_delete_refs_on_disconnect(ray_start_cluster):
    """Check that closing the client releases the references it held.

    Inside the block ``ray`` is the client half of the client/server pair,
    while ``real_ray`` is used to inspect the object table.
    """
    cluster = ray_start_cluster
    with ray_start_cluster_client_server_pair(cluster.address) as pair:
        ray, server = pair

        @ray.remote
        def f(x):
            return x + 2

        task_result_ref = f.remote(6)  # noqa
        put_ref = ray.put("Hello World")  # noqa

        # One put, one function -- the function result task_result_ref is
        # in a different category, according to the raylet.
        assert len(real_ray.objects()) == 2
        # But we're maintaining the reference
        has_three_refs = server_object_ref_count(server, 3)
        assert has_three_refs()
        # And can get the data
        assert ray.get(task_result_ref) == 8

        # Close the client.
        ray.close()

        wait_for_condition(server_object_ref_count(server, 0), timeout=5)

        # Connect to the real ray again, since we disconnected
        # upon num_clients = 0.
        real_ray.init(address=cluster.address)

        def object_table_is_empty():
            return len(real_ray.objects()) == 0

        wait_for_condition(object_table_is_empty, timeout=5)
Example #10
0
 def on_trial_result(self, *args, **kwargs):
     """Count the iteration, optionally report usage, and cap open files.

     Increments ``self.iter_``. When ``self.verbose`` is set, prints the
     iteration number, the size of ``ray.objects()``, the virtual memory
     figure shifted down by 30 bits, and the number of open files. Always
     asserts that the process has fewer than 20 files open.
     """
     self.iter_ += 1
     open_files = self.process.open_files()
     if self.verbose:
         print("Iteration", self.iter_)
         print("=" * 10)
         object_count = len(ray.objects())
         print("Number of objects: ", object_count)
         virt_mem_gb = self.get_virt_mem() >> 30
         print("Virtual Mem:", virt_mem_gb, "gb")
         print("File Descriptors:", len(open_files))
     assert len(open_files) < 20
Example #11
0
 def wait_for_object_table():
     """Poll until both tracked objects report non-``None`` "ManagerIDs".

     Looks up ``x_id`` and ``result_id`` from the enclosing scope in
     ``ray.objects()`` every 0.1 s for up to 10 s.

     Raises:
         RayTestTimeoutException: if either entry still has ``None``
             "ManagerIDs" when the time is up.
     """
     timeout = 10
     began_at = time.time()
     while time.time() - began_at < timeout:
         table = ray.objects()
         if (table[x_id]["ManagerIDs"] is None
                 or table[result_id]["ManagerIDs"] is None):
             time.sleep(0.1)
             continue
         return
     raise RayTestTimeoutException(
         "Timed out while waiting for object table to "
         "update.")
Example #12
0
def test_global_state_api(shutdown_only):
    """Check resources and the node, actor and job tables after ``ray.init``.

    The test starts Ray with fixed resources, creates one actor named
    "test_actor", and verifies what the global state queries report about
    the single node, the actor and the driver's job.
    """
    ray.init(num_cpus=5, num_gpus=3, resources={"CustomResource": 1})

    expected_resources = {"CPU": 5, "GPU": 3, "CustomResource": 1}
    for resource_name, amount in expected_resources.items():
        assert ray.cluster_resources()[resource_name] == amount

    # A driver/worker creates a temporary object during startup. Although the
    # temporary object is freed immediately, in rare cases its object ref can
    # still be found in GCS, because the Raylet removes it asynchronously.
    # We can't control when workers create these temporary objects, so we
    # can't assert that `ray.objects()` returns an empty dict. Here we only
    # make sure that `ray.objects()` succeeds.
    assert len(ray.objects()) >= 0

    driver_worker_id = ray.WorkerID(ray.worker.global_worker.worker_id)
    job_id = ray.utils.compute_job_id_from_driver(driver_worker_id)

    node_table = ray.nodes()
    node_ip_address = ray.worker.global_worker.node_ip_address

    assert len(node_table) == 1
    assert node_table[0]["NodeManagerAddress"] == node_ip_address

    @ray.remote
    class Actor:
        def __init__(self):
            pass

    # The handle is bound to a name only so that it stays referenced.
    actor_handle = Actor.options(name="test_actor").remote()  # noqa: F841
    # Wait for actor to be created
    wait_for_num_actors(1)

    actor_table = ray.actors()
    assert len(actor_table) == 1

    (actor_info,) = actor_table.values()
    assert actor_info["JobID"] == job_id.hex()
    assert actor_info["Name"] == "test_actor"
    actor_address = actor_info["Address"]
    owner_address = actor_info["OwnerAddress"]
    assert "IPAddress" in actor_address
    assert "IPAddress" in owner_address
    assert actor_address["Port"] != owner_address["Port"]

    job_table = ray.jobs()

    assert len(job_table) == 1
    first_job = job_table[0]
    assert first_job["JobID"] == job_id.hex()
    assert first_job["DriverIPAddress"] == node_ip_address
Example #13
0
def test_lease_request_leak(shutdown_only):
    """Run pairs of tasks on one CPU and wait for the object table to empty.

    Each pair of tasks shares one ``ray.put`` argument whose local reference
    is dropped right after submission. After all tasks have finished, the
    object table must become empty within 10 seconds.
    """
    system_config = {"object_timeout_milliseconds": 200}
    ray.init(num_cpus=1, _system_config=system_config)
    assert len(ray.objects()) == 0

    @ray.remote
    def f(x):
        time.sleep(0.1)
        return

    # Submit pairs of tasks. Tasks in a pair can reuse the same worker leased
    # from the raylet.
    pending = []
    for _ in range(10):
        shared_arg = ray.put(1)
        for _ in range(2):
            pending.append(f.remote(shared_arg))
        del shared_arg
    ray.get(pending)

    def object_table_is_empty():
        return len(ray.objects()) == 0

    wait_for_condition(object_table_is_empty, timeout=10)
Example #14
0
def test_global_state_api(shutdown_only):
    """Exercise the global state queries before and after ``ray.init``.

    Before initialization each query is expected to raise with the
    "cannot be used before ray.init" message. After initialization the test
    checks the cluster resources, the driver's own task table entry, the node
    table, the task and object entries created by one remote task, and the
    job table.
    """

    error_message = ("The ray global state API cannot be used "
                     "before ray.init has been called.")

    with pytest.raises(Exception, match=error_message):
        ray.objects()

    with pytest.raises(Exception, match=error_message):
        ray.tasks()

    with pytest.raises(Exception, match=error_message):
        ray.nodes()

    with pytest.raises(Exception, match=error_message):
        ray.jobs()

    ray.init(num_cpus=5, num_gpus=3, resources={"CustomResource": 1})

    assert ray.cluster_resources()["CPU"] == 5
    assert ray.cluster_resources()["GPU"] == 3
    assert ray.cluster_resources()["CustomResource"] == 1

    assert ray.objects() == {}

    job_id = ray.utils.compute_job_id_from_driver(
        ray.WorkerID(ray.worker.global_worker.worker_id))
    driver_task_id = ray.worker.global_worker.current_task_id.hex()

    # One task is put in the task table which corresponds to this driver.
    wait_for_num_tasks(1)
    task_table = ray.tasks()
    assert len(task_table) == 1
    assert driver_task_id == list(task_table.keys())[0]
    task_spec = task_table[driver_task_id]["TaskSpec"]
    nil_unique_id_hex = ray.UniqueID.nil().hex()
    nil_actor_id_hex = ray.ActorID.nil().hex()

    assert task_spec["TaskID"] == driver_task_id
    assert task_spec["ActorID"] == nil_actor_id_hex
    assert task_spec["Args"] == []
    assert task_spec["JobID"] == job_id.hex()
    assert task_spec["FunctionID"] == nil_unique_id_hex
    assert task_spec["ReturnObjectIDs"] == []

    client_table = ray.nodes()
    node_ip_address = ray.worker.global_worker.node_ip_address

    assert len(client_table) == 1
    assert client_table[0]["NodeManagerAddress"] == node_ip_address

    @ray.remote
    def f(*xs):
        return 1

    x_id = ray.put(1)
    result_id = f.remote(1, "hi", x_id)

    # Wait for one additional task to complete.
    wait_for_num_tasks(1 + 1)
    task_table = ray.tasks()
    assert len(task_table) == 1 + 1
    # Drop the driver's own entry; whatever remains is the task for ``f``.
    task_id_set = set(task_table.keys())
    task_id_set.remove(driver_task_id)
    task_id = list(task_id_set)[0]

    task_spec = task_table[task_id]["TaskSpec"]
    assert task_spec["ActorID"] == nil_actor_id_hex
    assert task_spec["Args"] == [
        signature.DUMMY_TYPE, 1, signature.DUMMY_TYPE, "hi",
        signature.DUMMY_TYPE, x_id
    ]
    assert task_spec["JobID"] == job_id.hex()
    assert task_spec["ReturnObjectIDs"] == [result_id]

    assert task_table[task_id] == ray.tasks(task_id)

    # Wait for two objects, one for the x_id and one for result_id.
    wait_for_num_objects(2)

    # NOTE: a nested ``wait_for_object_table`` helper used to be defined here
    # but was never called; it was removed as dead code. The checks below are
    # the same ones that ran before.
    object_table = ray.objects()
    assert len(object_table) == 2

    assert object_table[x_id] == ray.objects(x_id)
    object_table_entry = ray.objects(result_id)
    assert object_table[result_id] == object_table_entry

    job_table = ray.jobs()

    assert len(job_table) == 1
    assert job_table[0]["JobID"] == job_id.hex()
    assert job_table[0]["NodeManagerAddress"] == node_ip_address
Example #15
0
def get_object_ids():
    """Return the keys of ``ray.objects()`` as a list."""
    object_table = ray.objects()
    return list(object_table.keys())
Example #16
0
    def _xray_clean_up_entries_for_job(self, job_id):
        """Delete the redis task and object entries that belong to a job.

        Tasks are selected by comparing the "JobID" of every task spec in
        ``ray.tasks()`` with ``job_id``. Objects are selected when the task
        id computed from their object id is one of those tasks. The keys are
        grouped per redis shard and deleted on a best-effort basis: a
        shortfall in the number of deleted keys is only logged.

        Args:
            job_id: The job id.
        """
        task_key_prefix = (
            ray.gcs_utils.TablePrefix_RAYLET_TASK_string.encode("ascii"))
        object_key_prefix = (
            ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii"))

        # Binary ids of every task whose spec carries this job's id; tasks
        # from other drivers are ignored.
        all_tasks = ray.tasks()
        wanted_job_hex = binary_to_hex(job_id)
        job_task_bins = {
            hex_to_binary(task_hex)
            for task_hex, task_entry in all_tasks.items()
            if task_entry["TaskSpec"]["JobID"] == wanted_job_hex
        }

        # Binary ids of the objects whose computed task id is one of the
        # tasks collected above.
        job_object_bins = {
            object_id.binary()
            for object_id in ray.objects()
            if ray._raylet.compute_task_id(object_id).binary() in job_task_bins
        }

        def shard_of(id_bin):
            # Task ids and object ids are told apart by their binary length.
            if len(id_bin) == ray.TaskID.size():
                typed_id = binary_to_task_id(id_bin)
            else:
                typed_id = binary_to_object_id(id_bin)
            return typed_id.redis_shard_hash() % len(
                ray.state.state.redis_clients)

        # Form the redis keys to delete, one bucket per shard.
        keys_by_shard = [[] for _ in range(len(ray.state.state.redis_clients))]
        for task_bin in job_task_bins:
            keys_by_shard[shard_of(task_bin)].append(
                task_key_prefix + task_bin)
        for object_bin in job_object_bins:
            keys_by_shard[shard_of(object_bin)].append(
                object_key_prefix + object_bin)

        # Remove with best effort.
        for shard_index, shard_keys in enumerate(keys_by_shard):
            if len(shard_keys) == 0:
                continue
            shard_client = ray.state.state.redis_clients[shard_index]
            num_deleted = shard_client.delete(*shard_keys)
            logger.info("Monitor: "
                        "Removed {} dead redis entries of the "
                        "driver from redis shard {}.".format(
                            num_deleted, shard_index))
            if num_deleted != len(shard_keys):
                logger.warning("Monitor: "
                               "Failed to remove {} relevant redis "
                               "entries from redis shard {}.".format(
                                   len(shard_keys) - num_deleted,
                                   shard_index))
Example #17
0
        # backed by the object store.
        values = [random.random() for i in range(5)]
        keys = [i for i in range(len(values))]
        self.weights = dict(zip(keys, values))

    def push(self, keys, values):
        """Add each value to the weight stored under the matching key.

        ``keys`` and ``values`` are paired positionally; pairing stops at the
        shorter of the two.
        """
        weights = self.weights
        for key, delta in zip(keys, values):
            weights[key] += delta

    def pull(self, keys):
        """Return the stored weights for ``keys``, in the order given."""
        weights = self.weights
        return [weights[key] for key in keys]


@ray.remote
def worker_task(ps, worker_index, batch_size=50):
    """Accept the worker arguments and do nothing; returns ``None``."""
    return None


if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(redis_address=args.redis_address)

    # Create two ParameterServer handles and store each one as an object.
    ps = ParameterServer.remote()
    ps1 = ParameterServer.remote()
    object_ids = [ray.put(handle) for handle in (ps, ps1)]

    # Show the object table, then fetch the handles back and pull keys 0 and
    # 1 from each of them.
    print(ray.objects())
    weights_per_server = []
    for server in ray.get(object_ids):
        weights_per_server.append(ray.get(server.pull.remote([0, 1])))
    print(weights_per_server)
Example #18
0
 def object_table(self, object_id=None):
     """Log a deprecation warning and forward to ``ray.objects``.

     Args:
         object_id: Passed through as ``ray.objects(object_id=...)``.

     Returns:
         Whatever ``ray.objects`` returns for ``object_id``.
     """
     deprecation_notice = (
         "ray.global_state.object_table() is deprecated and will be "
         "removed in a subsequent release. Use ray.objects() instead.")
     logger.warning(deprecation_notice)
     return ray.objects(object_id=object_id)
Example #19
0
 def save(self, *args, **kwargs):
     """Save via the parent class, then assert at most 12 objects exist.

     Returns:
         The checkpoint returned by the parent ``save``.
     """
     saved_checkpoint = super(CustomExecutor, self).save(*args, **kwargs)
     object_count = len(ray.objects())
     assert object_count <= 12
     return saved_checkpoint
Example #20
0
def test_global_state_api(shutdown_only):
    """Exercise the global state queries before and after ``ray.init``.

    Before initialization each query is expected to raise with the
    "cannot be used before ray.init" message. After initialization the test
    checks the cluster resources, the empty object table, the driver's own
    task table entry, and the node, actor and job tables for one driver that
    created one actor.
    """
    not_initialized = ("The ray global state API cannot be used "
                       "before ray.init has been called.")

    for query in (ray.objects, ray.actors, ray.tasks, ray.nodes, ray.jobs):
        with pytest.raises(Exception, match=not_initialized):
            query()

    ray.init(num_cpus=5, num_gpus=3, resources={"CustomResource": 1})

    expected_resources = {"CPU": 5, "GPU": 3, "CustomResource": 1}
    for resource_name, amount in expected_resources.items():
        assert ray.cluster_resources()[resource_name] == amount

    assert ray.objects() == {}

    driver_worker_id = ray.WorkerID(ray.worker.global_worker.worker_id)
    job_id = ray.utils.compute_job_id_from_driver(driver_worker_id)
    driver_task_id = ray.worker.global_worker.current_task_id.hex()

    # One task is put in the task table which corresponds to this driver.
    wait_for_num_tasks(1)
    task_table = ray.tasks()
    assert len(task_table) == 1
    assert driver_task_id == list(task_table.keys())[0]
    driver_spec = task_table[driver_task_id]["TaskSpec"]
    nil_unique_id_hex = ray.UniqueID.nil().hex()
    nil_actor_id_hex = ray.ActorID.nil().hex()

    assert driver_spec["TaskID"] == driver_task_id
    assert driver_spec["ActorID"] == nil_actor_id_hex
    assert driver_spec["Args"] == []
    assert driver_spec["JobID"] == job_id.hex()
    assert driver_spec["FunctionID"] == nil_unique_id_hex
    assert driver_spec["ReturnObjectIDs"] == []

    node_table = ray.nodes()
    node_ip_address = ray.worker.global_worker.node_ip_address

    assert len(node_table) == 1
    assert node_table[0]["NodeManagerAddress"] == node_ip_address

    @ray.remote
    class Actor:
        def __init__(self):
            pass

    # The handle is bound to a name only so that it stays referenced.
    actor_handle = Actor.remote()  # noqa: F841
    # Wait for actor to be created
    wait_for_num_actors(1)

    actor_table = ray.actors()
    assert len(actor_table) == 1

    (actor_info,) = actor_table.values()
    actor_address = actor_info["Address"]
    owner_address = actor_info["OwnerAddress"]
    assert actor_info["JobID"] == job_id.hex()
    assert "IPAddress" in actor_address
    assert "IPAddress" in owner_address
    assert actor_address["Port"] != owner_address["Port"]

    job_table = ray.jobs()

    assert len(job_table) == 1
    first_job = job_table[0]
    assert first_job["JobID"] == job_id.hex()
    assert first_job["NodeManagerAddress"] == node_ip_address
Example #21
0
 def StateSummary():
     """Return ``(number of objects, number of tasks)`` from the state API."""
     return len(ray.objects()), len(ray.tasks())
Example #22
0
 def test_cond():
     """Report whether ``real_ray.objects()`` is currently empty."""
     remaining = real_ray.objects()
     return len(remaining) == 0
Example #23
0
 def all_agent_ids(self):
     """Return the keys view of the table returned by ``ray.objects()``."""
     object_table = ray.objects()
     return object_table.keys()
Example #24
0
 def _no_objects():
     """Report whether ``ray.objects()`` is currently empty."""
     remaining = ray.objects()
     return len(remaining) == 0