예제 #1
0
def test_task_runner_validates_cached_state_inputs_with_upstream_handlers_if_task_has_caching(
    client,
):
    """An upstream input whose handler reads back 1337 must not hit the cache.

    The mocked client offers two candidate states. The first has no cached
    inputs; the second is cached for ``x == 2``. The current input is wrapped
    with a handler whose ``read`` always returns 1337. Presumably the runner
    reads inputs through that handler, so no candidate validates and the
    returned state is expected to still be pending.
    """

    class Handler(ResultHandler):
        def read(self, val):
            return 1337

    def expires_in(minutes):
        return datetime.datetime.utcnow() + datetime.timedelta(minutes=minutes)

    @prefect.task(
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=all_inputs,
        result_handler=JSONResultHandler(),
    )
    def cached_task(x):
        return 42

    inputless_state = Cached(
        cached_result_expiration=expires_in(2),
        result=SafeResult("-1", JSONResultHandler()),
    )
    inputs_state = Cached(
        cached_result_expiration=expires_in(2),
        result=SafeResult("99", JSONResultHandler()),
        cached_inputs={"x": SafeResult("2", result_handler=JSONResultHandler())},
    )
    client.get_latest_cached_states = MagicMock(
        return_value=[inputless_state, inputs_state]
    )

    runner = CloudTaskRunner(task=cached_task)
    current_inputs = {"x": Result(2, result_handler=Handler())}
    outcome = runner.check_task_is_cached(Pending(), inputs=current_inputs)

    assert client.get_latest_cached_states.called
    assert outcome.is_pending()
예제 #2
0
 def test_inputs_validate_with_defaults(self):
     """``partial_inputs_only(None)`` accepts both identical and changed inputs.

     The cached state stores ``s="str"``. Validation is expected to return
     ``True`` whether the current ``s`` is ``"str"`` or ``"strs"``.
     """
     for current in (dict(x=1, s="str"), dict(x=1, s="strs")):
         cached = Cached(cached_inputs=dict(x=1, s="str"))
         assert partial_inputs_only(None)(cached, current, None) is True
예제 #3
0
 def test_handles_none(self):
     """A missing input mapping on either side never validates the cache.

     First case: the cached state has only ``cached_parameters`` (no cached
     inputs). Second case: the current inputs are ``None``. Both are expected
     to return ``False``.
     """

     def validate(state, current_inputs):
         return partial_inputs_only(validate_on=["x"])(state, current_inputs, None)

     params_only = Cached(cached_parameters=dict(x=5))
     assert validate(params_only, dict(x=5)) is False

     inputs_only = Cached(cached_inputs=dict(x=5))
     assert validate(inputs_only, None) is False
예제 #4
0
    def test_task_runner_validates_cached_state_inputs_if_task_has_caching(
        self, client
    ):
        """A cached state whose stored inputs match the current inputs is reused.

        The mocked client returns a state without cached inputs, then one
        cached for ``x == 2``. Checking with ``x == 2`` is expected to yield a
        successful, cached state whose result is 99.
        """

        def expires_in(minutes):
            return datetime.datetime.utcnow() + datetime.timedelta(minutes=minutes)

        @prefect.task(
            cache_for=datetime.timedelta(minutes=1),
            cache_validator=all_inputs,
            result=PrefectResult(),
        )
        def cached_task(x):
            return 42

        inputless_state = Cached(
            cached_result_expiration=expires_in(2),
            result=PrefectResult(location="-1"),
        )
        matching_state = Cached(
            cached_result_expiration=expires_in(2),
            result=PrefectResult(location="99"),
            cached_inputs={"x": PrefectResult(location="2")},
        )
        client.get_latest_cached_states = MagicMock(
            return_value=[inputless_state, matching_state]
        )

        runner = CloudTaskRunner(task=cached_task)
        outcome = runner.check_task_is_cached(
            Pending(), inputs={"x": PrefectResult(value=2)}
        )

        assert client.get_latest_cached_states.called
        assert outcome.is_successful()
        assert outcome.is_cached()
        assert outcome.result == 99
예제 #5
0
 def test_parameters_validate_with_defaults(self):
     """``partial_parameters_only()`` with no keys accepts a changed parameter.

     The cached state stores ``s="str"``. Validation is expected to return
     ``True`` whether the current ``s`` is ``"str"`` or ``"strs"``.
     """
     for current in (dict(x=1, s="str"), dict(x=1, s="strs")):
         cached = Cached(cached_parameters=dict(x=1, s="str"))
         assert partial_parameters_only()(cached, None, current) is True
예제 #6
0
def test_task_runner_validates_cached_state_inputs_if_task_has_caching(client):
    """A cached state whose stored inputs match the current inputs is reused.

    The mocked client returns a state without cached inputs, then one cached
    for ``x == 2``. Checking with ``x == 2`` is expected to produce a
    successful, cached state whose result is 99.
    """

    def expires_in(minutes):
        return datetime.datetime.utcnow() + datetime.timedelta(minutes=minutes)

    @prefect.task(
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=all_inputs,
        result_handler=JSONResultHandler(),
    )
    def cached_task(x):
        return 42

    inputless_state = Cached(
        cached_result_expiration=expires_in(2),
        result=Result(-1, JSONResultHandler()),
    )
    matching_state = Cached(
        cached_result_expiration=expires_in(2),
        result=Result(99, JSONResultHandler()),
        cached_inputs={"x": SafeResult("2", result_handler=JSONResultHandler())},
    )
    client.get_latest_cached_states = MagicMock(
        return_value=[inputless_state, matching_state]
    )

    runner = CloudTaskRunner(task=cached_task)
    outcome = runner.check_task_is_cached(
        Pending(),
        inputs={"x": Result(2, result_handler=LocalResultHandler())},
    )

    assert client.get_latest_cached_states.called
    assert outcome.is_successful()
    assert outcome.is_cached()
    assert outcome.result == 99
예제 #7
0
    def test_task_runner_validates_cached_state_inputs_with_upstream_handlers_if_task_has_caching(
        self, client
    ):
        """An input whose Result reads back as 1337 must not match a cache for ``x == 2``.

        ``MyResult.read`` returns a copy of itself carrying the value 1337.
        Presumably the runner re-reads upstream inputs through their Result, so
        the only input-bearing candidate (cached for ``x == 2``) is expected not
        to validate, and the state stays pending.
        """

        class MyResult(Result):
            def read(self, *args, **kwargs):
                duplicate = self.copy()
                duplicate.value = 1337
                return duplicate

        def expires_in(minutes):
            return datetime.datetime.utcnow() + datetime.timedelta(minutes=minutes)

        @prefect.task(
            cache_for=datetime.timedelta(minutes=1),
            cache_validator=all_inputs,
            result=PrefectResult(),
        )
        def cached_task(x):
            return 42

        inputless_state = Cached(
            cached_result_expiration=expires_in(2),
            result=PrefectResult(location="-1"),
        )
        inputs_state = Cached(
            cached_result_expiration=expires_in(2),
            result=PrefectResult(location="99"),
            cached_inputs={"x": PrefectResult(location="2")},
        )
        client.get_latest_cached_states = MagicMock(
            return_value=[inputless_state, inputs_state]
        )

        runner = CloudTaskRunner(task=cached_task)
        outcome = runner.check_task_is_cached(Pending(), inputs={"x": MyResult(value=2)})

        assert client.get_latest_cached_states.called
        assert outcome.is_pending()
예제 #8
0
 def test_validate_on_kwarg(self):
     """``validate_on`` limits which inputs must match the cached inputs.

     The cached state stores ``s="str"``. A current ``s="strs"`` only
     invalidates the cache when ``"s"`` is one of the validated keys.
     """
     cases = [
         (["x", "s"], dict(x=1, s="str"), True),
         (["x", "s"], dict(x=1, s="strs"), False),
         (["x"], dict(x=1, s="strs"), True),
         (["s"], dict(x=1, s="strs"), False),
     ]
     for keys, current, expected in cases:
         cached = Cached(cached_inputs=dict(x=1, s="str"))
         validator = partial_inputs_only(validate_on=keys)
         assert validator(cached, current, None) is expected
예제 #9
0
 def test_validate_on_kwarg(self):
     """``validate_on`` limits which parameters must match the cached parameters.

     Uses the factory form ``partial_parameters_only(validate_on=...)``, which
     returns the validator callable ``(state, inputs, parameters) -> bool``.
     This is the same calling convention used by
     ``partial_parameters_only()(...)`` and ``partial_inputs_only(...)(...)``.
     The cached state stores ``s="str"``; a current ``s="strs"`` only
     invalidates the cache when ``"s"`` is one of the validated keys.
     """
     cases = [
         (["x", "s"], dict(x=1, s="str"), True),
         (["x", "s"], dict(x=1, s="strs"), False),
         (["x"], dict(x=1, s="strs"), True),
         (["s"], dict(x=1, s="strs"), False),
     ]
     for keys, current, expected in cases:
         cached = Cached(cached_parameters=dict(x=1, s="str"))
         validator = partial_parameters_only(validate_on=keys)
         assert validator(cached, None, current) is expected
예제 #10
0
    def cache_result(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Caches the result of a successful task, if appropriate.

        Tasks are cached if:
            - task.cache_for is not None
            - the task state is Successful
            - the task state is not Skipped (which is a subclass of Successful)

        Any other state (including a Failed one) is returned unchanged; nothing
        is cached for it.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments. Each input's `value` is tokenized and
                stored as `hashed_inputs` on the new Cached state; `None` is treated as
                "no inputs".

        Returns:
            - State: a new `Cached` state expiring `task.cache_for` from now when the
                task qualifies for caching, otherwise the original `state`
        """
        cacheable = (
            state.is_successful()
            and not state.is_skipped()
            and self.task.cache_for is not None
        )
        if not cacheable:
            return state

        expiration = pendulum.now("utc") + self.task.cache_for
        # Only a token per input is kept on the cached state, not the Result itself.
        hashed_inputs = {
            key: tokenize(val.value) for key, val in (inputs or {}).items()
        }
        return Cached(
            result=state._result,
            hashed_inputs=hashed_inputs,
            cached_result_expiration=expiration,
            cached_parameters=prefect.context.get("parameters"),
            message=state.message,
        )
예제 #11
0
    def check_target(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Checks if a Result exists at the task's target.

        Existence is checked with `self.result.exists(target, **prefect.context)`.
        If it reports data, the target template is formatted with the current
        `prefect.context`, read back through `self.result`, and wrapped in a `Cached`
        state with `cached_result_expiration=None`.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments; stored as `cached_inputs` on the
                resulting Cached state.

        Returns:
            - State: a `Cached` state holding the result found at the target, or the
                original `state` if no result or target is configured or nothing exists
        """
        result = self.result
        target = self.task.target

        if result and target:
            if result.exists(target, **prefect.context):
                # Format the template once so the read location and the message agree.
                known_location = target.format(**prefect.context)
                new_res = result.read(known_location)
                return Cached(
                    result=new_res,
                    cached_inputs=inputs,
                    cached_result_expiration=None,
                    cached_parameters=prefect.context.get("parameters"),
                    message=f"Result found at task target {known_location}",
                )

        return state
예제 #12
0
    def test_providing_cachedstate_with_simple_example(self, executor):
        """A valid Cached state passed through ``task_states`` supplies the task's result.

        ``TestTask.run`` would return its call count (1). The flow is instead
        expected to report 100, the result stored in the provided Cached state.
        """

        class TestTask(Task):
            call_count = 0

            def run(self, x, s):
                self.call_count += 1
                return self.call_count

        with Flow(name="test") as flow:
            counter = TestTask(
                cache_validator=duration_only,
                cache_for=datetime.timedelta(days=1),
            )
            x_param = Parameter("x")
            upstream = SuccessTask()
            flow.add_edge(x_param, counter, key="x")
            flow.add_edge(upstream, counter, key="s")

        still_valid = Cached(
            cached_result_expiration=pendulum.now("utc") + datetime.timedelta(days=1),
            result=100,
        )
        runner = FlowRunner(flow=flow)
        flow_state = runner.run(
            executor=executor,
            parameters=dict(x=1),
            return_tasks=[counter],
            task_states={counter: still_valid},
        )

        assert isinstance(flow_state, Success)
        assert flow_state.result[counter].result == 100
예제 #13
0
    def test_state_kwarg_is_prioritized_over_db_caches(self, client):
        """An explicitly passed Cached state wins over one returned by the client."""
        task = Task(
            cache_for=datetime.timedelta(minutes=1),
            cache_validator=duration_only,
            result=PrefectResult(),
        )

        def fresh_cache(location):
            return Cached(
                result=PrefectResult(location=location),
                cached_result_expiration=pendulum.now("utc").add(minutes=1),
            )

        from_db = fresh_cache("2")
        from_kwarg = fresh_cache("99")

        client.get_latest_cached_states = MagicMock(return_value=[from_db])
        checked = CloudTaskRunner(task).check_task_is_cached(
            state=from_kwarg, inputs={"a": Result(1)}
        )

        assert checked is from_kwarg
        assert checked.result == 99
예제 #14
0
 def test_state_type_methods_with_cached_state(self):
     """Cached reports itself as cached, finished and successful, and as nothing else."""
     state = Cached()
     expected_truthy = {"is_cached", "is_finished", "is_successful"}
     checked_methods = [
         "is_cached",
         "is_retrying",
         "is_pending",
         "is_running",
         "is_finished",
         "is_skipped",
         "is_scheduled",
         "is_successful",
         "is_failed",
         "is_mapped",
         "is_meta_state",
     ]
     for method_name in checked_methods:
         answer = getattr(state, method_name)()
         if method_name in expected_truthy:
             assert answer
         else:
             assert not answer
예제 #15
0
def test_task_runner_validates_cached_states_if_task_has_caching(client):
    """Expired cached states are ignored and the task actually runs.

    Both candidate states expired in the past, so the run is expected to
    produce the task's own value (42) in a cached state.
    """

    @prefect.task(cache_for=datetime.timedelta(minutes=1), result=PrefectResult())
    def cached_task():
        return 42

    def expired_since(delta):
        return datetime.datetime.utcnow() - delta

    recently_expired = Cached(
        cached_result_expiration=expired_since(datetime.timedelta(minutes=2)),
        result=PrefectResult(location="99"),
    )
    long_expired = Cached(
        cached_result_expiration=expired_since(datetime.timedelta(days=1)),
        result=PrefectResult(location="13"),
    )
    client.get_latest_cached_states = MagicMock(
        return_value=[recently_expired, long_expired]
    )

    outcome = CloudTaskRunner(task=cached_task).run()

    assert client.get_latest_cached_states.called
    assert outcome.is_successful()
    assert outcome.is_cached()
    assert outcome.result == 42
예제 #16
0
def test_task_runner_queries_for_cached_states_if_task_has_caching(client):
    """The runner queries the client and reuses the still-valid cached state (99)."""

    @prefect.task(cache_for=datetime.timedelta(minutes=1))
    def cached_task():
        return 42

    one_day = datetime.timedelta(days=1)
    valid = Cached(
        cached_result_expiration=datetime.datetime.utcnow() + one_day,
        result=Result(99, JSONResultHandler()),
    )
    expired = Cached(
        cached_result_expiration=datetime.datetime.utcnow() - one_day,
        result=13,
    )
    client.get_latest_cached_states = MagicMock(return_value=[valid, expired])

    outcome = CloudTaskRunner(task=cached_task).run()

    assert client.get_latest_cached_states.called
    assert outcome.is_successful()
    assert outcome.is_cached()
    assert outcome.result == 99
예제 #17
0
def test_task_runner_treats_unfound_files_as_invalid_caches(client, tmpdir):
    """A cached state pointing at a non-existent file is treated as invalid.

    The first candidate's LocalResult points at a file that is never
    created. The runner is expected to fall back to the second candidate,
    whose result is 13.
    """

    @prefect.task(cache_for=datetime.timedelta(minutes=1), result=PrefectResult())
    def cached_task():
        return 42

    missing_path = str(tmpdir / "made_up_data.prefect")
    unreadable = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(minutes=2),
        result=LocalResult(location=missing_path),
    )
    readable = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(days=1),
        result=PrefectResult(location="13"),
    )
    client.get_latest_cached_states = MagicMock(return_value=[unreadable, readable])

    outcome = CloudTaskRunner(task=cached_task).run()

    assert client.get_latest_cached_states.called
    assert outcome.is_successful()
    assert outcome.is_cached()
    assert outcome.result == 13
예제 #18
0
def test_serialize_and_deserialize_on_raw_cached_state():
    """Raw (non-safe) payloads are dropped by a serialize/deserialize round trip.

    After the round trip the result is ``None`` and every cached input is
    ``NoResult``, while the color and expiration are preserved.
    """
    original = Cached(
        cached_inputs=dict(x=Result(99), p=Result("p")),
        result=dict(hi=5, bye=6),
        cached_result_expiration=pendulum.now("utc"),
    )
    restored = State.deserialize(original.serialize())

    assert isinstance(restored, Cached)
    assert restored.color == original.color
    assert restored.result is None
    assert restored.cached_result_expiration == original.cached_result_expiration
    assert restored.cached_inputs == {key: NoResult for key in ("x", "p")}
예제 #19
0
def test_preparing_state_for_cloud_doesnt_copy_data():
    """``prepare_state_for_cloud`` keeps the very same result object on the state."""

    class FakeHandler(ResultHandler):
        def read(self, val):
            return val

        def write(self, val):
            return val

    payload = 124.090909
    original = Cached(result=Result(payload, result_handler=FakeHandler()))
    prepared = prepare_state_for_cloud(original)

    assert prepared.is_cached()
    assert prepared.result is original.result
예제 #20
0
    def check_target(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Checks if a Result exists at the task's target.

        The target is formatted with a merged mapping built from the context's
        `parameters`, the full `prefect.context`, and the raw `.value` of each
        input (later sources win on key collisions). A non-string target is called
        with that mapping to produce the location template. If
        `self.result.exists(...)` reports data there, it is read back and wrapped in
        a `Cached` state with `cached_result_expiration=None`.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: a `Cached` state holding the result read from the target if one
                exists, otherwise the original `state`
        """
        from dask.base import tokenize

        result = self.result
        target = self.task.target

        if result and target:
            # unwrap each input Result to its raw value for use in the template
            raw_inputs = {k: r.value for k, r in inputs.items()}
            # precedence (lowest -> highest): parameters, context entries, task inputs
            formatting_kwargs = {
                **prefect.context.get("parameters", {}).copy(),
                **prefect.context,
                **raw_inputs,
            }

            # self can't be used as a formatting parameter because it would ruin all method calls such as
            # result.exists() by providing two values of self
            formatting_kwargs.pop("self", None)

            # a non-string target is treated as a callable that builds the template
            if not isinstance(target, str):
                target = target(**formatting_kwargs)

            if result.exists(target, **formatting_kwargs):  # type: ignore
                known_location = target.format(
                    **formatting_kwargs)  # type: ignore
                new_res = result.read(known_location)
                # inputs are recorded only as tokenized hashes, not as Results
                cached_state = Cached(
                    result=new_res,
                    hashed_inputs={
                        key: tokenize(val.value)
                        for key, val in inputs.items()
                    },
                    cached_result_expiration=None,
                    # NOTE(review): "parameters" here comes from the spread
                    # prefect.context, unless a task input named "parameters"
                    # overrides it (inputs take precedence) -- confirm intended
                    cached_parameters=formatting_kwargs.get("parameters"),
                    message=f"Result found at task target {known_location}",
                )
                return cached_state

        return state
예제 #21
0
def test_serialize_and_deserialize_on_mixed_cached_state():
    """A SafeResult payload survives the round trip; plain Result inputs become NoResult."""
    safe_payload = SafeResult(dict(hi=5, bye=6), result_handler=JSONResultHandler())
    original = Cached(
        cached_inputs=dict(x=Result(2), p=Result("p")),
        result=safe_payload,
        cached_result_expiration=pendulum.now("utc"),
    )
    restored = State.deserialize(original.serialize())

    assert isinstance(restored, Cached)
    assert restored.color == original.color
    assert restored.result == dict(hi=5, bye=6)
    assert restored.cached_result_expiration == original.cached_result_expiration
    assert restored.cached_inputs == {key: NoResult for key in ("x", "p")}
예제 #22
0
def test_serialize_and_deserialize_on_mixed_cached_state():
    """The result's location survives the round trip; cached input values are dropped."""
    encoded = json.dumps(dict(hi=5, bye=6))
    original = Cached(
        cached_inputs=dict(x=PrefectResult(value=2), p=PrefectResult(value="p")),
        result=PrefectResult(location=encoded),
        cached_result_expiration=pendulum.now("utc"),
    )
    restored = State.deserialize(original.serialize())

    assert isinstance(restored, Cached)
    assert restored.color == original.color
    assert restored._result.location == encoded
    assert restored.cached_result_expiration == original.cached_result_expiration
    assert restored.cached_inputs == dict.fromkeys(("x", "p"), PrefectResult())
예제 #23
0
def test_serialize_and_deserialize_on_safe_cached_state():
    """SafeResult inputs and results round-trip unchanged."""
    shared_input = SafeResult("99", result_handler=JSONResultHandler())
    original = Cached(
        cached_inputs=dict(x=shared_input, p=shared_input),
        result=SafeResult(dict(hi=5, bye=6), result_handler=JSONResultHandler()),
        cached_result_expiration=pendulum.now("utc"),
    )
    restored = State.deserialize(original.serialize())

    assert isinstance(restored, Cached)
    assert restored.color == original.color
    assert restored.result == dict(hi=5, bye=6)
    assert restored.cached_result_expiration == original.cached_result_expiration
    assert restored.cached_inputs == original.cached_inputs
예제 #24
0
    def test_reads_result_if_cached_valid(self, client):
        """A still-valid Cached state passed in directly is returned as-is with value 2.

        Constructing the task is expected to emit a UserWarning.
        """
        with pytest.warns(UserWarning):
            task = Task(cache_validator=duration_only, result=PrefectResult())

        provided = Cached(
            result=PrefectResult(location="2"),
            cached_result_expiration=pendulum.now("utc").add(minutes=1),
        )
        client.get_latest_cached_states = MagicMock(return_value=[])

        checked = CloudTaskRunner(task).check_task_is_cached(
            state=provided, inputs={"a": PrefectResult(value=1)}
        )

        assert checked is provided
        assert checked.result == 2
예제 #25
0
    def check_target(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Checks if a Result exists at the task's target.

        The target is formatted with a merged mapping built from the context's
        `parameters`, the full `prefect.context`, and the raw `.value` of each
        input (later sources win on key collisions). A non-string target is called
        with that mapping to produce the location template. If
        `self.result.exists(...)` reports data there, it is read back and wrapped in
        a `Cached` state with `cached_result_expiration=None`.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: a `Cached` state holding the result read from the target if one
                exists, otherwise the original `state`
        """
        result = self.result
        target = self.task.target

        if result and target:
            raw_inputs = {k: r.value for k, r in inputs.items()}
            # precedence (lowest -> highest): parameters, context entries, task inputs
            formatting_kwargs = {
                **prefect.context.get("parameters", {}).copy(),
                **prefect.context,
                **raw_inputs,
            }

            # `self` cannot be forwarded as a keyword argument: bound-method calls such
            # as result.exists(target, **formatting_kwargs) would receive two values
            # for `self` and raise TypeError.
            formatting_kwargs.pop("self", None)

            if not isinstance(target, str):
                target = target(**formatting_kwargs)

            if result.exists(target, **formatting_kwargs):  # type: ignore
                known_location = target.format(**formatting_kwargs)  # type: ignore
                new_res = result.read(known_location)
                return Cached(
                    result=new_res,
                    hashed_inputs={
                        # Plain-dict inputs are hashed by their keys only.
                        # NOTE(review): dict inputs that differ only in their values
                        # hash identically -- confirm this is intended.
                        key: tokenize(list(val.value.keys()))
                        if type(val.value) is dict
                        else tokenize(val.value)
                        for key, val in inputs.items()
                    },
                    cached_result_expiration=None,
                    cached_parameters=formatting_kwargs.get("parameters"),
                    message=f"Result found at task target {known_location}",
                )

        return state
예제 #26
0
    def test_reads_result_using_handler_attribute_if_cached_valid(
        self, client
    ):
        """The task's own Result type is used to load a valid cached value.

        The provided state was built with a PrefectResult for "2", but the
        task's ``MyResult.read`` sets the value to 53. The returned state is
        expected to be the same object, now reporting 53.
        """

        class MyResult(Result):
            def read(self, *args, **kwargs):
                self.value = 53
                return self

        with pytest.warns(UserWarning):
            task = Task(cache_validator=duration_only, result=MyResult())

        provided = Cached(
            result=PrefectResult(location="2"),
            cached_result_expiration=pendulum.now("utc").add(minutes=1),
        )
        client.get_latest_cached_states = MagicMock(return_value=[])

        checked = CloudTaskRunner(task).check_task_is_cached(
            state=provided, inputs={"a": Result(1)}
        )

        assert checked is provided
        assert checked.result == 53
예제 #27
0
    def test_reads_result_if_cached_valid_using_task_result(self, client):
        """A cached state fetched from the client is loaded via the task's Result type.

        The client returns a still-valid Cached state built with a PrefectResult
        for "2". The task's ``MyResult.read`` sets the value to 53, and the same
        state object is expected back reporting 53.

        The first parameter was previously named ``task``. As a method it
        receives the test-class instance, and the name was immediately shadowed
        by the local ``task`` below, so it is named ``self``.
        """

        class MyResult(Result):
            def read(self, *args, **kwargs):
                self.value = 53
                return self

        task = Task(
            result=MyResult(),
            cache_for=datetime.timedelta(minutes=1),
            cache_validator=duration_only,
        )
        db_state = Cached(
            result=PrefectResult(location="2"),
            cached_result_expiration=pendulum.now("utc").add(minutes=1),
        )

        client.get_latest_cached_states = MagicMock(return_value=[db_state])
        new = CloudTaskRunner(task).check_task_is_cached(
            state=Pending(), inputs={"a": Result(1)}
        )
        assert new is db_state
        assert new.result == 53
예제 #28
0
 def test_additional_parameters_invalidate(self):
     """An extra parameter absent from the cache makes ``all_parameters`` fail."""
     cached = Cached(cached_parameters=dict(x=1, s="str"))
     current = dict(x=1, s="str", noise="e")
     assert all_parameters(cached, None, current) is False
예제 #29
0
    def test_skipped_is_success(self):
        """Skipped is modelled as a subclass of Success."""
        child, parent = Skipped, Success
        assert issubclass(child, parent)

    def test_timedout_is_failed(self):
        """TimedOut is modelled as a subclass of Failed."""
        child, parent = TimedOut, Failed
        assert issubclass(child, parent)

    def test_trigger_failed_is_failed(self):
        """TriggerFailed is modelled as a subclass of Failed."""
        child, parent = TriggerFailed, Failed
        assert issubclass(child, parent)


@pytest.mark.parametrize(
    "state_check",
    [
        dict(state=Cancelled(), assert_true={"is_finished", "is_failed"}),
        dict(state=Cached(), assert_true={"is_cached", "is_finished", "is_successful"}),
        dict(state=ClientFailed(), assert_true={"is_meta_state"}),
        dict(state=Failed(), assert_true={"is_finished", "is_failed"}),
        dict(state=Finished(), assert_true={"is_finished"}),
        dict(state=Looped(), assert_true={"is_finished", "is_looped"}),
        dict(state=Mapped(), assert_true={"is_finished", "is_mapped", "is_successful"}),
        dict(state=Paused(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Pending(), assert_true={"is_pending"}),
        dict(state=Queued(), assert_true={"is_meta_state", "is_queued"}),
        dict(state=Resume(), assert_true={"is_pending", "is_scheduled"}),
        dict(
            state=Retrying(), assert_true={"is_pending", "is_scheduled", "is_retrying"}
        ),
        dict(state=Running(), assert_true={"is_running"}),
        dict(state=Scheduled(), assert_true={"is_pending", "is_scheduled"}),
        dict(
예제 #30
0
    def test_skipped_is_success(self):
        """Skipped is modelled as a subclass of Success."""
        child, parent = Skipped, Success
        assert issubclass(child, parent)

    def test_timedout_is_failed(self):
        """TimedOut is modelled as a subclass of Failed."""
        child, parent = TimedOut, Failed
        assert issubclass(child, parent)

    def test_trigger_failed_is_failed(self):
        """TriggerFailed is modelled as a subclass of Failed."""
        child, parent = TriggerFailed, Failed
        assert issubclass(child, parent)


@pytest.mark.parametrize(
    "state_check",
    [
        dict(state=Cancelled(), assert_true={"is_finished"}),
        dict(state=Cached(),
             assert_true={"is_cached", "is_finished", "is_successful"}),
        dict(state=ClientFailed(), assert_true={"is_meta_state"}),
        dict(state=Failed(), assert_true={"is_finished", "is_failed"}),
        dict(state=Finished(), assert_true={"is_finished"}),
        dict(state=Looped(), assert_true={"is_finished", "is_looped"}),
        dict(state=Mapped(),
             assert_true={"is_finished", "is_mapped", "is_successful"}),
        dict(state=Paused(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Pending(), assert_true={"is_pending"}),
        dict(state=Queued(), assert_true={"is_meta_state", "is_queued"}),
        dict(state=Resume(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Retrying(),
             assert_true={"is_pending", "is_scheduled", "is_retrying"}),
        dict(state=Running(), assert_true={"is_running"}),
        dict(state=Scheduled(), assert_true={"is_pending", "is_scheduled"}),