async def test_set_multiple_task_run_states(
    self, run_query, task_run_id, task_run_id_2, task_run_id_3, running_flow_run_id
):
    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(
                states=[
                    dict(task_run_id=task_run_id, state=Running().serialize()),
                    dict(task_run_id=task_run_id_2, state=Success().serialize()),
                    dict(
                        task_run_id=task_run_id_3,
                        version=1,
                        state=Retrying().serialize(),
                    ),
                ]
            )
        ),
    )
    assert result.data.set_task_run_states.states == [
        {"id": task_run_id, "status": "SUCCESS", "message": None},
        {"id": task_run_id_2, "status": "SUCCESS", "message": None},
        {"id": task_run_id_3, "status": "SUCCESS", "message": None},
    ]

    tr1 = await models.TaskRun.where(
        id=result.data.set_task_run_states.states[0].id
    ).first({"state", "version"})
    assert tr1.version == 1
    assert tr1.state == "Running"

    tr2 = await models.TaskRun.where(
        id=result.data.set_task_run_states.states[1].id
    ).first({"state", "version"})
    assert tr2.version == 2
    assert tr2.state == "Success"

    tr3 = await models.TaskRun.where(
        id=result.data.set_task_run_states.states[2].id
    ).first({"state", "version"})
    assert tr3.version == 2
    assert tr3.state == "Retrying"
async def test_set_task_run_state_bad_version(self, run_query, locked_task_run_id):
    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(
                states=[
                    dict(
                        task_run_id=locked_task_run_id,
                        version=100,
                        state=Running().serialize(),
                    )
                ]
            )
        ),
    )
    assert result.data.set_task_run_states is None
    assert "State update failed" in result.errors[0].message
class TestCheckFlowPendingOrRunning:
    @pytest.mark.parametrize("state", [Pending(), Running(), Retrying(), Scheduled()])
    def test_pending_or_running_are_ok(self, state):
        flow = Flow(name="test", tasks=[Task()])
        new_state = FlowRunner(flow=flow).check_flow_is_pending_or_running(state=state)
        assert new_state is state

    @pytest.mark.parametrize(
        "state", [Finished(), Success(), Failed(), Skipped(), State()]
    )
    def test_not_pending_or_running_raise_endrun(self, state):
        flow = Flow(name="test", tasks=[Task()])
        with pytest.raises(ENDRUN):
            FlowRunner(flow=flow).check_flow_is_pending_or_running(state=state)
async def test_set_task_run_state(self, run_query, task_run_id, running_flow_run_id):
    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(
                states=[dict(task_run_id=task_run_id, state=Running().serialize())]
            )
        ),
    )
    assert result.data.set_task_run_states.states[0].id == task_run_id

    tr = await models.TaskRun.where(id=task_run_id).first({"state", "version"})
    assert tr.version == 1
    assert tr.state == "Running"
def test_determine_final_state_preserves_running_states_when_tasks_still_running(
    self,
):
    task = Task()
    flow = Flow(name="test", tasks=[task])
    old_state = Running()
    new_state = FlowRunner(flow=flow).get_flow_run_state(
        state=old_state,
        task_states={task: Retrying(start_time=pendulum.now("utc").add(days=1))},
        task_contexts={},
        return_tasks=set(),
        task_runner_state_handlers=[],
        executor=LocalExecutor(),
    )
    assert new_state is old_state
class TestRunModels:
    @pytest.mark.parametrize(
        "state",
        [
            Running(message="running", result=1),
            Scheduled(message="scheduled", result=1, start_time=pendulum.now()),
        ],
    )
    async def test_flow_run_fields_from_state(self, state):
        dt = pendulum.now()
        info = models.FlowRunState.fields_from_state(state)

        assert info["state"] == type(state).__name__
        assert info["timestamp"] > dt
        assert info["message"] == state.message
        assert info["result"] == state.result
        assert info["serialized_state"] == state.serialize()

    @pytest.mark.parametrize(
        "state",
        [
            Running(message="running", result=1),
            Scheduled(message="scheduled", result=1, start_time=pendulum.now()),
        ],
    )
    async def test_task_run_fields_from_state(self, state):
        dt = pendulum.now()
        info = models.TaskRunState.fields_from_state(state)

        assert info["state"] == type(state).__name__
        assert info["timestamp"] > dt
        assert info["message"] == state.message
        assert info["result"] == state.result
        assert info["serialized_state"] == state.serialize()
def test_does_not_write_checkpoint_file_to_disk_on_failure(self, tmp_path):
    result_handler = PandasResultHandler(
        tmp_path / "dummy.csv", "csv", write_kwargs={"index": False}
    )
    task = Task(name="Task", result_handler=result_handler)
    result = pd.DataFrame({"one": [1, 2, 3], "two": [4, 5, 6]})

    task_runner = DSTaskRunner(task)
    task_runner.upstream_states = {}
    old_state = Running()
    new_state = Failed(result=result)
    dsh.checkpoint_handler(task_runner, old_state, new_state)

    # nothing should have been checkpointed, so the file must not exist
    with pytest.raises(IOError):
        pd.read_csv(tmp_path / "dummy.csv")
def test_writes_checkpointed_file_to_disk_on_success(self, tmp_path):
    result_handler = PandasResultHandler(
        tmp_path / "dummy.csv", "csv", write_kwargs={"index": False}
    )
    task = Task(name="Task", result_handler=result_handler)
    expected_result = pd.DataFrame({"one": [1, 2, 3], "two": [4, 5, 6]})

    task_runner = DSTaskRunner(task)
    task_runner.upstream_states = {}
    old_state = Running()
    new_state = Success(result=expected_result)
    dsh.checkpoint_handler(task_runner, old_state, new_state)

    actual_result = pd.read_csv(tmp_path / "dummy.csv")
    pd.testing.assert_frame_equal(expected_result, actual_result)
async def test_setting_flow_run_to_cancelled_state_sets_unfinished_task_runs_to_cancelled(
    self, flow_run_id
):
    task_runs = await models.TaskRun.where(
        {"flow_run_id": {"_eq": flow_run_id}}
    ).get({"id"})
    task_run_ids = [run.id for run in task_runs]

    # update the state to Running
    await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=Running())

    # Currently this flow_run_id fixture has at least 3 tasks; if this
    # changes, the test will need to be updated
    assert len(task_run_ids) >= 3, "flow_run_id fixture has changed"

    # Set one task run to pending, one to running, and the rest to success
    pending_task_run = task_run_ids[0]
    running_task_run = task_run_ids[1]
    rest = task_run_ids[2:]
    await api.states.set_task_run_state(task_run_id=pending_task_run, state=Pending())
    await api.states.set_task_run_state(task_run_id=running_task_run, state=Running())
    for task_run_id in rest:
        await api.states.set_task_run_state(task_run_id=task_run_id, state=Success())

    # set the flow run to a cancelled state
    await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=Cancelled())

    # Confirm the unfinished task runs have been marked as cancelled
    task_runs = await models.TaskRun.where(
        {"flow_run_id": {"_eq": flow_run_id}}
    ).get({"id", "state"})
    new_states = {run.id: run.state for run in task_runs}

    assert new_states[pending_task_run] == "Cancelled"
    assert new_states[running_task_run] == "Cancelled"
    assert all(new_states[id] == "Success" for id in rest)
def test_determine_final_state_has_final_say(self):
    class MyFlowRunner(FlowRunner):
        def determine_final_state(self, *args, **kwargs):
            return Failed("Very specific error message")

    flow = Flow(name="test", tasks=[Task()])
    new_state = MyFlowRunner(flow=flow).get_flow_run_state(
        state=Running(),
        task_states={},
        task_contexts={},
        return_tasks=set(),
        task_runner_state_handlers=[],
        executor=LocalExecutor(),
    )
    assert new_state.is_failed()
    assert new_state.message == "Very specific error message"
async def test_set_flow_run_state(self, run_query, flow_run_id):
    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(
                states=[dict(flow_run_id=flow_run_id, state=Running().serialize())]
            )
        ),
    )
    assert result.data.set_flow_run_states.states[0].id == flow_run_id
    assert result.data.set_flow_run_states.states[0].status == "SUCCESS"
    assert result.data.set_flow_run_states.states[0].message is None

    fr = await models.FlowRun.where(id=flow_run_id).first({"state", "version"})
    assert fr.version == 3
    assert fr.state == "Running"
async def test_set_running_task_run_state_works_when_flow_run_is_not_running_if_force(
    self, flow_run_id, task_run_id
):
    await states.set_flow_run_state(flow_run_id, state=Success())
    await states.set_task_run_state(
        task_run_id=task_run_id, state=Running(), force=True
    )

    query = await models.TaskRun.where(id=task_run_id).first(
        {"version", "state", "serialized_state"}
    )
    assert query.version == 1
    assert query.state == "Running"
    assert query.serialized_state["type"] == "Running"
def determine_final_state(
    self,
    key_states: Set[State],
    return_states: Dict[Task, State],
    terminal_states: Set[State],
) -> State:
    """
    Implements the logic for determining the final state of the flow run.

    Args:
        - key_states (Set[State]): the states which will determine the
            success / failure of the flow run
        - return_states (Dict[Task, State]): states to return as results
        - terminal_states (Set[State]): the states of the terminal tasks
            for this flow

    Returns:
        - State: the final state of the flow run
    """
    state = State()  # mypy initialization

    # check that the flow is finished
    if not all(s.is_finished() for s in terminal_states):
        self.logger.info("Flow run RUNNING: terminal tasks are incomplete.")
        state = Running(message="Flow run in progress.", result=return_states)

    # check if any key task failed
    elif any(s.is_failed() for s in key_states):
        self.logger.info("Flow run FAILED: some reference tasks failed.")
        state = Failed(message="Some reference tasks failed.", result=return_states)

    # check if all reference tasks succeeded
    elif all(s.is_successful() for s in key_states):
        self.logger.info("Flow run SUCCESS: all reference tasks succeeded")
        state = Success(
            message="All reference tasks succeeded.", result=return_states
        )

    # check for any unanticipated state that is finished but neither success nor failed
    else:
        self.logger.info("Flow run SUCCESS: no reference tasks failed")
        state = Success(message="No reference tasks failed.", result=return_states)

    return state
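# A minimal, hypothetical sketch (not part of the source) of how the branch
# precedence in `determine_final_state` plays out, assuming it is a method of
# the same `FlowRunner` exercised in the tests above: an unfinished terminal
# task outranks everything, and a failed reference task outranks success.
from prefect import Flow
from prefect.engine import FlowRunner
from prefect.engine.state import Failed, Running, Success

_runner = FlowRunner(flow=Flow(name="sketch"))

# One terminal task still running -> the whole flow run stays Running
assert _runner.determine_final_state(
    key_states={Success()},
    return_states={},
    terminal_states={Success(), Running()},
).is_running()

# All terminal tasks finished, but a reference task failed -> Failed
assert _runner.determine_final_state(
    key_states={Failed()},
    return_states={},
    terminal_states={Failed(), Success()},
).is_failed()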
async def test_set_multiple_flow_run_states(
    self, run_query, flow_run_id, flow_run_id_2, flow_run_id_3
):
    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(
                states=[
                    dict(flow_run_id=flow_run_id, state=Running().serialize()),
                    dict(flow_run_id=flow_run_id_2, state=Success().serialize()),
                    dict(flow_run_id=flow_run_id_3, state=Retrying().serialize()),
                ]
            )
        ),
    )
    assert result.data.set_flow_run_states.states == [
        {"id": flow_run_id, "status": "SUCCESS", "message": None},
        {"id": flow_run_id_2, "status": "SUCCESS", "message": None},
        {"id": flow_run_id_3, "status": "SUCCESS", "message": None},
    ]

    fr1 = await models.FlowRun.where(
        id=result.data.set_flow_run_states.states[0].id
    ).first({"state", "version"})
    assert fr1.version == 3
    assert fr1.state == "Running"

    fr2 = await models.FlowRun.where(
        id=result.data.set_flow_run_states.states[1].id
    ).first({"state", "version"})
    assert fr2.version == 4
    assert fr2.state == "Success"

    fr3 = await models.FlowRun.where(
        id=result.data.set_flow_run_states.states[2].id
    ).first({"state", "version"})
    assert fr3.version == 5
    assert fr3.state == "Retrying"
def test_watch_flow_run_timeout(monkeypatch):
    flow_run = FlowRunView._from_flow_run_data(FLOW_RUN_DATA_1)
    flow_run.state = Running()  # Not finished
    flow_run.get_latest = MagicMock(return_value=flow_run)
    flow_run.get_logs = MagicMock()

    MockView = MagicMock()
    MockView.from_flow_run_id.return_value = flow_run

    monkeypatch.setattr("prefect.backend.flow_run.FlowRunView", MockView)
    # Mock sleep so that we do not have a slow test
    monkeypatch.setattr("prefect.backend.flow_run.time.sleep", MagicMock())

    with pytest.raises(RuntimeError, match="timed out after 12 hours of waiting"):
        for log in watch_flow_run("id"):
            pass
def test_flow_runner_prioritizes_kwarg_states_over_db_states(monkeypatch, state):
    flow = prefect.Flow(name="test")
    db_state = state("already", result=10)
    get_flow_run_info = MagicMock(return_value=MagicMock(state=db_state))
    set_flow_run_state = MagicMock()
    client = MagicMock(
        get_flow_run_info=get_flow_run_info, set_flow_run_state=set_flow_run_state
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.flow_runner.Client", MagicMock(return_value=client)
    )
    res = CloudFlowRunner(flow=flow).run(state=Pending("let's do this"))

    ## assertions
    assert get_flow_run_info.call_count == 1  # one time to pull latest state
    assert set_flow_run_state.call_count == 2  # Pending -> Running -> Success

    states = [call[1]["state"] for call in set_flow_run_state.call_args_list]
    assert states == [Running(), Success(result={})]
def vclient(monkeypatch):
    cloud_client = MagicMock(
        get_flow_run_info=MagicMock(return_value=MagicMock(state=None)),
        set_flow_run_state=MagicMock(),
        get_task_run_info=MagicMock(return_value=MagicMock(state=None)),
        set_task_run_state=MagicMock(
            side_effect=VersionLockError(),
            return_value=Running()
            # side_effect=lambda task_run_id, version, state, cache_for: state
        ),
        get_latest_task_run_states=MagicMock(
            side_effect=lambda flow_run_id, states: states
        ),
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client", MagicMock(return_value=cloud_client)
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.flow_runner.Client", MagicMock(return_value=cloud_client)
    )
    yield cloud_client
def set_flow_to_running(self, state: State) -> State:
    """
    Puts Pending flows in a Running state; leaves Running flows Running.

    Args:
        - state (State): the current state of this flow

    Returns:
        - State: the state of the flow after running the check

    Raises:
        - ENDRUN: if the flow is not pending or running
    """
    if state.is_pending():
        return Running(message="Running flow.")
    elif state.is_running():
        return state
    else:
        raise ENDRUN(state)
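# Hypothetical usage sketch (not from the source), assuming `set_flow_to_running`
# is the FlowRunner method exercised elsewhere in these tests: Pending is
# promoted, Running is passed through unchanged, and anything else ends the run.
from prefect import Flow
from prefect.engine import FlowRunner
from prefect.engine.runner import ENDRUN
from prefect.engine.state import Pending, Running, Success

_runner = FlowRunner(flow=Flow(name="sketch"))

# Pending -> Running
assert _runner.set_flow_to_running(state=Pending()).is_running()

# Running is returned as-is (the same object)
_running = Running()
assert _runner.set_flow_to_running(state=_running) is _running

# Any other state raises ENDRUN carrying that state
try:
    _runner.set_flow_to_running(state=Success())
except ENDRUN as exc:
    assert exc.state.is_successful()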
async def test_set_flow_run_state_with_bad_version(
    self, run_query, locked_flow_run_id
):
    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(
                states=[
                    dict(
                        flow_run_id=locked_flow_run_id,
                        version=10,
                        state=Running().serialize(),
                    )
                ]
            )
        ),
    )
    assert "State update failed" in result.errors[0].message

    fr = await models.FlowRun.where(id=locked_flow_run_id).first({"state", "version"})
    assert fr.version == 1
    assert fr.state == "Scheduled"
async def test_running_states_can_not_be_set_if_flow_run_is_not_running(
    self, flow_run_id, task_run_id, flow_run_state
):
    await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=flow_run_state)

    set_running_coroutine = api.states.set_task_run_state(
        task_run_id=task_run_id, state=Running()
    )

    if flow_run_state.is_running():
        assert await set_running_coroutine
        assert (
            await models.TaskRun.where(id=task_run_id).first({"state"})
        ).state == "Running"
    else:
        with pytest.raises(ValueError, match="is not in a running state"):
            await set_running_coroutine
        assert (
            await models.TaskRun.where(id=task_run_id).first({"state"})
        ).state != "Running"
def test_watch_flow_run_default_timeout(monkeypatch):
    # Test the default behavior, which sets the timeout to 12 hours
    # when the `max_duration` kwarg is not provided
    flow_run = FlowRunView._from_flow_run_data(FLOW_RUN_DATA_1)
    flow_run.state = Running()  # Not finished
    flow_run.get_latest = MagicMock(return_value=flow_run)
    flow_run.get_logs = MagicMock()

    MockView = MagicMock()
    MockView.from_flow_run_id.return_value = flow_run

    monkeypatch.setattr("prefect.backend.flow_run.FlowRunView", MockView)
    # Mock sleep so that we do not have a slow test
    monkeypatch.setattr("prefect.backend.flow_run.time.sleep", MagicMock())

    with pytest.raises(RuntimeError, match="timed out after 12.0 hours of waiting"):
        for log in watch_flow_run("id"):
            pass
async def test_set_task_run_states_rejects_states_with_large_payloads(
    self, run_query, task_run_id, task_run_id_2, running_flow_run_id
):
    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(
                states=[
                    dict(
                        task_run_id=task_run_id,
                        # this state should be set successfully
                        state=Running().serialize(),
                    ),
                    dict(
                        task_run_id=task_run_id_2,
                        # nonsense payload, just large
                        state={
                            i: os.urandom(2 * 1000000).decode("latin")
                            for i in range(2)
                        },
                    ),
                ]
            )
        ),
    )
    assert "State payload is too large" in result.errors[0].message
def set_task_to_running(self, state: State) -> State:
    """
    Sets the task to running

    Args:
        - state (State): the current state of this task

    Returns:
        - State: the state of the task after running the check

    Raises:
        - ENDRUN: if the task is not ready to run
    """
    if not state.is_pending():
        self.logger.debug(
            "Task '{name}': can't set state to Running because it "
            "isn't Pending; ending run.".format(
                name=prefect.context.get("task_full_name", self.task.name)
            )
        )
        raise ENDRUN(state)

    return Running(message="Starting task run.")
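# Hypothetical sketch (not from the source), assuming `set_task_to_running`
# is a TaskRunner method with the single-argument signature shown above:
# only Pending states are promoted to Running; any other state raises ENDRUN.
from prefect import Task
from prefect.engine.runner import ENDRUN
from prefect.engine.state import Pending, Running
from prefect.engine.task_runner import TaskRunner

_task_runner = TaskRunner(task=Task(name="sketch"))

# Pending -> Running
assert _task_runner.set_task_to_running(state=Pending()).is_running()

# A non-Pending state ends the run, carrying the offending state
try:
    _task_runner.set_task_to_running(state=Running())
except ENDRUN as exc:
    assert exc.state.is_running()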
class TestPrefectMessageCloudHook:
    @pytest.mark.parametrize("state", [Running(), Success(), Failed()])
    async def test_prefect_message_cloud_hook(
        self, run_query, flow_run_id, state, tenant_id
    ):
        await models.Message.where().delete()

        cloud_hook_id = await api.cloud_hooks.create_cloud_hook(
            tenant_id=tenant_id,
            type="PREFECT_MESSAGE",
            config={},
            states=[type(state).__name__],
        )

        set_flow_run_state_mutation = """
            mutation($input: set_flow_run_states_input!) {
                set_flow_run_states(input: $input) {
                    states {
                        id
                        status
                        message
                    }
                }
            }
        """

        await run_query(
            query=set_flow_run_state_mutation,
            variables=dict(
                input=dict(
                    states=[
                        dict(
                            flow_run_id=flow_run_id,
                            version=1,
                            state=state.serialize(),
                        )
                    ]
                )
            ),
        )

        await asyncio.sleep(1)

        assert (
            await models.Message.where({"tenant_id": {"_eq": tenant_id}}).count() == 1
        )
class TestCheckScheduledStep:
    @pytest.mark.parametrize("state", [Failed(), Pending(), Running(), Success()])
    def test_non_scheduled_states(self, state):
        assert (
            FlowRunner(flow=Flow(name="test")).check_flow_reached_start_time(
                state=state
            )
            is state
        )

    def test_scheduled_states_without_start_time(self):
        state = Scheduled(start_time=None)
        assert (
            FlowRunner(flow=Flow(name="test")).check_flow_reached_start_time(
                state=state
            )
            is state
        )

    def test_scheduled_states_with_future_start_time(self):
        state = Scheduled(
            start_time=pendulum.now("utc") + datetime.timedelta(minutes=10)
        )
        with pytest.raises(ENDRUN) as exc:
            FlowRunner(flow=Flow(name="test")).check_flow_reached_start_time(
                state=state
            )
        assert exc.value.state is state

    def test_scheduled_states_with_past_start_time(self):
        state = Scheduled(
            start_time=pendulum.now("utc") - datetime.timedelta(minutes=1)
        )
        assert (
            FlowRunner(flow=Flow(name="test")).check_flow_reached_start_time(
                state=state
            )
            is state
        )
def test_simple_two_task_flow_with_final_task_already_running(monkeypatch, executor):
    flow_run_id = str(uuid.uuid4())
    task_run_id_1 = str(uuid.uuid4())
    task_run_id_2 = str(uuid.uuid4())

    with prefect.Flow(name="test") as flow:
        t1 = prefect.Task()
        t2 = prefect.Task()
        t2.set_upstream(t1)

    client = MockedCloudClient(
        flow_runs=[FlowRun(id=flow_run_id)],
        task_runs=[
            TaskRun(
                id=task_run_id_1, task_slug=flow.slugs[t1], flow_run_id=flow_run_id
            ),
            TaskRun(
                id=task_run_id_2,
                task_slug=flow.slugs[t2],
                version=1,
                flow_run_id=flow_run_id,
                state=Running(),
            ),
        ],
        monkeypatch=monkeypatch,
    )

    with prefect.context(flow_run_id=flow_run_id):
        state = CloudFlowRunner(flow=flow).run(
            return_tasks=flow.tasks, executor=executor
        )

    assert state.is_running()
    assert client.flow_runs[flow_run_id].state.is_running()
    assert client.task_runs[task_run_id_1].state.is_successful()
    assert client.task_runs[task_run_id_1].version == 2
    assert client.task_runs[task_run_id_2].state.is_running()
    assert client.task_runs[task_run_id_2].version == 1
def test_set_flow_run_state_gets_queued(patch_post):
    response = {
        "data": {
            "set_flow_run_states": {
                "states": [{"id": "74-salt", "status": "QUEUED", "message": None}]
            }
        }
    }
    post = patch_post(response)

    with set_temporary_config(
        {
            "cloud.api": "http://my-cloud.foo",
            "cloud.auth_token": "secret_token",
            "backend": "cloud",
        }
    ):
        client = Client()

    state = Running()
    result = client.set_flow_run_state(flow_run_id="74-salt", version=0, state=state)
    assert isinstance(result, State)
    assert state != result
    assert result.is_queued()
class TestFlowRunStates:
    async def test_returns_status_dict(self, flow_run_id: str):
        result = await states.set_flow_run_state(flow_run_id, state=Success())
        assert result["status"] == "SUCCESS"

    @pytest.mark.parametrize(
        "state_cls", [s for s in State.children() if s not in _MetaState.children()]
    )
    async def test_set_flow_run_state(self, flow_run_id, state_cls):
        result = await states.set_flow_run_state(
            flow_run_id=flow_run_id, state=state_cls()
        )

        query = await models.FlowRun.where(id=flow_run_id).first(
            {"version", "state", "serialized_state"}
        )
        assert query.version == 2
        assert query.state == state_cls.__name__
        assert query.serialized_state["type"] == state_cls.__name__

    @pytest.mark.parametrize("state", [Running(), Success()])
    async def test_set_flow_run_state_fails_with_wrong_flow_run_id(self, state):
        with pytest.raises(ValueError, match="Invalid flow run ID"):
            await states.set_flow_run_state(
                flow_run_id=str(uuid.uuid4()), state=state
            )

    async def test_trigger_failed_state_does_not_set_end_time(self, flow_run_id):
        # there is no logic in Prefect that would create this sequence of
        # events, but a user could manually do this
        await states.set_flow_run_state(
            flow_run_id=flow_run_id, state=TriggerFailed()
        )

        flow_run_info = await models.FlowRun.where(id=flow_run_id).first(
            {"id", "start_time", "end_time"}
        )
        assert not flow_run_info.start_time
        assert not flow_run_info.end_time
async def test_returns_status_from_underlying_call(
    self, run_query, flow_run_id, payload_response, monkeypatch
):
    """
    Ensures that the `status` field is determined by the underlying
    `api.states.set_flow_run_state()` call.
    """
    mock_state_api = CoroutineMock(return_value=payload_response)
    monkeypatch.setattr(
        "src.prefect_server.graphql.states.api.states.set_flow_run_state",
        mock_state_api,
    )

    result = await run_query(
        query=self.mutation,
        variables=dict(
            input=dict(
                states=[
                    dict(
                        flow_run_id=flow_run_id,
                        version=1,
                        state=Running().serialize(),
                    )
                ]
            )
        ),
    )
    mock_state_api.assert_awaited_once()
    assert (
        result.data.set_flow_run_states.states[0].status
        == payload_response["status"]
    )
class TestCancelFlowRun:
    mutation = """
        mutation($input: cancel_flow_run_input!) {
            cancel_flow_run(input: $input) {
                state
            }
        }
    """

    @pytest.mark.parametrize(
        "state,res_state,version",
        [
            (Running(), "Cancelling", 4),
            (Success(), "Success", 3),
            (Submitted(), "Cancelled", 4),
        ],
    )
    async def test_cancel_flow_run(
        self, run_query, flow_run_id, state, res_state, version
    ):
        await api.states.set_flow_run_state(
            flow_run_id=flow_run_id, version=1, state=state
        )
        result = await run_query(
            query=self.mutation,
            variables={"input": {"flow_run_id": flow_run_id}},
        )
        assert result.data.cancel_flow_run.state == res_state

        fr = await models.FlowRun.where(id=flow_run_id).first({"state", "version"})
        assert fr.version == version
        assert fr.state == res_state