def test_task_runner_validates_cached_state_inputs_with_upstream_handlers_if_task_has_caching(
    client,
):
    """
    An upstream input whose handler reads back as 1337 must not match the
    cached input "2", so the runner is expected to leave the task Pending.
    """

    class Handler(ResultHandler):
        def read(self, val):
            return 1337

    @prefect.task(
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=all_inputs,
        result_handler=JSONResultHandler(),
    )
    def cached_task(x):
        return 42

    def two_minutes_from_now():
        return datetime.datetime.utcnow() + datetime.timedelta(minutes=2)

    # Neither candidate should be accepted: one has no cached inputs at all,
    # the other caches x == "2".
    inputless_state = Cached(
        cached_result_expiration=two_minutes_from_now(),
        result=SafeResult("-1", JSONResultHandler()),
    )
    state_with_inputs = Cached(
        cached_result_expiration=two_minutes_from_now(),
        result=SafeResult("99", JSONResultHandler()),
        cached_inputs={"x": SafeResult("2", result_handler=JSONResultHandler())},
    )
    client.get_latest_cached_states = MagicMock(
        return_value=[inputless_state, state_with_inputs]
    )

    upstream_inputs = {"x": Result(2, result_handler=Handler())}
    res = CloudTaskRunner(task=cached_task).check_task_is_cached(
        Pending(), inputs=upstream_inputs
    )

    assert client.get_latest_cached_states.called
    assert res.is_pending()
def test_inputs_validate_with_defaults(self):
    """With no keys to validate on, matching and changed inputs are both accepted."""
    for current_inputs in (dict(x=1, s="str"), dict(x=1, s="strs")):
        state = Cached(cached_inputs=dict(x=1, s="str"))
        assert partial_inputs_only(None)(state, current_inputs, None) is True
def test_handles_none(self):
    """Validation fails when either side of the input comparison is missing."""
    # The state records parameters only; it has no cached inputs to compare.
    params_only = Cached(cached_parameters=dict(x=5))
    assert partial_inputs_only(validate_on=["x"])(params_only, dict(x=5), None) is False

    # Cached inputs exist, but the current run supplies none.
    with_inputs = Cached(cached_inputs=dict(x=5))
    assert partial_inputs_only(validate_on=["x"])(with_inputs, None, None) is False
def test_task_runner_validates_cached_state_inputs_if_task_has_caching(
    self, client
):
    """The cached state whose stored input x == 2 matches the current input and is reused."""

    @prefect.task(
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=all_inputs,
        result=PrefectResult(),
    )
    def cached_task(x):
        return 42

    def two_minutes_from_now():
        return datetime.datetime.utcnow() + datetime.timedelta(minutes=2)

    inputless_state = Cached(
        cached_result_expiration=two_minutes_from_now(),
        result=PrefectResult(location="-1"),
    )
    state_with_inputs = Cached(
        cached_result_expiration=two_minutes_from_now(),
        result=PrefectResult(location="99"),
        cached_inputs={"x": PrefectResult(location="2")},
    )
    client.get_latest_cached_states = MagicMock(
        return_value=[inputless_state, state_with_inputs]
    )

    runner = CloudTaskRunner(task=cached_task)
    res = runner.check_task_is_cached(Pending(), inputs={"x": PrefectResult(value=2)})

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 99
def test_parameters_validate_with_defaults(self):
    """A default-configured parameter validator accepts both matching and changed parameters."""
    for current_parameters in (dict(x=1, s="str"), dict(x=1, s="strs")):
        state = Cached(cached_parameters=dict(x=1, s="str"))
        assert partial_parameters_only()(state, None, current_parameters) is True
def test_task_runner_validates_cached_state_inputs_if_task_has_caching(client):
    """The cached state whose stored input x == "2" matches the current input and is reused."""

    @prefect.task(
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=all_inputs,
        result_handler=JSONResultHandler(),
    )
    def cached_task(x):
        return 42

    def two_minutes_from_now():
        return datetime.datetime.utcnow() + datetime.timedelta(minutes=2)

    inputless_state = Cached(
        cached_result_expiration=two_minutes_from_now(),
        result=Result(-1, JSONResultHandler()),
    )
    state_with_inputs = Cached(
        cached_result_expiration=two_minutes_from_now(),
        result=Result(99, JSONResultHandler()),
        cached_inputs={"x": SafeResult("2", result_handler=JSONResultHandler())},
    )
    client.get_latest_cached_states = MagicMock(
        return_value=[inputless_state, state_with_inputs]
    )

    runner = CloudTaskRunner(task=cached_task)
    res = runner.check_task_is_cached(
        Pending(), inputs={"x": Result(2, result_handler=LocalResultHandler())}
    )

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 99
def test_task_runner_validates_cached_state_inputs_with_upstream_handlers_if_task_has_caching(
    self, client
):
    """
    An upstream input whose Result reads back as 1337 must not match the
    cached input "2", so the runner is expected to leave the task Pending.
    """

    class MyResult(Result):
        def read(self, *args, **kwargs):
            clone = self.copy()
            clone.value = 1337
            return clone

    @prefect.task(
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=all_inputs,
        result=PrefectResult(),
    )
    def cached_task(x):
        return 42

    def two_minutes_from_now():
        return datetime.datetime.utcnow() + datetime.timedelta(minutes=2)

    inputless_state = Cached(
        cached_result_expiration=two_minutes_from_now(),
        result=PrefectResult(location="-1"),
    )
    state_with_inputs = Cached(
        cached_result_expiration=two_minutes_from_now(),
        result=PrefectResult(location="99"),
        cached_inputs={"x": PrefectResult(location="2")},
    )
    client.get_latest_cached_states = MagicMock(
        return_value=[inputless_state, state_with_inputs]
    )

    runner = CloudTaskRunner(task=cached_task)
    res = runner.check_task_is_cached(Pending(), inputs={"x": MyResult(value=2)})

    assert client.get_latest_cached_states.called
    assert res.is_pending()
def test_validate_on_kwarg(self):
    """Only the input keys listed in `validate_on` take part in the comparison."""
    exact_match = Cached(cached_inputs=dict(x=1, s="str"))
    assert (
        partial_inputs_only(validate_on=["x", "s"])(exact_match, dict(x=1, s="str"), None)
        is True
    )

    # Current inputs differ from the cache only in "s".
    state = Cached(cached_inputs=dict(x=1, s="str"))
    expectations = (
        (["x", "s"], False),
        (["x"], True),
        (["s"], False),
    )
    for keys, expected in expectations:
        outcome = partial_inputs_only(validate_on=keys)(state, dict(x=1, s="strs"), None)
        assert outcome is expected
def test_validate_on_kwarg(self):
    """Only the parameter keys listed in `validate_on` take part in the comparison."""
    exact_match = Cached(cached_parameters=dict(x=1, s="str"))
    assert (
        partial_parameters_only(
            exact_match, None, dict(x=1, s="str"), validate_on=["x", "s"]
        )
        is True
    )

    # Current parameters differ from the cache only in "s".
    state = Cached(cached_parameters=dict(x=1, s="str"))
    expectations = (
        (["x", "s"], False),
        (["x"], True),
        (["s"], False),
    )
    for keys, expected in expectations:
        outcome = partial_parameters_only(
            state, None, dict(x=1, s="strs"), validate_on=keys
        )
        assert outcome is expected
def cache_result(self, state: State, inputs: Dict[str, Result]) -> State:
    """
    Caches the result of a successful task, if appropriate.

    Tasks are cached if:
        - task.cache_for is not None
        - the task state is Successful
        - the task state is not Skipped (which is a subclass of Successful)

    When these hold, a new `Cached` state is returned. It carries the original
    state's result and message, a `tokenize` hash of every input value
    (`hashed_inputs`), an expiration of now (UTC) + `task.cache_for`, and the
    current `parameters` from `prefect.context`. Otherwise the incoming state
    is returned unchanged.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments; `None` is treated as "no inputs"

    Returns:
        - State: the state of the task after running the check
    """
    if (
        state.is_successful()
        and not state.is_skipped()
        and self.task.cache_for is not None
    ):
        expiration = pendulum.now("utc") + self.task.cache_for
        cached_state = Cached(
            result=state._result,
            # Only a token of each input value is stored, not the value itself.
            hashed_inputs={
                key: tokenize(val.value) for key, val in (inputs or {}).items()
            },
            cached_result_expiration=expiration,
            cached_parameters=prefect.context.get("parameters"),
            message=state.message,
        )
        return cached_state

    return state
def check_target(self, state: State, inputs: Dict[str, Result]) -> State:
    """
    Checks if a Result exists at the task's target.

    If both a result and a target are configured and `result.exists` reports
    data at the target, the target template is formatted with
    `prefect.context`, read through the task's result, and wrapped in a
    non-expiring `Cached` state. Otherwise `state` is returned unchanged.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments.

    Returns:
        - State: the state of the task after running the check
    """
    result = self.result
    target = self.task.target

    if result and target:
        if result.exists(target, **prefect.context):
            # Format the template once so the read and the reported message
            # refer to the same concrete location (not the raw template).
            known_location = target.format(**prefect.context)
            new_res = result.read(known_location)
            cached_state = Cached(
                result=new_res,
                cached_inputs=inputs,
                cached_result_expiration=None,
                cached_parameters=prefect.context.get("parameters"),
                message=f"Result found at task target {known_location}",
            )
            return cached_state

    return state
def test_providing_cachedstate_with_simple_example(self, executor):
    """A Cached state handed to the flow runner via `task_states` supplies the task's result."""

    class TestTask(Task):
        call_count = 0

        def run(self, x, s):
            self.call_count += 1
            return self.call_count

    with Flow(name="test") as f:
        y = TestTask(
            cache_validator=duration_only, cache_for=datetime.timedelta(days=1)
        )
        x = Parameter("x")
        s = SuccessTask()
        f.add_edge(x, y, key="x")
        f.add_edge(s, y, key="s")

    tomorrow = pendulum.now("utc") + datetime.timedelta(days=1)
    provided_state = Cached(cached_result_expiration=tomorrow, result=100)

    runner = FlowRunner(flow=f)
    flow_state = runner.run(
        executor=executor,
        parameters=dict(x=1),
        return_tasks=[y],
        task_states={y: provided_state},
    )

    assert isinstance(flow_state, Success)
    assert flow_state.result[y].result == 100
def test_state_kwarg_is_prioritized_over_db_caches(self, client):
    """A valid Cached state passed as `state=` wins over one returned by the backend."""
    task = Task(
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=duration_only,
        result=PrefectResult(),
    )

    def one_minute_from_now():
        return pendulum.now("utc").add(minutes=1)

    backend_state = Cached(
        result=PrefectResult(location="2"),
        cached_result_expiration=one_minute_from_now(),
    )
    explicit_state = Cached(
        result=PrefectResult(location="99"),
        cached_result_expiration=one_minute_from_now(),
    )
    client.get_latest_cached_states = MagicMock(return_value=[backend_state])

    runner = CloudTaskRunner(task)
    new = runner.check_task_is_cached(state=explicit_state, inputs={"a": Result(1)})

    assert new is explicit_state
    assert new.result == 99
def test_state_type_methods_with_cached_state(self):
    """Of the predicates checked, only is_cached, is_finished and is_successful hold for Cached."""
    state = Cached()
    expectations = [
        ("is_cached", True),
        ("is_retrying", False),
        ("is_pending", False),
        ("is_running", False),
        ("is_finished", True),
        ("is_skipped", False),
        ("is_scheduled", False),
        ("is_successful", True),
        ("is_failed", False),
        ("is_mapped", False),
        ("is_meta_state", False),
    ]
    for method_name, expected in expectations:
        assert bool(getattr(state, method_name)()) is expected, method_name
def test_task_runner_validates_cached_states_if_task_has_caching(client):
    """Expired cached states are not reused; the task runs and its own result (42) is cached."""

    @prefect.task(cache_for=datetime.timedelta(minutes=1), result=PrefectResult())
    def cached_task():
        return 42

    utcnow = datetime.datetime.utcnow
    recently_expired = Cached(
        cached_result_expiration=utcnow() - datetime.timedelta(minutes=2),
        result=PrefectResult(location="99"),
    )
    long_expired = Cached(
        cached_result_expiration=utcnow() - datetime.timedelta(days=1),
        result=PrefectResult(location="13"),
    )
    client.get_latest_cached_states = MagicMock(
        return_value=[recently_expired, long_expired]
    )

    res = CloudTaskRunner(task=cached_task).run()

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 42
def test_task_runner_queries_for_cached_states_if_task_has_caching(client):
    """The still-valid backend cached state (result 99) is used instead of running the task."""

    @prefect.task(cache_for=datetime.timedelta(minutes=1))
    def cached_task():
        return 42

    utcnow = datetime.datetime.utcnow
    valid_state = Cached(
        cached_result_expiration=utcnow() + datetime.timedelta(days=1),
        result=Result(99, JSONResultHandler()),
    )
    expired_state = Cached(
        cached_result_expiration=utcnow() - datetime.timedelta(days=1),
        result=13,
    )
    client.get_latest_cached_states = MagicMock(
        return_value=[valid_state, expired_state]
    )

    res = CloudTaskRunner(task=cached_task).run()

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 99
def test_task_runner_treats_unfound_files_as_invalid_caches(client, tmpdir):
    """A cached state pointing at a missing file is skipped in favour of the next valid one."""

    @prefect.task(cache_for=datetime.timedelta(minutes=1), result=PrefectResult())
    def cached_task():
        return 42

    utcnow = datetime.datetime.utcnow
    missing_file = str(tmpdir / "made_up_data.prefect")
    dangling_state = Cached(
        cached_result_expiration=utcnow() + datetime.timedelta(minutes=2),
        result=LocalResult(location=missing_file),
    )
    usable_state = Cached(
        cached_result_expiration=utcnow() + datetime.timedelta(days=1),
        result=PrefectResult(location="13"),
    )
    client.get_latest_cached_states = MagicMock(
        return_value=[dangling_state, usable_state]
    )

    res = CloudTaskRunner(task=cached_task).run()

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 13
def test_serialize_and_deserialize_on_raw_cached_state():
    """After a round trip, raw results are dropped and raw inputs become NoResult."""
    expiration = pendulum.now("utc")
    raw_state = Cached(
        cached_inputs=dict(x=Result(99), p=Result("p")),
        result=dict(hi=5, bye=6),
        cached_result_expiration=expiration,
    )

    restored = State.deserialize(raw_state.serialize())

    assert isinstance(restored, Cached)
    assert restored.color == raw_state.color
    assert restored.result is None
    assert restored.cached_result_expiration == raw_state.cached_result_expiration
    assert restored.cached_inputs == dict.fromkeys(["x", "p"], NoResult)
def test_preparing_state_for_cloud_doesnt_copy_data():
    """The prepared state keeps the very same result object as the original state."""

    class FakeHandler(ResultHandler):
        def read(self, val):
            return val

        def write(self, val):
            return val

    original = Cached(result=Result(124.090909, result_handler=FakeHandler()))
    prepared = prepare_state_for_cloud(original)

    assert prepared.is_cached()
    assert prepared.result is original.result
def check_target(self, state: State, inputs: Dict[str, Result]) -> State:
    """
    Looks for an existing Result at the task's target and, if one is found,
    short-circuits with a `Cached` state.

    The target (a format string, or a callable producing one) is resolved with
    a merged mapping of the run's parameters, `prefect.context`, and the raw
    input values. Later sources win on key collisions.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments.

    Returns:
        - State: a non-expiring `Cached` state wrapping the data read from the
            target, or `state` unchanged if there is no result/target or
            nothing exists there
    """
    from dask.base import tokenize

    result = self.result
    target = self.task.target
    if not (result and target):
        return state

    raw_values = {name: res.value for name, res in inputs.items()}
    fmt_kwargs = {
        **prefect.context.get("parameters", {}).copy(),
        **prefect.context,
        **raw_values,
    }
    # "self" must never be forwarded as a keyword: it would collide with the
    # bound `self` of method calls such as result.exists(**fmt_kwargs).
    fmt_kwargs.pop("self", None)

    if not isinstance(target, str):
        target = target(**fmt_kwargs)

    if not result.exists(target, **fmt_kwargs):  # type: ignore
        return state

    known_location = target.format(**fmt_kwargs)  # type: ignore
    return Cached(
        result=result.read(known_location),
        hashed_inputs={name: tokenize(res.value) for name, res in inputs.items()},
        cached_result_expiration=None,
        cached_parameters=fmt_kwargs.get("parameters"),
        message=f"Result found at task target {known_location}",
    )
def test_serialize_and_deserialize_on_mixed_cached_state():
    """A SafeResult survives a round trip, while raw inputs come back as NoResult."""
    expiration = pendulum.now("utc")
    mixed_state = Cached(
        cached_inputs=dict(x=Result(2), p=Result("p")),
        result=SafeResult(dict(hi=5, bye=6), result_handler=JSONResultHandler()),
        cached_result_expiration=expiration,
    )

    restored = State.deserialize(mixed_state.serialize())

    assert isinstance(restored, Cached)
    assert restored.color == mixed_state.color
    assert restored.result == dict(hi=5, bye=6)
    assert restored.cached_result_expiration == mixed_state.cached_result_expiration
    assert restored.cached_inputs == dict.fromkeys(["x", "p"], NoResult)
def test_serialize_and_deserialize_on_mixed_cached_state():
    """The result location survives a round trip; input values are reduced to empty PrefectResults."""
    payload_location = json.dumps(dict(hi=5, bye=6))
    expiration = pendulum.now("utc")
    mixed_state = Cached(
        cached_inputs=dict(x=PrefectResult(value=2), p=PrefectResult(value="p")),
        result=PrefectResult(location=payload_location),
        cached_result_expiration=expiration,
    )

    restored = State.deserialize(mixed_state.serialize())

    assert isinstance(restored, Cached)
    assert restored.color == mixed_state.color
    assert restored._result.location == json.dumps(dict(hi=5, bye=6))
    assert restored.cached_result_expiration == mixed_state.cached_result_expiration
    assert restored.cached_inputs == dict.fromkeys(["x", "p"], PrefectResult())
def test_serialize_and_deserialize_on_safe_cached_state():
    """When every result is a SafeResult, results and inputs both survive the round trip."""
    shared_input = SafeResult("99", result_handler=JSONResultHandler())
    payload = SafeResult(dict(hi=5, bye=6), result_handler=JSONResultHandler())
    expiration = pendulum.now("utc")
    safe_state = Cached(
        cached_inputs=dict(x=shared_input, p=shared_input),
        result=payload,
        cached_result_expiration=expiration,
    )

    restored = State.deserialize(safe_state.serialize())

    assert isinstance(restored, Cached)
    assert restored.color == safe_state.color
    assert restored.result == dict(hi=5, bye=6)
    assert restored.cached_result_expiration == safe_state.cached_result_expiration
    assert restored.cached_inputs == safe_state.cached_inputs
def test_reads_result_if_cached_valid(self, client):
    """A still-valid Cached state passed in directly is returned with its result readable."""
    cached_result = PrefectResult(location="2")

    with pytest.warns(UserWarning):
        task = Task(cache_validator=duration_only, result=PrefectResult())

    incoming = Cached(
        result=cached_result,
        cached_result_expiration=pendulum.now("utc").add(minutes=1),
    )
    client.get_latest_cached_states = MagicMock(return_value=[])

    runner = CloudTaskRunner(task)
    new = runner.check_task_is_cached(
        state=incoming, inputs={"a": PrefectResult(value=1)}
    )

    assert new is incoming
    assert new.result == 2
def check_target(self, state: State, inputs: Dict[str, Result]) -> State:
    """
    Checks if a Result exists at the task's target.

    The target (a format string, or a callable producing one) is resolved with
    a merged mapping of the run's parameters, `prefect.context`, and the raw
    input values. Later sources win on key collisions. If the task's result
    reports data at the target, that data is read and wrapped in a
    non-expiring `Cached` state.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments.

    Returns:
        - State: the state of the task after running the check
    """
    result = self.result
    target = self.task.target

    if result and target:
        raw_inputs = {k: r.value for k, r in inputs.items()}
        formatting_kwargs = {
            **prefect.context.get("parameters", {}).copy(),
            **prefect.context,
            **raw_inputs,
        }
        # An input or context entry named "self" would collide with the bound
        # `self` of method calls such as result.exists(**formatting_kwargs) and
        # raise TypeError ("multiple values for argument 'self'"), so drop it.
        formatting_kwargs.pop("self", None)

        if not isinstance(target, str):
            target = target(**formatting_kwargs)

        if result.exists(target, **formatting_kwargs):  # type: ignore
            known_location = target.format(**formatting_kwargs)  # type: ignore
            new_res = result.read(known_location)
            cached_state = Cached(
                result=new_res,
                # Inputs whose value is exactly a `dict` are tokenized by their
                # keys only. NOTE(review): two dicts with equal keys but different
                # values hash identically here — confirm this is intended.
                hashed_inputs={
                    key: tokenize(list(val.value.keys()))
                    if type(val.value) == dict
                    else tokenize(val.value)
                    for key, val in inputs.items()
                },
                cached_result_expiration=None,
                cached_parameters=formatting_kwargs.get("parameters"),
                message=f"Result found at task target {known_location}",
            )
            return cached_state

    return state
def test_reads_result_using_handler_attribute_if_cached_valid(self, client):
    """The state's result is read through the task's own Result class (which yields 53)."""

    class MyResult(Result):
        def read(self, *args, **kwargs):
            self.value = 53
            return self

    with pytest.warns(UserWarning):
        task = Task(cache_validator=duration_only, result=MyResult())

    incoming = Cached(
        result=PrefectResult(location="2"),
        cached_result_expiration=pendulum.now("utc").add(minutes=1),
    )
    client.get_latest_cached_states = MagicMock(return_value=[])

    runner = CloudTaskRunner(task)
    new = runner.check_task_is_cached(state=incoming, inputs={"a": Result(1)})

    assert new is incoming
    assert new.result == 53
def test_reads_result_if_cached_valid_using_task_result(task, client):
    """A backend cached state's result is read through the task's own Result class (yields 53)."""

    class MyResult(Result):
        def read(self, *args, **kwargs):
            self.value = 53
            return self

    # Note: the `task` argument is never used; a fresh Task is built here.
    cached_task = Task(
        result=MyResult(),
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=duration_only,
    )
    backend_state = Cached(
        result=PrefectResult(location="2"),
        cached_result_expiration=pendulum.now("utc").add(minutes=1),
    )
    client.get_latest_cached_states = MagicMock(return_value=[backend_state])

    runner = CloudTaskRunner(cached_task)
    new = runner.check_task_is_cached(state=Pending(), inputs={"a": Result(1)})

    assert new is backend_state
    assert new.result == 53
def test_additional_parameters_invalidate(self):
    """A parameter absent from the cached parameters makes `all_parameters` reject the cache."""
    cached = Cached(cached_parameters=dict(x=1, s="str"))
    current_parameters = dict(x=1, s="str", noise="e")
    assert all_parameters(cached, None, current_parameters) is False
def test_skipped_is_success(self): assert issubclass(Skipped, Success) def test_timedout_is_failed(self): assert issubclass(TimedOut, Failed) def test_trigger_failed_is_failed(self): assert issubclass(TriggerFailed, Failed) @pytest.mark.parametrize( "state_check", [ dict(state=Cancelled(), assert_true={"is_finished", "is_failed"}), dict(state=Cached(), assert_true={"is_cached", "is_finished", "is_successful"}), dict(state=ClientFailed(), assert_true={"is_meta_state"}), dict(state=Failed(), assert_true={"is_finished", "is_failed"}), dict(state=Finished(), assert_true={"is_finished"}), dict(state=Looped(), assert_true={"is_finished", "is_looped"}), dict(state=Mapped(), assert_true={"is_finished", "is_mapped", "is_successful"}), dict(state=Paused(), assert_true={"is_pending", "is_scheduled"}), dict(state=Pending(), assert_true={"is_pending"}), dict(state=Queued(), assert_true={"is_meta_state", "is_queued"}), dict(state=Resume(), assert_true={"is_pending", "is_scheduled"}), dict( state=Retrying(), assert_true={"is_pending", "is_scheduled", "is_retrying"} ), dict(state=Running(), assert_true={"is_running"}), dict(state=Scheduled(), assert_true={"is_pending", "is_scheduled"}), dict(
def test_skipped_is_success(self): assert issubclass(Skipped, Success) def test_timedout_is_failed(self): assert issubclass(TimedOut, Failed) def test_trigger_failed_is_failed(self): assert issubclass(TriggerFailed, Failed) @pytest.mark.parametrize( "state_check", [ dict(state=Cancelled(), assert_true={"is_finished"}), dict(state=Cached(), assert_true={"is_cached", "is_finished", "is_successful"}), dict(state=ClientFailed(), assert_true={"is_meta_state"}), dict(state=Failed(), assert_true={"is_finished", "is_failed"}), dict(state=Finished(), assert_true={"is_finished"}), dict(state=Looped(), assert_true={"is_finished", "is_looped"}), dict(state=Mapped(), assert_true={"is_finished", "is_mapped", "is_successful"}), dict(state=Paused(), assert_true={"is_pending", "is_scheduled"}), dict(state=Pending(), assert_true={"is_pending"}), dict(state=Queued(), assert_true={"is_meta_state", "is_queued"}), dict(state=Resume(), assert_true={"is_pending", "is_scheduled"}), dict(state=Retrying(), assert_true={"is_pending", "is_scheduled", "is_retrying"}), dict(state=Running(), assert_true={"is_running"}), dict(state=Scheduled(), assert_true={"is_pending", "is_scheduled"}),