class TestCheckFlowPendingOrRunning:
    @pytest.mark.parametrize(
        "state", [Pending(), Running(), Retrying(), Scheduled()]
    )
    def test_pending_or_running_are_ok(self, state):
        flow = Flow(name="test", tasks=[Task()])
        new_state = FlowRunner(flow=flow).check_flow_is_pending_or_running(state=state)
        assert new_state is state

    @pytest.mark.parametrize(
        "state", [Finished(), Success(), Failed(), Skipped(), State()]
    )
    def test_not_pending_or_running_raise_endrun(self, state):
        flow = Flow(name="test", tasks=[Task()])
        with pytest.raises(ENDRUN):
            FlowRunner(flow=flow).check_flow_is_pending_or_running(state=state)

def test_flow_runner_prioritizes_kwarg_states_over_db_states(monkeypatch, state, client):
    flow = prefect.Flow(name="test")
    db_state = state("already", result=10)
    get_flow_run_info = MagicMock(return_value=MagicMock(state=db_state))
    client.get_flow_run_info = get_flow_run_info
    monkeypatch.setattr(
        "prefect.engine.cloud.flow_runner.Client", MagicMock(return_value=client)
    )
    res = CloudFlowRunner(flow=flow).run(state=Pending("let's do this"))

    ## assertions
    assert get_flow_run_info.call_count == 2  # initial state & cancel check
    assert client.set_flow_run_state.call_count == 2  # Pending -> Running -> Success

    states = [call[1]["state"] for call in client.set_flow_run_state.call_args_list]
    assert states == [Running(), Success(result={})]

def set_flow_to_running(self, state: State) -> State:
    """
    Puts Pending flows in a Running state; leaves Running flows Running.

    Args:
        - state (State): the current state of this flow

    Returns:
        - State: the state of the flow after running the check

    Raises:
        - ENDRUN: if the flow is not pending or running
    """
    if state.is_pending():
        return Running(message="Running flow.")
    elif state.is_running():
        return state
    else:
        raise ENDRUN(state)

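# A minimal usage sketch, not from the source, assuming the FlowRunner and
# state classes above: Pending transitions to Running, Running passes through
# unchanged, and any other state aborts the run via ENDRUN.
runner = FlowRunner(flow=Flow(name="sketch"))
assert runner.set_flow_to_running(Pending()).is_running()

running = Running()
assert runner.set_flow_to_running(running) is running  # no-op for Running

try:
    runner.set_flow_to_running(Success())
except ENDRUN as exc:
    assert exc.state.is_successful()  # ENDRUN carries the offending state
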
async def test_set_flow_run_state_with_bad_version(self, run_query, locked_flow_run_id):
    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(
                states=[
                    dict(
                        flow_run_id=locked_flow_run_id,
                        version=10,
                        state=Running().serialize(),
                    )
                ]
            )
        ),
    )
    assert "State update failed" in result.errors[0].message

    fr = await models.FlowRun.where(id=locked_flow_run_id).first({"state", "version"})
    assert fr.version == 1
    assert fr.state == "Scheduled"

async def test_set_task_run_state_with_version(self, run_query, task_run_id, running_flow_run_id):
    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(
                states=[
                    dict(
                        task_run_id=task_run_id,
                        version=0,
                        state=Running().serialize(),
                    )
                ]
            )
        ),
    )
    assert result.data.set_task_run_states.states[0].id == task_run_id

    tr = await models.TaskRun.where(id=task_run_id).first({"state", "version"})
    assert tr.version == 1
    assert tr.state == "Running"

def determine_final_state(
    self,
    key_states: Set[State],
    return_states: Dict[Task, State],
    terminal_states: Set[State],
) -> State:
    """
    Implements the logic for determining the final state of the flow run.

    Args:
        - key_states (Set[State]): the states which will determine the
            success / failure of the flow run
        - return_states (Dict[Task, State]): states to return as results
        - terminal_states (Set[State]): the states of the terminal tasks
            for this flow

    Returns:
        - State: the final state of the flow run
    """
    state = State()  # mypy initialization

    # check that the flow is finished
    if not all(s.is_finished() for s in terminal_states):
        self.logger.info("Flow run RUNNING: terminal tasks are incomplete.")
        state = Running(message="Flow run in progress.", result=return_states)

    # check if any key task failed
    elif any(s.is_failed() for s in key_states):
        self.logger.info("Flow run FAILED: some reference tasks failed.")
        state = Failed(message="Some reference tasks failed.", result=return_states)

    # check if all reference tasks succeeded
    elif all(s.is_successful() for s in key_states):
        self.logger.info("Flow run SUCCESS: all reference tasks succeeded")
        state = Success(message="All reference tasks succeeded.", result=return_states)

    # check for any unanticipated state that is finished but neither success nor failed
    else:
        self.logger.info("Flow run SUCCESS: no reference tasks failed")
        state = Success(message="No reference tasks failed.", result=return_states)

    return state

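# Hedged sketch, not from the source, of the precedence encoded above:
# unfinished terminal tasks dominate (Running), then any failed reference
# task (Failed), then Success. Assumes direct invocation with hand-built
# state sets.
runner = FlowRunner(flow=Flow(name="sketch"))

assert runner.determine_final_state(
    key_states={Success()},
    return_states={},
    terminal_states={Success(), Skipped()},  # all finished, none failed
).is_successful()

assert runner.determine_final_state(
    key_states={Failed()},
    return_states={},
    terminal_states={Failed(), Success()},  # finished, but a reference task failed
).is_failed()

assert runner.determine_final_state(
    key_states={Success()},
    return_states={},
    terminal_states={Running()},  # a terminal task is still incomplete
).is_running()
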
async def test_set_multiple_task_run_states(
    self, run_query, task_run_id, task_run_id_2, task_run_id_3, running_flow_run_id
):
    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(
                states=[
                    dict(task_run_id=task_run_id, state=Running().serialize()),
                    dict(task_run_id=task_run_id_2, state=Success().serialize()),
                    dict(
                        task_run_id=task_run_id_3,
                        version=1,
                        state=Retrying().serialize(),
                    ),
                ]
            )
        ),
    )
    assert result.data.set_task_run_states.states == [
        {"id": task_run_id, "status": "SUCCESS", "message": None},
        {"id": task_run_id_2, "status": "SUCCESS", "message": None},
        {"id": task_run_id_3, "status": "SUCCESS", "message": None},
    ]

    tr1 = await models.TaskRun.where(
        id=result.data.set_task_run_states.states[0].id
    ).first({"state", "version"})
    assert tr1.version == 2
    assert tr1.state == "Running"

    tr2 = await models.TaskRun.where(
        id=result.data.set_task_run_states.states[1].id
    ).first({"state", "version"})
    assert tr2.version == 3
    assert tr2.state == "Success"

    tr3 = await models.TaskRun.where(
        id=result.data.set_task_run_states.states[2].id
    ).first({"state", "version"})
    assert tr3.version == 3
    assert tr3.state == "Retrying"

def test_watch_flow_run_default_timeout(monkeypatch):
    # Test the default behavior, which sets the timeout to 12 hours
    # when the `max_duration` kwarg is not provided
    flow_run = FlowRunView._from_flow_run_data(FLOW_RUN_DATA_1)
    flow_run.state = Running()  # Not finished
    flow_run.get_latest = MagicMock(return_value=flow_run)
    flow_run.get_logs = MagicMock()

    MockView = MagicMock()
    MockView.from_flow_run_id.return_value = flow_run

    monkeypatch.setattr("prefect.backend.flow_run.FlowRunView", MockView)

    # Mock sleep so that we do not have a slow test
    monkeypatch.setattr("prefect.backend.flow_run.time.sleep", MagicMock())

    with pytest.raises(RuntimeError, match="timed out after 12.0 hours of waiting"):
        for log in watch_flow_run("id"):
            pass

async def test_set_flow_run_state(self, run_query, flow_run_id):
    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(
                states=[
                    dict(
                        flow_run_id=flow_run_id,
                        version=1,
                        state=Running().serialize(),
                    )
                ]
            )
        ),
    )
    assert result.data.set_flow_run_states.states[0].status == "SUCCESS"
    assert result.data.set_flow_run_states.states[0].message is None

    fr = await models.FlowRun.where(
        id=result.data.set_flow_run_states.states[0].id
    ).first({"state", "version"})
    assert fr.version == 2
    assert fr.state == "Running"

def set_task_to_running(self, state: State) -> State:
    """
    Sets the task to running

    Args:
        - state (State): the current state of this task

    Returns:
        - State: the state of the task after running the check

    Raises:
        - ENDRUN: if the task is not ready to run
    """
    if not state.is_pending():
        self.logger.debug(
            "Task '{name}': can't set state to Running because it "
            "isn't Pending; ending run.".format(
                name=prefect.context.get("task_full_name", self.task.name)
            )
        )
        raise ENDRUN(state)

    return Running(message="Starting task run.")

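# Minimal sketch, not from the source, of the guard above, assuming a bare
# TaskRunner: only Pending states may transition to Running; anything else
# short-circuits the run by raising ENDRUN with the offending state.
runner = TaskRunner(task=Task())
assert runner.set_task_to_running(Pending()).is_running()

try:
    runner.set_task_to_running(Success())
except ENDRUN as exc:
    assert exc.state.is_successful()
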
def test_watch_flow_run_timeout(monkeypatch):
    flow_run = FlowRunView._from_flow_run_data(FLOW_RUN_DATA_1)
    flow_run.state = Running()  # Not finished
    flow_run.get_latest = MagicMock(return_value=flow_run)
    flow_run.get_logs = MagicMock()

    MockView = MagicMock()
    MockView.from_flow_run_id.return_value = flow_run

    monkeypatch.setattr("prefect.backend.flow_run.FlowRunView", MockView)

    # Mock sleep so that we do not have a slow test
    monkeypatch.setattr("prefect.backend.flow_run.time.sleep", MagicMock())

    with pytest.raises(RuntimeError, match="timed out after 36.5 hours of waiting"):
        for log in watch_flow_run(
            "id", max_duration=timedelta(days=1, hours=12.5, seconds=1)
        ):
            pass

@pytest.fixture()
def vclient(monkeypatch):
    cloud_client = MagicMock(
        get_flow_run_info=MagicMock(return_value=MagicMock(state=None)),
        set_flow_run_state=MagicMock(),
        get_task_run_info=MagicMock(return_value=MagicMock(state=None)),
        set_task_run_state=MagicMock(
            side_effect=VersionLockError(),
            return_value=Running(),
            # side_effect=lambda task_run_id, version, state, cache_for: state
        ),
        get_latest_task_run_states=MagicMock(
            side_effect=lambda flow_run_id, states: states
        ),
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client", MagicMock(return_value=cloud_client)
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.flow_runner.Client", MagicMock(return_value=cloud_client)
    )
    yield cloud_client

async def test_set_task_run_states_rejects_states_with_large_payloads(
    self, run_query, task_run_id, task_run_id_2, running_flow_run_id
):
    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(
                states=[
                    dict(
                        # this state should be set successfully
                        task_run_id=task_run_id,
                        state=Running().serialize(),
                    ),
                    dict(
                        task_run_id=task_run_id_2,
                        # nonsense payload, just large
                        state={
                            i: os.urandom(2 * 1000000).decode("latin")
                            for i in range(2)
                        },
                    ),
                ]
            )
        ),
    )
    assert "State payload is too large" in result.errors[0].message

class TestPrefectMessageCloudHook:
    @pytest.mark.parametrize("state", [Running(), Success(), Failed()])
    async def test_prefect_message_cloud_hook(
        self, run_query, flow_run_id, state, tenant_id
    ):
        await models.Message.where().delete()

        cloud_hook_id = await api.cloud_hooks.create_cloud_hook(
            tenant_id=tenant_id,
            type="PREFECT_MESSAGE",
            config={},
            states=[type(state).__name__],
        )

        set_flow_run_state_mutation = """
            mutation($input: set_flow_run_states_input!) {
                set_flow_run_states(input: $input) {
                    states {
                        id
                        status
                        message
                    }
                }
            }
        """

        await run_query(
            query=set_flow_run_state_mutation,
            variables=dict(
                input=dict(
                    states=[
                        dict(flow_run_id=flow_run_id, version=1, state=state.serialize())
                    ]
                )
            ),
        )

        await asyncio.sleep(1)

        assert (
            await models.Message.where({"tenant_id": {"_eq": tenant_id}}).count() == 1
        )

class TestCheckScheduledStep:
    @pytest.mark.parametrize("state", [Failed(), Pending(), Running(), Success()])
    def test_non_scheduled_states(self, state):
        assert (
            FlowRunner(flow=Flow(name="test")).check_flow_reached_start_time(state=state)
            is state
        )

    def test_scheduled_states_without_start_time(self):
        state = Scheduled(start_time=None)
        assert (
            FlowRunner(flow=Flow(name="test")).check_flow_reached_start_time(state=state)
            is state
        )

    def test_scheduled_states_with_future_start_time(self):
        state = Scheduled(
            start_time=pendulum.now("utc") + datetime.timedelta(minutes=10)
        )
        with pytest.raises(ENDRUN) as exc:
            FlowRunner(flow=Flow(name="test")).check_flow_reached_start_time(state=state)
        assert exc.value.state is state

    def test_scheduled_states_with_past_start_time(self):
        state = Scheduled(
            start_time=pendulum.now("utc") - datetime.timedelta(minutes=1)
        )
        assert (
            FlowRunner(flow=Flow(name="test")).check_flow_reached_start_time(state=state)
            is state
        )

def test_simple_two_task_flow_with_final_task_already_running(monkeypatch, executor):
    flow_run_id = str(uuid.uuid4())
    task_run_id_1 = str(uuid.uuid4())
    task_run_id_2 = str(uuid.uuid4())

    with prefect.Flow(name="test") as flow:
        t1 = prefect.Task()
        t2 = prefect.Task()
        t2.set_upstream(t1)

    client = MockedCloudClient(
        flow_runs=[FlowRun(id=flow_run_id)],
        task_runs=[
            TaskRun(
                id=task_run_id_1, task_slug=flow.slugs[t1], flow_run_id=flow_run_id
            ),
            TaskRun(
                id=task_run_id_2,
                task_slug=flow.slugs[t2],
                version=1,
                flow_run_id=flow_run_id,
                state=Running(),
            ),
        ],
        monkeypatch=monkeypatch,
    )

    with prefect.context(flow_run_id=flow_run_id):
        state = CloudFlowRunner(flow=flow).run(return_tasks=flow.tasks, executor=executor)

    assert state.is_running()
    assert client.flow_runs[flow_run_id].state.is_running()
    assert client.task_runs[task_run_id_1].state.is_successful()
    assert client.task_runs[task_run_id_1].version == 2
    assert client.task_runs[task_run_id_2].state.is_running()
    assert client.task_runs[task_run_id_2].version == 1

def test_set_flow_run_state_gets_queued(patch_post):
    response = {
        "data": {
            "set_flow_run_states": {
                "states": [{"id": "74-salt", "status": "QUEUED", "message": None}]
            }
        }
    }
    post = patch_post(response)

    with set_temporary_config(
        {
            "cloud.api": "http://my-cloud.foo",
            "cloud.auth_token": "secret_token",
            "backend": "cloud",
        }
    ):
        client = Client()

    state = Running()
    result = client.set_flow_run_state(flow_run_id="74-salt", version=0, state=state)
    assert isinstance(result, State)
    assert state != result
    assert result.is_queued()

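# Hedged sketch, not from the source, of how a caller might react to the
# Queued result above: treat it as "try the transition again later" rather
# than as a terminal state. Assumes the configured `client` from the test
# above; the polling loop and interval are illustrative only.
import time

state = client.set_flow_run_state(flow_run_id="74-salt", version=0, state=Running())
while state.is_queued():
    time.sleep(30)  # illustrative back-off before re-attempting
    state = client.set_flow_run_state(
        flow_run_id="74-salt", version=0, state=Running()
    )
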
class TestCancelFlowRun:
    mutation = """
        mutation($input: cancel_flow_run_input!) {
            cancel_flow_run(input: $input) {
                state
            }
        }
    """

    @pytest.mark.parametrize(
        "state,res_state,version",
        [
            (Running(), "Cancelling", 4),
            (Success(), "Success", 3),
            (Submitted(), "Cancelled", 4),
        ],
    )
    async def test_cancel_flow_run(self, run_query, flow_run_id, state, res_state, version):
        await api.states.set_flow_run_state(flow_run_id=flow_run_id, version=1, state=state)
        result = await run_query(
            query=self.mutation,
            variables={"input": {"flow_run_id": flow_run_id}},
        )
        assert result.data.cancel_flow_run.state == res_state

        fr = await models.FlowRun.where(id=flow_run_id).first({"state", "version"})
        assert fr.version == version
        assert fr.state == res_state

async def test_set_multiple_flow_run_states_with_error(
    self, run_query, flow_run_id, flow_run_id_2, flow_run_id_3
):
    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(
                states=[
                    dict(
                        flow_run_id=flow_run_id,
                        version=1,
                        state=Running().serialize(),
                    ),
                    dict(
                        flow_run_id=flow_run_id_2,
                        version=10,
                        state="a bad state",
                    ),
                    dict(
                        flow_run_id=flow_run_id_3,
                        version=3,
                        state=Retrying().serialize(),
                    ),
                ]
            )
        ),
    )
    assert result.data.set_flow_run_states is None
    assert result.errors[0].message

async def test_returns_status_from_underlying_call(
    self, run_query, flow_run_id, payload_response, monkeypatch
):
    """
    This test ensures that the `status` field is determined by the
    underlying `api.states.set_flow_run_state()` call.
    """
    mock_state_api = CoroutineMock(return_value=payload_response)
    monkeypatch.setattr(
        "src.prefect_server.graphql.states.api.states.set_flow_run_state",
        mock_state_api,
    )

    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(
                states=[
                    dict(
                        flow_run_id=flow_run_id,
                        version=1,
                        state=Running().serialize(),
                    )
                ]
            )
        ),
    )
    mock_state_api.assert_awaited_once()
    assert (
        result.data.set_flow_run_states.states[0].status == payload_response["status"]
    )

class TestFlowRunStates:
    async def test_returns_status_dict(self, flow_run_id: str):
        result = await states.set_flow_run_state(flow_run_id, state=Success())
        assert result["status"] == "SUCCESS"

    @pytest.mark.parametrize(
        "state_cls", [s for s in State.children() if s not in _MetaState.children()]
    )
    async def test_set_flow_run_state(self, flow_run_id, state_cls):
        result = await states.set_flow_run_state(flow_run_id=flow_run_id, state=state_cls())

        query = await models.FlowRun.where(id=flow_run_id).first(
            {"version", "state", "serialized_state"}
        )
        assert query.version == 2
        assert query.state == state_cls.__name__
        assert query.serialized_state["type"] == state_cls.__name__

    @pytest.mark.parametrize("state", [Running(), Success()])
    async def test_set_flow_run_state_fails_with_wrong_flow_run_id(self, state):
        with pytest.raises(ValueError, match="Invalid flow run ID"):
            await states.set_flow_run_state(flow_run_id=str(uuid.uuid4()), state=state)

    async def test_trigger_failed_state_does_not_set_end_time(self, flow_run_id):
        # there is no logic in Prefect that would create this sequence of
        # events, but a user could manually do this
        await states.set_flow_run_state(flow_run_id=flow_run_id, state=TriggerFailed())

        flow_run_info = await models.FlowRun.where(id=flow_run_id).first(
            {"id", "start_time", "end_time"}
        )
        assert not flow_run_info.start_time
        assert not flow_run_info.end_time

dict(state=Cancelled(), assert_true={"is_finished"}), dict(state=Cached(), assert_true={"is_cached", "is_finished", "is_successful"}), dict(state=ClientFailed(), assert_true={"is_meta_state"}), dict(state=Failed(), assert_true={"is_finished", "is_failed"}), dict(state=Finished(), assert_true={"is_finished"}), dict(state=Looped(), assert_true={"is_finished", "is_looped"}), dict(state=Mapped(), assert_true={"is_finished", "is_mapped", "is_successful"}), dict(state=Paused(), assert_true={"is_pending", "is_scheduled"}), dict(state=Pending(), assert_true={"is_pending"}), dict(state=Queued(), assert_true={"is_meta_state", "is_queued"}), dict(state=Resume(), assert_true={"is_pending", "is_scheduled"}), dict(state=Retrying(), assert_true={"is_pending", "is_scheduled", "is_retrying"}), dict(state=Running(), assert_true={"is_running"}), dict(state=Scheduled(), assert_true={"is_pending", "is_scheduled"}), dict(state=Skipped(), assert_true={"is_finished", "is_successful", "is_skipped"}), dict(state=Submitted(), assert_true={"is_meta_state", "is_submitted"}), dict(state=Success(), assert_true={"is_finished", "is_successful"}), dict(state=TimedOut(), assert_true={"is_finished", "is_failed"}), dict(state=TriggerFailed(), assert_true={"is_finished", "is_failed"}), ], ) def test_state_is_methods(state_check): """ Iterates over all of the "is_*()" methods of the state, asserting that each one is False, unless the name of that method is provided as `assert_true`. For example, if `state_check == (Pending(), {'is_pending'})`, then this method will
class TestTaskRunStates:
    async def test_set_task_run_state(self, task_run_id):
        result = await api.states.set_task_run_state(task_run_id=task_run_id, state=Failed())

        assert result.task_run_id == task_run_id

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"}
        )
        assert query.version == 2
        assert query.state == "Failed"
        assert query.serialized_state["type"] == "Failed"

    @pytest.mark.parametrize("state", [Failed(), Success()])
    async def test_set_task_run_state_fails_with_wrong_task_run_id(self, state):
        with pytest.raises(ValueError, match="State update failed"):
            await api.states.set_task_run_state(task_run_id=str(uuid.uuid4()), state=state)

    @pytest.mark.parametrize(
        "state", [s() for s in State.children() if not s().is_running()]
    )
    async def test_state_does_not_set_heartbeat_unless_running(self, state, task_run_id):
        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)

        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

    async def test_running_state_sets_heartbeat(self, task_run_id, running_flow_run_id):
        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

        dt = pendulum.now("UTC")
        await api.states.set_task_run_state(task_run_id=task_run_id, state=Running())

        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat > dt

    async def test_trigger_failed_state_does_not_set_end_time(self, task_run_id):
        await api.states.set_task_run_state(task_run_id=task_run_id, state=TriggerFailed())

        task_run_info = await models.TaskRun.where(id=task_run_id).first(
            {"id", "start_time", "end_time"}
        )
        assert not task_run_info.start_time
        assert not task_run_info.end_time

    @pytest.mark.parametrize(
        "state",
        [s() for s in State.children() if s not in _MetaState.children()],
        ids=[s.__name__ for s in State.children() if s not in _MetaState.children()],
    )
    async def test_setting_a_task_run_state_pulls_cached_inputs_if_possible(
        self, task_run_id, state, running_flow_run_id
    ):
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Failed(cached_inputs=complex_result)
        await models.TaskRun.where(id=task_run_id).update(
            set=dict(serialized_state=cached_state.serialize())
        )

        # attempt to set the task run to the parametrized state
        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)

        task_run = await models.TaskRun.where(id=task_run_id).first({"serialized_state"})

        # ensure the state change took place
        assert task_run.serialized_state["type"] == type(state).__name__
        assert task_run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert task_run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state",
        [
            s(cached_inputs=None)
            for s in State.children()
            if s not in _MetaState.children()
        ],
        ids=[s.__name__ for s in State.children() if s not in _MetaState.children()],
    )
    async def test_task_runs_with_null_cached_inputs_do_not_overwrite_cache(
        self, state, task_run_id, running_flow_run_id
    ):
        # set an initial state with null cached inputs
        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)

        # set up a Retrying state with non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id, state=cached_state)

        run = await models.TaskRun.where(id=task_run_id).first({"serialized_state"})

        assert run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state_cls", [s for s in State.children() if s not in _MetaState.children()]
    )
    async def test_task_runs_cached_inputs_give_preference_to_new_cached_inputs(
        self, state_cls, task_run_id, running_flow_run_id
    ):
        # set up an initial state with cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"a": 2}, result_handler=JSONResultHandler())
        complex_result = {"b": res1, "c": res2}
        cached_state = state_cls(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id, state=cached_state)

        # set up a Retrying state with non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id, state=cached_state)

        run = Box(await models.TaskRun.where(id=task_run_id).first({"serialized_state"}))

        # verify that we have cached inputs, and that preference has been given to the
        # new cached inputs
        assert run.serialized_state.cached_inputs
        assert run.serialized_state.cached_inputs.x.value == 1
        assert run.serialized_state.cached_inputs.y.value == {"z": 2}

    @pytest.mark.parametrize(
        "flow_run_state", [Pending(), Running(), Failed(), Success()]
    )
    async def test_running_states_can_not_be_set_if_flow_run_is_not_running(
        self, flow_run_id, task_run_id, flow_run_state
    ):
        await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=flow_run_state)

        set_running_coroutine = api.states.set_task_run_state(
            task_run_id=task_run_id, state=Running()
        )

        if flow_run_state.is_running():
            assert await set_running_coroutine
            assert (
                await models.TaskRun.where(id=task_run_id).first({"state"})
            ).state == "Running"
        else:
            with pytest.raises(ValueError, match="is not in a running state"):
                await set_running_coroutine
            assert (
                await models.TaskRun.where(id=task_run_id).first({"state"})
            ).state != "Running"

class TestFlowRunStates:
    async def test_set_flow_run_state(self, flow_run_id):
        result = await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=Running())

        assert result.flow_run_id == flow_run_id

        query = await models.FlowRun.where(id=flow_run_id).first(
            {"version", "state", "serialized_state"}
        )
        assert query.version == 3
        assert query.state == "Running"
        assert query.serialized_state["type"] == "Running"

    @pytest.mark.parametrize("state", [Running(), Success()])
    async def test_set_flow_run_state_fails_with_wrong_flow_run_id(self, state):
        with pytest.raises(ValueError, match="State update failed"):
            await api.states.set_flow_run_state(flow_run_id=str(uuid.uuid4()), state=state)

    async def test_trigger_failed_state_does_not_set_end_time(self, flow_run_id):
        # there is no logic in Prefect that would create this sequence of
        # events, but a user could manually do this
        await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=TriggerFailed())

        flow_run_info = await models.FlowRun.where(id=flow_run_id).first(
            {"id", "start_time", "end_time"}
        )
        assert not flow_run_info.start_time
        assert not flow_run_info.end_time

    @pytest.mark.parametrize(
        "state",
        [
            s()
            for s in State.children()
            if not s().is_running() and not s().is_submitted()
        ],
    )
    async def test_state_does_not_set_heartbeat_unless_running_or_submitted(
        self, state, flow_run_id
    ):
        flow_run = await models.FlowRun.where(id=flow_run_id).first({"heartbeat"})
        assert flow_run.heartbeat is None

        dt = pendulum.now("UTC")
        await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=state)

        flow_run = await models.FlowRun.where(id=flow_run_id).first({"heartbeat"})
        assert flow_run.heartbeat is None

    @pytest.mark.parametrize("state", [Running(), Submitted()])
    async def test_running_and_submitted_state_sets_heartbeat(self, state, flow_run_id):
        """
        Both Running and Submitted states need to set heartbeats for services
        like Lazarus to function properly.
        """
        flow_run = await models.FlowRun.where(id=flow_run_id).first({"heartbeat"})
        assert flow_run.heartbeat is None

        dt = pendulum.now("UTC")
        await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=state)

        flow_run = await models.FlowRun.where(id=flow_run_id).first({"heartbeat"})
        assert flow_run.heartbeat > dt

    async def test_setting_flow_run_to_cancelled_state_sets_unfinished_task_runs_to_cancelled(
        self, flow_run_id
    ):
        task_runs = await models.TaskRun.where(
            {"flow_run_id": {"_eq": flow_run_id}}
        ).get({"id"})
        task_run_ids = [run.id for run in task_runs]

        # update the state to Running
        await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=Running())

        # Currently this flow_run_id fixture has at least 3 tasks; if this
        # changes, the test will need to be updated
        assert len(task_run_ids) >= 3, "flow_run_id fixture has changed"

        # Set one task run to pending, one to running, and the rest to success
        pending_task_run = task_run_ids[0]
        running_task_run = task_run_ids[1]
        rest = task_run_ids[2:]
        await api.states.set_task_run_state(task_run_id=pending_task_run, state=Pending())
        await api.states.set_task_run_state(task_run_id=running_task_run, state=Running())
        for task_run_id in rest:
            await api.states.set_task_run_state(task_run_id=task_run_id, state=Success())

        # set the flow run to a cancelled state
        await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=Cancelled())

        # Confirm the unfinished task runs have been marked as cancelled
        task_runs = await models.TaskRun.where(
            {"flow_run_id": {"_eq": flow_run_id}}
        ).get({"id", "state"})
        new_states = {run.id: run.state for run in task_runs}

        assert new_states[pending_task_run] == "Cancelled"
        assert new_states[running_task_run] == "Cancelled"
        assert all(new_states[id] == "Success" for id in rest)

class TestExecuteFlowRunInSubprocess:
    @pytest.fixture()
    def mocks(self, monkeypatch):
        class Mocks:
            subprocess = MagicMock()
            wait_for_flow_run_start_time = MagicMock()
            fail_flow_run = MagicMock()

        mocks = Mocks()
        monkeypatch.setattr("prefect.backend.execution.subprocess", mocks.subprocess)
        monkeypatch.setattr(
            "prefect.backend.execution._wait_for_flow_run_start_time",
            mocks.wait_for_flow_run_start_time,
        )
        monkeypatch.setattr(
            "prefect.backend.execution._fail_flow_run", mocks.fail_flow_run
        )

        # Since we mocked the module this error cannot be used in try/catch without
        # replacing it with the correct type
        mocks.subprocess.CalledProcessError = CalledProcessError

        return mocks

    def test_creates_subprocess_correctly(self, cloud_mocks, mocks):
        # Return a scheduled flow run to start
        cloud_mocks.FlowRunView.from_flow_run_id().state = Scheduled()
        # Return a finished flow run after the first iteration
        cloud_mocks.FlowRunView().get_latest().state = Success()

        execute_flow_run_in_subprocess("flow-run-id")

        # Should pass the correct flow run id to wait for
        mocks.wait_for_flow_run_start_time.assert_called_once_with("flow-run-id")

        # Calls the correct command w/ environment variables
        mocks.subprocess.run.assert_called_once_with(
            [sys.executable, "-m", "prefect", "execute", "flow-run"],
            env={
                "PREFECT__CLOUD__SEND_FLOW_RUN_LOGS": "True",
                "PREFECT__LOGGING__LEVEL": "INFO",
                "PREFECT__LOGGING__FORMAT": "[%(asctime)s] %(levelname)s - %(name)s | %(message)s",
                "PREFECT__LOGGING__DATEFMT": "%Y-%m-%d %H:%M:%S%z",
                "PREFECT__BACKEND": "cloud",
                "PREFECT__CLOUD__API": "https://api.prefect.io",
                "PREFECT__CLOUD__TENANT_ID": "",
                "PREFECT__CLOUD__API_KEY": cloud_mocks.Client().api_key,
                "PREFECT__CONTEXT__FLOW_RUN_ID": "flow-run-id",
                "PREFECT__CONTEXT__FLOW_ID": cloud_mocks.FlowRunView.from_flow_run_id().flow_id,
                "PREFECT__ENGINE__FLOW_RUNNER__DEFAULT_CLASS": "prefect.engine.cloud.CloudFlowRunner",
                "PREFECT__ENGINE__TASK_RUNNER__DEFAULT_CLASS": "prefect.engine.cloud.CloudTaskRunner",
            },
        )

        # Return code is checked
        mocks.subprocess.run().check_returncode.assert_called_once()

    @pytest.mark.parametrize("start_state", [Submitted(), Running()])
    def test_fails_immediately_if_flow_run_is_being_executed_elsewhere(
        self, cloud_mocks, start_state, mocks
    ):
        cloud_mocks.FlowRunView.from_flow_run_id().state = start_state

        with pytest.raises(RuntimeError, match="already in state"):
            execute_flow_run_in_subprocess("flow-run-id")

    def test_handles_signal_interrupt(self, cloud_mocks, mocks):
        cloud_mocks.FlowRunView.from_flow_run_id().state = Scheduled()
        mocks.subprocess.run.side_effect = KeyboardInterrupt()

        # Keyboard interrupt should be re-raised
        with pytest.raises(KeyboardInterrupt):
            execute_flow_run_in_subprocess("flow-run-id")

        # Only tried to run once
        mocks.subprocess.run.assert_called_once()

        # Flow run is failed with the proper message
        mocks.fail_flow_run.assert_called_once_with(
            flow_run_id="flow-run-id", message="Flow run received an interrupt signal."
        )

    def test_handles_unexpected_exception(self, cloud_mocks, mocks):
        cloud_mocks.FlowRunView.from_flow_run_id().state = Scheduled()
        mocks.subprocess.run.side_effect = Exception("Foobar")

        # Re-raised as `RuntimeError`
        with pytest.raises(
            RuntimeError, match="encountered unexpected exception during execution"
        ):
            execute_flow_run_in_subprocess("flow-run-id")

        # Only tried to run once
        mocks.subprocess.run.assert_called_once()

        # Flow run is failed with the proper message
        mocks.fail_flow_run.assert_called_once_with(
            flow_run_id="flow-run-id",
            message=(
                "Flow run encountered unexpected exception during execution: "
                f"{Exception('Foobar')!r}"
            ),
        )

    def test_handles_bad_subprocess_result(self, cloud_mocks, mocks):
        cloud_mocks.FlowRunView.from_flow_run_id().state = Scheduled()
        mocks.subprocess.run.return_value.check_returncode.side_effect = (
            CalledProcessError(cmd="foo", returncode=1)
        )

        # Re-raised as `RuntimeError`
        with pytest.raises(RuntimeError, match="flow run process failed"):
            execute_flow_run_in_subprocess("flow-run-id")

        # Only tried to run once
        mocks.subprocess.run.assert_called_once()

        # Flow run is not failed at this time -- left to the FlowRunner
        mocks.fail_flow_run.assert_not_called()

    def test_loops_until_flow_run_is_finished(self, cloud_mocks, mocks):
        cloud_mocks.FlowRunView.from_flow_run_id().state = Scheduled()
        cloud_mocks.FlowRunView.from_flow_run_id().get_latest.side_effect = [
            MagicMock(state=Running()),
            MagicMock(state=Running()),
            MagicMock(state=Success()),
        ]

        execute_flow_run_in_subprocess("flow-run-id")

        # Ran the subprocess twice
        assert mocks.subprocess.run.call_count == 2
        # Waited each time
        assert mocks.wait_for_flow_run_start_time.call_count == 2

def get_flow_run_info(*args, _version=itertools.count(), **kwargs):
    state = Cancelling() if trigger.is_set() else Running()
    return MagicMock(version=next(_version), state=state)

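# Note on the helper above (illustrative sketch, not from the source): the
# `_version=itertools.count()` default is a deliberate use of a mutable
# default argument. The counter is created once, at function definition
# time, so every call advances the same shared sequence:
import itertools


def next_version(_counter=itertools.count()):
    return next(_counter)


assert [next_version() for _ in range(3)] == [0, 1, 2]
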
class TestTaskRunStates:
    async def test_returns_status_dict(self, running_flow_run_id: str, task_run_id: str):
        result = await states.set_task_run_state(task_run_id, state=Success())
        assert result["status"] == "SUCCESS"

    @pytest.mark.parametrize(
        "state_cls", [s for s in State.children() if s not in _MetaState.children()]
    )
    async def test_set_task_run_state(self, running_flow_run_id, task_run_id, state_cls):
        await states.set_task_run_state(task_run_id=task_run_id, state=state_cls())

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"}
        )
        assert query.version == 1
        assert query.state == state_cls.__name__
        assert query.serialized_state["type"] == state_cls.__name__

    @pytest.mark.parametrize(
        "state_cls",
        [
            s
            for s in State.children()
            if s not in _MetaState.children() and not s().is_running()
        ],
    )
    async def test_set_non_running_task_run_state_works_when_flow_run_is_not_running(
        self, flow_run_id, task_run_id, state_cls
    ):
        await states.set_flow_run_state(flow_run_id, state=Success())
        await states.set_task_run_state(task_run_id=task_run_id, state=state_cls())

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"}
        )
        assert query.version == 1
        assert query.state == state_cls.__name__
        assert query.serialized_state["type"] == state_cls.__name__

    async def test_set_running_task_run_state_fails_when_flow_run_is_not_running(
        self, flow_run_id, task_run_id
    ):
        await states.set_flow_run_state(flow_run_id, state=Success())

        with pytest.raises(ValueError, match="State update failed"):
            await states.set_task_run_state(task_run_id=task_run_id, state=Running())

    async def test_set_running_task_run_state_works_when_flow_run_is_not_running_if_force(
        self, flow_run_id, task_run_id
    ):
        await states.set_flow_run_state(flow_run_id, state=Success())
        await states.set_task_run_state(task_run_id=task_run_id, state=Running(), force=True)

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"}
        )
        assert query.version == 1
        assert query.state == "Running"
        assert query.serialized_state["type"] == "Running"

    async def test_set_task_run_state_does_not_increment_run_count_when_looping(
        self, task_run_id, running_flow_run_id
    ):
        # simulate some looping
        await states.set_task_run_state(task_run_id=task_run_id, state=Running())
        await states.set_task_run_state(task_run_id=task_run_id, state=Looped())
        result = await states.set_task_run_state(task_run_id=task_run_id, state=Running())

        task_run = await models.TaskRun.where(id=task_run_id).first({"run_count"})
        assert task_run.run_count == 1

    @pytest.mark.parametrize("state", [Running(), Success()])
    async def test_set_task_run_state_fails_with_wrong_task_run_id(
        self, state, running_flow_run_id
    ):
        with pytest.raises(ValueError, match="Invalid task run ID"):
            await states.set_task_run_state(task_run_id=str(uuid.uuid4()), state=state)

    async def test_trigger_failed_state_does_not_set_end_time(self, task_run_id):
        await states.set_task_run_state(task_run_id=task_run_id, state=TriggerFailed())

        task_run_info = await models.TaskRun.where(id=task_run_id).first(
            {"id", "start_time", "end_time"}
        )
        assert not task_run_info.start_time
        assert not task_run_info.end_time

    @pytest.mark.parametrize(
        "state_cls", [s for s in State.children() if s not in _MetaState.children()]
    )
    async def test_setting_a_task_run_state_pulls_cached_inputs_if_possible(
        self, running_flow_run_id, task_run_id, state_cls
    ):
        # set up a Failed state with cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Failed(cached_inputs=complex_result)
        await models.TaskRun.where(id=task_run_id).update(
            set=dict(serialized_state=cached_state.serialize())
        )

        # attempt to set the task run to the parametrized state
        await states.set_task_run_state(task_run_id=task_run_id, state=state_cls())

        task_run = await models.TaskRun.where(id=task_run_id).first({"serialized_state"})

        # ensure the state change took place
        assert task_run.serialized_state["type"] == state_cls.__name__
        assert task_run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert task_run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state",
        [
            s(cached_inputs=None)
            for s in State.children()
            if s not in _MetaState.children()
        ],
        ids=[s.__name__ for s in State.children() if s not in _MetaState.children()],
    )
    async def test_task_runs_with_null_cached_inputs_do_not_overwrite_cache(
        self, running_flow_run_id, state, task_run_id
    ):
        # set an initial state with null cached inputs
        await states.set_task_run_state(task_run_id=task_run_id, state=state)

        # set up a Retrying state with non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await states.set_task_run_state(task_run_id=task_run_id, state=cached_state)

        run = await models.TaskRun.where(id=task_run_id).first({"serialized_state"})

        assert run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state_cls", [s for s in State.children() if s not in _MetaState.children()]
    )
    async def test_task_runs_cached_inputs_give_preference_to_new_cached_inputs(
        self, running_flow_run_id, state_cls, task_run_id
    ):
        # set up an initial state with cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"a": 2}, result_handler=JSONResultHandler())
        complex_result = {"b": res1, "c": res2}
        cached_state = state_cls(cached_inputs=complex_result)
        await states.set_task_run_state(task_run_id=task_run_id, state=cached_state)

        # set up a Retrying state with non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await states.set_task_run_state(task_run_id=task_run_id, state=cached_state)

        run = Box(await models.TaskRun.where(id=task_run_id).first({"serialized_state"}))

        # verify that we have cached inputs, and that preference has been given to the
        # new cached inputs
        assert run.serialized_state.cached_inputs
        assert run.serialized_state.cached_inputs.x.value == 1
        assert run.serialized_state.cached_inputs.y.value == {"z": 2}
