def test_reads_result_using_handler_attribute_if_cached_valid(self, client):
    """A valid Cached state handed to check_task_is_cached is returned unchanged.

    Its `.result` ends up as 53, a value only the task's custom Result
    subclass produces (through `read`).
    """

    class ConstantResult(Result):
        # Ignores every argument and always "reads" the value 53.
        def read(self, *args, **kwargs):
            self.value = 53
            return self

    # Task() emits a UserWarning here; presumably because a cache_validator is
    # supplied without cache_for.
    with pytest.warns(UserWarning):
        cache_task = Task(cache_validator=duration_only, result=ConstantResult())

    stored = PrefectResult(location="2")
    cached_state = Cached(
        result=stored,
        cached_result_expiration=pendulum.now("utc").add(minutes=1),
    )
    client.get_latest_cached_states = MagicMock(return_value=[])

    runner = CloudTaskRunner(cache_task)
    outcome = runner.check_task_is_cached(state=cached_state, inputs={"a": Result(1)})

    assert outcome is cached_state
    assert outcome.result == 53
def check_target(self, state: State, inputs: Dict[str, Result]) -> State:
    """
    Checks if a Result exists at the task's target.

    The target (a format string, or a callable whose return value is then
    formatted) is rendered with the flow parameters, the task's raw input
    values and the Prefect context. If the runner's `result` reports that
    something exists there, it is read back and wrapped in a `Cached` state.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments.

    Returns:
        - State: a new `Cached` state when a result exists at the target,
            otherwise the unchanged `state`
    """
    result = self.result
    target = self.task.target

    if result and target:
        raw_inputs = {k: r.value for k, r in inputs.items()}
        # Later entries win: context keys override input names, which in turn
        # override flow parameters with the same name.
        formatting_kwargs = {
            **prefect.context.get("parameters", {}),
            **raw_inputs,
            **prefect.context,
        }

        # A non-string target is called to build the template string.
        if not isinstance(target, str):
            target = target(**formatting_kwargs)

        if result.exists(target, **formatting_kwargs):
            # Render the template once, so the read and the state message both
            # refer to the concrete location rather than the raw template.
            known_location = target.format(**formatting_kwargs)
            new_res = result.read(known_location)
            return Cached(
                result=new_res,
                hashed_inputs={
                    key: tokenize(val.value) for key, val in inputs.items()
                },
                cached_result_expiration=None,
                cached_parameters=formatting_kwargs.get("parameters"),
                message=f"Result found at task target {known_location}",
            )

    return state
def test_reads_result_if_cached_valid_using_task_result(task, client):
    """A valid Cached state fetched from the client is returned.

    Its `.result` is 53, a value only the task's custom Result subclass
    produces (through `read`).
    """

    class ConstantResult(Result):
        # Ignores every argument and always "reads" the value 53.
        def read(self, *args, **kwargs):
            self.value = 53
            return self

    # The `task` fixture argument is not used; the test builds its own Task.
    caching_task = Task(
        result=ConstantResult(),
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=duration_only,
    )
    remote_state = Cached(
        result=PrefectResult(location="2"),
        cached_result_expiration=pendulum.now("utc").add(minutes=1),
    )
    # The mocked client hands back exactly one candidate cached state.
    client.get_latest_cached_states = MagicMock(return_value=[remote_state])

    runner = CloudTaskRunner(caching_task)
    outcome = runner.check_task_is_cached(state=Pending(), inputs={"a": Result(1)})

    assert outcome is remote_state
    assert outcome.result == 53
def cache_result(self, state: State, inputs: Dict[str, Result]) -> State:
    """
    Wrap a successful task state in a `Cached` state when caching is enabled.

    As a side effect, a failed state gets the given inputs attached as
    `cached_inputs`. A `Cached` state is produced only when all of these hold:
        - the state is successful
        - the state is not skipped (Skipped is a subclass of Success)
        - `task.cache_for` is not None

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys
            correspond to the task's `run()` arguments.

    Returns:
        - State: a new `Cached` state expiring `cache_for` after the current UTC
            time, or the original `state` when caching does not apply
    """
    if state.is_failed():
        # Attach the inputs the failed run received to the state itself.
        state.cached_inputs = inputs  # type: ignore

    # Guard clause: leave the state alone unless it is an unskipped success
    # and the task requested caching.
    if not state.is_successful() or state.is_skipped() or self.task.cache_for is None:
        return state

    expires_at = pendulum.now("utc") + self.task.cache_for
    return Cached(
        result=state._result,
        cached_inputs=inputs,
        cached_result_expiration=expires_at,
        cached_parameters=prefect.context.get("parameters"),
        message=state.message,
    )
def test_parameters_validate_with_defaults(self):
    """partial_parameters_only() with no arguments accepts two kinds of input.

    It accepts parameters identical to the cached ones, and also parameters
    whose `s` value differs from them.
    """
    for current in ({"x": 1, "s": "str"}, {"x": 1, "s": "strs"}):
        cached = Cached(cached_parameters={"x": 1, "s": "str"})
        assert partial_parameters_only()(cached, None, current) is True
def test_skipped_is_success(self):
    """Skipped is a subclass of Success."""
    assert issubclass(Skipped, Success)


def test_timedout_is_failed(self):
    """TimedOut is a subclass of Failed."""
    assert issubclass(TimedOut, Failed)


def test_trigger_failed_is_failed(self):
    """TriggerFailed is a subclass of Failed."""
    assert issubclass(TriggerFailed, Failed)


# NOTE(review): this parametrize list continues past the visible source; the
# decorated test and the remaining rows are not shown here.
@pytest.mark.parametrize(
    "state_check",
    [
        # NOTE(review): confirm whether Cancelled should also report is_failed.
        dict(state=Cancelled(), assert_true={"is_finished"}),
        dict(state=Cached(), assert_true={"is_cached", "is_finished", "is_successful"}),
        dict(state=ClientFailed(), assert_true={"is_meta_state"}),
        dict(state=Failed(), assert_true={"is_finished", "is_failed"}),
        dict(state=Finished(), assert_true={"is_finished"}),
        dict(state=Looped(), assert_true={"is_finished", "is_looped"}),
        dict(state=Mapped(), assert_true={"is_finished", "is_mapped", "is_successful"}),
        dict(state=Paused(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Pending(), assert_true={"is_pending"}),
        dict(state=Queued(), assert_true={"is_meta_state", "is_queued"}),
        dict(state=Resume(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Retrying(), assert_true={"is_pending", "is_scheduled", "is_retrying"}),
        dict(state=Running(), assert_true={"is_running"}),
        dict(state=Scheduled(), assert_true={"is_pending", "is_scheduled"}),
def test_state_type_methods_with_cached_state(self):
    """Among the predicates checked, Cached() is only cached, finished and successful."""
    cached = Cached()
    expectations = [
        ("is_cached", True),
        ("is_retrying", False),
        ("is_pending", False),
        ("is_running", False),
        ("is_finished", True),
        ("is_skipped", False),
        ("is_scheduled", False),
        ("is_successful", True),
        ("is_failed", False),
        ("is_mapped", False),
        ("is_meta_state", False),
    ]
    # Call each predicate in order and compare its truthiness to the expectation.
    for predicate, expected in expectations:
        assert bool(getattr(cached, predicate)()) is expected
def test_skipped_is_success(self):
    """Skipped is a subclass of Success."""
    assert issubclass(Skipped, Success)


def test_timedout_is_failed(self):
    """TimedOut is a subclass of Failed."""
    assert issubclass(TimedOut, Failed)


def test_trigger_failed_is_failed(self):
    """TriggerFailed is a subclass of Failed."""
    assert issubclass(TriggerFailed, Failed)


# NOTE(review): this parametrize list continues past the visible source; the
# decorated test and the remaining rows are not shown here.
@pytest.mark.parametrize(
    "state_check",
    [
        # This table expects Cancelled to report both finished and failed.
        dict(state=Cancelled(), assert_true={"is_finished", "is_failed"}),
        dict(state=Cached(), assert_true={"is_cached", "is_finished", "is_successful"}),
        dict(state=ClientFailed(), assert_true={"is_meta_state"}),
        dict(state=Failed(), assert_true={"is_finished", "is_failed"}),
        dict(state=Finished(), assert_true={"is_finished"}),
        dict(state=Looped(), assert_true={"is_finished", "is_looped"}),
        dict(state=Mapped(), assert_true={"is_finished", "is_mapped", "is_successful"}),
        dict(state=Paused(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Pending(), assert_true={"is_pending"}),
        dict(state=Queued(), assert_true={"is_meta_state", "is_queued"}),
        dict(state=Resume(), assert_true={"is_pending", "is_scheduled"}),
        dict(
            state=Retrying(), assert_true={"is_pending", "is_scheduled", "is_retrying"}
        ),
        dict(state=Running(), assert_true={"is_running"}),
        dict(state=Scheduled(), assert_true={"is_pending", "is_scheduled"}),
        dict(
def test_parameters_validate(self):
    """all_parameters accepts current parameters identical to the cached ones."""
    cached = Cached(cached_parameters={"x": 1, "s": "str"})
    verdict = all_parameters(cached, None, {"x": 1, "s": "str"})
    assert verdict is True
def test_additional_parameters_invalidate(self):
    """An extra parameter absent from the cache makes all_parameters reject it."""
    cached = Cached(cached_parameters={"x": 1, "s": "str"})
    with_extra = {"x": 1, "s": "str", "noise": "e"}
    verdict = all_parameters(cached, None, with_extra)
    assert verdict is False
def test_inputs_validate(self):
    """all_inputs accepts current inputs identical to the cached inputs."""
    # NOTE(review): cached_inputs holds raw values here, while other tests wrap
    # them in Result; confirm which form all_inputs expects.
    cached = Cached(cached_inputs={"x": 1, "s": "str"})
    verdict = all_inputs(cached, {"x": 1, "s": "str"}, None)
    assert verdict is True
def test_hashed_inputs_validate(self):
    """all_inputs accepts inputs whose tokens match the cached hashed_inputs."""
    fingerprints = {"x": tokenize(1), "s": tokenize("str")}
    cached = Cached(hashed_inputs=fingerprints)
    verdict = all_inputs(cached, {"x": 1, "s": "str"}, None)
    assert verdict is True
def test_unexpired_cache(self):
    """duration_only accepts a cache whose expiration is a day in the future."""
    tomorrow = pendulum.now("utc") + timedelta(days=1)
    fresh = Cached(cached_result_expiration=tomorrow)
    assert duration_only(fresh, None, None) is True
def test_handles_none(self):
    """partial_inputs_only(validate_on=["x"]) returns False when one side is missing.

    The missing side is either the cached inputs or the current inputs.
    """
    # Cached state carries only parameters, no cached_inputs.
    params_only = Cached(cached_parameters={"x": 5})
    assert partial_inputs_only(validate_on=["x"])(params_only, {"x": 5}, None) is False

    # Cached inputs exist, but the current inputs are None.
    inputs_only = Cached(cached_inputs={"x": Result(5)})
    assert partial_inputs_only(validate_on=["x"])(inputs_only, None, None) is False
def test_expired_cache_stateful(validator):
    """A Cached state that expired a day ago is rejected by the built validator.

    The `validator` factory is presumably supplied via parametrization.
    """
    yesterday = pendulum.now("utc") - timedelta(days=1)
    stale = Cached(cached_result_expiration=yesterday)
    check = validator()
    assert check(stale, None, None) is False
def test_curried(self):
    """A validator built with validate_on=["x"] checks the `x` parameter.

    An unlisted parameter (`s`) may be absent, but a changed `x` is rejected.
    """
    cached = Cached(cached_parameters={"x": 1, "s": "str"})
    only_x = partial_parameters_only(validate_on=["x"])

    same_x = only_x(cached, None, {"x": 1})
    assert same_x is True

    changed_x = only_x(cached, None, {"x": 2, "s": "str"})
    assert changed_x is False
def test_inputs_validate_with_defaults(self):
    """partial_inputs_only(None) accepts two kinds of input.

    It accepts inputs identical to the cached ones, and also inputs whose `s`
    value differs from them.
    """
    for current in ({"x": 1, "s": "str"}, {"x": 1, "s": "strs"}):
        cached = Cached(cached_inputs={"x": Result(1), "s": Result("str")})
        assert partial_inputs_only(None)(cached, current, None) is True
def test_curried(self):
    """A validator built with validate_on=["x"] checks the `x` input.

    An unlisted input (`s`) may be absent, but a changed `x` is rejected.
    """
    cached = Cached(cached_inputs={"x": Result(1), "s": Result("str")})
    only_x = partial_inputs_only(validate_on=["x"])

    same_x = only_x(cached, {"x": 1}, None)
    assert same_x is True

    changed_x = only_x(cached, {"x": 2, "s": "str"}, None)
    assert changed_x is False
def test_additional_inputs_invalidate(self):
    """An extra input absent from the cache makes all_inputs reject the inputs."""
    # NOTE(review): cached_inputs holds raw values here, while other tests wrap
    # them in Result; confirm which form all_inputs expects.
    cached = Cached(cached_inputs={"x": 1, "s": "str"})
    with_extra = {"x": 1, "s": "str", "noise": "e"}
    verdict = all_inputs(cached, with_extra, None)
    assert verdict is False
def test_cached_result_expiration_none_is_interpreted_as_infinite(self):
    """duration_only treats a missing (None) expiration as never expiring."""
    open_ended = Cached(cached_result_expiration=None)
    verdict = duration_only(open_ended, None, None)
    assert verdict is True
def test_inputs_invalidate(self):
    """all_inputs rejects current inputs whose `s` differs from the cached value."""
    cached = Cached(cached_inputs={"x": Result(1), "s": Result("str")})
    mismatched = {"x": 1, "s": "strs"}
    verdict = all_inputs(cached, mismatched, None)
    assert verdict is False