async def test_task_runs_cached_inputs_give_preference_to_new_cached_inputs(
    self, state_cls, task_run_id, running_flow_run_id
):
    """Newer cached inputs take precedence over previously stored ones.

    A state of the parametrized ``state_cls`` is stored first with cached
    inputs keyed ``b``/``c``. A ``Retrying`` state with cached inputs keyed
    ``x``/``y`` is stored afterwards. The persisted serialized state must
    expose the second, newer set of inputs.
    """
    # first, store a `state_cls` state with non-null cached inputs ("b", "c")
    old_res1 = SafeResult(1, result_handler=JSONResultHandler())
    old_res2 = SafeResult({"a": 2}, result_handler=JSONResultHandler())
    old_inputs = {"b": old_res1, "c": old_res2}
    old_state = state_cls(cached_inputs=old_inputs)
    await api.states.set_task_run_state(task_run_id=task_run_id, state=old_state)

    # then store a Retrying state with different non-null cached inputs ("x", "y")
    new_res1 = SafeResult(1, result_handler=JSONResultHandler())
    new_res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
    new_inputs = {"x": new_res1, "y": new_res2}
    new_state = Retrying(cached_inputs=new_inputs)
    await api.states.set_task_run_state(task_run_id=task_run_id, state=new_state)

    run = Box(
        await models.TaskRun.where(id=task_run_id).first({"serialized_state"})
    )

    # verify that we have cached inputs, and that preference has been given to
    # the new cached inputs
    assert run.serialized_state.cached_inputs
    assert run.serialized_state.cached_inputs.x.value == 1
    assert run.serialized_state.cached_inputs.y.value == {"z": 2}
def test_task_runner_validates_cached_state_inputs_with_upstream_handlers_if_task_has_caching(
    client,
):
    """A cached state is not reused when the upstream handler reads a different value.

    The upstream input for ``x`` carries a handler whose ``read`` always
    returns 1337. The candidate Cached state recorded ``x`` as "2", so no
    cache hit is expected and the state stays Pending.
    """

    class Handler(ResultHandler):
        def read(self, val):
            return 1337

    @prefect.task(
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=all_inputs,
        result_handler=JSONResultHandler(),
    )
    def cached_task(x):
        return 42

    def expiring_cached(**state_kwargs):
        # every candidate expires two minutes from now
        expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=2)
        return Cached(cached_result_expiration=expiration, **state_kwargs)

    dull_state = expiring_cached(result=SafeResult("-1", JSONResultHandler()))
    state = expiring_cached(
        result=SafeResult("99", JSONResultHandler()),
        cached_inputs={"x": SafeResult("2", result_handler=JSONResultHandler())},
    )
    client.get_latest_cached_states = MagicMock(return_value=[dull_state, state])

    runner = CloudTaskRunner(task=cached_task)
    upstream_inputs = {"x": Result(2, result_handler=Handler())}
    checked = runner.check_task_is_cached(Pending(), inputs=upstream_inputs)

    assert client.get_latest_cached_states.called
    assert checked.is_pending()
def test_safe_result_requires_both_init_args(self):
    """SafeResult cannot be built unless both value and result_handler are given."""
    missing_both = "2 required positional arguments"
    missing_one = "1 required positional argument"
    for init_kwargs, expected_message in [
        ({}, missing_both),
        ({"value": "3"}, missing_one),
        ({"result_handler": JSONResultHandler()}, missing_one),
    ]:
        with pytest.raises(TypeError, match=expected_message):
            SafeResult(**init_kwargs)
def complex_states():
    """Return a representative list of State objects for serialization tests.

    Covers looped, pending, paused, retrying, scheduled, resume, tagged running,
    submitted, queued, cached and timed-out states. Time-bearing states appear
    with both a timezone-aware (pendulum) and a naive ``datetime``.
    """
    naive_dt = datetime.datetime(2020, 1, 1)
    utc_dt = pendulum.datetime(2020, 1, 1)

    inputs = {
        "x": SafeResult(1, result_handler=JSONResultHandler()),
        "y": SafeResult({"z": 2}, result_handler=JSONResultHandler()),
    }
    cached_value = SafeResult(
        dict(x=1, y={"z": 2}), result_handler=JSONResultHandler()
    )

    # one Cached state per expiration flavor: tz-aware first, then naive
    cached_variants = [
        state.Cached(
            cached_inputs=inputs,
            result=cached_value,
            cached_parameters={"x": 1, "y": {"z": 2}},
            cached_result_expiration=expiration,
        )
        for expiration in (utc_dt, naive_dt)
    ]

    tagged_running = state.Running()
    tagged_running.context = dict(tags=["1", "2", "3"])

    return [
        state.Looped(loop_count=45),
        state.Pending(cached_inputs=inputs),
        state.Paused(cached_inputs=inputs),
        state.Retrying(start_time=utc_dt, run_count=3),
        state.Retrying(start_time=naive_dt, run_count=3),
        state.Scheduled(start_time=utc_dt),
        state.Scheduled(start_time=naive_dt),
        state.Resume(start_time=utc_dt),
        state.Resume(start_time=naive_dt),
        tagged_running,
        state.Submitted(state=state.Retrying(start_time=utc_dt, run_count=2)),
        state.Submitted(state=state.Resume(start_time=utc_dt)),
        state.Queued(state=state.Pending()),
        state.Queued(state=state.Pending(), start_time=utc_dt),
        state.Queued(state=state.Retrying(start_time=utc_dt, run_count=2)),
        *cached_variants,
        state.TimedOut(cached_inputs=inputs),
    ]
def test_safe_result_requires_both_init_args(self):
    """SafeResult must be given both ``value`` and ``result_handler``.

    Omitting either or both raises a ``TypeError`` whose message states how
    many required positional arguments are missing.
    """
    # pytest.raises(match=...) applies re.search to str(exc); these literals
    # contain no regex metacharacters, so this is equivalent to a substring check.
    with pytest.raises(TypeError, match="2 required positional arguments"):
        SafeResult()
    with pytest.raises(TypeError, match="1 required positional argument"):
        SafeResult(value="3")
    with pytest.raises(TypeError, match="1 required positional argument"):
        SafeResult(result_handler=JSONResultHandler())
async def test_task_runs_with_null_cached_inputs_do_not_overwrite_cache(
    self, state, task_run_id, running_flow_run_id
):
    """The stored cached inputs reflect the Retrying state set after ``state``.

    The parametrized ``state`` is applied first (per the test name, presumably
    one without cached inputs). A Retrying state carrying cached inputs
    follows, and the persisted serialized state must hold those inputs.
    """
    await api.states.set_task_run_state(task_run_id=task_run_id, state=state)

    # follow up with a Retrying state that does carry cached inputs
    retry_inputs = dict(
        x=SafeResult(1, result_handler=JSONResultHandler()),
        y=SafeResult({"z": 2}, result_handler=JSONResultHandler()),
    )
    await api.states.set_task_run_state(
        task_run_id=task_run_id, state=Retrying(cached_inputs=retry_inputs)
    )

    run = await models.TaskRun.where(id=task_run_id).first({"serialized_state"})
    stored_inputs = run.serialized_state["cached_inputs"]
    assert stored_inputs["x"]["value"] == 1
    assert stored_inputs["y"]["value"] == {"z": 2}
def test_serialize_and_deserialize_on_safe_cached_state():
    """A Cached state built entirely from SafeResults survives a round trip."""
    expiration = pendulum.now("utc")
    safe_input = SafeResult("99", result_handler=JSONResultHandler())
    safe_output = SafeResult(dict(hi=5, bye=6), result_handler=JSONResultHandler())
    original = Cached(
        cached_inputs=dict(x=safe_input, p=safe_input),
        result=safe_output,
        cached_result_expiration=expiration,
    )

    roundtripped = State.deserialize(original.serialize())

    assert isinstance(roundtripped, Cached)
    assert roundtripped.color == original.color
    assert roundtripped.result == dict(hi=5, bye=6)
    assert roundtripped.cached_result_expiration == original.cached_result_expiration
    assert roundtripped.cached_inputs == original.cached_inputs
def test_serialization_of_cached_inputs():
    """Pending cached_inputs made of SafeResults survive serialize/deserialize."""
    shared = SafeResult(5, result_handler=JSONResultHandler())
    before = Pending(cached_inputs={"hi": shared, "bye": shared})

    after = State.deserialize(before.serialize())

    assert isinstance(after, Pending)
    assert after.cached_inputs == before.cached_inputs
def test_storing_happens_once(self):
    """store_safe_value() leaves an already-assigned safe_value untouched."""
    result = Result(value=4, result_handler=JSONResultHandler())
    precomputed = SafeResult(value="123", result_handler=JSONResultHandler())
    result.safe_value = precomputed

    result.store_safe_value()

    # identity check: the very same object must still be attached
    assert result.safe_value is precomputed
def test_to_result_returns_hydrated_result_for_safe(self):
    """to_result() reads the safe value back through its handler into a Result."""
    safe = SafeResult("3", result_handler=JSONResultHandler())
    hydrated = safe.to_result()

    assert isinstance(hydrated, Result)
    assert hydrated.value == 3  # the stored "3" is read back as the int 3
    assert hydrated.safe_value is safe
    assert hydrated.result_handler is safe.result_handler
def test_task_runner_sends_checkpointed_success_states_to_cloud(self, client):
    """A checkpointed task reports its Success state to Cloud with a safe value.

    Exactly two ``set_task_run_state`` calls are expected: Running, then
    Success. The Success state's result carries the JSON-handled safe value
    ``"2"``.
    """
    handler = JSONResultHandler()

    @prefect.task(checkpoint=True, result_handler=handler)
    def add(x, y):
        return x + y

    x_state, y_state = Success(result=Result(1)), Success(result=Result(1))
    upstream_states = {
        Edge(Task(), Task(), key="x"): x_state,
        Edge(Task(), Task(), key="y"): y_state,
    }

    CloudTaskRunner(task=add).run(upstream_states=upstream_states)

    ## assertions
    assert client.get_task_run_info.call_count == 0  # never called
    # two state updates are sent: Running -> Success (no Pending update)
    assert client.set_task_run_state.call_count == 2

    states = [call[1]["state"] for call in client.set_task_run_state.call_args_list]
    assert states[0].is_running()
    assert states[1].is_successful()
    assert states[1]._result.safe_value == SafeResult("2", result_handler=handler)
def test_task_runner_validates_cached_state_inputs_if_task_has_caching(client):
    """With ``all_inputs`` validation, the Cached state whose inputs match is used.

    The mocked client returns two Cached candidates. Only the second records
    cached input ``x`` as "2", matching the upstream value 2, so its result
    (99) is expected.
    """

    @prefect.task(
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=all_inputs,
        result_handler=JSONResultHandler(),
    )
    def cached_task(x):
        return 42

    def expiring_cached(**state_kwargs):
        # every candidate expires two minutes from now
        expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=2)
        return Cached(cached_result_expiration=expiration, **state_kwargs)

    dull_state = expiring_cached(result=Result(-1, JSONResultHandler()))
    state = expiring_cached(
        result=Result(99, JSONResultHandler()),
        cached_inputs={"x": SafeResult("2", result_handler=JSONResultHandler())},
    )
    client.get_latest_cached_states = MagicMock(return_value=[dull_state, state])

    runner = CloudTaskRunner(task=cached_task)
    upstream_inputs = {"x": Result(2, result_handler=LocalResultHandler())}
    checked = runner.check_task_is_cached(Pending(), inputs=upstream_inputs)

    assert client.get_latest_cached_states.called
    assert checked.is_successful()
    assert checked.is_cached()
    assert checked.result == 99
def test_serialization_of_cached_inputs_with_safe_values(cls):
    """For the parametrized state ``cls``, SafeResult cached_inputs round-trip."""
    shared = SafeResult(5, result_handler=JSONResultHandler())
    before = cls(cached_inputs={"hi": shared, "bye": shared})

    after = State.deserialize(before.serialize())

    assert isinstance(after, cls)
    assert after.cached_inputs == before.cached_inputs
def test_value_raises_error_on_dump_if_not_valid_json():
    """StateResultSchema rejects a SafeResult whose value nests a callable."""
    unserializable = {"x": {"y": {"z": lambda: 1}}}
    safe = SafeResult(value=unserializable, result_handler=JSONResultHandler())
    with pytest.raises(marshmallow.exceptions.ValidationError):
        StateResultSchema().dump(safe)
def test_result_must_be_valid_json():
    """A JSON-compatible SafeResult value is dumped under ``_result.value``."""
    payload = {"x": {"y": {"z": 1}}}
    success = state.Success(
        result=SafeResult(payload, result_handler=JSONResultHandler())
    )

    dumped = StateSchema().dump(success)

    assert dumped["_result"]["value"] == success.result
def test_serialize_state_with_safe_result(cls):
    """StateSchema.dump of a ``cls`` state holding a SafeResult emits the expected fields."""
    safe = SafeResult(value="1", result_handler=JSONResultHandler())
    dumped = StateSchema().dump(cls(message="message", result=safe))

    assert isinstance(dumped, dict)
    assert dumped["type"] == cls.__name__
    assert dumped["message"] == "message"
    result_payload = dumped["_result"]
    assert result_payload["type"] == "SafeResult"
    assert result_payload["value"] == "1"
    assert dumped["__version__"] == prefect.__version__
def test_result_raises_error_on_dump_if_not_valid_json():
    """StateSchema refuses to dump a state whose SafeResult value nests a callable."""
    bad_value = {"x": {"y": {"z": lambda: 1}}}
    success = state.Success(
        result=SafeResult(bad_value, result_handler=JSONResultHandler())
    )
    with pytest.raises(marshmallow.exceptions.ValidationError):
        StateSchema().dump(success)
def test_to_result_uses_provided_result_handler(self):
    """A handler passed explicitly to to_result() overrides the SafeResult's own."""

    class WeirdHandler(ResultHandler):
        def read(self, loc):
            return 99

    safe = SafeResult("4", result_handler=JSONResultHandler())
    hydrated = safe.to_result(result_handler=WeirdHandler())

    assert isinstance(hydrated, Result)
    assert isinstance(hydrated.result_handler, WeirdHandler)
    assert hydrated.value == 99
    # the override also propagates to the attached safe value
    assert isinstance(hydrated.safe_value.result_handler, WeirdHandler)
async def test_setting_a_task_run_state_pulls_cached_inputs_if_possible(
    self, task_run_id, state, running_flow_run_id
):
    """Applying a new state keeps the cached inputs of the stored state.

    The task run's serialized state is first overwritten directly in the
    database with a Failed state holding cached inputs. The parametrized
    ``state`` (presumably one without its own cached inputs) is then applied
    through the API. The stored state must have the new type while still
    exposing the earlier cached inputs.
    """
    res1 = SafeResult(1, result_handler=JSONResultHandler())
    res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
    complex_result = {"x": res1, "y": res2}
    cached_state = Failed(cached_inputs=complex_result)
    # write the Failed state straight to the DB, bypassing the state API
    await models.TaskRun.where(id=task_run_id).update(
        set=dict(serialized_state=cached_state.serialize())
    )

    # apply the parametrized state through the API (not necessarily Scheduled)
    await api.states.set_task_run_state(task_run_id=task_run_id, state=state)

    task_run = await models.TaskRun.where(id=task_run_id).first(
        {"serialized_state"}
    )

    # ensure the state change took place and the cached inputs were carried over
    assert task_run.serialized_state["type"] == type(state).__name__
    assert task_run.serialized_state["cached_inputs"]["x"]["value"] == 1
    assert task_run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}
def test_set_task_run_state_serializes(patch_post):
    """A state whose SafeResult wraps a lambda makes set_task_run_state raise.

    The lambda value cannot be serialized, so a marshmallow ValidationError is
    expected.
    """
    patch_post({"data": {"setTaskRunStates": {"states": [{"status": "SUCCESS"}]}}})

    cloud_config = {
        "cloud.api": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
    }
    with set_temporary_config(cloud_config):
        client = Client()

    unserializable = SafeResult(lambda: None, result_handler=None)
    with pytest.raises(marshmallow.exceptions.ValidationError):
        client.set_task_run_state(
            task_run_id="76-salt", version=0, state=Pending(result=unserializable)
        )
def test_serialize_and_deserialize_on_mixed_cached_state():
    """A Cached state with non-safe input Results round-trips with NoResult inputs.

    The SafeResult output and the expiration survive the round trip. The plain
    ``Result`` cached inputs come back as ``NoResult``.
    """
    expiration = pendulum.now("utc")
    safe_output = SafeResult(dict(hi=5, bye=6), result_handler=JSONResultHandler())
    original = Cached(
        cached_inputs=dict(x=Result(2), p=Result("p")),
        result=safe_output,
        cached_result_expiration=expiration,
    )

    roundtripped = State.deserialize(original.serialize())

    assert isinstance(roundtripped, Cached)
    assert roundtripped.color == original.color
    assert roundtripped.result == dict(hi=5, bye=6)
    assert roundtripped.cached_result_expiration == original.cached_result_expiration
    assert roundtripped.cached_inputs == {"x": NoResult, "p": NoResult}
def test_set_task_run_state_serializes(monkeypatch):
    """A state whose SafeResult wraps a lambda makes set_task_run_state raise.

    ``requests.post`` is replaced by a mock. The lambda value cannot be
    serialized, so a marshmallow ValidationError is expected.
    """
    fake_response = MagicMock(
        json=MagicMock(return_value={"data": {"setTaskRunState": None}})
    )
    monkeypatch.setattr("requests.post", MagicMock(return_value=fake_response))

    cloud_config = {
        "cloud.graphql": "http://my-cloud.foo",
        "cloud.auth_token": "secret_token",
    }
    with set_temporary_config(cloud_config):
        client = Client()

    unserializable = SafeResult(lambda: None, result_handler=None)
    with pytest.raises(marshmallow.exceptions.ValidationError):
        client.set_task_run_state(
            task_run_id="76-salt", version=0, state=Pending(result=unserializable)
        )
async def test_set_task_run_state_with_safe_result(self, run_query, task_run_id):
    """The set-states mutation accepts a Success state carrying a SafeResult.

    Afterwards the task run is stored as "Success" at version 2.
    """
    success = Success(result=SafeResult("10", result_handler=JSONResultHandler()))
    state_input = dict(task_run_id=task_run_id, state=success.serialize())

    response = await run_query(
        query=self.mutation,
        variables=dict(input=dict(states=[state_input])),
    )

    updated_id = response.data.set_task_run_states.states[0].id
    tr = await models.TaskRun.where(id=updated_id).first({"state", "version"})
    assert tr.version == 2
    assert tr.state == "Success"
def test_task_failure_with_upstream_secrets_doesnt_store_secret_value_and_recompute_if_necessary(
    client,
):
    """A task failing on a Secret input does not cache the secret's raw value,
    and a later run re-reads the secret from context.

    First run: ``is_p_three`` receives p == 3 and raises. Retries are
    configured, so the task ends in Retrying. Its cached input for ``p`` does
    not equal a plain ``Result`` built with ``SecretResultHandler``, but does
    equal it once ``store_safe_value()`` has been called on that Result. Per
    the test name, this presumably means the stored safe value does not hold
    the secret itself; confirm against SecretResultHandler.

    Second run: the Secret's state and the task's cached input are replaced
    with a SafeResult, the secret in context is changed to 4, and the task is
    expected to succeed with result 4.
    """

    @prefect.task(max_retries=2, retry_delay=timedelta(seconds=100))
    def is_p_three(p):
        if p == 3:
            raise ValueError("No thank you.")
        return p

    with prefect.Flow("test", result_handler=JSONResultHandler()) as f:
        p = prefect.tasks.secrets.Secret("p")
        res = is_p_three(p)

    # first run: the secret resolves to 3, so the task raises and is retried
    with prefect.context(secrets=dict(p=3)):
        state = CloudFlowRunner(flow=f).run(return_tasks=[res])

    assert state.is_running()
    assert isinstance(state.result[res], Retrying)

    exp_res = Result(3, result_handler=SecretResultHandler(p))
    # before its safe value is stored, the expected Result does not compare equal
    assert not state.result[res].cached_inputs["p"] == exp_res
    exp_res.store_safe_value()
    assert state.result[res].cached_inputs["p"] == exp_res

    ## here we set the result of the secret to a saferesult, ensuring
    ## it will get converted to a "true" result;
    ## we expect that the upstream value will actually get recomputed from context
    ## through the SecretResultHandler
    safe = SafeResult("p", result_handler=SecretResultHandler(p))
    state.result[p] = Success(result=safe)
    # presumably makes the retry due now instead of after retry_delay (100s) — confirm
    state.result[res].start_time = pendulum.now("utc")
    state.result[res].cached_inputs = dict(p=safe)

    # second run: the secret now resolves to 4, which the task accepts
    with prefect.context(secrets=dict(p=4)):
        new_state = CloudFlowRunner(flow=f).run(return_tasks=[res], task_states=state.result)

    assert new_state.is_successful()
    assert new_state.result[res].result == 4
def test_basic_safe_result_repr():
    """repr() of a SafeResult renders as ``<SafeResult: value>``."""
    expected = "<SafeResult: 2>"
    assert repr(SafeResult(2, result_handler=JSONResultHandler())) == expected
def test_safe_result_inits_with_both_args(self):
    """Both constructor args are stored, and a SafeResult is its own safe_value."""
    safe = SafeResult(value="3", result_handler=JSONResultHandler())

    assert safe.value == "3"
    assert safe.result_handler == JSONResultHandler()
    assert safe.safe_value is safe
def test_safe_results_are_same(self):
    """Two SafeResults with equal values and equal handler types compare equal."""
    first, second = (
        SafeResult("3", result_handler=JSONResultHandler()) for _ in range(2)
    )
    assert first == second
def test_safe_results_with_different_handlers_are_not_same(self):
    """Equal values wrapped with different handler types are not equal."""
    json_backed = SafeResult("3", result_handler=JSONResultHandler())
    local_backed = SafeResult("3", result_handler=LocalResultHandler())
    assert json_backed != local_backed
def test_store_safe_value_for_safe_results(self):
    """store_safe_value() on a SafeResult returns None and leaves its value intact."""
    safe = SafeResult(value=4, result_handler=JSONResultHandler())

    returned = safe.store_safe_value()

    assert returned is None
    assert isinstance(safe.safe_value, SafeResult)
    assert safe.value == 4
def test_safe_results_to_results_remain_the_same(self):
    """Hydrating two equal SafeResults yields equal Results."""
    hydrated = [
        SafeResult("3", result_handler=JSONResultHandler()).to_result()
        for _ in range(2)
    ]
    assert hydrated[0] == hydrated[1]