Example #1
0
 def test_cache_validator_provided_if_needed(self):
     """
     Supplying only `cache_for` (and no `cache_validator`) must still leave the
     Result with a callable `cache_validator`.
     """
     two_days = datetime.timedelta(days=2)
     res = Result(value=3, cache_for=two_days)
     validator = res.cache_validator
     assert validator is not None
     assert callable(validator)
Example #2
0
def test_preparing_state_for_cloud_does_nothing_if_result_is_none(cls):
    """A cached input whose value is None serializes with type `NoResultType`."""
    none_input = Result(None, result_handler=JSONResultHandler())
    prepared = prepare_state_for_cloud(cls(cached_inputs={"x": none_input}))

    assert isinstance(prepared, cls)
    assert prepared.result is None
    assert prepared._result == NoResult
    assert prepared.cached_inputs == {"x": none_input}

    serialized_x = prepared.serialize()["cached_inputs"]["x"]
    assert serialized_x["type"] == "NoResultType"
Example #3
0
    async def test_set_flow_run_state_with_result(self, run_query, flow_run_id):
        """A Success state carrying a stored Result is accepted and bumps the run to version 2."""
        stored = Result(10, result_handler=JSONResultHandler())
        stored.store_safe_value()
        success_state = Success(result=stored)

        state_input = dict(flow_run_id=flow_run_id, state=success_state.serialize())
        response = await run_query(
            query=self.mutation,
            variables=dict(input=dict(states=[state_input])),
        )

        updated_id = response.data.set_flow_run_states.states[0].id
        flow_run = await models.FlowRun.where(id=updated_id).first(
            {"state", "version"}
        )
        assert flow_run.version == 2
        assert flow_run.state == "Success"
Example #4
0
def test_preparing_state_for_cloud_replaces_cached_inputs_with_safe(cls):
    """A JSON-handled cached input of 3 serializes with the value string `"3"`."""
    three = Result(3, result_handler=JSONResultHandler())
    prepared = prepare_state_for_cloud(cls(cached_inputs={"x": three}))

    assert isinstance(prepared, cls)
    assert prepared.result is None
    assert prepared._result == NoResult
    assert prepared.cached_inputs == {"x": three}

    serialized_x = prepared.serialize()["cached_inputs"]["x"]
    assert serialized_x["value"] == "3"
Example #5
0
    def test_result_validate_warns_when_run_without_run_validators_flag(
            self, caplog):
        """Calling `validate()` with `run_validators=False` still validates but logs a warning."""
        validator = MagicMock(return_value=True)
        res = Result(value=None, validators=[validator], run_validators=False)

        with caplog.at_level(logging.WARNING, "prefect.Result"):
            outcome = res.validate()

        # validation proceeds as normal: the validator is invoked with the Result ...
        validator.assert_called_once_with(res)
        assert outcome is True

        # ... but a WARNING mentioning `run_validators` is also published
        log_text = caplog.text
        assert "WARNING" in log_text
        assert "run_validators" in log_text
Example #6
0
    def run_task(
        self,
        task: Task,
        state: State,
        upstream_states: Dict[Edge, State],
        context: Dict[str, Any],
        task_runner_state_handlers: Iterable[Callable],
        executor: "prefect.engine.executors.Executor",
    ) -> State:
        """
        Runs a specific task. This method is intended to be called by submitting it to
        an executor.

        Args:
            - task (Task): the task to run
            - state (State): starting state for the task; passed unchanged to the task
                runner's `run`
            - upstream_states (Dict[Edge, State]): dictionary of upstream states. NOTE: any
                `Mapped` upstream state feeding a non-mapped edge is modified in place: its
                `map_states` are replaced by the executor-waited versions and its `result`
                becomes the list of the children's results
            - context (Dict[str, Any]): a context dictionary for the task run
            - task_runner_state_handlers (Iterable[Callable]): A list of state change
                handlers that will be provided to the task_runner, and called whenever a task changes
                state.
            - executor (Executor): executor used to wait on mapped upstream states and
                passed on to the task runner for computation

        Returns:
            - State: the task's final post-run state, as returned by the task runner's `run`
        """
        with prefect.context(self.context):
            # the task's own result wins over the flow's; fall back to a bare Result
            default_result = task.result or self.flow.result
            task_runner = self.task_runner_cls(
                task=task,
                state_handlers=task_runner_state_handlers,
                result=default_result or Result(),
                default_result=self.flow.result,
            )

            # if this task reduces over a mapped state, make sure its children have finished
            for edge, upstream_state in upstream_states.items():

                # if the upstream state is Mapped, wait until its results are all available
                if not edge.mapped and upstream_state.is_mapped():
                    assert isinstance(upstream_state, Mapped)  # mypy assert
                    upstream_state.map_states = executor.wait(
                        upstream_state.map_states)
                    upstream_state.result = [
                        s.result for s in upstream_state.map_states
                    ]

            return task_runner.run(
                state=state,
                upstream_states=upstream_states,
                context=context,
                executor=executor,
            )
Example #7
0
    def __init__(
        self,
        task: Task,
        state_handlers: Iterable[Callable] = None,
        flow_result: Result = None,
    ):
        """
        Args:
            - task (Task): the task this runner executes
            - state_handlers (Iterable[Callable], optional): state-change handlers handed
                to the base runner's constructor
            - flow_result (Result, optional): the parent Flow's result; a copy of it is
                used when the task has no `result` of its own
        """
        self.context = prefect.context.to_dict()
        self.task = task

        # a task-level result takes precedence; otherwise use a copy of the flow's
        # result (or of a fresh Result when the flow has none)
        if not task.result:
            fallback = Result() if flow_result is None else flow_result
            self.result = fallback.copy()
        else:
            self.result = task.result

        self.flow_result = flow_result
        super().__init__(state_handlers=state_handlers)
def test_task_runner_puts_cloud_in_context(client):
    """A task run by CloudTaskRunner sees `checkpointing` set to True in context."""

    @prefect.task(result=Result())
    def whats_in_ctx():
        return prefect.context.get("checkpointing")

    final_state = CloudTaskRunner(task=whats_in_ctx).run()

    assert final_state.is_successful()
    assert final_state.result is True
Example #9
0
def test_preparing_state_for_cloud_ignores_the_lack_of_result_handlers_for_cached_inputs(
    cls,
):
    """A cached input with no result handler still serializes, with type `NoResultType`."""
    unhandled = Result(3, result_handler=None)
    prepared = prepare_state_for_cloud(cls(cached_inputs={"x": unhandled}))

    assert isinstance(prepared, cls)
    assert prepared.result is None
    assert prepared._result == NoResult
    assert prepared.cached_inputs == {"x": unhandled}

    serialized_x = prepared.serialize()["cached_inputs"]["x"]
    assert serialized_x["type"] == "NoResultType"
Example #10
0
def test_deserialize_json_with_context():
    """StateSchema.load restores a Running state together with its `context` payload."""
    payload = {"type": "Running", "context": {"boo": ["a", "b", "c"]}}
    loaded = StateSchema().load(payload)

    assert type(loaded) is state.Running
    assert loaded.is_running()
    assert loaded.message is None
    assert loaded.context == {"boo": ["a", "b", "c"]}
    assert loaded.result is None
    assert loaded._result == Result()
Example #11
0
    def __init__(
        self,
        task: Task,
        state_handlers: Iterable[Callable] = None,
        flow_result: Result = None,
    ):
        """
        Args:
            - task (Task): the task this runner executes
            - state_handlers (Iterable[Callable], optional): state-change handlers handed
                to the base runner's constructor
            - flow_result (Result, optional): the parent Flow's result. When the task has
                no `result` of its own, a COPY of this is used, with its `location` set to
                the task's `target` (if any). The caller's object is never modified.
        """
        self.context = prefect.context.to_dict()
        self.task = task

        if task.result:
            self.result = task.result
        else:
            # `flow_result` belongs to the caller (the parent Flow). Copy it before
            # writing this task's target into `location`, otherwise the target would
            # leak into every other holder of the same Result object.
            self.result = Result() if flow_result is None else flow_result.copy()
            if self.task.target:
                self.result.location = self.task.target
        self.flow_result = flow_result
        super().__init__(state_handlers=state_handlers)
Example #12
0
def test_serialize_state_with_handled_result(cls):
    """Dumping a state that holds a located Result records the Result's type and location."""
    located = Result(value=1, location="src/place")
    serialized = StateSchema().dump(cls(message="message", result=located))

    assert isinstance(serialized, dict)
    assert serialized["type"] == cls.__name__
    assert serialized["message"] == "message"

    result_part = serialized["_result"]
    assert result_part["type"] == "Result"
    assert result_part["location"] == "src/place"
    assert serialized["__version__"] == prefect.__version__
Example #13
0
def test_deserialize_mapped():
    """After a dump/load round trip, a Mapped state keeps its child count but not the children."""
    children = [
        state.Success(message="1", result=1),
        state.Failed(message="2", result=2),
    ]
    dumped = StateSchema().dump(state.Mapped(message="message", map_states=children))
    loaded = StateSchema().load(dumped)

    assert isinstance(loaded, state.Mapped)
    assert len(loaded.map_states) == 2
    assert loaded.map_states == [None, None]
    assert loaded._result == Result()
    assert loaded.result is None
Example #14
0
def test_has_abstract_interfaces(abstract_interface: str):
    """
    Every abstract interface, called directly on the base `Result` class,
    must raise `NotImplementedError`.
    """
    method = getattr(Result(value=3), abstract_interface)
    with pytest.raises(NotImplementedError):
        method()
Example #15
0
    def test_state_load_cached_results_reads_if_location_is_provided(self, cls):
        """A provided Result with a location has its `read` used to hydrate the cached input."""
        class MyResult(Result):
            def read(self, *args, **kwargs):
                self.value = "bar"
                return self

        original = cls(cached_inputs={"y": Result()})
        loaded = original.load_cached_results({"y": MyResult(location="foo")})

        hydrated = loaded.cached_inputs["y"]
        assert hydrated.value == "bar"
        assert hydrated.location == "foo"
Example #16
0
    def check_for_retry(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Checks to see if a FAILED task should be retried.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments. When `task_loop_count` is set in
                context, this dict is updated in place with `_loop_count` and
                `_loop_result` entries.

        Returns:
            - State: a `Retrying` state when `state` is failed and the run count has not
                exceeded `max_retries`; otherwise `state` itself
        """
        if not state.is_failed():
            return state

        run_count = prefect.context.get("task_run_count", 1)

        if prefect.context.get("task_loop_count") is not None:
            # record the current loop bookkeeping among the inputs that will be
            # cached on the Retrying state
            inputs.update(
                _loop_count=Result(
                    value=prefect.context["task_loop_count"],
                    result_handler=JSONResultHandler(),
                ),
                _loop_result=Result(
                    value=prefect.context.get("task_loop_result"),
                    result_handler=self.result_handler,
                ),
            )

        if run_count > self.task.max_retries:
            return state

        start_time = pendulum.now("utc") + self.task.retry_delay
        total_attempts = self.task.max_retries + 1
        msg = "Retrying Task (after attempt {n} of {m})".format(
            n=run_count, m=total_attempts)
        return Retrying(
            start_time=start_time,
            cached_inputs=inputs,
            message=msg,
            run_count=run_count,
        )
Example #17
0
    def test_uses_provided_cache_validator(self):
        """An explicitly supplied `cache_validator` is stored as-is (same object)."""
        def custom_cache_validator(*args, **kwargs):
            # a dedicated function object so the identity check below is meaningful
            return True

        two_days = datetime.timedelta(days=2)
        res = Result(value=3, cache_for=two_days, cache_validator=custom_cache_validator)
        assert res.cache_validator is custom_cache_validator
Example #18
0
    def test_result_inits_with_value(self):
        """
        Passing the value positionally or as `value=` sets `value` and leaves every
        other option at its default.
        """
        positional = Result(3)
        keyword = Result(value=5)

        for res, expected_value in ((positional, 3), (keyword, 5)):
            assert res.value == expected_value
            assert res.safe_value is NoResult
            assert res.result_handler is None
            assert res.validators is None
            assert res.cache_for is None
            assert res.cache_validator is None
            assert res.filepath_template is None
            # previously the keyword case re-checked the positional Result here
            assert res.run_validators is True
Example #19
0
def test_task_failure_caches_constant_inputs_automatically(client):
    """A failing task with retries caches its constant input with a ConstantResultHandler."""

    @prefect.task(max_retries=2, retry_delay=timedelta(seconds=100))
    def is_p_three(p):
        if p == 3:
            raise ValueError("No thank you.")

    with prefect.Flow("test") as flow:
        task_ref = is_p_three(3)

    flow_state = CloudFlowRunner(flow=flow).run(return_tasks=[task_ref])
    assert flow_state.is_running()

    task_state = flow_state.result[task_ref]
    assert isinstance(task_state, Retrying)

    expected = Result(3, result_handler=ConstantResultHandler(3))
    # not equal until the expected Result has its safe value stored as well
    assert not task_state.cached_inputs["p"] == expected
    expected.store_safe_value()
    assert task_state.cached_inputs["p"] == expected

    last_call_kwargs = client.set_task_run_state.call_args_list[-1][-1]
    last_state = last_call_kwargs["state"]
    assert isinstance(last_state, Retrying)
    assert last_state.cached_inputs["p"] == expected
Example #20
0
    def test_create_state_with_tags_in_context(self, cls):
        """`task_tags` from context become the state's `tags`, unless a context is passed explicitly."""
        tags = set("abcdef")

        with prefect.context(task_tags=tags):
            tagged = cls()
        assert tagged.message is None
        assert tagged.result is None
        assert tagged._result == Result()
        assert tagged.context == {"tags": list(tags)}

        with prefect.context(task_tags=set("abcdef")):
            explicit = cls(context={"tags": ["foo"]})
        assert explicit.context == {"tags": ["foo"]}
Example #21
0
 def test_validate_on_kwarg(self):
     """
     `partial_inputs_only(validate_on=...)` compares only the cached inputs whose
     names appear in `validate_on`.
     """
     unchanged = Cached(cached_inputs=dict(x=Result(1), s=Result("str")))
     validator = partial_inputs_only(validate_on=["x", "s"])
     assert validator(unchanged, dict(x=1, s="str"), None) is True

     # `s` changes from "str" to "strs"; only validators that look at `s` notice
     state = Cached(cached_inputs=dict(x=Result(1), s=Result("str")))
     cases = ((["x", "s"], False), (["x"], True), (["s"], False))
     for keys, expected in cases:
         outcome = partial_inputs_only(validate_on=keys)(state, dict(x=1, s="strs"), None)
         assert outcome is expected
Example #22
0
def test_task_failure_with_upstream_secrets_doesnt_store_secret_value_and_recompute_if_necessary(
    client,
):
    """
    A Secret input is cached through a SecretResultHandler, and on the next run its
    value is re-read from context rather than reused.
    """

    @prefect.task(max_retries=2, retry_delay=timedelta(seconds=100))
    def is_p_three(p):
        if p == 3:
            raise ValueError("No thank you.")
        return p

    with prefect.Flow("test", result_handler=JSONResultHandler()) as flow:
        p = prefect.tasks.secrets.Secret("p")
        task_ref = is_p_three(p)

    with prefect.context(secrets=dict(p=3)):
        first_run = CloudFlowRunner(flow=flow).run(return_tasks=[task_ref])

    assert first_run.is_running()
    assert isinstance(first_run.result[task_ref], Retrying)

    expected = Result(3, result_handler=SecretResultHandler(p))
    assert not first_run.result[task_ref].cached_inputs["p"] == expected
    expected.store_safe_value()
    assert first_run.result[task_ref].cached_inputs["p"] == expected

    # Replace the secret's result with a SafeResult so it has to be converted back
    # into a "true" result. The upstream value should then be recomputed from
    # context through the SecretResultHandler.
    safe = SafeResult("p", result_handler=SecretResultHandler(p))
    task_states = first_run.result
    task_states[p] = Success(result=safe)
    retrying = task_states[task_ref]
    retrying.start_time = pendulum.now("utc")
    retrying.cached_inputs = dict(p=safe)

    with prefect.context(secrets=dict(p=4)):
        second_run = CloudFlowRunner(flow=flow).run(
            return_tasks=[task_ref], task_states=task_states
        )

    assert second_run.is_successful()
    assert second_run.result[task_ref].result == 4
Example #23
0
def test_preparing_state_for_cloud_doesnt_copy_data():
    """prepare_state_for_cloud keeps the very same result object on a Cached state."""

    class PassthroughHandler(ResultHandler):
        def read(self, val):
            return val

        def write(self, val):
            return val

    original = Cached(result=Result(124.090909, result_handler=PassthroughHandler()))
    prepared = prepare_state_for_cloud(original)

    assert prepared.is_cached()
    assert prepared.result is original.result
def checkpoint_handler(task_runner: DSTaskRunner, old_state: State, new_state: State) -> State:
    """
    A handler designed to implement result caching by filename. If the result handler's ``read``
    method can be successfully run, this handler loads the result of that method as the task result
    and sets the task state to ``Success``. Similarly, on successful
    completion of the task, if the task was actually run and not loaded from cache, this handler
    will apply the result handler's ``write`` method to the task.

    Parameters
    ----------
    task_runner : instance of DSTaskRunner
        The task runner associated with the flow the handler is used in.
    old_state : instance of prefect.engine.state.State
        The current state of the task.
    new_state : instance of prefect.engine.state.State
        The expected new state of the task.

    Returns
    -------
    new_state : instance of prefect.engine.state.State
        The actual new state of the task.

    Raises
    ------
    AttributeError
        If ``PREFECT__FLOWS__CHECKPOINTING`` is set to ``"true"``.
    TypeError
        If the task runner has no ``upstream_states`` attribute, or if the result
        handler's ``read`` raises ``TypeError`` (treated as not accepting ``input_mapping``).
    """
    # This handler and Prefect's built-in checkpointing are treated as mutually exclusive.
    if os.environ.get("PREFECT__FLOWS__CHECKPOINTING") == "true":
        raise AttributeError("Cannot use standard prefect checkpointing with this handler")

    # NOTE(review): the guard checks `task_runner.result_handler`, but reads/writes go
    # through `task_runner.task.result_handler`; confirm these are always both set together.
    if task_runner.result_handler is None:
        return new_state

    # Pending -> Running: try to load a previously written result instead of running.
    if old_state.is_pending() and new_state.is_running():
        input_mapping = _upstream_input_mapping(task_runner)
        try:
            data = task_runner.task.result_handler.read(input_mapping=input_mapping)
        except FileNotFoundError:
            # nothing stored for these inputs yet: let the task run normally
            return new_state
        except TypeError as err:  # unexpected argument input_mapping
            raise TypeError(
                "Result handler could not accept input_mapping argument. "
                "Please ensure that you are using a handler from prefect_ds."
            ) from err
        result = Result(value=data, result_handler=task_runner.task.result_handler)
        return Success(result=result, message="Task loaded from disk.")

    # Running -> Success: persist the freshly computed result.
    if old_state.is_running() and new_state.is_successful():
        input_mapping = _upstream_input_mapping(task_runner)
        task_runner.task.result_handler.write(new_state.result, input_mapping=input_mapping)

    return new_state


def _upstream_input_mapping(task_runner):
    """Build the input mapping from ``task_runner.upstream_states``, or raise TypeError if absent."""
    if not hasattr(task_runner, "upstream_states"):
        raise TypeError(
            "upstream_states not found in task runner. Make sure to use "
            "prefect_ds.task_runner.DSTaskRunner."
        )
    return _create_input_mapping(task_runner.upstream_states)
Example #25
0
def test_cloud_task_runner_handles_retries_with_queued_states_from_cloud(client):
    """A Queued response from the backend during a retry is handled and the task still succeeds."""
    submitted = []

    def queued_mock(*args, **kwargs):
        submitted.append(kwargs)
        # the first retry attempt (the 4th submission) is answered with Queued
        if len(submitted) != 4:
            return kwargs.get("state")
        return Queued()  # immediate start time

    client.set_task_run_state = queued_mock

    @prefect.task(
        max_retries=2,
        retry_delay=datetime.timedelta(seconds=0),
        result_handler=ResultHandler(),
    )
    def tagged_task(x):
        if prefect.context.get("task_run_count", 1) == 1:
            raise ValueError("gimme a sec")
        return x

    upstream_edge = Edge(Task(), tagged_task, key="x")
    upstream_result = Result(value=42, result_handler=JSONResultHandler())
    final_state = CloudTaskRunner(task=tagged_task).run(
        context={"task_run_version": 1},
        state=None,
        upstream_states={upstream_edge: Success(result=upstream_result)},
        executor=prefect.engine.executors.LocalExecutor(),
    )

    assert final_state.is_successful()
    assert final_state.result == 42

    # States submitted to the client, in order. The Queued answer to the 4th
    # submission sends the task through Running a second time.
    assert len(submitted) == 6
    submitted_names = [type(call["state"]).__name__ for call in submitted]
    assert submitted_names == [
        "Running",
        "Failed",
        "Retrying",
        "Running",
        "Running",
        "Success",
    ]

    # the Retrying state's cached input carries the persisted safe value
    assert submitted[2]["state"].cached_inputs["x"].safe_value.value == "42"
Example #26
0
    def test_state_load_result_reads_if_location_is_provided(self, cls):
        """`load_result` uses the given Result's `read` when that Result has a location."""
        class MyResult(Result):
            def read(self, *args, **kwargs):
                self.value = "bar"
                return self

        before = cls(result=Result())
        assert before.message is None
        assert before.result is None
        assert before._result.location is None

        after = before.load_result(MyResult(location="foo"))
        assert after.message is None
        assert after.result == "bar"
        assert after._result.location == "foo"
def test_task_runner_validates_cached_states_if_task_has_caching(client):
    """Cached states whose expiration has passed are not reused; the task reruns and returns 42."""

    @prefect.task(cache_for=datetime.timedelta(minutes=1),
                  result_handler=JSONResultHandler())
    def cached_task():
        return 42

    now = datetime.datetime.utcnow()
    expired_recently = Cached(
        cached_result_expiration=now - datetime.timedelta(minutes=2),
        result=Result(99, JSONResultHandler()),
    )
    expired_long_ago = Cached(
        cached_result_expiration=now - datetime.timedelta(days=1),
        result=Result(13, JSONResultHandler()),
    )
    client.get_latest_cached_states = MagicMock(
        return_value=[expired_recently, expired_long_ago])

    final_state = CloudTaskRunner(task=cached_task).run()

    assert client.get_latest_cached_states.called
    assert final_state.is_successful()
    assert final_state.is_cached()
    assert final_state.result == 42
Example #28
0
    def test_state_load_cached_results_doesnt_call_read_if_location_is_none(
            self, cls):
        """
        With neither a value nor a location, None is taken to be the correct value,
        so the provided Result's `read` is not used.
        """
        class MyResult(Result):
            def read(self, *args, **kwargs):
                self.location = "foo"
                self.value = "bar"
                return self

        original = cls(cached_inputs={"x": Result()})
        loaded_x = original.load_cached_results({"x": MyResult()}).cached_inputs["x"]

        assert loaded_x.value is None
        assert loaded_x.location is None
Example #29
0
    def test_state_load_cached_results_doesnt_call_read_if_value_present(
            self, cls):
        """
        A cached input that already holds a value is not re-read, so repeated calls
        to `load_cached_results` do not cause redundant reads from the result location.
        """
        class MyResult(Result):
            def read(self, *args, **kwargs):
                self.location = "foo"
                self.value = "bar"
                return self

        original = cls(cached_inputs={"x": Result(value=42)})
        loaded_x = original.load_cached_results({"x": MyResult()}).cached_inputs["x"]

        assert loaded_x.value == 42
        assert loaded_x.location is None
Example #30
0
    def test_state_load_cached_results_calls_read(self, cls):
        """
        The `read` of the Result passed in is used rather than the state's own Result.
        This matters when hydrating JSON representations of Results that come from Cloud.
        """
        class MyResult(Result):
            def read(self, *args, **kwargs):
                self.location = "foo"
                self.value = 42
                return self

        original = cls(cached_inputs={"x": Result()})
        loaded = original.load_cached_results({"x": MyResult(location="")})

        hydrated = loaded.cached_inputs["x"]
        assert hydrated.value == 42
        assert hydrated.location == "foo"