예제 #1
0
파일: test_events.py 프로젝트: tchordia/ray
def test_wait_for_multiple_events(workflow_start_regular_shared):
    """If a workflow has multiple event arguments, it should wait for them at the
    same time.
    """
    class EventListener1(workflow.EventListener):
        async def poll_for_event(self):
            utils.set_global_mark("listener1")
            while not utils.check_global_mark("trigger_event"):
                await asyncio.sleep(0.1)
            return "event1"

    class EventListener2(workflow.EventListener):
        async def poll_for_event(self):
            utils.set_global_mark("listener2")
            while not utils.check_global_mark("trigger_event"):
                await asyncio.sleep(0.1)
            return "event2"

    @ray.remote
    def trivial_step(arg1, arg2):
        return f"{arg1} {arg2}"

    event1_promise = workflow.wait_for_event(EventListener1)
    event2_promise = workflow.wait_for_event(EventListener2)

    promise = workflow.create(trivial_step.bind(event1_promise,
                                                event2_promise)).run_async()

    while not (utils.check_global_mark("listener1")
               and utils.check_global_mark("listener2")):
        time.sleep(0.1)

    utils.set_global_mark("trigger_event")
    assert ray.get(promise) == "event1 event2"
예제 #2
0
파일: test_events.py 프로젝트: tchordia/ray
def test_event_after_arg_resolution(workflow_start_regular_shared):
    """Ensure that a workflow resolves all of its non-event arguments while it
    waiting the the event to occur.
    """
    class MyEventListener(workflow.EventListener):
        async def poll_for_event(self):
            while not utils.check_global_mark():
                await asyncio.sleep(0.1)
            # Give the other step time to finish.
            await asyncio.sleep(1)

    @ray.remote
    def triggers_event():
        utils.set_global_mark()

    @ray.remote
    def gather(*args):
        return args

    event_promise = workflow.wait_for_event(MyEventListener)

    assert workflow.create(gather.bind(event_promise,
                                       triggers_event.bind())).run() == (
                                           None,
                                           None,
                                       )
예제 #3
0
파일: test_events.py 프로젝트: tchordia/ray
def test_event_during_arg_resolution(workflow_start_regular_shared):
    """If a workflow's arguments are being executed when the event occurs, the
    workflow should run immediately with no issues.
    """
    class MyEventListener(workflow.EventListener):
        async def poll_for_event(self):
            while not utils.check_global_mark():
                await asyncio.sleep(0.1)
            utils.set_global_mark("event_returning")

    @ray.remote
    def triggers_event():
        utils.set_global_mark()
        while not utils.check_global_mark("event_returning"):
            time.sleep(0.1)

    @ray.remote
    def gather(*args):
        return args

    event_promise = workflow.wait_for_event(MyEventListener)
    assert workflow.create(gather.bind(event_promise,
                                       triggers_event.bind())).run() == (
                                           None,
                                           None,
                                       )
예제 #4
0
def test_crash_after_commit(workflow_start_regular_shared):
    _storage = storage.get_global_storage()
    """Ensure that we don't re-call poll_for_event after `event_checkpointed`
       returns, even after a crash. Here we must call `event_checkpointed`
       twice, because there's no way to know if we called it after
       checkpointing.

    """
    class MyEventListener(workflow.EventListener):
        async def poll_for_event(self):
            assert not utils.check_global_mark("committed")

        async def event_checkpointed(self, event):
            utils.set_global_mark("committed")
            if utils.check_global_mark("first"):
                utils.set_global_mark("second")
            else:
                utils.set_global_mark("first")
                await asyncio.sleep(1000000)

    event_promise = workflow.wait_for_event(MyEventListener)
    event_promise.run_async("workflow")

    while not utils.check_global_mark("first"):
        time.sleep(0.1)

    ray.shutdown()
    subprocess.check_output(["ray", "stop", "--force"])

    ray.init(num_cpus=4)
    workflow.init(storage=_storage)
    workflow.resume("workflow")

    ray.get(workflow.get_output("workflow"))
    assert utils.check_global_mark("second")
예제 #5
0
def test_event_as_workflow(workflow_start_regular_shared):
    class MyEventListener(workflow.EventListener):
        async def poll_for_event(self):
            while not utils.check_global_mark():
                await asyncio.sleep(1)

    utils.unset_global_mark()
    promise = workflow.wait_for_event(MyEventListener).run_async("wf")

    assert workflow.get_status("wf") == workflow.WorkflowStatus.RUNNING

    utils.set_global_mark()
    assert ray.get(promise) is None
예제 #6
0
def test_crash_during_event_checkpointing(workflow_start_regular_shared):
    """Ensure that if the cluster dies while the event is being checkpointed, we
    properly re-poll for the event."""

    from ray._private import storage

    storage_uri = storage._storage_uri

    """Ensure that we don't re-call poll_for_event after `event_checkpointed`
       returns, even after a crash."""

    class MyEventListener(workflow.EventListener):
        async def poll_for_event(self):
            assert not utils.check_global_mark("committed")
            if utils.check_global_mark("first"):
                utils.set_global_mark("second")
            utils.set_global_mark("first")

            utils.set_global_mark("time_to_die")
            while not utils.check_global_mark("resume"):
                time.sleep(0.1)

        async def event_checkpointed(self, event):
            utils.set_global_mark("committed")

    @ray.remote
    def wait_then_finish(arg):
        pass

    event_promise = workflow.wait_for_event(MyEventListener)
    workflow.run_async(wait_then_finish.bind(event_promise), workflow_id="workflow")

    while not utils.check_global_mark("time_to_die"):
        time.sleep(0.1)

    assert utils.check_global_mark("first")
    ray.shutdown()
    subprocess.check_output(["ray", "stop", "--force"])

    # Give the workflow some time to kill the cluster.
    # time.sleep(3)

    ray.init(num_cpus=4, storage=storage_uri)
    workflow.init()
    workflow.resume_async("workflow")
    utils.set_global_mark("resume")

    workflow.get_output("workflow")
    assert utils.check_global_mark("second")
예제 #7
0
파일: test_events.py 프로젝트: tchordia/ray
def test_types(workflow_start_regular_shared):
    class NotAnEventListener:
        pass

    with pytest.raises(TypeError):
        workflow.wait_for_event(NotAnEventListener)