def test_cache_validator_provided_if_needed(self):
    """
    If `cache_for` is provided and `cache_validator` is not, a default
    `cache_validator` should be supplied.
    """
    r = Result(value=3, cache_for=datetime.timedelta(days=2))
    assert r.cache_validator is not None
    assert callable(r.cache_validator)
def test_preparing_state_for_cloud_does_nothing_if_result_is_none(cls):
    xres = Result(None, result_handler=JSONResultHandler())
    state = prepare_state_for_cloud(cls(cached_inputs=dict(x=xres)))
    assert isinstance(state, cls)
    assert state.result is None
    assert state._result == NoResult
    assert state.cached_inputs == dict(x=xres)
    assert state.serialize()["cached_inputs"]["x"]["type"] == "NoResultType"
async def test_set_flow_run_state_with_result(self, run_query, flow_run_id):
    result = Result(10, result_handler=JSONResultHandler())
    result.store_safe_value()
    state = Success(result=result)
    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(
                states=[dict(flow_run_id=flow_run_id, state=state.serialize())]
            )
        ),
    )
    fr = await models.FlowRun.where(
        id=result.data.set_flow_run_states.states[0].id
    ).first({"state", "version"})
    assert fr.version == 2
    assert fr.state == "Success"
def test_preparing_state_for_cloud_replaces_cached_inputs_with_safe(cls):
    xres = Result(3, result_handler=JSONResultHandler())
    state = prepare_state_for_cloud(cls(cached_inputs=dict(x=xres)))
    assert isinstance(state, cls)
    assert state.result is None
    assert state._result == NoResult
    assert state.cached_inputs == dict(x=xres)
    assert state.serialize()["cached_inputs"]["x"]["value"] == "3"
def test_result_validate_warns_when_run_without_run_validators_flag(
    self, caplog
):
    _example_function = MagicMock(return_value=True)
    r = Result(value=None, validators=[_example_function], run_validators=False)
    with caplog.at_level(logging.WARNING, "prefect.Result"):
        is_valid = r.validate()

    # it should have behaved normally and called the validator functions
    _example_function.assert_called_once_with(r)
    assert is_valid is True

    # but it should ALSO have published a warning log mentioning that
    # `run_validators` was not set
    assert caplog.text.find("WARNING") > -1
    assert caplog.text.find("run_validators") > -1
def run_task(
    self,
    task: Task,
    state: State,
    upstream_states: Dict[Edge, State],
    context: Dict[str, Any],
    task_runner_state_handlers: Iterable[Callable],
    executor: "prefect.engine.executors.Executor",
) -> State:
    """
    Runs a specific task. This method is intended to be called by submitting
    it to an executor.

    Args:
        - task (Task): the task to run
        - state (State): starting state for the task. Defaults to `Pending`
        - upstream_states (Dict[Edge, State]): dictionary of upstream states
        - context (Dict[str, Any]): a context dictionary for the task run
        - task_runner_state_handlers (Iterable[Callable]): A list of state change
            handlers that will be provided to the task_runner, and called
            whenever a task changes state.
        - executor (Executor): executor to use when performing computation;
            defaults to the executor provided in your prefect configuration

    Returns:
        - State: `State` representing the final post-run state of the task.
    """
    with prefect.context(self.context):
        default_result = task.result or self.flow.result
        task_runner = self.task_runner_cls(
            task=task,
            state_handlers=task_runner_state_handlers,
            result=default_result or Result(),
            default_result=self.flow.result,
        )

        # if this task reduces over a mapped state, make sure its children have finished
        for edge, upstream_state in upstream_states.items():
            # if the upstream state is Mapped, wait until its results are all available
            if not edge.mapped and upstream_state.is_mapped():
                assert isinstance(upstream_state, Mapped)  # mypy assert
                upstream_state.map_states = executor.wait(upstream_state.map_states)
                upstream_state.result = [s.result for s in upstream_state.map_states]

        return task_runner.run(
            state=state,
            upstream_states=upstream_states,
            context=context,
            executor=executor,
        )
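# A self-contained sketch of the reduce-over-mapped branch above: when a task
# consumes a Mapped upstream over a non-mapped edge, the child states are
# resolved via the executor and their results collected into a list before the
# TaskRunner runs. The states and executor below are illustrative, not taken
# from the surrounding suite.
from prefect.engine.executors import LocalExecutor
from prefect.engine.state import Mapped, Success

executor = LocalExecutor()
upstream_state = Mapped(map_states=[Success(result=1), Success(result=2)])

# mirror of the `not edge.mapped and upstream_state.is_mapped()` branch;
# with LocalExecutor, wait() simply unwraps the already-computed "futures"
upstream_state.map_states = executor.wait(upstream_state.map_states)
upstream_state.result = [s.result for s in upstream_state.map_states]
assert upstream_state.result == [1, 2]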
def __init__(
    self,
    task: Task,
    state_handlers: Iterable[Callable] = None,
    flow_result: Result = None,
):
    self.context = prefect.context.to_dict()
    self.task = task

    # use the result from the task over the one provided off the parent Flow object
    if task.result:
        self.result = task.result
    else:
        self.result = Result().copy() if flow_result is None else flow_result.copy()

    self.flow_result = flow_result
    super().__init__(state_handlers=state_handlers)
def test_task_runner_puts_cloud_in_context(client):
    @prefect.task(result=Result())
    def whats_in_ctx():
        return prefect.context.get("checkpointing")

    res = CloudTaskRunner(task=whats_in_ctx).run()

    assert res.is_successful()
    assert res.result is True
def test_preparing_state_for_cloud_ignores_the_lack_of_result_handlers_for_cached_inputs(
    cls,
):
    xres = Result(3, result_handler=None)
    state = prepare_state_for_cloud(cls(cached_inputs=dict(x=xres)))
    assert isinstance(state, cls)
    assert state.result is None
    assert state._result == NoResult
    assert state.cached_inputs == dict(x=xres)
    assert state.serialize()["cached_inputs"]["x"]["type"] == "NoResultType"
def test_deserialize_json_with_context():
    deserialized = StateSchema().load(
        {"type": "Running", "context": {"boo": ["a", "b", "c"]}}
    )
    assert type(deserialized) is state.Running
    assert deserialized.is_running()
    assert deserialized.message is None
    assert deserialized.context == dict(boo=["a", "b", "c"])
    assert deserialized.result is None
    assert deserialized._result == Result()
def __init__(
    self,
    task: Task,
    state_handlers: Iterable[Callable] = None,
    flow_result: Result = None,
):
    self.context = prefect.context.to_dict()
    self.task = task

    # if the result was provided off the parent Flow object,
    # we want to use the task's target as the result's location
    if task.result:
        self.result = task.result
    else:
        self.result = Result() if flow_result is None else flow_result
        if self.task.target:
            self.result.location = self.task.target

    self.flow_result = flow_result
    super().__init__(state_handlers=state_handlers)
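# A minimal sketch of the precedence rules implemented above, assuming the
# Prefect 0.x API where Task accepts a `target` and Result a `location` (both
# used by the code above): a task-level result always wins; otherwise the
# flow-level result is used, and a task `target` overrides its location.
# The names and location strings are illustrative.
from prefect import Task
from prefect.engine.result import Result

flow_result = Result(location="flow-level/{task_name}.prefect")
task = Task(name="demo", target="targets/demo.prefect")  # no task-level result

if task.result:
    chosen = task.result
else:
    chosen = flow_result
    if task.target:  # the task's target becomes the flow result's location
        chosen.location = task.target

assert chosen.location == "targets/demo.prefect"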
def test_serialize_state_with_handled_result(cls):
    res = Result(value=1, location="src/place")
    state = cls(message="message", result=res)
    serialized = StateSchema().dump(state)
    assert isinstance(serialized, dict)
    assert serialized["type"] == cls.__name__
    assert serialized["message"] == "message"
    assert serialized["_result"]["type"] == "Result"
    assert serialized["_result"]["location"] == "src/place"
    assert serialized["__version__"] == prefect.__version__
def test_deserialize_mapped():
    s = state.Success(message="1", result=1)
    f = state.Failed(message="2", result=2)
    serialized = StateSchema().dump(state.Mapped(message="message", map_states=[s, f]))
    deserialized = StateSchema().load(serialized)
    assert isinstance(deserialized, state.Mapped)
    assert len(deserialized.map_states) == 2
    assert deserialized.map_states == [None, None]
    assert deserialized._result == Result()
    assert deserialized.result is None
def test_has_abstract_interfaces(abstract_interface: str):
    """
    Tests that calling the abstract interfaces directly on the base `Result`
    class raises `NotImplementedError`.
    """
    r = Result(value=3)
    func = getattr(r, abstract_interface)
    with pytest.raises(NotImplementedError):
        func()
def test_state_load_cached_results_reads_if_location_is_provided(self, cls):
    class MyResult(Result):
        def read(self, *args, **kwargs):
            self.value = "bar"
            return self

    state = cls(cached_inputs=dict(y=Result()))
    new_state = state.load_cached_results(dict(y=MyResult(location="foo")))
    assert new_state.cached_inputs["y"].value == "bar"
    assert new_state.cached_inputs["y"].location == "foo"
def check_for_retry(self, state: State, inputs: Dict[str, Result]) -> State:
    """
    Checks to see if a FAILED task should be retried.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result]): a dictionary of inputs whose keys
            correspond to the task's `run()` arguments.

    Returns:
        - State: the state of the task after running the check
    """
    if state.is_failed():
        run_count = prefect.context.get("task_run_count", 1)

        if prefect.context.get("task_loop_count") is not None:
            loop_context = {
                "_loop_count": Result(
                    value=prefect.context["task_loop_count"],
                    result_handler=JSONResultHandler(),
                ),
                "_loop_result": Result(
                    value=prefect.context.get("task_loop_result"),
                    result_handler=self.result_handler,
                ),
            }
            inputs.update(loop_context)

        if run_count <= self.task.max_retries:
            start_time = pendulum.now("utc") + self.task.retry_delay
            msg = "Retrying Task (after attempt {n} of {m})".format(
                n=run_count, m=self.task.max_retries + 1
            )
            retry_state = Retrying(
                start_time=start_time,
                cached_inputs=inputs,
                message=msg,
                run_count=run_count,
            )
            return retry_state

    return state
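# A hypothetical illustration of the loop bookkeeping above: for a failed
# looping task, the loop counters are stashed in `cached_inputs` so a retry
# can resume the loop where it left off. All concrete values here (the loop
# count, partial result, and message) are assumptions for the example; the
# Retrying kwargs match the construction in check_for_retry.
import pendulum
from prefect.engine.result import Result
from prefect.engine.result_handlers import JSONResultHandler
from prefect.engine.state import Retrying

inputs = {
    "_loop_count": Result(value=3, result_handler=JSONResultHandler()),
    "_loop_result": Result(value="partial", result_handler=JSONResultHandler()),
}
retry_state = Retrying(
    start_time=pendulum.now("utc"),  # the real check adds task.retry_delay
    cached_inputs=inputs,
    message="Retrying Task (after attempt 1 of 3)",
    run_count=1,
)
assert retry_state.run_count == 1
assert retry_state.cached_inputs["_loop_count"].value == 3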
def test_uses_provided_cache_validator(self):
    def custom_cache_validator(*args, **kwargs):
        # a custom function used for identity comparison
        return True

    r = Result(
        value=3,
        cache_for=datetime.timedelta(days=2),
        cache_validator=custom_cache_validator,
    )
    assert r.cache_validator is custom_cache_validator
def test_result_inits_with_value(self):
    r = Result(3)
    assert r.value == 3
    assert r.safe_value is NoResult
    assert r.result_handler is None
    assert r.validators is None
    assert r.cache_for is None
    assert r.cache_validator is None
    assert r.filepath_template is None
    assert r.run_validators is True

    s = Result(value=5)
    assert s.value == 5
    assert s.safe_value is NoResult
    assert s.result_handler is None
    assert s.validators is None
    assert s.cache_for is None
    assert s.cache_validator is None
    assert s.filepath_template is None
    assert s.run_validators is True
def test_task_failure_caches_constant_inputs_automatically(client):
    @prefect.task(max_retries=2, retry_delay=timedelta(seconds=100))
    def is_p_three(p):
        if p == 3:
            raise ValueError("No thank you.")

    with prefect.Flow("test") as f:
        res = is_p_three(3)

    state = CloudFlowRunner(flow=f).run(return_tasks=[res])
    assert state.is_running()
    assert isinstance(state.result[res], Retrying)

    exp_res = Result(3, result_handler=ConstantResultHandler(3))
    assert not state.result[res].cached_inputs["p"] == exp_res
    exp_res.store_safe_value()
    assert state.result[res].cached_inputs["p"] == exp_res

    last_state = client.set_task_run_state.call_args_list[-1][-1]["state"]
    assert isinstance(last_state, Retrying)
    assert last_state.cached_inputs["p"] == exp_res
def test_create_state_with_tags_in_context(self, cls):
    with prefect.context(task_tags=set("abcdef")):
        state = cls()
    assert state.message is None
    assert state.result is None
    assert state._result == Result()
    assert state.context == dict(tags=list(set("abcdef")))

    with prefect.context(task_tags=set("abcdef")):
        state = cls(context={"tags": ["foo"]})
    assert state.context == dict(tags=["foo"])
def test_validate_on_kwarg(self):
    state = Cached(cached_inputs=dict(x=Result(1), s=Result("str")))
    assert (
        partial_inputs_only(validate_on=["x", "s"])(state, dict(x=1, s="str"), None)
        is True
    )

    state = Cached(cached_inputs=dict(x=Result(1), s=Result("str")))
    assert (
        partial_inputs_only(validate_on=["x", "s"])(state, dict(x=1, s="strs"), None)
        is False
    )
    assert (
        partial_inputs_only(validate_on=["x"])(state, dict(x=1, s="strs"), None)
        is True
    )
    assert (
        partial_inputs_only(validate_on=["s"])(state, dict(x=1, s="strs"), None)
        is False
    )
def test_task_failure_with_upstream_secrets_doesnt_store_secret_value_and_recompute_if_necessary(
    client,
):
    @prefect.task(max_retries=2, retry_delay=timedelta(seconds=100))
    def is_p_three(p):
        if p == 3:
            raise ValueError("No thank you.")
        return p

    with prefect.Flow("test", result_handler=JSONResultHandler()) as f:
        p = prefect.tasks.secrets.Secret("p")
        res = is_p_three(p)

    with prefect.context(secrets=dict(p=3)):
        state = CloudFlowRunner(flow=f).run(return_tasks=[res])

    assert state.is_running()
    assert isinstance(state.result[res], Retrying)

    exp_res = Result(3, result_handler=SecretResultHandler(p))
    assert not state.result[res].cached_inputs["p"] == exp_res
    exp_res.store_safe_value()
    assert state.result[res].cached_inputs["p"] == exp_res

    ## here we set the result of the secret to a SafeResult, ensuring
    ## it will get converted to a "true" result;
    ## we expect that the upstream value will actually get recomputed from context
    ## through the SecretResultHandler
    safe = SafeResult("p", result_handler=SecretResultHandler(p))
    state.result[p] = Success(result=safe)
    state.result[res].start_time = pendulum.now("utc")
    state.result[res].cached_inputs = dict(p=safe)

    with prefect.context(secrets=dict(p=4)):
        new_state = CloudFlowRunner(flow=f).run(
            return_tasks=[res], task_states=state.result
        )

    assert new_state.is_successful()
    assert new_state.result[res].result == 4
def test_preparing_state_for_cloud_doesnt_copy_data():
    class FakeHandler(ResultHandler):
        def read(self, val):
            return val

        def write(self, val):
            return val

    value = 124.090909
    result = Result(value, result_handler=FakeHandler())
    state = Cached(result=result)
    cloud_state = prepare_state_for_cloud(state)
    assert cloud_state.is_cached()
    assert cloud_state.result is state.result
def checkpoint_handler(
    task_runner: DSTaskRunner, old_state: State, new_state: State
) -> State:
    """
    A handler designed to implement result caching by filename.

    If the result handler's ``read`` method can be successfully run, this
    handler loads the result of that method as the task result and sets the
    task state to ``Success``. Similarly, on successful completion of the
    task, if the task was actually run and not loaded from cache, this handler
    will apply the result handler's ``write`` method to the task's result.

    Parameters
    ----------
    task_runner : instance of DSTaskRunner
        The task runner associated with the flow the handler is used in.
    old_state : instance of prefect.engine.state.State
        The current state of the task.
    new_state : instance of prefect.engine.state.State
        The expected new state of the task.

    Returns
    -------
    new_state : instance of prefect.engine.state.State
        The actual new state of the task.
    """
    if os.environ.get("PREFECT__FLOWS__CHECKPOINTING") == "true":
        raise AttributeError(
            "Cannot use standard prefect checkpointing with this handler"
        )

    if (
        task_runner.result_handler is not None
        and old_state.is_pending()
        and new_state.is_running()
    ):
        if not hasattr(task_runner, "upstream_states"):
            raise TypeError(
                "upstream_states not found in task runner. Make sure to use "
                "prefect_ds.task_runner.DSTaskRunner."
            )
        input_mapping = _create_input_mapping(task_runner.upstream_states)
        try:
            data = task_runner.task.result_handler.read(input_mapping=input_mapping)
        except FileNotFoundError:
            return new_state
        except TypeError:  # unexpected argument input_mapping
            raise TypeError(
                "Result handler could not accept input_mapping argument. "
                "Please ensure that you are using a handler from prefect_ds."
            )
        result = Result(value=data, result_handler=task_runner.task.result_handler)
        state = Success(result=result, message="Task loaded from disk.")
        return state

    if (
        task_runner.result_handler is not None
        and old_state.is_running()
        and new_state.is_successful()
    ):
        input_mapping = _create_input_mapping(task_runner.upstream_states)
        task_runner.task.result_handler.write(
            new_state.result, input_mapping=input_mapping
        )

    return new_state
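# A hedged usage sketch for checkpoint_handler. The prefect_ds import paths
# and the PandasResultHandler signature are assumptions for illustration;
# DSTaskRunner is the runner the handler's own error message asks for, since
# it records upstream_states. Only FlowRunner's `task_runner_cls` and
# `task_runner_state_handlers` hooks are standard Prefect 0.x API.
import pandas as pd
import prefect
from prefect import Flow
from prefect.engine.flow_runner import FlowRunner
from prefect_ds.pandas_result_handler import PandasResultHandler  # assumed path
from prefect_ds.task_runner import DSTaskRunner  # assumed path

@prefect.task(result_handler=PandasResultHandler("frame.csv", "csv"))  # assumed signature
def make_frame():
    return pd.DataFrame({"x": [1, 2, 3]})

with Flow("checkpointed") as flow:
    frame = make_frame()

# attach checkpoint_handler at the task-runner level so it receives the
# DSTaskRunner instance as its first argument
state = FlowRunner(flow=flow, task_runner_cls=DSTaskRunner).run(
    task_runner_state_handlers=[checkpoint_handler],
    return_tasks=[frame],
)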
def test_cloud_task_runner_handles_retries_with_queued_states_from_cloud(client):
    calls = []

    def queued_mock(*args, **kwargs):
        calls.append(kwargs)
        # the first retry attempt will get queued
        if len(calls) == 4:
            return Queued()  # immediate start time
        else:
            return kwargs.get("state")

    client.set_task_run_state = queued_mock

    @prefect.task(
        max_retries=2,
        retry_delay=datetime.timedelta(seconds=0),
        result_handler=ResultHandler(),
    )
    def tagged_task(x):
        if prefect.context.get("task_run_count", 1) == 1:
            raise ValueError("gimme a sec")
        return x

    upstream_result = Result(value=42, result_handler=JSONResultHandler())
    res = CloudTaskRunner(task=tagged_task).run(
        context={"task_run_version": 1},
        state=None,
        upstream_states={
            Edge(Task(), tagged_task, key="x"): Success(result=upstream_result)
        },
        executor=prefect.engine.executors.LocalExecutor(),
    )

    assert res.is_successful()
    assert res.result == 42
    assert (
        len(calls) == 6
    )  # Running -> Failed -> Retrying -> Queued -> Running -> Success
    assert [type(c["state"]).__name__ for c in calls] == [
        "Running",
        "Failed",
        "Retrying",
        "Running",
        "Running",
        "Success",
    ]

    # ensures the result handler was called and the value persisted
    assert calls[2]["state"].cached_inputs["x"].safe_value.value == "42"
def test_state_load_result_reads_if_location_is_provided(self, cls):
    class MyResult(Result):
        def read(self, *args, **kwargs):
            self.value = "bar"
            return self

    state = cls(result=Result())
    assert state.message is None
    assert state.result is None
    assert state._result.location is None

    new_state = state.load_result(MyResult(location="foo"))
    assert new_state.message is None
    assert new_state.result == "bar"
    assert new_state._result.location == "foo"
def test_task_runner_validates_cached_states_if_task_has_caching(client):
    @prefect.task(
        cache_for=datetime.timedelta(minutes=1), result_handler=JSONResultHandler()
    )
    def cached_task():
        return 42

    state = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        - datetime.timedelta(minutes=2),
        result=Result(99, JSONResultHandler()),
    )
    old_state = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        - datetime.timedelta(days=1),
        result=Result(13, JSONResultHandler()),
    )
    client.get_latest_cached_states = MagicMock(return_value=[state, old_state])

    res = CloudTaskRunner(task=cached_task).run()
    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 42
def test_state_load_cached_results_doesnt_call_read_if_location_is_none(self, cls):
    """
    If both the value and location information are None, we assume that None
    is the correct return value and perform no action.
    """

    class MyResult(Result):
        def read(self, *args, **kwargs):
            self.location = "foo"
            self.value = "bar"
            return self

    state = cls(cached_inputs=dict(x=Result()))
    new_state = state.load_cached_results(dict(x=MyResult()))
    assert new_state.cached_inputs["x"].value is None
    assert new_state.cached_inputs["x"].location is None
def test_state_load_cached_results_doesnt_call_read_if_value_present(self, cls):
    """
    This test ensures that multiple calls to `load_cached_results` will not
    result in multiple redundant reads from the remote result location.
    """

    class MyResult(Result):
        def read(self, *args, **kwargs):
            self.location = "foo"
            self.value = "bar"
            return self

    state = cls(cached_inputs=dict(x=Result(value=42)))
    new_state = state.load_cached_results(dict(x=MyResult()))
    assert new_state.cached_inputs["x"].value == 42
    assert new_state.cached_inputs["x"].location is None
def test_state_load_cached_results_calls_read(self, cls):
    """
    This test ensures that the read logic of the provided result is used
    instead of self._result; this is important when "hydrating" JSON
    representations of `Result` objects that come from Cloud.
    """

    class MyResult(Result):
        def read(self, *args, **kwargs):
            self.location = "foo"
            self.value = 42
            return self

    state = cls(cached_inputs=dict(x=Result()))
    new_state = state.load_cached_results(dict(x=MyResult(location="")))
    assert new_state.cached_inputs["x"].value == 42
    assert new_state.cached_inputs["x"].location == "foo"