예제 #1
0
    def test_reads_result_using_handler_attribute_if_cached_valid(self, client):
        """A still-valid Cached state is returned as-is, re-read through the task's Result.

        Building the task is expected to emit a ``UserWarning``; the value 53
        supplied by the task's custom ``Result.read`` must end up on the state.
        """

        class FixedValueResult(Result):
            # Stand-in Result whose read() always yields 53.
            def read(self, *args, **kwargs):
                self.value = 53
                return self

        with pytest.warns(UserWarning):
            cached_task = Task(cache_validator=duration_only, result=FixedValueResult())

        stored = PrefectResult(location="2")
        expires_at = pendulum.now("utc").add(minutes=1)
        cached = Cached(result=stored, cached_result_expiration=expires_at)

        # Stub the client so it reports no previously cached states.
        client.get_latest_cached_states = MagicMock(return_value=[])

        runner = CloudTaskRunner(cached_task)
        checked = runner.check_task_is_cached(state=cached, inputs={"a": Result(1)})

        assert checked is cached
        assert checked.result == 53
예제 #2
0
    def check_target(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Checks if a Result exists at the task's target.

        The target may be a template string or a callable. It is resolved
        against one merged mapping built from the flow parameters, the task's
        raw input values and `prefect.context`, in that order. On duplicate
        keys the later source wins. A callable target is called with that
        mapping, and its return value is then formatted with the same mapping.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: a `Cached` state holding the result read from the resolved
                target location when one exists there; otherwise `state` unchanged
        """
        result = self.result
        target = self.task.target

        if result and target:
            raw_inputs = {k: r.value for k, r in inputs.items()}
            # Dict-unpacking order matters: later entries win on key collisions,
            # so context values shadow raw inputs, which shadow parameters.
            formatting_kwargs = {
                **prefect.context.get("parameters", {}).copy(),
                **raw_inputs,
                **prefect.context,
            }

            if not isinstance(target, str):
                # Callable targets build their template from the same mapping.
                target = target(**formatting_kwargs)

            if result.exists(target, **formatting_kwargs):
                # Resolve the template once so the location we read from and the
                # location reported in the state message are the same string
                # (previously the message showed the unformatted template).
                known_location = target.format(**formatting_kwargs)
                new_res = result.read(known_location)
                return Cached(
                    result=new_res,
                    hashed_inputs={
                        key: tokenize(val.value) for key, val in inputs.items()
                    },
                    # Target-based caches carry no expiration.
                    cached_result_expiration=None,
                    cached_parameters=formatting_kwargs.get("parameters"),
                    message=f"Result found at task target {known_location}",
                )

        return state
예제 #3
0
    def test_reads_result_if_cached_valid_using_task_result(task, client):
        """A valid cached state returned by the client is reused and re-read via the task's Result."""

        class FiftyThreeResult(Result):
            # read() ignores its arguments and always produces 53.
            def read(self, *args, **kwargs):
                self.value = 53
                return self

        # NOTE: the `task` parameter is not used; this test builds its own task.
        cacheable_task = Task(
            result=FiftyThreeResult(),
            cache_for=datetime.timedelta(minutes=1),
            cache_validator=duration_only,
        )
        prior_state = Cached(
            result=PrefectResult(location="2"),
            cached_result_expiration=pendulum.now("utc").add(minutes=1),
        )
        client.get_latest_cached_states = MagicMock(return_value=[prior_state])

        runner = CloudTaskRunner(cacheable_task)
        outcome = runner.check_task_is_cached(state=Pending(), inputs={"a": Result(1)})

        assert outcome is prior_state
        assert outcome.result == 53
예제 #4
0
    def cache_result(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Wrap a successful task's result in a `Cached` state when caching is enabled.
        For a failed task, record its inputs on the state instead.

        A `Cached` state is produced only when all of these hold:
            - the task state is Successful
            - the task state is not Skipped (Skipped is a subclass of Successful)
            - `task.cache_for` is not None

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: a new `Cached` state expiring `task.cache_for` after the current
                UTC time when caching applies; otherwise the given `state`
        """
        if state.is_failed():
            # Store the inputs on the failed state itself (mutates `state` in place).
            state.cached_inputs = inputs  # type: ignore

        eligible = (
            state.is_successful()
            and not state.is_skipped()
            and self.task.cache_for is not None
        )
        if not eligible:
            return state

        expires = pendulum.now("utc") + self.task.cache_for
        return Cached(
            result=state._result,
            cached_inputs=inputs,
            cached_result_expiration=expires,
            cached_parameters=prefect.context.get("parameters"),
            message=state.message,
        )
예제 #5
0
 def test_parameters_validate_with_defaults(self):
     """Without ``validate_on``, even a changed parameter value still validates."""
     for current in (dict(x=1, s="str"), dict(x=1, s="strs")):
         # A fresh state and validator per case, as each check stands alone.
         cached = Cached(cached_parameters=dict(x=1, s="str"))
         assert partial_parameters_only()(cached, None, current) is True
예제 #6
0
    def test_skipped_is_success(self):
        """Skipped states are a kind of Success."""
        child, parent = Skipped, Success
        assert issubclass(child, parent)

    def test_timedout_is_failed(self):
        """TimedOut states are a kind of Failed."""
        child, parent = TimedOut, Failed
        assert issubclass(child, parent)

    def test_trigger_failed_is_failed(self):
        """TriggerFailed states are a kind of Failed."""
        child, parent = TriggerFailed, Failed
        assert issubclass(child, parent)


@pytest.mark.parametrize(
    "state_check",
    [
        dict(state=Cancelled(), assert_true={"is_finished"}),
        dict(state=Cached(),
             assert_true={"is_cached", "is_finished", "is_successful"}),
        dict(state=ClientFailed(), assert_true={"is_meta_state"}),
        dict(state=Failed(), assert_true={"is_finished", "is_failed"}),
        dict(state=Finished(), assert_true={"is_finished"}),
        dict(state=Looped(), assert_true={"is_finished", "is_looped"}),
        dict(state=Mapped(),
             assert_true={"is_finished", "is_mapped", "is_successful"}),
        dict(state=Paused(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Pending(), assert_true={"is_pending"}),
        dict(state=Queued(), assert_true={"is_meta_state", "is_queued"}),
        dict(state=Resume(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Retrying(),
             assert_true={"is_pending", "is_scheduled", "is_retrying"}),
        dict(state=Running(), assert_true={"is_running"}),
        dict(state=Scheduled(), assert_true={"is_pending", "is_scheduled"}),
예제 #7
0
 def test_state_type_methods_with_cached_state(self):
     """Of the checked predicates, a bare Cached state is only cached, finished and successful."""
     state = Cached()
     # Predicate name -> expected truthiness, checked in this order.
     expected = {
         "is_cached": True,
         "is_retrying": False,
         "is_pending": False,
         "is_running": False,
         "is_finished": True,
         "is_skipped": False,
         "is_scheduled": False,
         "is_successful": True,
         "is_failed": False,
         "is_mapped": False,
         "is_meta_state": False,
     }
     for method_name, should_hold in expected.items():
         assert bool(getattr(state, method_name)()) is should_hold
예제 #8
0
    def test_skipped_is_success(self):
        """Skipped must derive from Success."""
        subclass, base = Skipped, Success
        assert issubclass(subclass, base)

    def test_timedout_is_failed(self):
        """TimedOut must derive from Failed."""
        subclass, base = TimedOut, Failed
        assert issubclass(subclass, base)

    def test_trigger_failed_is_failed(self):
        """TriggerFailed must derive from Failed."""
        subclass, base = TriggerFailed, Failed
        assert issubclass(subclass, base)


@pytest.mark.parametrize(
    "state_check",
    [
        dict(state=Cancelled(), assert_true={"is_finished", "is_failed"}),
        dict(state=Cached(), assert_true={"is_cached", "is_finished", "is_successful"}),
        dict(state=ClientFailed(), assert_true={"is_meta_state"}),
        dict(state=Failed(), assert_true={"is_finished", "is_failed"}),
        dict(state=Finished(), assert_true={"is_finished"}),
        dict(state=Looped(), assert_true={"is_finished", "is_looped"}),
        dict(state=Mapped(), assert_true={"is_finished", "is_mapped", "is_successful"}),
        dict(state=Paused(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Pending(), assert_true={"is_pending"}),
        dict(state=Queued(), assert_true={"is_meta_state", "is_queued"}),
        dict(state=Resume(), assert_true={"is_pending", "is_scheduled"}),
        dict(
            state=Retrying(), assert_true={"is_pending", "is_scheduled", "is_retrying"}
        ),
        dict(state=Running(), assert_true={"is_running"}),
        dict(state=Scheduled(), assert_true={"is_pending", "is_scheduled"}),
        dict(
예제 #9
0
 def test_parameters_validate(self):
     """Identical cached and current parameters pass `all_parameters`."""
     cached = Cached(cached_parameters=dict(x=1, s="str"))
     current = dict(x=1, s="str")
     assert all_parameters(cached, None, current) is True
예제 #10
0
 def test_additional_parameters_invalidate(self):
     """A parameter absent from the cached set makes `all_parameters` fail."""
     cached = Cached(cached_parameters=dict(x=1, s="str"))
     current = dict(x=1, s="str", noise="e")
     assert all_parameters(cached, None, current) is False
예제 #11
0
 def test_inputs_validate(self):
     """Matching cached and current inputs pass `all_inputs`."""
     cached = Cached(cached_inputs=dict(x=1, s="str"))
     current = dict(x=1, s="str")
     assert all_inputs(cached, current, None) is True
예제 #12
0
 def test_hashed_inputs_validate(self):
     """`all_inputs` also validates against a state carrying only tokenized input hashes."""
     hashes = {"x": tokenize(1), "s": tokenize("str")}
     cached = Cached(hashed_inputs=hashes)
     assert all_inputs(cached, {"x": 1, "s": "str"}, None) is True
예제 #13
0
 def test_unexpired_cache(self):
     """A cache expiring a day from now is still valid under `duration_only`."""
     tomorrow = pendulum.now("utc") + timedelta(days=1)
     cached = Cached(cached_result_expiration=tomorrow)
     assert duration_only(cached, None, None) is True
예제 #14
0
 def test_handles_none(self):
     """`partial_inputs_only` fails when the state has no cached inputs or no inputs are given."""
     # State built with parameters only, so it carries no cached inputs.
     params_only = Cached(cached_parameters=dict(x=5))
     assert partial_inputs_only(validate_on=["x"])(params_only, dict(x=5), None) is False

     # State with cached inputs, but the current inputs are None.
     inputs_only = Cached(cached_inputs=dict(x=Result(5)))
     assert partial_inputs_only(validate_on=["x"])(inputs_only, None, None) is False
예제 #15
0
def test_expired_cache_stateful(validator):
    """A cache that expired a day ago fails the check built by ``validator()``."""
    yesterday = pendulum.now("utc") - timedelta(days=1)
    expired = Cached(cached_result_expiration=yesterday)
    check = validator()
    assert check(expired, None, None) is False
예제 #16
0
 def test_curried(self):
     """Only parameters named in ``validate_on`` are compared."""
     cached = Cached(cached_parameters=dict(x=1, s="str"))
     only_x = partial_parameters_only(validate_on=["x"])
     cases = [(dict(x=1), True), (dict(x=2, s="str"), False)]
     for current, expected in cases:
         assert only_x(cached, None, current) is expected
예제 #17
0
 def test_inputs_validate_with_defaults(self):
     """With ``validate_on=None``, even a changed input value still validates."""
     for current in (dict(x=1, s="str"), dict(x=1, s="strs")):
         # A fresh state and validator per case, as each check stands alone.
         cached = Cached(cached_inputs=dict(x=Result(1), s=Result("str")))
         assert partial_inputs_only(None)(cached, current, None) is True
예제 #18
0
 def test_curried(self):
     """Only inputs named in ``validate_on`` are compared."""
     cached = Cached(cached_inputs=dict(x=Result(1), s=Result("str")))
     only_x = partial_inputs_only(validate_on=["x"])
     cases = [(dict(x=1), True), (dict(x=2, s="str"), False)]
     for current, expected in cases:
         assert only_x(cached, current, None) is expected
예제 #19
0
 def test_additional_inputs_invalidate(self):
     """An input absent from the cached inputs makes `all_inputs` fail."""
     cached = Cached(cached_inputs=dict(x=1, s="str"))
     current = dict(x=1, s="str", noise="e")
     assert all_inputs(cached, current, None) is False
예제 #20
0
 def test_cached_result_expiration_none_is_interpreted_as_infinite(self):
     """A ``None`` expiration is treated as never expiring by `duration_only`."""
     never_expires = Cached(cached_result_expiration=None)
     assert duration_only(never_expires, None, None) is True
예제 #21
0
 def test_inputs_invalidate(self):
     """A changed input value makes `all_inputs` fail."""
     cached = Cached(cached_inputs=dict(x=Result(1), s=Result("str")))
     changed = dict(x=1, s="strs")
     assert all_inputs(cached, changed, None) is False