Example #1
def test_task_runner_handles_looping_with_retries(client):
    # note that looping _requires_ a result handler in Cloud
    @prefect.task(
        max_retries=1,
        retry_delay=datetime.timedelta(seconds=0),
        result=PrefectResult(),
    )
    def looper():
        if (prefect.context.get("task_loop_count") == 2
                and prefect.context.get("task_run_count", 1) == 1):
            raise ValueError("Stop")
        if prefect.context.get("task_loop_count", 1) < 3:
            raise LOOP(result=prefect.context.get("task_loop_result", 0) + 10)
        return prefect.context.get("task_loop_result")

    client.get_task_run_info.side_effect = [
        MagicMock(version=i,
                  state=Pending() if i == 0 else Looped(loop_count=i))
        for i in range(5)
    ]
    res = CloudTaskRunner(task=looper).run(context={"task_run_version": 1},
                                           state=None,
                                           upstream_states={})

    ## assertions
    assert res.is_successful()
    assert client.get_task_run_info.call_count == 4
    assert (
        client.set_task_run_state.call_count == 9
    )  # Pending -> Running -> Looped (1) -> Running -> Failed -> Retrying -> Running -> Looped(2) -> Running -> Success
    versions = [
        call[1]["version"] for call in client.set_task_run_state.call_args_list
        if call[1]["version"]
    ]
    assert versions == [1, 2, 3]
Example #2
    def test_reads_result_if_cached_valid(self, client):
        result = PrefectResult(location="2")

        with pytest.warns(UserWarning):
            task = Task(cache_validator=duration_only, result=PrefectResult())

        state = Cached(
            result=result,
            cached_result_expiration=pendulum.now("utc").add(minutes=1))

        client.get_latest_cached_states = MagicMock(return_value=[])

        new = CloudTaskRunner(task).check_task_is_cached(
            state=state, inputs={"a": PrefectResult(value=1)})
        assert new is state
        assert new.result == 2
Example #3
def test_task_runner_handles_looping(client):
    @prefect.task(result=PrefectResult())
    def looper():
        if prefect.context.get("task_loop_count", 1) < 3:
            raise LOOP(result=prefect.context.get("task_loop_result", 0) + 10)
        return prefect.context.get("task_loop_result")

    client.get_task_run_info.side_effect = [
        MagicMock(version=i, state=Pending()) for i in range(1, 4)
    ]

    res = CloudTaskRunner(task=looper).run(context={"task_run_version": 1},
                                           state=None,
                                           upstream_states={})

    ## assertions
    assert res.is_successful()
    assert client.get_task_run_info.call_count == 3
    assert (
        client.set_task_run_state.call_count == 6
    )  # Pending -> Running -> Looped (1) -> Running -> Looped (2) -> Running -> Success
    versions = [
        call[1]["version"] for call in client.set_task_run_state.call_args_list
        if call[1]["version"]
    ]
    assert versions == [1, 2, 3]
Example #4
def test_cloud_task_runner_handles_retries_with_queued_states_from_cloud(
        client):
    calls = []

    def queued_mock(*args, **kwargs):
        calls.append(kwargs)
        # first retry attempt will get queued
        if len(calls) == 4:
            return Queued()  # immediate start time
        else:
            return kwargs.get("state")

    client.set_task_run_state = queued_mock

    @prefect.task(
        max_retries=2,
        retry_delay=datetime.timedelta(seconds=0),
        result=PrefectResult(),
    )
    def tagged_task(x):
        if prefect.context.get("task_run_count", 1) == 1:
            raise ValueError("gimme a sec")
        return x

    upstream_result = PrefectResult(value=42, location="42")
    res = CloudTaskRunner(task=tagged_task).run(
        context={"task_run_version": 1},
        state=None,
        upstream_states={
            Edge(Task(), tagged_task, key="x"): Success(result=upstream_result)
        },
    )

    assert res.is_successful()
    assert res.result == 42
    assert (len(calls) == 6
            )  # Running -> Failed -> Retrying -> Queued -> Running -> Success
    assert [type(c["state"]).__name__ for c in calls] == [
        "Running",
        "Failed",
        "Retrying",
        "Running",
        "Running",
        "Success",
    ]

    assert calls[2]["state"].cached_inputs["x"].value == 42
Example #5
def test_serialize_and_deserialize_on_mixed_cached_state():
    safe_dct = PrefectResult(location=json.dumps(dict(hi=5, bye=6)))
    now = pendulum.now("utc")
    state = Cached(
        cached_inputs=dict(x=PrefectResult(value=2),
                           p=PrefectResult(value="p")),
        result=safe_dct,
        cached_result_expiration=now,
    )
    serialized = state.serialize()
    new_state = State.deserialize(serialized)
    assert isinstance(new_state, Cached)
    assert new_state.color == state.color
    assert new_state._result.location == json.dumps(dict(hi=5, bye=6))
    assert new_state.cached_result_expiration == state.cached_result_expiration
    assert new_state.cached_inputs == dict.fromkeys(["x", "p"],
                                                    PrefectResult())
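
A compact sketch of the behavior this test exercises (an illustration added for clarity, not part of the original listing; it assumes only the PrefectResult and State APIs used in these examples): serialization preserves each result's location, while in-memory values that were never written are dropped, which is why the cached_inputs round-trip to bare PrefectResult() objects.

import json

import pendulum
from prefect.engine.results import PrefectResult
from prefect.engine.state import Cached, State

# a Cached state whose main result has a location, but whose input was never written
state = Cached(
    result=PrefectResult(location=json.dumps({"hi": 5})),
    cached_inputs=dict(x=PrefectResult(value=2)),  # value only, no location
    cached_result_expiration=pendulum.now("utc"),
)
roundtripped = State.deserialize(state.serialize())
assert roundtripped._result.location == json.dumps({"hi": 5})  # the location survives
assert roundtripped.cached_inputs["x"] == PrefectResult()      # the unwritten value does not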
Example #6
def test_deep_map_with_a_failure(monkeypatch, executor):

    flow_run_id = str(uuid.uuid4())
    task_run_id_1 = str(uuid.uuid4())
    task_run_id_2 = str(uuid.uuid4())
    task_run_id_3 = str(uuid.uuid4())

    with prefect.Flow(name="test", result=PrefectResult()) as flow:
        t1 = plus_one.map([-1, 0, 1])
        t2 = invert_fail_once.map(t1)
        t3 = plus_one.map(t2)

    client = MockedCloudClient(
        flow_runs=[FlowRun(id=flow_run_id)],
        task_runs=[
            TaskRun(id=task_run_id_1,
                    task_slug=flow.slugs[t1],
                    flow_run_id=flow_run_id),
            TaskRun(id=task_run_id_2,
                    task_slug=flow.slugs[t2],
                    flow_run_id=flow_run_id),
            TaskRun(id=task_run_id_3,
                    task_slug=flow.slugs[t3],
                    flow_run_id=flow_run_id),
        ] + [
            TaskRun(id=str(uuid.uuid4()),
                    task_slug=flow.slugs[t],
                    flow_run_id=flow_run_id)
            for t in flow.tasks if t not in [t1, t2, t3]
        ],
        monkeypatch=monkeypatch,
    )

    with prefect.context(flow_run_id=flow_run_id):
        state = CloudFlowRunner(flow=flow).run(return_tasks=flow.tasks)

    assert state.is_failed()
    assert client.flow_runs[flow_run_id].state.is_failed()
    assert client.task_runs[task_run_id_1].state.is_mapped()
    assert client.task_runs[task_run_id_2].state.is_mapped()
    assert client.task_runs[task_run_id_3].state.is_mapped()

    # there should be a total of 4 task runs for each mapped task (1 parent + 3 children)
    for t in [t1, t2, t3]:
        assert (len([
            tr for tr in client.task_runs.values()
            if tr.task_slug == flow.slugs[t]
        ]) == 4)

    # t2's first child task should have failed
    t2_0 = next(tr for tr in client.task_runs.values()
                if tr.task_slug == flow.slugs[t2] and tr.map_index == 0)
    assert t2_0.state.is_failed()

    # t3's first child task should have failed
    t3_0 = next(tr for tr in client.task_runs.values()
                if tr.task_slug == flow.slugs[t3] and tr.map_index == 0)
    assert t3_0.state.is_failed()
Example #7
def test_deep_map(monkeypatch, executor):

    flow_run_id = str(uuid.uuid4())
    task_run_id_1 = str(uuid.uuid4())
    task_run_id_2 = str(uuid.uuid4())
    task_run_id_3 = str(uuid.uuid4())

    with prefect.Flow(name="test", result=PrefectResult()) as flow:
        t1 = plus_one.map([0, 1, 2])
        t2 = plus_one.map(t1)
        t3 = plus_one.map(t2)

    client = MockedCloudClient(
        flow_runs=[FlowRun(id=flow_run_id)],
        task_runs=[
            TaskRun(
                id=task_run_id_1, task_slug=flow.slugs[t1], flow_run_id=flow_run_id
            ),
            TaskRun(
                id=task_run_id_2, task_slug=flow.slugs[t2], flow_run_id=flow_run_id
            ),
            TaskRun(
                id=task_run_id_3, task_slug=flow.slugs[t3], flow_run_id=flow_run_id
            ),
        ]
        + [
            TaskRun(
                id=str(uuid.uuid4()), task_slug=flow.slugs[t], flow_run_id=flow_run_id
            )
            for t in flow.tasks
            if t not in [t1, t2, t3]
        ],
        monkeypatch=monkeypatch,
    )

    with prefect.context(flow_run_id=flow_run_id):
        state = CloudFlowRunner(flow=flow).run(
            return_tasks=flow.tasks, executor=executor
        )

    assert state.is_successful()
    assert client.flow_runs[flow_run_id].state.is_successful()
    assert client.task_runs[task_run_id_1].state.is_mapped()
    assert client.task_runs[task_run_id_2].state.is_mapped()
    assert client.task_runs[task_run_id_3].state.is_mapped()

    # there should be a total of 4 task runs for each mapped task (1 parent + 3 children)
    for t in [t1, t2, t3]:
        assert (
            len(
                [
                    tr
                    for tr in client.task_runs.values()
                    if tr.task_slug == flow.slugs[t]
                ]
            )
            == 4
        )
Example #8
def test_task_runner_uses_upstream_result_handlers(client):
    class MyResult(Result):
        def read(self, *args, **kwargs):
            self.value = "cool"
            return self

        def write(self, *args, **kwargs):
            return self

    @prefect.task(result=PrefectResult())
    def t(x):
        return x

    success = Success(result=PrefectResult(location="1"))
    upstream_states = {Edge(Task(result=MyResult()), t, key="x"): success}
    state = CloudTaskRunner(task=t).run(upstream_states=upstream_states)
    assert state.is_successful()
    assert state.result == "cool"
Example #9
    def check_for_retry(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Checks to see if a FAILED task should be retried.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: the state of the task after running the check
        """
        if state.is_failed():
            run_count = prefect.context.get("task_run_count", 1)
            if prefect.context.get("task_loop_count") is not None:

                loop_result = self.result.from_value(
                    value=prefect.context.get("task_loop_result")
                )

                ## checkpoint the loop result if one is present, unless the user has opted out by disabling checkpointing
                if (
                    prefect.context.get("checkpointing") is True
                    and self.task.checkpoint is not False
                    and loop_result.value is not None
                ):
                    try:
                        value = prefect.context.get("task_loop_result")
                        loop_result = self.result.write(
                            value, filename="output", **prefect.context
                        )
                    except NotImplementedError:
                        pass

                loop_context = {
                    "_loop_count": PrefectResult(
                        location=json.dumps(prefect.context["task_loop_count"]),
                    ),
                    "_loop_result": loop_result,
                }
                inputs.update(loop_context)
            if run_count <= self.task.max_retries:
                start_time = pendulum.now("utc") + self.task.retry_delay
                msg = "Retrying Task (after attempt {n} of {m})".format(
                    n=run_count, m=self.task.max_retries + 1
                )
                retry_state = Retrying(
                    start_time=start_time,
                    cached_inputs=inputs,
                    message=msg,
                    run_count=run_count,
                )
                return retry_state

        return state
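
An illustrative sketch of the loop context built above (an assumption added for clarity, not part of the original listing): if a looping task that uses PrefectResult fails on its second iteration with an accumulated result of 10, the entries stashed into cached_inputs would look roughly like this.

import json

from prefect.engine.results import PrefectResult

loop_context = {
    # the loop counter is checkpointed as a JSON-encoded PrefectResult
    "_loop_count": PrefectResult(location=json.dumps(2)),
    # the accumulated loop value is wrapped using the task's own result type
    "_loop_result": PrefectResult().from_value(value=10),
}
assert loop_context["_loop_count"].location == "2"
assert loop_context["_loop_result"].value == 10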
Example #10
def test_cloud_task_runner_sends_heartbeat_on_queued_retries(client):
    calls = []
    tr_ids = []

    def queued_mock(*args, **kwargs):
        calls.append(kwargs)
        # first retry attempt will get queued
        if len(calls) == 4:
            return Queued()  # immediate start time
        else:
            return kwargs.get("state")

    def mock_heartbeat(**kwargs):
        tr_ids.append(kwargs.get("task_run_id"))

    client.set_task_run_state = queued_mock
    client.update_task_run_heartbeat = mock_heartbeat

    @prefect.task(
        max_retries=2,
        retry_delay=datetime.timedelta(seconds=0),
        result=PrefectResult(),
    )
    def tagged_task(x):
        if prefect.context.get("task_run_count", 1) == 1:
            raise ValueError("gimme a sec")
        return x

    upstream_result = PrefectResult(value=42, location="42")
    CloudTaskRunner(task=tagged_task).run(
        context={
            "task_run_version": 1,
            "task_run_id": "id"
        },
        state=None,
        upstream_states={
            Edge(Task(), tagged_task, key="x"): Success(result=upstream_result)
        },
    )

    assert len(calls) == 6
    assert tr_ids == ["id", "id"]
Example #11
    def test_load_results_from_upstream_reads_results_using_upstream_handlers(
            self, cloud_api):
        class CustomResult(Result):
            def read(self, *args, **kwargs):
                return "foo-bar-baz".split("-")

        state = Success(result=PrefectResult(location="1"))
        edge = Edge(Task(result=CustomResult()), 2, key="x")
        new_state, upstreams = CloudTaskRunner(task=Task()).load_results(
            state=Pending(), upstream_states={edge: state})
        assert upstreams[edge].result == ["foo", "bar", "baz"]
Example #12
    def test_state_kwarg_is_prioritized_over_db_caches(self, client):
        task = Task(
            cache_for=datetime.timedelta(minutes=1),
            cache_validator=duration_only,
            result=PrefectResult(),
        )
        state_a = Cached(
            result=PrefectResult(location="2"),
            cached_result_expiration=pendulum.now("utc").add(minutes=1),
        )
        state_b = Cached(
            result=PrefectResult(location="99"),
            cached_result_expiration=pendulum.now("utc").add(minutes=1),
        )

        client.get_latest_cached_states = MagicMock(return_value=[state_a])
        new = CloudTaskRunner(task).check_task_is_cached(
            state=state_b, inputs={"a": Result(1)})
        assert new is state_b
        assert new.result == 99
Example #13
def test_task_runner_uses_cached_inputs_from_db_state(monkeypatch):
    @prefect.task(name="test", result=PrefectResult())
    def add_one(x):
        return x + 1

    db_state = Retrying(cached_inputs=dict(x=PrefectResult(value=41)))
    get_task_run_info = MagicMock(return_value=MagicMock(state=db_state))
    set_task_run_state = MagicMock(
        side_effect=lambda task_run_id, version, state, cache_for: state)
    client = MagicMock(get_task_run_info=get_task_run_info,
                       set_task_run_state=set_task_run_state)
    monkeypatch.setattr("prefect.engine.cloud.task_runner.Client",
                        MagicMock(return_value=client))
    res = CloudTaskRunner(task=add_one).run(context={"map_index": 1})

    ## assertions
    assert get_task_run_info.call_count == 1  # one time to pull latest state
    assert set_task_run_state.call_count == 2  # Pending -> Running -> Success
    assert res.is_successful()
    assert res.result == 42
Example #14
    def __init__(
        self,
        name: str,
        default: JSONSerializableParameterValue = no_default,
        required: bool = True,
        tags: Iterable[str] = None,
    ) -> None:
        super().__init__(name=name,
                         default=default,
                         required=required,
                         tags=tags)
        self.result = PrefectResult(serializer=DateTimeSerializer())
Example #15
    def __init__(
        self,
        name: str,
        required: bool = True,
        tags: Iterable[str] = None,
    ) -> None:
        default = no_default if required else None
        super().__init__(name=name,
                         default=default,
                         required=required,
                         tags=tags)
        self.result = PrefectResult(serializer=DateTimeSerializer())
Example #16
    def test_load_results_from_upstream_reads_secret_results(self, cloud_api):
        secret_result = SecretResult(
            prefect.tasks.secrets.PrefectSecret(name="foo"))

        state = Success(result=PrefectResult(location="foo"))

        with prefect.context(secrets=dict(foo=42)):
            edge = Edge(Task(result=secret_result), 2, key="x")
            new_state, upstreams = CloudTaskRunner(task=Task()).load_results(
                state=Pending(), upstream_states={edge: state})

        assert upstreams[edge].result == 42
Example #17
def test_task_runner_treats_unfound_files_as_invalid_caches(client, tmpdir):
    @prefect.task(cache_for=datetime.timedelta(minutes=1), result=PrefectResult())
    def cached_task():
        return 42

    state = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(minutes=2),
        result=LocalResult(location=str(tmpdir / "made_up_data.prefect")),
    )
    old_state = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(days=1),
        result=PrefectResult(location="13"),
    )
    client.get_latest_cached_states = MagicMock(return_value=[state, old_state])

    res = CloudTaskRunner(task=cached_task).run()
    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 13
Example #18
def test_task_runner_validates_cached_states_if_task_has_caching(client):
    @prefect.task(cache_for=datetime.timedelta(minutes=1), result=PrefectResult())
    def cached_task():
        return 42

    state = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        - datetime.timedelta(minutes=2),
        result=PrefectResult(location="99"),
    )
    old_state = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        - datetime.timedelta(days=1),
        result=PrefectResult(location="13"),
    )
    client.get_latest_cached_states = MagicMock(return_value=[state, old_state])

    res = CloudTaskRunner(task=cached_task).run()
    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 42
Example #19
    def test_task_runner_preserves_location_of_inputs_when_retrying(
            self, client):
        """
        If a user opts out of checkpointing via checkpoint=False, we don't want to
        surprise them by storing the result in cached_inputs.  This test ensures
        that whatever location is provided to a downstream task is the one that is used.
        """
        @prefect.task(max_retries=1, retry_delay=datetime.timedelta(days=1))
        def add(x, y):
            return x + y

        x = PrefectResult(value=1)
        y = PrefectResult(value="0", location="foo")
        state = Pending(cached_inputs=dict(x=x, y=y))
        x_state = Success()
        y_state = Success()
        upstream_states = {
            Edge(Task(), Task(), key="x"): x_state,
            Edge(Task(), Task(), key="y"): y_state,
        }
        res = CloudTaskRunner(task=add).run(state=state,
                                            upstream_states=upstream_states)

        ## assertions
        assert client.get_task_run_info.call_count == 0  # never called
        assert (client.set_task_run_state.call_count == 3
                )  # Pending -> Running -> Failed -> Retrying

        states = [
            call[1]["state"]
            for call in client.set_task_run_state.call_args_list
        ]
        assert states[0].is_running()
        assert states[1].is_failed()
        assert isinstance(states[2], Retrying)
        assert states[2].cached_inputs["x"].location is None
        assert states[2].cached_inputs["y"].location == "foo"
Example #20
    def test_task_runner_validates_cached_state_inputs_if_task_has_caching_and_uses_task_handler(
        self, client
    ):
        class MyResult(Result):
            def read(self, *args, **kwargs):
                new = self.copy()
                new.value = 1337
                return new

        @prefect.task(
            cache_for=datetime.timedelta(minutes=1),
            cache_validator=all_inputs,
            result=MyResult(),
        )
        def cached_task(x):
            return 42

        dull_state = Cached(
            cached_result_expiration=datetime.datetime.utcnow()
            + datetime.timedelta(minutes=2),
            result=PrefectResult(location="-1"),
        )
        state = Cached(
            cached_result_expiration=datetime.datetime.utcnow()
            + datetime.timedelta(minutes=2),
            result=PrefectResult(location="99"),
            cached_inputs={"x": PrefectResult(location="2")},
        )
        client.get_latest_cached_states = MagicMock(return_value=[dull_state, state])

        res = CloudTaskRunner(task=cached_task).check_task_is_cached(
            Pending(), inputs={"x": PrefectResult(value=2)}
        )
        assert client.get_latest_cached_states.called
        assert res.is_successful()
        assert res.is_cached()
        assert res.result == 1337
Example #21
    async def test_set_task_run_state_with_result(self, run_query,
                                                  task_run_id):
        result = PrefectResult(location="10")
        state = Success(result=result)

        result = await run_query(
            query=self.mutation,
            variables=dict(input=dict(states=[
                dict(task_run_id=task_run_id, state=state.serialize())
            ])),
        )
        tr = await models.TaskRun.where(
            id=result.data.set_task_run_states.states[0].id
        ).first({"state", "version"})
        assert tr.version == 2
        assert tr.state == "Success"
Example #22
    def __init__(
        self,
        name: str,
        default: Any = no_default,
        required: bool = None,
        tags: Iterable[str] = None,
    ):
        if required is None:
            required = default is no_default
        if default is no_default:
            default = None
        self.required = required
        self.default = default

        super().__init__(
            name=name, slug=name, tags=tags, result=PrefectResult(), checkpoint=True,
        )
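
A minimal usage sketch (illustrative, not from the original listing): because the constructor above pins result=PrefectResult() with checkpoint=True, a Parameter's value can be recovered directly from its state without configuring any external result storage.

from prefect import Flow, Parameter

with Flow("params") as flow:
    n = Parameter("n", default=3)

state = flow.run()
assert state.result[n].result == 3  # the parameter value is carried in the state's result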
Example #23
    def test_serializer_not_configurable(self):
        # By default creates own JSONSerializer
        result = PrefectResult()
        assert isinstance(result.serializer, JSONSerializer)

        # Can specify one manually as well
        serializer = JSONSerializer()
        result = PrefectResult(serializer=serializer)
        assert result.serializer is serializer

        # Can set if it's a JSONSerializer
        serializer2 = JSONSerializer()
        result.serializer = serializer2
        assert result.serializer is serializer2

        # Type errors for other serializer types
        with pytest.raises(TypeError):
            result.serializer = PickleSerializer()
        with pytest.raises(TypeError):
            result = PrefectResult(serializer=PickleSerializer())
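
A minimal usage sketch consistent with the test above (illustrative, not from the original listing): with its JSON serializer, PrefectResult stores the serialized payload directly in location, so write/read round-trips any JSON-serializable value.

import json

from prefect.engine.results import PrefectResult

written = PrefectResult().write({"answer": 42})
assert json.loads(written.location) == {"answer": 42}  # the location is the JSON payload itself

restored = PrefectResult().read(written.location)
assert restored.value == {"answer": 42}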
Example #24
def test_task_runner_gracefully_handles_load_results_failures(client):
    class MyResult(Result):
        def read(self, *args, **kwargs):
            raise TypeError("something is wrong!")

    @prefect.task(result=PrefectResult())
    def t(x):
        return x

    success = Success(result=MyResult(location="foo.txt"))
    upstream_states = {Edge(Task(result=MyResult()), t, key="x"): success}
    state = CloudTaskRunner(task=t).run(upstream_states=upstream_states)

    assert state.is_failed()
    assert "task results" in state.message
    assert client.set_task_run_state.call_count == 1  # Pending -> Failed

    states = [call[1]["state"] for call in client.set_task_run_state.call_args_list]
    assert [type(s).__name__ for s in states] == ["Failed"]
Example #25
    def test_reads_result_using_handler_attribute_if_cached_valid(
            self, client):
        class MyResult(Result):
            def read(self, *args, **kwargs):
                self.value = 53
                return self

        with pytest.warns(UserWarning):
            task = Task(cache_validator=duration_only, result=MyResult())
        result = PrefectResult(location="2")
        state = Cached(
            result=result,
            cached_result_expiration=pendulum.now("utc").add(minutes=1))

        client.get_latest_cached_states = MagicMock(return_value=[])

        new = CloudTaskRunner(task).check_task_is_cached(
            state=state, inputs={"a": Result(1)})
        assert new is state
        assert new.result == 53
Example #26
    def __init__(
        self,
        name=None,
        slug=None,
        tags=None,
        max_retries=None,
        retry_delay=None,
        timeout=None,
        trigger=None,
        skip_on_upstream_skip=True,
        cache_for=None,
        cache_validator=None,
        cache_key=None,
        checkpoint=None,
        result_handler=None,
        state_handlers=None,
        on_failure=None,
        log_stdout=False,
        result=PrefectResult(),
        target=None,
    ):
        super().__init__(
            name=name,
            slug=slug,
            tags=tags,
            max_retries=max_retries,
            retry_delay=retry_delay,
            timeout=timeout,
            trigger=trigger,
            skip_on_upstream_skip=skip_on_upstream_skip,
            cache_for=cache_for,
            cache_validator=cache_validator,
            cache_key=cache_key,
            checkpoint=checkpoint,
            result_handler=result_handler,
            state_handlers=state_handlers,
            on_failure=on_failure,
            log_stdout=log_stdout,
            result=result,
            target=target,
        )
Example #27
    def test_reads_result_if_cached_valid_using_task_result(task, client):
        class MyResult(Result):
            def read(self, *args, **kwargs):
                self.value = 53
                return self

        task = Task(
            result=MyResult(),
            cache_for=datetime.timedelta(minutes=1),
            cache_validator=duration_only,
        )
        state = Cached(
            result=PrefectResult(location="2"),
            cached_result_expiration=pendulum.now("utc").add(minutes=1),
        )

        client.get_latest_cached_states = MagicMock(return_value=[state])
        new = CloudTaskRunner(task).check_task_is_cached(
            state=Pending(), inputs={"a": Result(1)})
        assert new is state
        assert new.result == 53
Example #28
    def check_for_retry(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Checks to see if a FAILED task should be retried.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: the state of the task after running the check
        """
        if state.is_failed():
            run_count = prefect.context.get("task_run_count", 1)
            if prefect.context.get("task_loop_count") is not None:
                loop_context = {
                    "_loop_count": PrefectResult(
                        location=json.dumps(prefect.context["task_loop_count"]),
                    ),
                    "_loop_result": self.result.from_value(
                        value=prefect.context.get("task_loop_result")
                    ),
                }
                inputs.update(loop_context)
            if run_count <= self.task.max_retries:
                start_time = pendulum.now("utc") + self.task.retry_delay
                msg = "Retrying Task (after attempt {n} of {m})".format(
                    n=run_count, m=self.task.max_retries + 1)
                retry_state = Retrying(
                    start_time=start_time,
                    cached_inputs=inputs,
                    message=msg,
                    run_count=run_count,
                )
                return retry_state

        return state
Example #29
def test_starting_at_arbitrary_loop_index_from_cloud_context(client):
    @prefect.task
    def looper(x):
        if prefect.context.get("task_loop_count", 1) < 20:
            raise LOOP(result=prefect.context.get("task_loop_result", 0) + x)
        return prefect.context.get("task_loop_result", 0) + x

    @prefect.task
    def downstream(l):
        return l**2

    with prefect.Flow(name="looping", result=PrefectResult()) as f:
        inter = looper(10)
        final = downstream(inter)

    client.get_flow_run_info = MagicMock(return_value=MagicMock(
        context={"task_loop_count": 20}))

    flow_state = CloudFlowRunner(flow=f).run(return_tasks=[inter, final])

    assert flow_state.is_successful()
    assert flow_state.result[inter].result == 10
    assert flow_state.result[final].result == 100
Example #30
def test_task_failure_with_upstream_secrets_doesnt_store_secret_value_and_recompute_if_necessary(
    client,
):
    @prefect.task(max_retries=2, retry_delay=timedelta(seconds=100))
    def is_p_three(p):
        if p == 3:
            raise ValueError("No thank you.")
        return p

    with prefect.Flow("test", result=PrefectResult()) as f:
        p = prefect.tasks.secrets.PrefectSecret("p")
        res = is_p_three(p)

    with prefect.context(secrets=dict(p=3)):
        state = CloudFlowRunner(flow=f).run(return_tasks=[res])

    assert state.is_running()
    assert isinstance(state.result[res], Retrying)

    assert state.result[res].cached_inputs["p"].location is None

    ## here we set the result of the secret to an empty result, ensuring
    ## it will get converted to a "true" result;
    ## we expect that the upstream value will actually get recomputed from context
    ## through the SecretResultHandler
    safe = SecretResult(p)
    state.result[p] = Success(result=safe)
    state.result[res].start_time = pendulum.now("utc")
    state.result[res].cached_inputs = dict(p=safe)

    with prefect.context(secrets=dict(p=4)):
        new_state = CloudFlowRunner(flow=f).run(
            return_tasks=[res], task_states=state.result
        )

    assert new_state.is_successful()
    assert new_state.result[res].result == 4