def test_task_runner_handles_outputs_prior_to_setting_state(self, client):
    """Upstream results are handled before states are sent; a Cached state is sent last."""

    @prefect.task(
        cache_for=datetime.timedelta(days=1), result_handler=JSONResultHandler()
    )
    def add(x, y):
        return x + y

    shared = Result(1, result_handler=JSONResultHandler())
    assert shared.safe_value is NoResult

    upstream = {
        Edge(Task(), Task(), key="x"): Success(result=shared),
        Edge(Task(), Task(), key="y"): Success(result=shared),
    }
    CloudTaskRunner(task=add).run(upstream_states=upstream)

    # running the task populated the shared upstream result's safe value
    assert shared.safe_value != NoResult

    assert client.get_task_run_info.call_count == 0
    # Pending -> Running -> Successful -> Cached
    assert client.set_task_run_state.call_count == 3

    running, successful, cached = (
        c[1]["state"] for c in client.set_task_run_state.call_args_list
    )
    assert running.is_running()
    assert successful.is_successful()
    assert isinstance(cached, Cached)
    assert cached.cached_inputs == {"x": shared, "y": shared}
    assert cached.result == 2
def test_task_runner_sends_checkpointed_success_states_to_cloud(self, client):
    """A checkpointed task sends a Success whose safe value was written by its handler."""
    handler = JSONResultHandler()

    @prefect.task(checkpoint=True, result_handler=handler)
    def add(x, y):
        return x + y

    upstream = {
        Edge(Task(), Task(), key="x"): Success(result=Result(1)),
        Edge(Task(), Task(), key="y"): Success(result=Result(1)),
    }
    CloudTaskRunner(task=add).run(upstream_states=upstream)

    # task run info is never fetched
    assert client.get_task_run_info.call_count == 0
    # Pending -> Running -> Successful
    assert client.set_task_run_state.call_count == 2

    first, second = [c[1]["state"] for c in client.set_task_run_state.call_args_list]
    assert first.is_running()
    assert second.is_successful()
    assert second._result.safe_value == SafeResult("2", result_handler=handler)
def test_task_runner_sets_mapped_state_prior_to_executor_mapping(client):
    """The Mapped state is sent to Cloud before the executor's `map` is invoked."""

    class MyExecutor(prefect.engine.executors.LocalExecutor):
        def map(self, *args, **kwargs):
            # blow up so we can inspect what was sent before mapping started
            raise SyntaxError("oops")

    mapped_edge = Edge(Task(), Task(), key="foo", mapped=True)
    upstream_states = {mapped_edge: Success(result=[1, 2])}

    with pytest.raises(SyntaxError):
        CloudTaskRunner(task=Task()).run_mapped_task(
            state=Pending(),
            upstream_states=upstream_states,
            context={},
            executor=MyExecutor(),
        )

    assert client.get_task_run_info.call_count == 0
    # Pending -> Mapped
    assert client.set_task_run_state.call_count == 1
    assert client.get_latest_cached_states.call_count == 0

    mapped_state = client.set_task_run_state.call_args[1]["state"]
    assert mapped_state.map_states == [None, None]
    assert mapped_state.is_mapped()
    assert "Preparing to submit 2" in mapped_state.message
def test_cloud_task_runner_handles_retries_with_queued_states_from_cloud(client):
    """A Queued reply from Cloud during a retry does not stop the run from succeeding."""
    recorded = []

    def queued_mock(*args, **kwargs):
        recorded.append(kwargs)
        # the 4th state sent (the first retry's Running) is answered with an
        # immediately-startable Queued; everything else is echoed back
        return Queued() if len(recorded) == 4 else kwargs.get("state")

    client.set_task_run_state = queued_mock

    @prefect.task(
        max_retries=2,
        retry_delay=datetime.timedelta(seconds=0),
        result=PrefectResult(),
    )
    def tagged_task(x):
        if prefect.context.get("task_run_count", 1) == 1:
            raise ValueError("gimme a sec")
        return x

    upstream_result = PrefectResult(value=42, location="42")
    res = CloudTaskRunner(task=tagged_task).run(
        context={"task_run_version": 1},
        state=None,
        upstream_states={
            Edge(Task(), tagged_task, key="x"): Success(result=upstream_result)
        },
    )

    assert res.is_successful()
    assert res.result == 42
    # sent: Running, Failed, Retrying, Running (answered Queued), Running, Success
    assert len(recorded) == 6
    sent_types = [type(kw["state"]).__name__ for kw in recorded]
    assert sent_types == [
        "Running",
        "Failed",
        "Retrying",
        "Running",
        "Running",
        "Success",
    ]
    assert recorded[2]["state"].cached_inputs["x"].value == 42
def test_task_runner_doesnt_call_client_if_map_index_is_none(client):
    """Without a map_index, task run info and cached states are never fetched."""
    res = CloudTaskRunner(task=Task(name="test")).run()

    assert client.get_task_run_info.call_count == 0
    # Pending -> Running -> Success
    assert client.set_task_run_state.call_count == 2
    assert client.get_latest_cached_states.call_count == 0

    sent = [c[1]["state"] for c in client.set_task_run_state.call_args_list]
    assert [type(s).__name__ for s in sent] == ["Running", "Success"]
    assert res.is_successful()
    # an untagged task sends an empty tag list in each state context
    for s in sent[:2]:
        assert s.context == {"tags": []}
def test_task_runner_raises_endrun_if_client_cant_communicate_during_state_updates(
    monkeypatch,
):
    """If sending a state fails, the run stops and returns the last computed state."""

    @prefect.task(name="test")
    def raise_error():
        raise NameError("I don't exist")

    set_task_run_state = MagicMock(side_effect=SyntaxError)
    fake_client = MagicMock(
        get_task_run_info=MagicMock(return_value=MagicMock(state=None)),
        set_task_run_state=set_task_run_state,
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client",
        MagicMock(return_value=fake_client),
    )

    # an ENDRUN makes the TaskRunner return the most recently computed state
    res = CloudTaskRunner(task=raise_error).run(context={"map_index": 1})
    assert set_task_run_state.called
    assert res.is_running()
def test_task_handlers_handle_retry_signals(client):
    """A RETRY raised from a state handler turns each failure into Retrying."""

    def state_handler(t, o, n):
        if n.is_failed():
            raise prefect.engine.signals.RETRY(
                "Will retry.", start_time=pendulum.now("utc").add(days=1)
            )

    @prefect.task(state_handlers=[state_handler])
    def fn():
        1 / 0

    first = CloudTaskRunner(task=fn).run()
    assert first.is_retrying()
    assert first.run_count == 1

    # pull the retry start time to now so the second run actually executes
    first.start_time = pendulum.now("utc")
    second = CloudTaskRunner(task=fn).run(state=first)
    assert second.is_retrying()
    assert second.run_count == 2

    sent_types = [
        type(c[1]["state"]).__name__ for c in client.set_task_run_state.call_args_list
    ]
    assert sent_types == ["Running", "Retrying", "Running", "Retrying"]
def test_task_runner_uses_upstream_result_handlers(client):
    """Upstream values are read through the upstream task's own Result type."""

    class MyResult(Result):
        def read(self, *args, **kwargs):
            self.value = "cool"
            return self

        def write(self, *args, **kwargs):
            return self

    @prefect.task(result=PrefectResult())
    def t(x):
        return x

    upstream_state = Success(result=PrefectResult(location="1"))
    upstream_edge = Edge(Task(result=MyResult()), t, key="x")
    final = CloudTaskRunner(task=t).run(upstream_states={upstream_edge: upstream_state})

    assert final.is_successful()
    # "cool" can only come from MyResult.read, not from the stored location "1"
    assert final.result == "cool"
def test_state_handler_failures_are_handled_appropriately(client):
    """An exception inside an on_failure handler becomes the failed state's result."""

    def bad(*args, **kwargs):
        raise SyntaxError("Syntax Errors are nice because they're so unique")

    @prefect.task(on_failure=bad)
    def do_nothing():
        raise ValueError("This task failed somehow")

    res = CloudTaskRunner(task=do_nothing).run()
    assert res.is_failed()
    assert "SyntaxError" in res.message
    assert isinstance(res.result, SyntaxError)

    assert client.set_task_run_state.call_count == 2
    running, failed = [c[1]["state"] for c in client.set_task_run_state.call_args_list]
    assert running.is_running()
    assert failed.is_failed()
    assert isinstance(failed.result, SyntaxError)
def test_task_runner_gracefully_handles_load_results_failures(client):
    """A failure reading upstream results fails the run without reaching Running."""

    class MyResult(Result):
        def read(self, *args, **kwargs):
            raise TypeError("something is wrong!")

    @prefect.task(result=PrefectResult())
    def t(x):
        return x

    upstream_state = Success(result=MyResult(location="foo.txt"))
    upstream_edge = Edge(Task(result=MyResult()), t, key="x")
    final = CloudTaskRunner(task=t).run(upstream_states={upstream_edge: upstream_state})

    assert final.is_failed()
    assert "task results" in final.message
    # Pending -> Failed
    assert client.set_task_run_state.call_count == 1
    sent_types = [
        type(c[1]["state"]).__name__ for c in client.set_task_run_state.call_args_list
    ]
    assert sent_types == ["Failed"]
def test_task_runner_uses_cached_inputs_from_db_state(monkeypatch):
    """Cached inputs on the state returned by Cloud feed the task's arguments."""

    @prefect.task(name="test")
    def add_one(x):
        return x + 1

    db_state = Retrying(cached_inputs={"x": Result(41)})
    get_task_run_info = MagicMock(return_value=MagicMock(state=db_state))
    set_task_run_state = MagicMock()
    fake_client = MagicMock(
        get_task_run_info=get_task_run_info, set_task_run_state=set_task_run_state
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client",
        MagicMock(return_value=fake_client),
    )

    res = CloudTaskRunner(task=add_one).run(context={"map_index": 1})

    assert get_task_run_info.call_count == 1  # pulled once for the latest state
    assert set_task_run_state.call_count == 2  # Pending -> Running -> Success
    assert res.is_successful()
    assert res.result == 42
def test_task_runner_places_task_tags_in_state_context_and_serializes_them(client):
    """Task tags are copied into the context of the Running and Success states sent."""
    task = Task(name="test", tags=["1", "2", "tag"])
    CloudTaskRunner(task=task).run()

    sent = [c[1]["state"] for c in client.set_task_run_state.call_args_list]
    assert sent[0].is_running()
    assert sent[1].is_successful()
    expected_tags = {"1", "2", "tag"}
    for s in sent[:2]:
        assert set(s.context["tags"]) == expected_tags
def test_task_runner_set_task_name_same_as_prefect_context(client):
    """A callable task_run_name is evaluated with the task's inputs and sent once."""

    @prefect.task(name="hey", task_run_name=lambda **kwargs: kwargs["config"])
    def test_task(config):
        return

    upstream = {Edge(Task(), Task(), key="config"): Success(result="any_value")}
    CloudTaskRunner(task=test_task).run(upstream_states=upstream)

    name_setter = client.set_task_run_name
    assert name_setter.call_count == 1
    assert name_setter.call_args[1]["name"] == "any_value"
def test_load_results_from_upstream_reads_results(self):
    """load_results reads each upstream value through the upstream task's Result."""
    upstream_result = PrefectResult(location="1")
    upstream_state = Success(result=upstream_result)
    # nothing has been read yet
    assert upstream_result.value is None

    edge = Edge(Task(result=PrefectResult()), 2, key="x")
    _, upstreams = CloudTaskRunner(task=Task()).load_results(
        state=Pending(), upstream_states={edge: upstream_state}
    )
    assert upstreams[edge].result == 1
def test_task_runner_places_task_tags_in_state_context_and_serializes_them(monkeypatch):
    """Task tags survive serialization into the GraphQL payloads posted by the client."""
    # NOTE(review): another test in this file has this exact name; if both live at
    # the same scope, the later definition shadows the earlier one - confirm.
    task = Task(name="test", tags=["1", "2", "tag"])
    session = MagicMock()
    monkeypatch.setattr("prefect.client.client.GraphQLResult", MagicMock())
    monkeypatch.setattr("requests.Session", MagicMock(return_value=session))

    res = CloudTaskRunner(task=task).run()
    assert res.is_successful()

    # decode the GraphQL "variables" JSON of every POST made through the session
    payloads = [
        json.loads(c[1]["json"]["variables"]) for c in session.post.call_args_list
    ]
    # keep the first state entry from each non-null payload
    sent_states = [p["input"]["states"][0] for p in payloads if p is not None]

    expected_tags = {"1", "2", "tag"}
    first, last = sent_states[0], sent_states[-1]
    assert first["state"]["type"] == "Running"
    assert set(first["state"]["context"]["tags"]) == expected_tags
    assert last["state"]["type"] == "Success"
    assert set(last["state"]["context"]["tags"]) == expected_tags
def test_load_results_from_upstream_reads_results_using_upstream_handlers(
    self, cloud_api
):
    """The upstream task's custom Result.read decides the loaded upstream value."""

    class CustomResult(Result):
        def read(self, *args, **kwargs):
            return "foo-bar-baz".split("-")

    upstream_state = Success(result=PrefectResult(location="1"))
    edge = Edge(Task(result=CustomResult()), 2, key="x")
    _, upstreams = CloudTaskRunner(task=Task()).load_results(
        state=Pending(), upstream_states={edge: upstream_state}
    )
    assert upstreams[edge].result == ["foo", "bar", "baz"]
def test_task_runner_calls_get_task_run_info_if_map_index_is_not_none(client):
    """With a map_index in context, the runner fetches task run info exactly once."""
    task = Task(name="test")
    CloudTaskRunner(task=task).run(context={"map_index": 1})

    ## assertions
    # fixed: this line was previously commented "never called", contradicting == 1
    assert client.get_task_run_info.call_count == 1  # called once
    assert client.set_task_run_state.call_count == 2  # Pending -> Running -> Success
    states = [call[1]["state"] for call in client.set_task_run_state.call_args_list]
    assert [type(s).__name__ for s in states] == ["Running", "Success"]
def test_load_results_from_upstream_reads_secret_results(self, cloud_api):
    """An upstream SecretResult resolves its value from the secrets in context."""
    secret_result = SecretResult(prefect.tasks.secrets.PrefectSecret(name="foo"))
    upstream_state = Success(result=PrefectResult(location="foo"))

    # loading happens while the "foo" secret is available in context
    with prefect.context(secrets={"foo": 42}):
        edge = Edge(Task(result=secret_result), 2, key="x")
        _, upstreams = CloudTaskRunner(task=Task()).load_results(
            state=Pending(), upstream_states={edge: upstream_state}
        )

    assert upstreams[edge].result == 42
def test_task_runner_handles_version_lock_error(monkeypatch):
    """On VersionLockError the runner's outcome depends on the state Cloud reports."""
    client = MagicMock()
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client", MagicMock(return_value=client)
    )
    client.set_task_run_state.side_effect = VersionLockError()

    runner = CloudTaskRunner(task=Task(name="test"))

    # Cloud reports Success: the call returns a successful state
    client.get_task_run_state.return_value = Success()
    outcome = runner.call_runner_target_handlers(Pending(), Running())
    assert outcome.is_successful()

    # Cloud reports Running: ENDRUN is raised
    client.get_task_run_state.return_value = Running()
    with pytest.raises(ENDRUN):
        runner.call_runner_target_handlers(Pending(), Running())

    # Cloud reports a Success whose result cannot be loaded: ENDRUN is raised
    unloadable = Success()
    unloadable.load_result = MagicMock(side_effect=Exception())
    client.get_task_run_state.return_value = unloadable
    with pytest.raises(ENDRUN):
        runner.call_runner_target_handlers(Pending(), Running())
def test_task_runner_handles_looping(client):
    """Each LOOP iteration sends new states with a monotonically increasing version."""

    @prefect.task(result_handler=ResultHandler())
    def looper():
        if prefect.context.get("task_loop_count", 1) < 3:
            raise LOOP(result=prefect.context.get("task_loop_result", 0) + 10)
        return prefect.context.get("task_loop_result")

    res = CloudTaskRunner(task=looper).run(
        context={"task_run_version": 1},
        state=None,
        upstream_states={},
        executor=prefect.engine.executors.LocalExecutor(),
    )

    assert res.is_successful()
    assert client.get_task_run_info.call_count == 0
    # Pending -> Running -> Looped (1) -> Running -> Looped (2) -> Running -> Success
    assert client.set_task_run_state.call_count == 6
    sent_versions = [c[1]["version"] for c in client.set_task_run_state.call_args_list]
    assert sent_versions == list(range(1, 7))
def test_task_runner_validates_cached_states_if_task_has_caching(client):
    """Expired cached states are not reused; the task runs and caches its own value."""

    @prefect.task(cache_for=datetime.timedelta(minutes=1), result=PrefectResult())
    def cached_task():
        return 42

    expired = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        - datetime.timedelta(minutes=2),
        result=PrefectResult(location="99"),
    )
    long_expired = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        - datetime.timedelta(days=1),
        result=PrefectResult(location="13"),
    )
    client.get_latest_cached_states = MagicMock(return_value=[expired, long_expired])

    res = CloudTaskRunner(task=cached_task).run()

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    # 42 is the task's own return value, not either stored candidate (99 / 13)
    assert res.result == 42
def test_task_runner_treats_unfound_files_as_invalid_caches(client, tmpdir):
    """A cached state whose result file is missing is skipped in favor of the next one."""

    @prefect.task(cache_for=datetime.timedelta(minutes=1), result=PrefectResult())
    def cached_task():
        return 42

    missing_file = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(minutes=2),
        result=LocalResult(location=str(tmpdir / "made_up_data.prefect")),
    )
    fallback = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(days=1),
        result=PrefectResult(location="13"),
    )
    client.get_latest_cached_states = MagicMock(return_value=[missing_file, fallback])

    res = CloudTaskRunner(task=cached_task).run()

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 13
def test_task_runner_handles_looping_with_no_result(client):
    """LOOP without a result still cycles to success; only some sends carry a version."""

    @prefect.task(result_handler=ResultHandler())
    def looper():
        if prefect.context.get("task_loop_count", 1) < 3:
            raise LOOP()
        return 42

    res = CloudTaskRunner(task=looper).run(
        context={"task_run_version": 1}, state=None, upstream_states={}
    )

    assert res.is_successful()
    assert client.get_task_run_info.call_count == 0
    # Pending -> Running -> Looped (1) -> Running -> Looped (2) -> Running -> Success
    assert client.set_task_run_state.call_count == 6

    all_versions = (c[1]["version"] for c in client.set_task_run_state.call_args_list)
    # drop falsy versions; only the truthy ones are compared
    assert [v for v in all_versions if v] == [1, 3, 5]
def test_task_runner_queries_for_cached_states_if_task_has_caching(client):
    """A caching task asks Cloud for cached states and reuses the unexpired one."""

    @prefect.task(cache_for=datetime.timedelta(minutes=1))
    def cached_task():
        return 42

    fresh = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(days=1),
        result=Result(99, JSONResultHandler()),
    )
    stale = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        - datetime.timedelta(days=1),
        result=13,
    )
    client.get_latest_cached_states = MagicMock(return_value=[fresh, stale])

    res = CloudTaskRunner(task=cached_task).run()

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 99
def test_task_runner_performs_retries_for_short_delays(client):
    """A zero-delay retry is performed in-process, ending in Success."""
    attempts = []

    @prefect.task(max_retries=1, retry_delay=datetime.timedelta(seconds=0))
    def noop():
        # fail only on the very first call; later calls return None
        if not attempts:
            attempts.append(0)
            raise ValueError("oops")

    res = CloudTaskRunner(task=noop).run(
        state=None,
        upstream_states={},
        executor=prefect.engine.executors.LocalExecutor(),
    )

    assert res.is_successful()
    assert client.get_task_run_info.call_count == 0
    # Pending -> Running -> Failed -> Retrying -> Running -> Success
    assert client.set_task_run_state.call_count == 5
def test_task_runner_validates_cached_state_inputs_if_task_has_caching_and_uses_task_handler(
    self, client
):
    """The cached state with matching inputs is chosen and read via the task's Result."""

    class MyResult(Result):
        def read(self, *args, **kwargs):
            loaded = self.copy()
            loaded.value = 1337
            return loaded

    @prefect.task(
        cache_for=datetime.timedelta(minutes=1),
        cache_validator=all_inputs,
        result=MyResult(),
    )
    def cached_task(x):
        return 42

    # carries no cached inputs
    no_inputs = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(minutes=2),
        result=PrefectResult(location="-1"),
    )
    # cached input x is stored at location "2", matching the value 2 passed below
    matching = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(minutes=2),
        result=PrefectResult(location="99"),
        cached_inputs={"x": PrefectResult(location="2")},
    )
    client.get_latest_cached_states = MagicMock(return_value=[no_inputs, matching])

    res = CloudTaskRunner(task=cached_task).check_task_is_cached(
        Pending(), inputs={"x": PrefectResult(value=2)}
    )

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    # 1337 can only come from MyResult.read
    assert res.result == 1337
def test_task_runner_handles_looping_with_retries(client):
    """A failure mid-loop is retried and the loop resumes to completion."""

    # note that looping _requires_ a result handler in Cloud
    @prefect.task(
        max_retries=1,
        retry_delay=datetime.timedelta(seconds=0),
        result=PrefectResult(),
    )
    def looper():
        loop_count = prefect.context.get("task_loop_count", 1)
        # fail once, on loop iteration 2 of the first run attempt
        if loop_count == 2 and prefect.context.get("task_run_count", 1) == 1:
            raise ValueError("Stop")
        if loop_count < 3:
            raise LOOP(result=prefect.context.get("task_loop_result", 0) + 10)
        return prefect.context.get("task_loop_result")

    client.get_task_run_info.side_effect = [
        MagicMock(version=i, state=Looped(loop_count=i) if i else Pending())
        for i in range(5)
    ]

    res = CloudTaskRunner(task=looper).run(
        context={"task_run_version": 1}, state=None, upstream_states={}
    )

    assert res.is_successful()
    assert client.get_task_run_info.call_count == 4
    # Pending -> Running -> Looped (1) -> Running -> Failed -> Retrying -> Running
    #   -> Looped (2) -> Running -> Success
    assert client.set_task_run_state.call_count == 9

    all_versions = (c[1]["version"] for c in client.set_task_run_state.call_args_list)
    # drop falsy versions; only the truthy ones are compared
    assert [v for v in all_versions if v] == [1, 2, 3]
def test_task_runner_respects_the_db_state(monkeypatch, state):
    """The state returned by Cloud is returned unchanged, with no state updates sent."""
    db_state = state("already", result=10)
    get_task_run_info = MagicMock(return_value=MagicMock(state=db_state))
    set_task_run_state = MagicMock()
    fake_client = MagicMock(
        get_task_run_info=get_task_run_info, set_task_run_state=set_task_run_state
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client",
        MagicMock(return_value=fake_client),
    )

    res = CloudTaskRunner(task=Task(name="test")).run(context={"map_index": 1})

    assert get_task_run_info.call_count == 1  # pulled once for the latest state
    assert set_task_run_state.call_count == 0  # no state update is ever sent
    assert res == db_state
def test_task_runner_raises_endrun_with_correct_state_if_client_cant_receive_state_updates(
    monkeypatch,
):
    """If fetching task run info fails, the runner returns the state it was given."""
    task = Task(name="test")
    get_task_run_info = MagicMock(side_effect=SyntaxError)
    fake_client = MagicMock(
        get_task_run_info=get_task_run_info, set_task_run_state=MagicMock()
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client",
        MagicMock(return_value=fake_client),
    )

    # an ENDRUN makes the TaskRunner return the most recently computed state
    initial = Pending(message="unique message", result=42)
    res = CloudTaskRunner(task=task).run(state=initial, context={"map_index": 1})
    assert get_task_run_info.called
    assert res is initial
def test_load_results_from_upstream_reads_cached_inputs_using_upstream_results(
    self,
):
    """Cached inputs on the incoming state are read via the upstream task's Result."""

    class CustomResult(Result):
        def read(self, *args, **kwargs):
            self.value = 99
            return self

    stored = PrefectResult(location="1")
    pending = Pending(cached_inputs={"x": stored})
    edge = Edge(Task(result=CustomResult()), 2, key="x")

    runner = CloudTaskRunner(task=Task(result=PrefectResult()))
    new_state, _ = runner.load_results(
        state=pending, upstream_states={edge: Success(result=stored)}
    )
    # 99 can only come from CustomResult.read
    assert new_state.cached_inputs["x"].value == 99