예제 #1
0
def test_delete_actor_on_disconnect(ray_start_regular):
    """Actors created through a client are released when the client closes.

    One ``Accumulator`` actor is created and incremented over the client
    connection. The test checks that the server tracks exactly one actor
    reference and that the actor holds the expected value. It then closes
    the client and waits until the server-side actor reference count reaches
    zero and no actor in ``real_ray.actors()`` is in a non-DEAD state.
    """
    with ray_start_client_server() as ray:

        @ray.remote
        class Accumulator:
            def __init__(self):
                self.acc = 0

            def inc(self):
                self.acc += 1

            def get(self):
                return self.acc

        counter = Accumulator.remote()
        counter.inc.remote()

        # The server should be tracking exactly one actor handle right now.
        assert server_actor_ref_count(1)()

        current = ray.get(counter.get.remote())
        assert current == 1

        # Disconnecting the client should drop every actor reference it held.
        ray.close()
        wait_for_condition(server_actor_ref_count(0), timeout=5)

        def no_live_actors():
            # True once every actor reported by the cluster is DEAD.
            return not any(
                info["State"] != ActorTableData.DEAD
                for info in real_ray.actors().values())

        wait_for_condition(no_live_actors, timeout=10)
예제 #2
0
def test_delete_refs_on_disconnect(ray_start_regular):
    """Object references held by a client are released when it disconnects.

    One task result (``thing1``) and one ``ray.put`` object (``thing2``) are
    created through the client. The test checks the cluster object table and
    the server-side reference count, then closes the client. Finally it waits
    for the server reference count and the cluster's object table to both
    drop to zero.
    """
    with ray_start_client_server() as ray:

        @ray.remote
        def f(x):
            return x + 2

        # Both refs are bound to locals so they stay alive until the client
        # disconnects; ``thing2`` is never read afterwards, hence ``noqa``.
        thing1 = f.remote(6)  # noqa
        thing2 = ray.put("Hello World")  # noqa

        # One put, one function -- the function result thing1 is
        # in a different category, according to the raylet.
        assert len(real_ray.objects()) == 2
        # But we're maintaining the reference
        # NOTE(review): the server counts 3 refs for the 2 user-visible
        # objects above; what the third one is cannot be seen here — confirm.
        assert server_object_ref_count(3)()
        # And can get the data
        assert ray.get(thing1) == 8

        # Close the client
        ray.close()

        # Closing should release all server-side refs owned by this client.
        wait_for_condition(server_object_ref_count(0), timeout=5)

        def test_cond():
            # True once the cluster no longer reports any objects.
            return len(real_ray.objects()) == 0

        wait_for_condition(test_cond, timeout=5)
예제 #3
0
def test_delete_ref_on_object_deletion(ray_start_regular):
    """Dropping a client-side ObjectRef releases its server-side reference.

    Two objects are put and their refs stored in a dict. Deleting one entry
    should eventually bring the server's object reference count down to 1,
    because ``vals["ref2"]`` still holds the other ref.
    """
    with ray_start_client_server() as ray:
        vals = {
            "ref": ray.put("Hello World"),
            "ref2": ray.put("This value stays"),
        }

        # Removing the dict entry drops the last client-side handle to "ref".
        del vals["ref"]

        wait_for_condition(server_object_ref_count(1), timeout=5)
예제 #4
0
def test_cancel_chain(ray_start_regular, use_force):
    """Cancelling one task in a dependency chain fails it and its dependents.

    Two chains of four ``wait_for`` tasks are built. Each link depends on the
    previous one, and the head blocks on a ``SignalActor``:

    * cancelling the head makes every link raise;
    * cancelling the third link makes links three and four raise, while
      links one and two stay pending until the signal is sent.
    """
    with ray_start_client_server() as ray:

        @ray.remote
        class SignalActor:
            def __init__(self):
                self.ready_event = asyncio.Event()

            def send(self, clear=False):
                self.ready_event.set()
                if clear:
                    self.ready_event.clear()

            async def wait(self, should_wait=True):
                if should_wait:
                    await self.ready_event.wait()

        signaler = SignalActor.remote()

        @ray.remote
        def wait_for(t):
            return ray.get(t[0])

        def make_chain(signal):
            # Link 1 waits on the signal; each later link waits on the
            # link before it. Refs are wrapped in a list, as in the callee.
            links = [wait_for.remote([signal.wait.remote()])]
            for _ in range(3):
                links.append(wait_for.remote([links[-1]]))
            return links

        def expect_cancelled(refs):
            for ref in refs:
                with pytest.raises(valid_exceptions(use_force)):
                    ray.get(ref)

        # Case 1: cancel the head, so the whole chain fails.
        chain = make_chain(signaler)
        ready = ray.wait([chain[0]], timeout=.1)[0]
        assert not ready
        ray.cancel(chain[0], force=use_force)
        expect_cancelled(chain)

        # Case 2: cancel the third link; only it and its dependent fail.
        signaler2 = SignalActor.remote()
        chain = make_chain(signaler2)
        ready = ray.wait([chain[2]], timeout=.1)[0]
        assert not ready
        ray.cancel(chain[2], force=use_force)
        expect_cancelled(chain[2:])

        # The first two links are unaffected and still blocked on the signal.
        for upstream in chain[:2]:
            with pytest.raises(GetTimeoutError):
                ray.get(upstream, timeout=.1)

        signaler2.send.remote()
        ray.get(chain[0])
예제 #5
0
def test_kill_actor_immediately_after_creation(ray_start_regular):
    """Actors killed right after ``.remote()`` must all end up dead.

    Two actors of the same class are created and immediately killed, in
    creation order. The test then waits until ``_all_actors_dead`` reports
    success.
    """
    with ray_start_client_server() as ray:

        @ray.remote
        class A:
            pass

        handles = [A.remote(), A.remote()]
        for handle in handles:
            ray.kill(handle)

        wait_for_condition(_all_actors_dead(ray), timeout=10)
예제 #6
0
def test_simple_multiple_references(ray_start_regular):
    """Two client refs to the same actor-owned object stay independently valid.

    The actor puts a value in ``__init__`` and returns that ObjectRef (inside
    a list) from ``get``. Two separate calls yield ``ref1`` and ``ref2``. The
    test deletes the actor handle, then deletes the refs one at a time, and
    checks that each remaining ref still resolves to ``"hi"``.
    """
    with ray_start_client_server() as ray:

        @ray.remote
        class A:
            def __init__(self):
                self.x = ray.put("hi")

            def get(self):
                # Returned inside a list so the caller receives the ObjectRef
                # itself; it is resolved separately via ray.get below.
                return [self.x]

        a = A.remote()
        ref1 = ray.get(a.get.remote())[0]
        ref2 = ray.get(a.get.remote())[0]
        # Drop the actor handle first; the refs must remain resolvable.
        del a
        assert ray.get(ref1) == "hi"
        # Releasing one ref must not invalidate the other.
        del ref1
        assert ray.get(ref2) == "hi"
        del ref2
예제 #7
0
def test_delete_actor(ray_start_regular):
    """Deleting one client actor handle decrements the server's actor count.

    Two ``Accumulator`` actors are created, so the server should track 2
    actor references. After ``actor`` is deleted, the count should
    eventually fall to 1, because ``actor2`` is still held.
    """
    with ray_start_client_server() as ray:

        @ray.remote
        class Accumulator:
            def __init__(self):
                self.acc = 0

            def inc(self):
                self.acc += 1

        actor = Accumulator.remote()
        actor.inc.remote()
        actor2 = Accumulator.remote()
        actor2.inc.remote()

        assert server_actor_ref_count(2)()

        # Drop only the first handle; actor2 keeps its reference alive.
        del actor

        wait_for_condition(server_actor_ref_count(1), timeout=5)
예제 #8
0
def test_cancel_chain(ray_start_regular, use_force):
    """Cancelling a link in a chain of dependent tasks fails it and its tail.

    Uses the ``SignalActor`` produced by ``create_remote_signal_actor``. Each
    chain has four ``wait_for`` tasks, each waiting on its predecessor, with
    the first waiting on the signal actor. Cancelling the first link fails
    all four. Cancelling the third link fails only links three and four;
    links one and two keep waiting and complete once the signal is sent.

    NOTE(review): a function with this same name may be defined elsewhere in
    the module; if so, only the later definition is collected by pytest.
    """
    with ray_start_client_server() as ray:
        SignalActor = create_remote_signal_actor(ray)
        signaler = SignalActor.remote()

        @ray.remote
        def wait_for(t):
            return ray.get(t[0])

        def build(signal_handle):
            # Start from the signal's wait() ref and extend by four links,
            # each one depending on the ref produced just before it.
            refs = []
            dep = signal_handle.wait.remote()
            for _ in range(4):
                dep = wait_for.remote([dep])
                refs.append(dep)
            return refs

        first, second, third, fourth = build(signaler)

        pending_ready = ray.wait([first], timeout=.1)[0]
        assert len(pending_ready) == 0
        ray.cancel(first, force=use_force)
        for doomed in (first, second, third, fourth):
            with pytest.raises(valid_exceptions(use_force)):
                ray.get(doomed)

        signaler2 = SignalActor.remote()
        first, second, third, fourth = build(signaler2)

        pending_ready = ray.wait([third], timeout=.1)[0]
        assert len(pending_ready) == 0
        ray.cancel(third, force=use_force)
        for doomed in (third, fourth):
            with pytest.raises(valid_exceptions(use_force)):
                ray.get(doomed)

        # Links upstream of the cancelled one are still blocked on the signal.
        for survivor in (first, second):
            with pytest.raises(GetTimeoutError):
                ray.get(survivor, timeout=.1)

        signaler2.send.remote()
        ray.get(first)