Ejemplo n.º 1
0
def test_cloud_task_runner_respects_queued_states_from_cloud(client):
    """Check the runner's behavior when its first state update is queued.

    ``client.set_task_run_state`` is replaced by a stub that answers the
    first request with a bare ``Queued()`` and echoes the requested state
    for every later request. The run is expected to succeed after exactly
    three requests: ``Running``, ``Running`` again, then ``Success``.
    """
    calls = []

    def queued_mock(*args, **kwargs):
        calls.append(kwargs)
        # Only the very first request is answered with Queued; afterwards
        # the submitted state is handed straight back.
        is_first_request = len(calls) == 1
        return Queued() if is_first_request else kwargs.get("state")

    client.set_task_run_state = queued_mock

    @prefect.task
    def tagged_task():
        pass

    runner = CloudTaskRunner(task=tagged_task)
    res = runner.run(context={"task_run_version": 1}, state=None, upstream_states={})

    assert res.is_successful()
    assert len(calls) == 3  # Running -> Running -> Success
    requested_states = [type(kwargs["state"]).__name__ for kwargs in calls]
    assert requested_states == ["Running", "Running", "Success"]
Ejemplo n.º 2
0
def test_task_runner_uses_cached_inputs_from_db_state(monkeypatch):
    """Check that cached inputs on the client-reported state feed the task.

    The mocked client reports a ``Retrying`` state whose cached inputs hold
    ``x=41``; ``add_one`` is then expected to finish successfully with 42.
    """

    @prefect.task(name="test", result_handler=JSONResultHandler())
    def add_one(x):
        return x + 1

    cached_x = Result(41, result_handler=JSONResultHandler())
    db_state = Retrying(cached_inputs={"x": cached_x})

    def echo_state(task_run_id, version, state, cache_for):
        # Accept every state update by returning the submitted state.
        return state

    get_task_run_info = MagicMock(return_value=MagicMock(state=db_state))
    set_task_run_state = MagicMock(side_effect=echo_state)

    client = MagicMock()
    client.get_task_run_info = get_task_run_info
    client.set_task_run_state = set_task_run_state

    # Make the runner build our mocked client instead of a real one.
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client", MagicMock(return_value=client)
    )

    runner = CloudTaskRunner(task=add_one)
    res = runner.run(context={"map_index": 1})

    ## assertions
    assert get_task_run_info.call_count == 1  # one time to pull latest state
    assert set_task_run_state.call_count == 2  # Pending -> Running -> Success
    assert res.is_successful()
    assert res.result == 42
Ejemplo n.º 3
0
def test_task_runner_validates_cached_state_inputs_if_task_has_caching(client):
    """Check that ``check_task_is_cached`` picks the state with matching inputs.

    The client returns two unexpired ``Cached`` states: the first has no
    cached inputs, the second stores ``x`` as the safe value ``"2"``. With
    the ``all_inputs`` validator and an input of ``x=2``, the test expects
    the returned state to be cached and to carry the result 99.
    """

    @prefect.task(
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=all_inputs,
        result_handler=JSONResultHandler(),
    )
    def cached_task(x):
        return 42

    two_minutes = datetime.timedelta(minutes=2)

    # Unexpired, but carries no cached inputs at all.
    dull_state = Cached(
        cached_result_expiration=datetime.datetime.utcnow() + two_minutes,
        result=Result(-1, JSONResultHandler()),
    )
    # Unexpired and carries a cached input for "x".
    stored_x = SafeResult("2", result_handler=JSONResultHandler())
    state = Cached(
        cached_result_expiration=datetime.datetime.utcnow() + two_minutes,
        result=Result(99, JSONResultHandler()),
        cached_inputs={"x": stored_x},
    )
    client.get_latest_cached_states = MagicMock(return_value=[dull_state, state])

    runner = CloudTaskRunner(task=cached_task)
    current_inputs = {"x": Result(2, result_handler=LocalResultHandler())}
    res = runner.check_task_is_cached(Pending(), inputs=current_inputs)

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 99
Ejemplo n.º 4
0
def test_task_runner_handles_looping_with_retries_with_no_result(client):
    """Run a task that loops, fails once mid-loop, is retried, and succeeds.

    The task raises ``LOOP`` without a result value. The mocked client is
    primed with five run-info responses (``Pending`` for version 0,
    ``Looped`` otherwise). The test checks the final state, the number of
    run-info lookups and state updates, and the truthy ``version`` values
    passed along with the state updates.
    """

    # note that looping with retries _requires_ a result handler in Cloud
    @prefect.task(
        max_retries=1,
        retry_delay=datetime.timedelta(seconds=0),
        result=PrefectResult(),
    )
    def looper():
        # Fail exactly once: on the second loop iteration of the first run.
        on_second_loop = prefect.context.get("task_loop_count") == 2
        if on_second_loop and prefect.context.get("task_run_count", 1) == 1:
            raise ValueError("Stop")
        if prefect.context.get("task_loop_count", 1) < 3:
            raise LOOP()
        return 42

    def run_info(version):
        reported = Pending() if version == 0 else Looped(loop_count=version)
        return MagicMock(version=version, state=reported)

    client.get_task_run_info.side_effect = [run_info(i) for i in range(5)]

    runner = CloudTaskRunner(task=looper)
    res = runner.run(context={"task_run_version": 1}, state=None, upstream_states={})

    ## assertions
    assert res.is_successful()
    assert client.get_task_run_info.call_count == 4
    # Pending -> Running -> Looped (1) -> Running -> Failed -> Retrying
    #   -> Running -> Looped(2) -> Running -> Success
    assert client.set_task_run_state.call_count == 9
    versions = [
        kwargs["version"]
        for _, kwargs in client.set_task_run_state.call_args_list
        if kwargs["version"]
    ]
    assert versions == [1, 2, 3]
Ejemplo n.º 5
0
def test_task_runner_handles_looping_with_no_result(client):
    """Run a task that raises ``LOOP`` (without a result) until its third loop.

    The mocked client is primed with three ``Pending`` run-info responses
    (versions 1 to 3). The test checks the final state, the number of
    run-info lookups and state updates, and the truthy ``version`` values
    passed along with the state updates.
    """

    @prefect.task(result=Result())
    def looper():
        if prefect.context.get("task_loop_count", 1) < 3:
            raise LOOP()
        return 42

    run_infos = [MagicMock(version=v, state=Pending()) for v in range(1, 4)]
    client.get_task_run_info.side_effect = run_infos

    runner = CloudTaskRunner(task=looper)
    res = runner.run(context={"task_run_version": 1}, state=None, upstream_states={})

    ## assertions
    assert res.is_successful()
    assert client.get_task_run_info.call_count == 3
    # Pending -> Running -> Looped (1) -> Running -> Looped (2) -> Running -> Success
    assert client.set_task_run_state.call_count == 6
    versions = [
        kwargs["version"]
        for _, kwargs in client.set_task_run_state.call_args_list
        if kwargs["version"]
    ]
    assert versions == [1, 2, 3]
Ejemplo n.º 6
0
    def test_task_runner_validates_cached_state_inputs_if_task_has_caching(
        self, client
    ):
        """Check that ``check_task_is_cached`` picks the state with matching inputs.

        The client returns two unexpired ``Cached`` states: the first has no
        cached inputs, the second stores ``x`` at location ``"2"``. With the
        ``all_inputs`` validator and an input of ``x=2``, the test expects the
        returned state to be cached and to carry the result 99.
        """

        @prefect.task(
            cache_for=datetime.timedelta(minutes=1),
            cache_validator=all_inputs,
            result=PrefectResult(),
        )
        def cached_task(x):
            return 42

        two_minutes = datetime.timedelta(minutes=2)

        # Unexpired, but carries no cached inputs at all.
        dull_state = Cached(
            cached_result_expiration=datetime.datetime.utcnow() + two_minutes,
            result=PrefectResult(location="-1"),
        )
        # Unexpired and carries a cached input for "x".
        state = Cached(
            cached_result_expiration=datetime.datetime.utcnow() + two_minutes,
            result=PrefectResult(location="99"),
            cached_inputs={"x": PrefectResult(location="2")},
        )
        client.get_latest_cached_states = MagicMock(
            return_value=[dull_state, state]
        )

        runner = CloudTaskRunner(task=cached_task)
        current_inputs = {"x": PrefectResult(value=2)}
        res = runner.check_task_is_cached(Pending(), inputs=current_inputs)

        assert client.get_latest_cached_states.called
        assert res.is_successful()
        assert res.is_cached()
        assert res.result == 99
Ejemplo n.º 7
0
def test_task_runner_performs_retries_for_short_delays(client):
    """Run a task that fails on its first call and returns on its second.

    The task allows one retry with a zero-second delay. The test checks
    that the run ends successfully, that run info is fetched once, and that
    five state updates are sent with versions 1 through 5.
    """
    attempts = []

    @prefect.task(max_retries=1, retry_delay=datetime.timedelta(seconds=0))
    def noop():
        # The list is empty only on the first call; fail once, then pass.
        if not attempts:
            attempts.append(0)
            raise ValueError("oops")

    run_infos = [MagicMock(version=v) for v in range(4, 7)]
    client.get_task_run_info.side_effect = run_infos

    runner = CloudTaskRunner(task=noop)
    res = runner.run(context={"task_run_version": 1}, state=None, upstream_states={})

    ## assertions
    assert res.is_successful()
    assert client.get_task_run_info.call_count == 1  # called once on the retry
    # Pending -> Running -> Failed -> Retrying -> Running -> Success
    assert client.set_task_run_state.call_count == 5
    versions = [
        kwargs["version"] for _, kwargs in client.set_task_run_state.call_args_list
    ]
    assert versions == [1, 2, 3, 4, 5]
Ejemplo n.º 8
0
def test_task_runner_handles_looping_with_retries(client):
    """Run a looping task that fails once mid-loop, is retried, and succeeds.

    Each loop iteration adds 10 to the loop result carried in the context.
    The test checks the final state, that run info is fetched once, and
    that nine state updates are sent with versions 1 through 9.
    """

    # note that looping _requires_ a result handler in Cloud
    @prefect.task(
        max_retries=1,
        retry_delay=datetime.timedelta(seconds=0),
        result_handler=JSONResultHandler(),
    )
    def looper():
        # Fail exactly once: on the second loop iteration of the first run.
        on_second_loop = prefect.context.get("task_loop_count") == 2
        if on_second_loop and prefect.context.get("task_run_count", 1) == 1:
            raise ValueError("Stop")
        if prefect.context.get("task_loop_count", 1) < 3:
            raise LOOP(result=prefect.context.get("task_loop_result", 0) + 10)
        return prefect.context.get("task_loop_result")

    run_infos = [MagicMock(version=v) for v in range(6, 9)]
    client.get_task_run_info.side_effect = run_infos

    runner = CloudTaskRunner(task=looper)
    res = runner.run(context={"task_run_version": 1}, state=None, upstream_states={})

    ## assertions
    assert res.is_successful()
    assert client.get_task_run_info.call_count == 1  # called once for retry
    # Pending -> Running -> Looped (1) -> Running -> Failed -> Retrying
    #   -> Running -> Looped(2) -> Running -> Success
    assert client.set_task_run_state.call_count == 9
    versions = [
        kwargs["version"] for _, kwargs in client.set_task_run_state.call_args_list
    ]
    assert versions == [1, 2, 3, 4, 5, 6, 7, 8, 9]
Ejemplo n.º 9
0
def test_task_runner_puts_cloud_in_context(client):
    """Check that a task run by ``CloudTaskRunner`` sees ``cloud`` as True."""

    @prefect.task
    def whats_in_ctx():
        return prefect.context.get("cloud")

    runner = CloudTaskRunner(task=whats_in_ctx)
    res = runner.run()

    assert res.is_successful()
    # The task returns the context value, so it surfaces as the run's result.
    assert res.result is True
Ejemplo n.º 10
0
def test_task_runner_puts_cloud_in_context(client):
    """Check that a task run by ``CloudTaskRunner`` sees ``checkpointing`` as True.

    NOTE(review): despite the function name, the context key read here is
    ``"checkpointing"``, not ``"cloud"``.
    """

    @prefect.task(result_handler=ResultHandler())
    def whats_in_ctx():
        return prefect.context.get("checkpointing")

    runner = CloudTaskRunner(task=whats_in_ctx)
    res = runner.run()

    assert res.is_successful()
    # The task returns the context value, so it surfaces as the run's result.
    assert res.result is True
Ejemplo n.º 11
0
    def test_task_runner_has_a_heartbeat_with_task_run_id(self, monkeypatch):
        """Check that the heartbeat call receives the ``task_run_id`` from context."""
        client = MagicMock()
        # Make the runner build our mocked client instead of a real one.
        client_factory = MagicMock(return_value=client)
        monkeypatch.setattr("prefect.engine.cloud.task_runner.Client", client_factory)

        task = Task(name="test")
        runner = CloudTaskRunner(task=task)
        res = runner.run(context={"task_run_id": 1234})

        assert res.is_successful()
        # First positional argument of the most recent heartbeat call.
        heartbeat_args = client.update_task_run_heartbeat.call_args[0]
        assert heartbeat_args[0] == 1234
Ejemplo n.º 12
0
def test_task_runner_doesnt_call_client_if_map_index_is_none(client):
    """Run a plain task with no context and check which client calls are made.

    The test expects no run-info or cached-state lookups and exactly two
    state updates: ``Running`` followed by ``Success``.
    """
    runner = CloudTaskRunner(task=Task(name="test"))

    res = runner.run()

    ## assertions
    assert client.get_task_run_info.call_count == 0  # never called
    assert client.set_task_run_state.call_count == 2  # Pending -> Running -> Success
    assert client.get_latest_cached_states.call_count == 0

    state_names = [
        type(kwargs["state"]).__name__
        for _, kwargs in client.set_task_run_state.call_args_list
    ]
    assert state_names == ["Running", "Success"]
    assert res.is_successful()
Ejemplo n.º 13
0
def test_cloud_task_runner_handles_retries_with_queued_states_from_cloud(client):
    """Check a retry whose first ``Running`` request is answered with ``Queued``.

    ``client.set_task_run_state`` is replaced by a stub that answers the
    fourth request with a bare ``Queued()`` and echoes the requested state
    otherwise. The task fails on its first run and returns its input on the
    next; the test checks the final result, the sequence of requested
    states, and the cached input stored on the ``Retrying`` state.
    """
    calls = []

    def queued_mock(*args, **kwargs):
        calls.append(kwargs)
        # first retry attempt will get queued
        is_first_retry_request = len(calls) == 4
        return Queued() if is_first_retry_request else kwargs.get("state")

    client.set_task_run_state = queued_mock

    @prefect.task(
        max_retries=2,
        retry_delay=datetime.timedelta(seconds=0),
        result_handler=ResultHandler(),
    )
    def tagged_task(x):
        if prefect.context.get("task_run_count", 1) == 1:
            raise ValueError("gimme a sec")
        return x

    upstream_result = Result(value=42, result_handler=JSONResultHandler())
    runner = CloudTaskRunner(task=tagged_task)
    upstream_states = {
        Edge(Task(), tagged_task, key="x"): Success(result=upstream_result)
    }
    res = runner.run(
        context={"task_run_version": 1},
        state=None,
        upstream_states=upstream_states,
        executor=prefect.engine.executors.LocalExecutor(),
    )

    assert res.is_successful()
    assert res.result == 42
    # Requested: Running -> Failed -> Retrying -> Running -> Running -> Success
    # (the fourth request is the one the stub answers with Queued).
    assert len(calls) == 6
    requested_states = [type(kwargs["state"]).__name__ for kwargs in calls]
    assert requested_states == [
        "Running",
        "Failed",
        "Retrying",
        "Running",
        "Running",
        "Success",
    ]

    # ensures result handler was called and persisted
    retrying_state = calls[2]["state"]
    assert retrying_state.cached_inputs["x"].safe_value.value == "42"
Ejemplo n.º 14
0
def test_task_runner_uses_upstream_result_handlers(client):
    """Check that the upstream task's result type is used to read its value.

    The upstream task is configured with ``MyResult``, whose ``read`` always
    yields ``"cool"``. The upstream state itself carries a ``PrefectResult``
    with location ``"1"``. The downstream task echoes its input, and the
    test expects the final result to be ``"cool"``.
    """

    class MyResult(Result):
        def read(self, *args, **kwargs):
            # Ignore the location and always report a fixed value.
            self.value = "cool"
            return self

        def write(self, *args, **kwargs):
            return self

    @prefect.task(result=PrefectResult())
    def t(x):
        return x

    upstream_success = Success(result=PrefectResult(location="1"))
    upstream_edge = Edge(Task(result=MyResult()), t, key="x")

    runner = CloudTaskRunner(task=t)
    final_state = runner.run(upstream_states={upstream_edge: upstream_success})

    assert final_state.is_successful()
    assert final_state.result == "cool"
Ejemplo n.º 15
0
def test_task_runner_places_task_tags_in_state_context_and_serializes_them(monkeypatch):
    """Check that task tags appear in the state context of the posted payloads.

    ``requests.Session`` is replaced by a mock so that the JSON bodies passed
    to ``session.post`` can be inspected. The first posted state is expected
    to be ``Running`` and the last ``Success``, both carrying the task's tags.
    """
    expected_tags = {"1", "2", "tag"}
    task = Task(name="test", tags=["1", "2", "tag"])
    session = MagicMock()
    monkeypatch.setattr("prefect.client.client.GraphQLResult", MagicMock())
    monkeypatch.setattr("requests.Session", MagicMock(return_value=session))

    runner = CloudTaskRunner(task=task)
    res = runner.run()
    assert res.is_successful()

    ## extract the variables payload from the calls to POST
    call_vars = [
        json.loads(kwargs["json"]["variables"])
        for _, kwargs in session.post.call_args_list
    ]

    # do some manipulation to get the state payloads
    inputs = [v["input"]["states"][0] for v in call_vars if v is not None]

    first_state = inputs[0]["state"]
    assert first_state["type"] == "Running"
    assert set(first_state["context"]["tags"]) == expected_tags

    last_state = inputs[-1]["state"]
    assert last_state["type"] == "Success"
    assert set(last_state["context"]["tags"]) == expected_tags
Ejemplo n.º 16
0
def test_task_runner_queries_for_cached_states_if_task_has_caching(client):
    """Check that a task with ``cache_for`` asks the client for cached states.

    The client returns one state expiring a day from now (result 99) and one
    that expired a day ago (result 13). The test expects the run to end in a
    cached state with result 99.
    """

    @prefect.task(cache_for=datetime.timedelta(minutes=1))
    def cached_task():
        return 42

    one_day = datetime.timedelta(days=1)

    # Still valid: expires tomorrow.
    state = Cached(
        cached_result_expiration=datetime.datetime.utcnow() + one_day,
        result=Result(99, JSONResultHandler()),
    )
    # Already expired: expiration was yesterday.
    old_state = Cached(
        cached_result_expiration=datetime.datetime.utcnow() - one_day,
        result=13,
    )
    client.get_latest_cached_states = MagicMock(return_value=[state, old_state])

    runner = CloudTaskRunner(task=cached_task)
    res = runner.run()

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 99
def test_task_runner_handles_looping_with_no_result(client):
    """Run a task that raises ``LOOP`` (without a result) until its third loop.

    The test checks that the run succeeds, that run info is never fetched,
    and that six state updates are sent with versions 1 through 6.
    """

    @prefect.task(result_handler=ResultHandler())
    def looper():
        if prefect.context.get("task_loop_count", 1) < 3:
            raise LOOP()
        return 42

    runner = CloudTaskRunner(task=looper)
    res = runner.run(
        context={"task_run_version": 1},
        state=None,
        upstream_states={},
        executor=prefect.engine.executors.LocalExecutor(),
    )

    ## assertions
    assert res.is_successful()
    assert client.get_task_run_info.call_count == 0
    # Pending -> Running -> Looped (1) -> Running -> Looped (2) -> Running -> Success
    assert client.set_task_run_state.call_count == 6
    versions = [
        kwargs["version"] for _, kwargs in client.set_task_run_state.call_args_list
    ]
    assert versions == [1, 2, 3, 4, 5, 6]
Ejemplo n.º 18
0
def test_task_runner_validates_cached_states_if_task_has_caching(client):
    """Check that expired cached states are not reused.

    Both states returned by the client have an expiration in the past
    (two minutes ago and one day ago). The test expects the task's own
    return value, 42, on a final state that reports itself as cached.
    """

    @prefect.task(cache_for=datetime.timedelta(minutes=1), result=PrefectResult())
    def cached_task():
        return 42

    # Expired two minutes ago.
    state = Cached(
        cached_result_expiration=(
            datetime.datetime.utcnow() - datetime.timedelta(minutes=2)
        ),
        result=PrefectResult(location="99"),
    )
    # Expired one day ago.
    old_state = Cached(
        cached_result_expiration=(
            datetime.datetime.utcnow() - datetime.timedelta(days=1)
        ),
        result=PrefectResult(location="13"),
    )
    client.get_latest_cached_states = MagicMock(return_value=[state, old_state])

    runner = CloudTaskRunner(task=cached_task)
    res = runner.run()

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 42
Ejemplo n.º 19
0
def test_task_runner_treats_unfound_files_as_invalid_caches(client, tmpdir):
    """Check that a cached state pointing at a missing file is passed over.

    The first cached state references a ``LocalResult`` file under ``tmpdir``
    that this test never creates; the second references a ``PrefectResult``
    at location ``"13"``. Both are unexpired. The test expects a cached
    final state with result 13.
    """

    @prefect.task(cache_for=datetime.timedelta(minutes=1), result=PrefectResult())
    def cached_task():
        return 42

    missing_file = str(tmpdir / "made_up_data.prefect")

    state = Cached(
        cached_result_expiration=(
            datetime.datetime.utcnow() + datetime.timedelta(minutes=2)
        ),
        result=LocalResult(location=missing_file),
    )
    old_state = Cached(
        cached_result_expiration=(
            datetime.datetime.utcnow() + datetime.timedelta(days=1)
        ),
        result=PrefectResult(location="13"),
    )
    client.get_latest_cached_states = MagicMock(return_value=[state, old_state])

    runner = CloudTaskRunner(task=cached_task)
    res = runner.run()

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 13
Ejemplo n.º 20
0
def test_task_runner_performs_retries_for_short_delays(client):
    """Run a task that fails on its first call and returns on its second.

    The task allows one retry with a zero-second delay. The test checks
    that the run ends successfully, that run info is never fetched, and
    that five state updates are sent.
    """
    attempts = []

    @prefect.task(max_retries=1, retry_delay=datetime.timedelta(seconds=0))
    def noop():
        # The list is empty only on the first call; fail once, then pass.
        if not attempts:
            attempts.append(0)
            raise ValueError("oops")

    runner = CloudTaskRunner(task=noop)
    res = runner.run(
        state=None,
        upstream_states={},
        executor=prefect.engine.executors.LocalExecutor(),
    )

    ## assertions
    assert res.is_successful()
    assert client.get_task_run_info.call_count == 0  # never called
    # Pending -> Running -> Failed -> Retrying -> Running -> Success
    assert client.set_task_run_state.call_count == 5
Ejemplo n.º 21
0
def test_task_runner_handles_looping(client):
    """Run a task that loops twice, accumulating a loop result, then returns it.

    Each loop iteration adds 10 to the loop result carried in the context.
    The test checks that the run succeeds, that run info is never fetched,
    that six state updates are sent, and that the truthy ``version`` values
    among them are 1, 3 and 5.
    """

    @prefect.task(result_handler=ResultHandler())
    def looper():
        if prefect.context.get("task_loop_count", 1) < 3:
            raise LOOP(result=prefect.context.get("task_loop_result", 0) + 10)
        return prefect.context.get("task_loop_result")

    runner = CloudTaskRunner(task=looper)
    res = runner.run(context={"task_run_version": 1}, state=None, upstream_states={})

    ## assertions
    assert res.is_successful()
    assert client.get_task_run_info.call_count == 0
    # Pending -> Running -> Looped (1) -> Running -> Looped (2) -> Running -> Success
    assert client.set_task_run_state.call_count == 6
    versions = [
        kwargs["version"]
        for _, kwargs in client.set_task_run_state.call_args_list
        if kwargs["version"]
    ]
    assert versions == [1, 3, 5]