Ejemplo n.º 1
0
    def test_task_runner_handles_outputs_prior_to_setting_state(self, client):
        """Check that a shared upstream ``Result`` is handled before states are sent.

        A cached ``add`` task runs with two upstream ``Success`` states that
        wrap the same ``Result``. Afterwards the result must expose a safe
        value, and three state updates (Running, Success, Cached) must have
        been sent through the client.
        """

        @prefect.task(
            cache_for=datetime.timedelta(days=1), result_handler=JSONResultHandler()
        )
        def add(x, y):
            return x + y

        result = Result(1, result_handler=JSONResultHandler())
        # Before the run, no safe value has been produced.
        assert result.safe_value is NoResult

        # Both inputs deliberately wrap the very same Result object.
        states_by_key = {key: Success(result=result) for key in ("x", "y")}
        upstream_states = {
            Edge(Task(), Task(), key=key): upstream
            for key, upstream in states_by_key.items()
        }

        CloudTaskRunner(task=add).run(upstream_states=upstream_states)
        assert result.safe_value != NoResult  # proves was handled

        # Task run info is never fetched; exactly three state updates are sent.
        assert client.get_task_run_info.call_count == 0
        assert client.set_task_run_state.call_count == 3

        running, succeeded, cached = [
            recorded[1]["state"]
            for recorded in client.set_task_run_state.call_args_list
        ]
        assert running.is_running()
        assert succeeded.is_successful()
        assert isinstance(cached, Cached)
        assert cached.cached_inputs == {"x": result, "y": result}
        assert cached.result == 2
Ejemplo n.º 2
0
    def test_task_runner_sends_checkpointed_success_states_to_cloud(self, client):
        """Check that a checkpointed task reports a Success state with a safe result.

        The ``add`` task is declared with ``checkpoint=True`` and a JSON result
        handler; the Success state sent to the client must carry a
        ``SafeResult`` of ``"2"`` bound to that same handler.
        """
        handler = JSONResultHandler()

        @prefect.task(checkpoint=True, result_handler=handler)
        def add(x, y):
            return x + y

        upstream_states = {
            Edge(Task(), Task(), key="x"): Success(result=Result(1)),
            Edge(Task(), Task(), key="y"): Success(result=Result(1)),
        }

        CloudTaskRunner(task=add).run(upstream_states=upstream_states)

        # Task run info is never fetched; two state updates are sent.
        assert client.get_task_run_info.call_count == 0
        assert client.set_task_run_state.call_count == 2

        running, succeeded = [
            recorded[1]["state"]
            for recorded in client.set_task_run_state.call_args_list
        ]
        assert running.is_running()
        assert succeeded.is_successful()
        expected_safe_value = SafeResult("2", result_handler=handler)
        assert succeeded._result.safe_value == expected_safe_value
Ejemplo n.º 3
0
def test_task_runner_sets_mapped_state_prior_to_executor_mapping(client):
    """Check that a Mapped state is sent before the executor's ``map`` is invoked.

    The executor's ``map`` raises ``SyntaxError``; the test asserts that one
    state update had already been sent and that it is a Mapped state.
    """

    class MyExecutor(prefect.engine.executors.LocalExecutor):
        def map(self, *args, **kwargs):
            raise SyntaxError("oops")

    mapped_edge = Edge(Task(), Task(), key="foo", mapped=True)
    upstream_states = {mapped_edge: Success(result=[1, 2])}

    with pytest.raises(SyntaxError):
        CloudTaskRunner(task=Task()).run_mapped_task(
            state=Pending(),
            upstream_states=upstream_states,
            context={},
            executor=MyExecutor(),
        )

    # Only the Mapped state update reached the client.
    assert client.get_task_run_info.call_count == 0
    assert client.set_task_run_state.call_count == 1
    assert client.get_latest_cached_states.call_count == 0

    mapped_state = client.set_task_run_state.call_args_list[-1][1]["state"]
    assert mapped_state.map_states == [None, None]
    assert mapped_state.is_mapped()
    assert "Preparing to submit 2" in mapped_state.message
Ejemplo n.º 4
0
def test_cloud_task_runner_handles_retries_with_queued_states_from_cloud(client):
    """Check that a retry still succeeds when one state update is answered with Queued.

    ``client.set_task_run_state`` is replaced by a function that records the
    keyword arguments of every call and echoes the requested state back,
    except for the fourth call, which is answered with ``Queued()``.
    """
    calls = []

    def queued_mock(*args, **kwargs):
        calls.append(kwargs)
        # The first retry attempt gets queued (with an immediate start time);
        # every other request is echoed back unchanged.
        return Queued() if len(calls) == 4 else kwargs.get("state")

    client.set_task_run_state = queued_mock

    @prefect.task(
        max_retries=2,
        retry_delay=datetime.timedelta(seconds=0),
        result=PrefectResult(),
    )
    def tagged_task(x):
        if prefect.context.get("task_run_count", 1) == 1:
            raise ValueError("gimme a sec")
        return x

    upstream_result = PrefectResult(value=42, location="42")
    upstream_states = {
        Edge(Task(), tagged_task, key="x"): Success(result=upstream_result)
    }
    res = CloudTaskRunner(task=tagged_task).run(
        context={"task_run_version": 1},
        state=None,
        upstream_states=upstream_states,
    )

    assert res.is_successful()
    assert res.result == 42

    # `calls` holds the states that were *requested*: the fourth request asks
    # for Running (and is answered with Queued), so Running shows up twice.
    assert len(calls) == 6
    requested_types = [type(kwargs["state"]).__name__ for kwargs in calls]
    assert requested_types == [
        "Running",
        "Failed",
        "Retrying",
        "Running",
        "Running",
        "Success",
    ]

    retrying_state = calls[2]["state"]
    assert retrying_state.cached_inputs["x"].value == 42
Ejemplo n.º 5
0
def test_task_runner_doesnt_call_client_if_map_index_is_none(client):
    """Check that task run info is not fetched when no ``map_index`` is in context."""
    runner = CloudTaskRunner(task=Task(name="test"))

    res = runner.run()

    assert client.get_task_run_info.call_count == 0
    # Two updates are sent: Running, then Success.
    assert client.set_task_run_state.call_count == 2
    assert client.get_latest_cached_states.call_count == 0

    sent_states = [
        recorded[1]["state"] for recorded in client.set_task_run_state.call_args_list
    ]
    sent_type_names = [type(sent).__name__ for sent in sent_states]
    assert sent_type_names == ["Running", "Success"]
    assert res.is_successful()

    running, succeeded = sent_states[0], sent_states[1]
    assert running.context == {"tags": []}
    assert succeeded.context == {"tags": []}
Ejemplo n.º 6
0
def test_task_runner_raises_endrun_if_client_cant_communicate_during_state_updates(
    monkeypatch,
):
    """Check the returned state when ``set_task_run_state`` raises on every call.

    The patched client raises ``SyntaxError`` from ``set_task_run_state``; the
    runner is expected to return a Running state.
    """

    @prefect.task(name="test")
    def raise_error():
        raise NameError("I don't exist")

    get_task_run_info = MagicMock(return_value=MagicMock(state=None))
    set_task_run_state = MagicMock(side_effect=SyntaxError)
    client = MagicMock(
        get_task_run_info=get_task_run_info,
        set_task_run_state=set_task_run_state,
    )
    client_factory = MagicMock(return_value=client)
    monkeypatch.setattr("prefect.engine.cloud.task_runner.Client", client_factory)

    # An ENDRUN causes the TaskRunner to return the most recently computed state.
    res = CloudTaskRunner(task=raise_error).run(context={"map_index": 1})

    assert set_task_run_state.called
    assert res.is_running()
Ejemplo n.º 7
0
def test_task_handlers_handle_retry_signals(client):
    """Check that a RETRY signal raised by a state handler yields a Retrying state.

    The task always fails; its state handler raises ``RETRY`` whenever the new
    state is failed. The task is run twice and the run count must grow from
    1 to 2.
    """

    def state_handler(t, o, n):
        # `n` is the state being transitioned to; only failures trigger a retry.
        if n.is_failed():
            raise prefect.engine.signals.RETRY(
                "Will retry.", start_time=pendulum.now("utc").add(days=1)
            )

    @prefect.task(state_handlers=[state_handler])
    def fn():
        1 / 0

    first_attempt = CloudTaskRunner(task=fn).run()
    assert first_attempt.is_retrying()
    assert first_attempt.run_count == 1

    # Pull the scheduled start time back to "now" so the second run executes.
    first_attempt.start_time = pendulum.now("utc")
    second_attempt = CloudTaskRunner(task=fn).run(state=first_attempt)
    assert second_attempt.is_retrying()
    assert second_attempt.run_count == 2

    sent_type_names = [
        type(recorded[1]["state"]).__name__
        for recorded in client.set_task_run_state.call_args_list
    ]
    assert sent_type_names == ["Running", "Retrying", "Running", "Retrying"]
Ejemplo n.º 8
0
def test_task_runner_uses_upstream_result_handlers(client):
    """Check that the upstream task's result class is used to read the input.

    The upstream task carries a ``MyResult`` whose ``read`` sets the value to
    ``"cool"``; the downstream task returns its input, so its result must be
    ``"cool"``.
    """

    class MyResult(Result):
        def read(self, *args, **kwargs):
            self.value = "cool"
            return self

        def write(self, *args, **kwargs):
            return self

    @prefect.task(result=PrefectResult())
    def t(x):
        return x

    upstream_task = Task(result=MyResult())
    upstream_state = Success(result=PrefectResult(location="1"))
    upstream_states = {Edge(upstream_task, t, key="x"): upstream_state}

    final_state = CloudTaskRunner(task=t).run(upstream_states=upstream_states)

    assert final_state.is_successful()
    assert final_state.result == "cool"
Ejemplo n.º 9
0
def test_state_handler_failures_are_handled_appropriately(client):
    """Check that an exception raised by ``on_failure`` surfaces in the Failed state.

    The task raises ``ValueError`` and its ``on_failure`` callback raises
    ``SyntaxError``; the final state must be Failed with the ``SyntaxError``
    as its result.
    """

    def bad(*args, **kwargs):
        raise SyntaxError("Syntax Errors are nice because they're so unique")

    @prefect.task(on_failure=bad)
    def do_nothing():
        raise ValueError("This task failed somehow")

    res = CloudTaskRunner(task=do_nothing).run()

    assert res.is_failed()
    assert "SyntaxError" in res.message
    assert isinstance(res.result, SyntaxError)

    # Two updates reached the client: Running, then Failed.
    assert client.set_task_run_state.call_count == 2
    running, failed = [
        recorded[1]["state"] for recorded in client.set_task_run_state.call_args_list
    ]
    assert running.is_running()
    assert failed.is_failed()
    assert isinstance(failed.result, SyntaxError)
Ejemplo n.º 10
0
def test_task_runner_gracefully_handles_load_results_failures(client):
    """Check that an error while reading upstream results produces a Failed state.

    ``MyResult.read`` raises ``TypeError``; the runner must return a Failed
    state mentioning "task results" and send exactly one Failed update.
    """

    class MyResult(Result):
        def read(self, *args, **kwargs):
            raise TypeError("something is wrong!")

    @prefect.task(result=PrefectResult())
    def t(x):
        return x

    upstream_state = Success(result=MyResult(location="foo.txt"))
    upstream_states = {Edge(Task(result=MyResult()), t, key="x"): upstream_state}

    final_state = CloudTaskRunner(task=t).run(upstream_states=upstream_states)

    assert final_state.is_failed()
    assert "task results" in final_state.message
    # The only update sent is the Failed state.
    assert client.set_task_run_state.call_count == 1

    sent_type_names = [
        type(recorded[1]["state"]).__name__
        for recorded in client.set_task_run_state.call_args_list
    ]
    assert sent_type_names == ["Failed"]
Ejemplo n.º 11
0
def test_task_runner_uses_cached_inputs_from_db_state(monkeypatch):
    """Check that cached inputs on the state returned by the client feed the task.

    The patched client reports a ``Retrying`` state with ``x=Result(41)`` as
    cached input; ``add_one`` must therefore succeed with 42.
    """

    @prefect.task(name="test")
    def add_one(x):
        return x + 1

    db_state = Retrying(cached_inputs={"x": Result(41)})
    get_task_run_info = MagicMock(return_value=MagicMock(state=db_state))
    set_task_run_state = MagicMock()
    client = MagicMock(
        get_task_run_info=get_task_run_info,
        set_task_run_state=set_task_run_state,
    )
    client_factory = MagicMock(return_value=client)
    monkeypatch.setattr("prefect.engine.cloud.task_runner.Client", client_factory)

    res = CloudTaskRunner(task=add_one).run(context={"map_index": 1})

    # One fetch to pull the latest state, then Running and Success updates.
    assert get_task_run_info.call_count == 1
    assert set_task_run_state.call_count == 2
    assert res.is_successful()
    assert res.result == 42
Ejemplo n.º 12
0
def test_task_runner_places_task_tags_in_state_context_and_serializes_them(client):
    """Check that the task's tags appear in the context of both states sent."""
    tags = ["1", "2", "tag"]
    task = Task(name="test", tags=tags)

    CloudTaskRunner(task=task).run()

    sent_kwargs = [recorded[1] for recorded in client.set_task_run_state.call_args_list]
    running = sent_kwargs[0]["state"]
    succeeded = sent_kwargs[1]["state"]

    assert running.is_running()
    assert succeeded.is_successful()
    # Tag order is not part of the contract, so compare as sets.
    assert set(running.context["tags"]) == {"1", "2", "tag"}
    assert set(succeeded.context["tags"]) == {"1", "2", "tag"}
Ejemplo n.º 13
0
def test_task_runner_set_task_name_same_as_prefect_context(client):
    """Check that a callable ``task_run_name`` is resolved from the task's inputs.

    The callable returns the ``config`` input, so the name sent through
    ``client.set_task_run_name`` must equal the upstream result.
    """

    @prefect.task(name="hey", task_run_name=lambda **kwargs: kwargs["config"])
    def test_task(config):
        return

    upstream_states = {
        Edge(Task(), Task(), key="config"): Success(result="any_value")
    }
    CloudTaskRunner(task=test_task).run(upstream_states=upstream_states)

    set_name = client.set_task_run_name
    assert set_name.call_count == 1
    assert set_name.call_args[1]["name"] == "any_value"
Ejemplo n.º 14
0
    def test_load_results_from_upstream_reads_results(self):
        """Check that ``load_results`` populates an upstream result that has only a location."""
        result = PrefectResult(location="1")
        # The value has not been read yet.
        assert result.value is None

        upstream_state = Success(result=result)
        upstream_task = Task(result=PrefectResult())
        edge = Edge(upstream_task, 2, key="x")

        runner = CloudTaskRunner(task=Task())
        _, upstreams = runner.load_results(
            state=Pending(), upstream_states={edge: upstream_state}
        )

        assert upstreams[edge].result == 1
Ejemplo n.º 15
0
def test_task_runner_places_task_tags_in_state_context_and_serializes_them(monkeypatch):
    """Check that task tags end up in the state payloads passed to ``session.post``.

    ``requests.Session`` is patched to return a mock session, and the
    ``variables`` JSON of each ``post`` call is decoded and inspected.
    """
    task = Task(name="test", tags=["1", "2", "tag"])
    session = MagicMock()
    monkeypatch.setattr("prefect.client.client.GraphQLResult", MagicMock())
    monkeypatch.setattr("requests.Session", MagicMock(return_value=session))

    res = CloudTaskRunner(task=task).run()
    assert res.is_successful()

    # Extract the variables payload from each call to POST.
    call_vars = [
        json.loads(posted[1]["json"]["variables"])
        for posted in session.post.call_args_list
    ]

    # Do some manipulation to get the state payloads.
    inputs = [
        variables["input"]["states"][0]
        for variables in call_vars
        if variables is not None
    ]
    first_state, last_state = inputs[0]["state"], inputs[-1]["state"]

    assert first_state["type"] == "Running"
    assert set(first_state["context"]["tags"]) == {"1", "2", "tag"}
    assert last_state["type"] == "Success"
    assert set(last_state["context"]["tags"]) == {"1", "2", "tag"}
Ejemplo n.º 16
0
    def test_load_results_from_upstream_reads_results_using_upstream_handlers(
        self, cloud_api
    ):
        """Check that ``load_results`` reads through the upstream task's result class.

        The upstream task's ``CustomResult.read`` returns a fixed list, which
        must become the loaded upstream result.
        """

        class CustomResult(Result):
            def read(self, *args, **kwargs):
                return "foo-bar-baz".split("-")

        upstream_state = Success(result=PrefectResult(location="1"))
        edge = Edge(Task(result=CustomResult()), 2, key="x")

        runner = CloudTaskRunner(task=Task())
        _, upstreams = runner.load_results(
            state=Pending(), upstream_states={edge: upstream_state}
        )

        assert upstreams[edge].result == ["foo", "bar", "baz"]
Ejemplo n.º 17
0
def test_task_runner_calls_get_task_run_info_if_map_index_is_not_none(client):
    """Check that task run info is fetched once when ``map_index`` is in context."""
    runner = CloudTaskRunner(task=Task(name="test"))

    runner.run(context={"map_index": 1})

    # Called exactly once.
    assert client.get_task_run_info.call_count == 1
    # Two updates are sent: Running, then Success.
    assert client.set_task_run_state.call_count == 2

    sent_type_names = [
        type(recorded[1]["state"]).__name__
        for recorded in client.set_task_run_state.call_args_list
    ]
    assert sent_type_names == ["Running", "Success"]
Ejemplo n.º 18
0
    def test_load_results_from_upstream_reads_secret_results(self, cloud_api):
        """Check that an upstream ``SecretResult`` is loaded from the secrets in context.

        With ``foo=42`` in the context secrets, the loaded upstream result for
        a ``PrefectSecret(name="foo")`` must be 42.
        """
        secret_task = prefect.tasks.secrets.PrefectSecret(name="foo")
        secret_result = SecretResult(secret_task)
        upstream_state = Success(result=PrefectResult(location="foo"))

        with prefect.context(secrets={"foo": 42}):
            edge = Edge(Task(result=secret_result), 2, key="x")
            runner = CloudTaskRunner(task=Task())
            _, upstreams = runner.load_results(
                state=Pending(), upstream_states={edge: upstream_state}
            )

        assert upstreams[edge].result == 42
Ejemplo n.º 19
0
def test_task_runner_handles_version_lock_error(monkeypatch):
    """Check ``call_runner_target_handlers`` when setting state raises ``VersionLockError``.

    Three scenarios are covered, keyed on what ``client.get_task_run_state``
    returns: a Success state is returned to the caller, a Running state leads
    to ``ENDRUN``, and a Success whose ``load_result`` raises also leads to
    ``ENDRUN``.
    """
    client = MagicMock()
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client", MagicMock(return_value=client)
    )
    client.set_task_run_state.side_effect = VersionLockError()

    runner = CloudTaskRunner(task=Task(name="test"))

    # Scenario 1: the client reports a successful state.
    client.get_task_run_state.return_value = Success()
    res = runner.call_runner_target_handlers(Pending(), Running())
    assert res.is_successful()

    # Scenario 2: the client reports the task as currently running.
    client.get_task_run_state.return_value = Running()
    with pytest.raises(ENDRUN):
        runner.call_runner_target_handlers(Pending(), Running())

    # Scenario 3: loading the reported state's result fails.
    unloadable = Success()
    unloadable.load_result = MagicMock(side_effect=Exception())
    client.get_task_run_state.return_value = unloadable
    with pytest.raises(ENDRUN):
        runner.call_runner_target_handlers(Pending(), Running())
def test_task_runner_handles_looping(client):
    """Check that a task raising ``LOOP`` twice succeeds and bumps the version on every update.

    Six state updates are expected, carrying versions 1 through 6.
    """

    @prefect.task(result_handler=ResultHandler())
    def looper():
        if prefect.context.get("task_loop_count", 1) < 3:
            raise LOOP(result=prefect.context.get("task_loop_result", 0) + 10)
        return prefect.context.get("task_loop_result")

    runner = CloudTaskRunner(task=looper)
    res = runner.run(
        context={"task_run_version": 1},
        state=None,
        upstream_states={},
        executor=prefect.engine.executors.LocalExecutor(),
    )

    assert res.is_successful()
    assert client.get_task_run_info.call_count == 0
    # Running -> Looped (1) -> Running -> Looped (2) -> Running -> Success
    assert client.set_task_run_state.call_count == 6

    versions = [
        recorded[1]["version"] for recorded in client.set_task_run_state.call_args_list
    ]
    assert versions == [1, 2, 3, 4, 5, 6]
Ejemplo n.º 21
0
def test_task_runner_validates_cached_states_if_task_has_caching(client):
    """Check that cached states with an expiration in the past are not reused.

    Both states returned by ``get_latest_cached_states`` expired before now;
    the final state is expected to be cached with 42, the task's own return
    value.
    """

    @prefect.task(cache_for=datetime.timedelta(minutes=1), result=PrefectResult())
    def cached_task():
        return 42

    two_minutes = datetime.timedelta(minutes=2)
    one_day = datetime.timedelta(days=1)

    state = Cached(
        cached_result_expiration=datetime.datetime.utcnow() - two_minutes,
        result=PrefectResult(location="99"),
    )
    old_state = Cached(
        cached_result_expiration=datetime.datetime.utcnow() - one_day,
        result=PrefectResult(location="13"),
    )
    client.get_latest_cached_states = MagicMock(return_value=[state, old_state])

    res = CloudTaskRunner(task=cached_task).run()

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 42
Ejemplo n.º 22
0
def test_task_runner_treats_unfound_files_as_invalid_caches(client, tmpdir):
    """Check that a cached state whose result file is absent is skipped.

    The first cached state points at a path under ``tmpdir`` that this test
    never creates; the final result is expected to be 13, taken from the
    second cached state.
    """

    @prefect.task(cache_for=datetime.timedelta(minutes=1), result=PrefectResult())
    def cached_task():
        return 42

    two_minutes = datetime.timedelta(minutes=2)
    one_day = datetime.timedelta(days=1)
    missing_file = str(tmpdir / "made_up_data.prefect")

    state = Cached(
        cached_result_expiration=datetime.datetime.utcnow() + two_minutes,
        result=LocalResult(location=missing_file),
    )
    old_state = Cached(
        cached_result_expiration=datetime.datetime.utcnow() + one_day,
        result=PrefectResult(location="13"),
    )
    client.get_latest_cached_states = MagicMock(return_value=[state, old_state])

    res = CloudTaskRunner(task=cached_task).run()

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 13
Ejemplo n.º 23
0
def test_task_runner_handles_looping_with_no_result(client):
    """Check looping when ``LOOP`` is raised without a result.

    Six state updates are expected; of the versions sent, only the truthy
    ones are collected and they must equal ``[1, 3, 5]``.
    """

    @prefect.task(result_handler=ResultHandler())
    def looper():
        if prefect.context.get("task_loop_count", 1) < 3:
            raise LOOP()
        return 42

    runner = CloudTaskRunner(task=looper)
    res = runner.run(
        context={"task_run_version": 1}, state=None, upstream_states={}
    )

    assert res.is_successful()
    assert client.get_task_run_info.call_count == 0
    # Running -> Looped (1) -> Running -> Looped (2) -> Running -> Success
    assert client.set_task_run_state.call_count == 6

    # Keep only the updates that carried a truthy version.
    versions = []
    for recorded in client.set_task_run_state.call_args_list:
        if recorded[1]["version"]:
            versions.append(recorded[1]["version"])
    assert versions == [1, 3, 5]
Ejemplo n.º 24
0
def test_task_runner_queries_for_cached_states_if_task_has_caching(client):
    """Check that a task with ``cache_for`` asks the client for cached states.

    The first cached state expires a day from now and holds 99; the second
    expired a day ago. The final result is expected to be 99.
    """

    @prefect.task(cache_for=datetime.timedelta(minutes=1))
    def cached_task():
        return 42

    one_day = datetime.timedelta(days=1)

    state = Cached(
        cached_result_expiration=datetime.datetime.utcnow() + one_day,
        result=Result(99, JSONResultHandler()),
    )
    old_state = Cached(
        cached_result_expiration=datetime.datetime.utcnow() - one_day,
        result=13,
    )
    client.get_latest_cached_states = MagicMock(return_value=[state, old_state])

    res = CloudTaskRunner(task=cached_task).run()

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 99
Ejemplo n.º 25
0
def test_task_runner_performs_retries_for_short_delays(client):
    """Check that a zero-delay retry is performed within a single ``run`` call.

    The task fails on its first invocation and returns normally afterwards;
    the run must end successfully after five state updates.
    """
    global_list = []

    @prefect.task(max_retries=1, retry_delay=datetime.timedelta(seconds=0))
    def noop():
        # An empty list marks the first invocation: record it and fail.
        if not global_list:
            global_list.append(0)
            raise ValueError("oops")
        return

    runner = CloudTaskRunner(task=noop)
    res = runner.run(
        state=None,
        upstream_states={},
        executor=prefect.engine.executors.LocalExecutor(),
    )

    assert res.is_successful()
    assert client.get_task_run_info.call_count == 0
    # Running -> Failed -> Retrying -> Running -> Success
    assert client.set_task_run_state.call_count == 5
    def test_task_runner_validates_cached_state_inputs_if_task_has_caching_and_uses_task_handler(
        self, client
    ):
        """Check cache validation on inputs and reading through the task's result class.

        Two unexpired cached states are offered; only the second carries
        ``cached_inputs``. The task's ``MyResult.read`` yields 1337, which
        must be the result of the state returned by ``check_task_is_cached``.
        """

        class MyResult(Result):
            def read(self, *args, **kwargs):
                new = self.copy()
                new.value = 1337
                return new

        @prefect.task(
            cache_for=datetime.timedelta(minutes=1),
            cache_validator=all_inputs,
            result=MyResult(),
        )
        def cached_task(x):
            return 42

        two_minutes = datetime.timedelta(minutes=2)

        # No cached inputs on this one.
        dull_state = Cached(
            cached_result_expiration=datetime.datetime.utcnow() + two_minutes,
            result=PrefectResult(location="-1"),
        )
        state = Cached(
            cached_result_expiration=datetime.datetime.utcnow() + two_minutes,
            result=PrefectResult(location="99"),
            cached_inputs={"x": PrefectResult(location="2")},
        )
        client.get_latest_cached_states = MagicMock(return_value=[dull_state, state])

        runner = CloudTaskRunner(task=cached_task)
        res = runner.check_task_is_cached(
            Pending(), inputs={"x": PrefectResult(value=2)}
        )

        assert client.get_latest_cached_states.called
        assert res.is_successful()
        assert res.is_cached()
        assert res.result == 1337
Ejemplo n.º 27
0
def test_task_runner_handles_looping_with_retries(client):
    """Check a looping task that fails once mid-loop and is retried.

    ``client.get_task_run_info`` yields a Pending state first and Looped
    states afterwards; the run must succeed after nine state updates and four
    info fetches.
    """

    # note that looping _requires_ a result handler in Cloud
    @prefect.task(
        max_retries=1,
        retry_delay=datetime.timedelta(seconds=0),
        result=PrefectResult(),
    )
    def looper():
        on_second_loop = prefect.context.get("task_loop_count") == 2
        if on_second_loop and prefect.context.get("task_run_count", 1) == 1:
            raise ValueError("Stop")
        if prefect.context.get("task_loop_count", 1) < 3:
            raise LOOP(result=prefect.context.get("task_loop_result", 0) + 10)
        return prefect.context.get("task_loop_result")

    # Version 0 is Pending; versions 1-4 are Looped with a matching loop count.
    run_infos = [MagicMock(version=0, state=Pending())]
    run_infos.extend(
        MagicMock(version=i, state=Looped(loop_count=i)) for i in range(1, 5)
    )
    client.get_task_run_info.side_effect = run_infos

    res = CloudTaskRunner(task=looper).run(
        context={"task_run_version": 1}, state=None, upstream_states={}
    )

    assert res.is_successful()
    assert client.get_task_run_info.call_count == 4
    # Running -> Looped (1) -> Running -> Failed -> Retrying -> Running
    # -> Looped (2) -> Running -> Success
    assert client.set_task_run_state.call_count == 9

    # Keep only the updates that carried a truthy version.
    versions = []
    for recorded in client.set_task_run_state.call_args_list:
        if recorded[1]["version"]:
            versions.append(recorded[1]["version"])
    assert versions == [1, 2, 3]
Ejemplo n.º 28
0
def test_task_runner_respects_the_db_state(monkeypatch, state):
    """Check that the state reported by the client is returned without any update.

    ``state`` is called to build the reported state (presumably a state class
    supplied by a fixture or parametrization - confirm against the test
    module).
    """
    db_state = state("already", result=10)
    get_task_run_info = MagicMock(return_value=MagicMock(state=db_state))
    set_task_run_state = MagicMock()
    client = MagicMock(
        get_task_run_info=get_task_run_info,
        set_task_run_state=set_task_run_state,
    )
    client_factory = MagicMock(return_value=client)
    monkeypatch.setattr("prefect.engine.cloud.task_runner.Client", client_factory)

    runner = CloudTaskRunner(task=Task(name="test"))
    res = runner.run(context={"map_index": 1})

    # One fetch to pull the latest state; no state update is ever sent.
    assert get_task_run_info.call_count == 1
    assert set_task_run_state.call_count == 0
    assert res == db_state
Ejemplo n.º 29
0
def test_task_runner_raises_endrun_with_correct_state_if_client_cant_receive_state_updates(
    monkeypatch,
):
    """Check the returned state when ``get_task_run_info`` raises.

    The patched client raises ``SyntaxError`` from ``get_task_run_info``; the
    runner must hand back the very state object it was given.
    """
    get_task_run_info = MagicMock(side_effect=SyntaxError)
    set_task_run_state = MagicMock()
    client = MagicMock(
        get_task_run_info=get_task_run_info,
        set_task_run_state=set_task_run_state,
    )
    client_factory = MagicMock(return_value=client)
    monkeypatch.setattr("prefect.engine.cloud.task_runner.Client", client_factory)

    # An ENDRUN causes the TaskRunner to return the most recently computed state.
    initial_state = Pending(message="unique message", result=42)
    runner = CloudTaskRunner(task=Task(name="test"))
    res = runner.run(state=initial_state, context={"map_index": 1})

    assert get_task_run_info.called
    assert res is initial_state
Ejemplo n.º 30
0
    def test_load_results_from_upstream_reads_cached_inputs_using_upstream_results(
        self,
    ):
        """Check that cached inputs are read through the upstream task's result class.

        The upstream task's ``CustomResult.read`` sets the value to 99, which
        must show up on the ``x`` cached input of the state returned by
        ``load_results``.
        """

        class CustomResult(Result):
            def read(self, *args, **kwargs):
                self.value = 99
                return self

        result = PrefectResult(location="1")
        pending = Pending(cached_inputs={"x": result})
        edge = Edge(Task(result=CustomResult()), 2, key="x")

        runner = CloudTaskRunner(task=Task(result=PrefectResult()))
        new_state, _ = runner.load_results(
            state=pending, upstream_states={edge: Success(result=result)}
        )

        assert new_state.cached_inputs["x"].value == 99