class TestFlowRunStates:
    """Tests for ``states.set_flow_run_state``.

    Each test awaits the API call and then reads the flow run back through
    ``models.FlowRun`` to check what was stored.
    """

    @pytest.mark.parametrize(
        "state_cls", [s for s in State.children() if s not in _MetaState.children()]
    )
    async def test_set_flow_run_state(self, flow_run_id, state_cls):
        """Setting a state stores its class name and serialized type.

        Parametrized over every ``State`` child that is not also a
        ``_MetaState`` child.
        """
        # The return value is not inspected by this test, so it is not bound.
        await states.set_flow_run_state(flow_run_id=flow_run_id, state=state_cls())

        query = await models.FlowRun.where(id=flow_run_id).first(
            {"version", "state", "serialized_state"}
        )

        assert query.version == 2
        assert query.state == state_cls.__name__
        assert query.serialized_state["type"] == state_cls.__name__

    @pytest.mark.parametrize("state", [Running(), Success()])
    async def test_set_flow_run_state_fails_with_wrong_flow_run_id(self, state):
        """A freshly generated (unknown) flow run ID raises ``ValueError``."""
        with pytest.raises(ValueError, match="Invalid flow run ID"):
            await states.set_flow_run_state(
                flow_run_id=str(uuid.uuid4()), state=state
            )

    async def test_trigger_failed_state_does_not_set_end_time(self, flow_run_id):
        """After ``TriggerFailed`` both ``start_time`` and ``end_time`` are falsy."""
        # there is no logic in Prefect that would create this sequence of
        # events, but a user could manually do this
        await states.set_flow_run_state(
            flow_run_id=flow_run_id, state=TriggerFailed()
        )

        flow_run_info = await models.FlowRun.where(id=flow_run_id).first(
            {"id", "start_time", "end_time"}
        )

        assert not flow_run_info.start_time
        assert not flow_run_info.end_time
class TestTaskRunStates:
    """Tests for ``api.states.set_task_run_state``.

    NOTE(review): a second ``class TestTaskRunStates`` is defined later in this
    file. If both really live in one module, the later definition rebinds the
    name and these tests would presumably not be collected -- confirm.
    """

    async def test_set_task_run_state(self, task_run_id):
        """Setting ``Failed`` echoes the task run ID and stores the state."""
        outcome = await api.states.set_task_run_state(
            task_run_id=task_run_id, state=Failed()
        )
        assert outcome.task_run_id == task_run_id

        wanted = {"version", "state", "serialized_state"}
        stored = await models.TaskRun.where(id=task_run_id).first(wanted)
        assert stored.version == 2
        assert stored.state == "Failed"
        assert stored.serialized_state["type"] == "Failed"

    @pytest.mark.parametrize("state", [Failed(), Success()])
    async def test_set_task_run_state_fails_with_wrong_task_run_id(self, state):
        """A freshly generated (unknown) task run ID raises ``ValueError``."""
        unknown_id = str(uuid.uuid4())
        with pytest.raises(ValueError, match="State update failed"):
            await api.states.set_task_run_state(task_run_id=unknown_id, state=state)

    @pytest.mark.parametrize(
        "state", [s() for s in State.children() if not s().is_running()]
    )
    async def test_state_does_not_set_heartbeat_unless_running(
        self, state, task_run_id
    ):
        """States whose ``is_running()`` is falsy leave ``heartbeat`` as ``None``."""
        before = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert before.heartbeat is None

        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)

        after = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert after.heartbeat is None

    async def test_running_state_sets_heartbeat(self, task_run_id, running_flow_run_id):
        """Setting ``Running`` stores a heartbeat later than the pre-call time."""
        before = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert before.heartbeat is None

        # Timestamp taken before the state change, used as a lower bound.
        lower_bound = pendulum.now("UTC")
        await api.states.set_task_run_state(task_run_id=task_run_id, state=Running())

        after = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert after.heartbeat > lower_bound

    async def test_trigger_failed_state_does_not_set_end_time(self, task_run_id):
        """After ``TriggerFailed`` both ``start_time`` and ``end_time`` are falsy."""
        await api.states.set_task_run_state(
            task_run_id=task_run_id, state=TriggerFailed()
        )

        timing_fields = {"id", "start_time", "end_time"}
        timing = await models.TaskRun.where(id=task_run_id).first(timing_fields)
        assert not timing.start_time
        assert not timing.end_time

    @pytest.mark.parametrize(
        "state",
        [s() for s in State.children() if s not in _MetaState.children()],
        ids=[s.__name__ for s in State.children() if s not in _MetaState.children()],
    )
    async def test_setting_a_task_run_state_pulls_cached_inputs_if_possible(
        self, task_run_id, state, running_flow_run_id
    ):
        """Cached inputs already stored on the run appear in the new state."""
        # Seed the stored row directly with a Failed state carrying cached inputs.
        seeded_inputs = {
            "x": SafeResult(1, result_handler=JSONResultHandler()),
            "y": SafeResult({"z": 2}, result_handler=JSONResultHandler()),
        }
        seeded_state = Failed(cached_inputs=seeded_inputs)
        await models.TaskRun.where(id=task_run_id).update(
            set={"serialized_state": seeded_state.serialize()}
        )

        # Move the task run to the parametrized state through the API.
        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)

        stored = await models.TaskRun.where(id=task_run_id).first(
            {"serialized_state"}
        )

        # The state change took place and the seeded inputs came along with it.
        assert stored.serialized_state["type"] == type(state).__name__
        assert stored.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert stored.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state",
        [
            s(cached_inputs=None)
            for s in State.children()
            if s not in _MetaState.children()
        ],
        ids=[s.__name__ for s in State.children() if s not in _MetaState.children()],
    )
    async def test_task_runs_with_null_cached_inputs_do_not_overwrite_cache(
        self, state, task_run_id, running_flow_run_id
    ):
        """A Retrying state's cached inputs are stored after a null-input state."""
        # First apply the parametrized state, which has cached_inputs=None.
        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)

        # Then apply a Retrying state that does carry cached inputs.
        retry_inputs = {
            "x": SafeResult(1, result_handler=JSONResultHandler()),
            "y": SafeResult({"z": 2}, result_handler=JSONResultHandler()),
        }
        retrying = Retrying(cached_inputs=retry_inputs)
        await api.states.set_task_run_state(task_run_id=task_run_id, state=retrying)

        stored = await models.TaskRun.where(id=task_run_id).first(
            {"serialized_state"}
        )
        assert stored.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert stored.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state_cls", [s for s in State.children() if s not in _MetaState.children()]
    )
    async def test_task_runs_cached_inputs_give_preference_to_new_cached_inputs(
        self, state_cls, task_run_id, running_flow_run_id
    ):
        """Inputs from the most recent state are the ones found on the run."""
        # First state: the parametrized class carrying inputs under "b" and "c".
        earlier_inputs = {
            "b": SafeResult(1, result_handler=JSONResultHandler()),
            "c": SafeResult({"a": 2}, result_handler=JSONResultHandler()),
        }
        earlier_state = state_cls(cached_inputs=earlier_inputs)
        await api.states.set_task_run_state(
            task_run_id=task_run_id, state=earlier_state
        )

        # Second state: Retrying carrying different inputs under "x" and "y".
        newer_inputs = {
            "x": SafeResult(1, result_handler=JSONResultHandler()),
            "y": SafeResult({"z": 2}, result_handler=JSONResultHandler()),
        }
        newer_state = Retrying(cached_inputs=newer_inputs)
        await api.states.set_task_run_state(task_run_id=task_run_id, state=newer_state)

        row = await models.TaskRun.where(id=task_run_id).first({"serialized_state"})
        run = Box(row)

        # Cached inputs exist and the newer values are the ones present.
        assert run.serialized_state.cached_inputs
        assert run.serialized_state.cached_inputs.x.value == 1
        assert run.serialized_state.cached_inputs.y.value == {"z": 2}

    @pytest.mark.parametrize(
        "flow_run_state", [Pending(), Running(), Failed(), Success()]
    )
    async def test_running_states_can_not_be_set_if_flow_run_is_not_running(
        self, flow_run_id, task_run_id, flow_run_state
    ):
        """``Running`` can be set on a task run only while the flow run is running."""
        await api.states.set_flow_run_state(
            flow_run_id=flow_run_id, state=flow_run_state
        )

        # Build the coroutine once; exactly one of the branches below awaits it.
        pending_update = api.states.set_task_run_state(
            task_run_id=task_run_id, state=Running()
        )

        if not flow_run_state.is_running():
            with pytest.raises(ValueError, match="is not in a running state"):
                await pending_update
            task_run = await models.TaskRun.where(id=task_run_id).first({"state"})
            assert task_run.state != "Running"
        else:
            assert await pending_update
            task_run = await models.TaskRun.where(id=task_run_id).first({"state"})
            assert task_run.state == "Running"
class TestTaskRunStates:
    """Tests for ``states.set_task_run_state``.

    NOTE(review): an earlier ``class TestTaskRunStates`` appears in this file.
    If both really live in one module, this definition rebinds the name and
    presumably hides the earlier tests from collection -- confirm.
    """

    async def test_returns_status_dict(
        self, running_flow_run_id: str, task_run_id: str
    ):
        """The returned object exposes ``result["status"] == "SUCCESS"``."""
        result = await states.set_task_run_state(task_run_id, state=Success())
        assert result["status"] == "SUCCESS"

    @pytest.mark.parametrize(
        "state_cls", [s for s in State.children() if s not in _MetaState.children()]
    )
    async def test_set_task_run_state(
        self, running_flow_run_id, task_run_id, state_cls
    ):
        """Setting a state stores its class name and serialized type.

        Parametrized over every ``State`` child that is not also a
        ``_MetaState`` child.
        """
        await states.set_task_run_state(task_run_id=task_run_id, state=state_cls())

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"}
        )

        assert query.version == 1
        assert query.state == state_cls.__name__
        assert query.serialized_state["type"] == state_cls.__name__

    @pytest.mark.parametrize(
        "state_cls",
        [
            s
            for s in State.children()
            if s not in _MetaState.children() and not s().is_running()
        ],
    )
    async def test_set_non_running_task_run_state_works_when_flow_run_is_not_running(
        self, flow_run_id, task_run_id, state_cls
    ):
        """Non-running states can be set after the flow run is ``Success``."""
        await states.set_flow_run_state(flow_run_id, state=Success())
        await states.set_task_run_state(task_run_id=task_run_id, state=state_cls())

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"}
        )

        assert query.version == 1
        assert query.state == state_cls.__name__
        assert query.serialized_state["type"] == state_cls.__name__

    async def test_set_running_task_run_state_fails_when_flow_run_is_not_running(
        self, flow_run_id, task_run_id
    ):
        """Setting ``Running`` raises ``ValueError`` once the flow run is ``Success``."""
        await states.set_flow_run_state(flow_run_id, state=Success())

        with pytest.raises(ValueError, match="State update failed"):
            await states.set_task_run_state(task_run_id=task_run_id, state=Running())

    async def test_set_running_task_run_state_works_when_flow_run_is_not_running_if_force(
        self,
        flow_run_id,
        task_run_id,
    ):
        """With ``force=True``, ``Running`` is stored even though the flow run is ``Success``."""
        await states.set_flow_run_state(flow_run_id, state=Success())
        await states.set_task_run_state(
            task_run_id=task_run_id, state=Running(), force=True
        )

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"}
        )

        assert query.version == 1
        assert query.state == "Running"
        assert query.serialized_state["type"] == "Running"

    async def test_set_task_run_state_does_not_increment_run_count_when_looping(
        self, task_run_id, running_flow_run_id
    ):
        """Running -> Looped -> Running leaves ``run_count`` at 1."""
        # simulate some looping; the return values are not inspected, so none
        # of them is bound to a name
        await states.set_task_run_state(task_run_id=task_run_id, state=Running())
        await states.set_task_run_state(task_run_id=task_run_id, state=Looped())
        await states.set_task_run_state(task_run_id=task_run_id, state=Running())

        task_run = await models.TaskRun.where(id=task_run_id).first({"run_count"})
        assert task_run.run_count == 1

    @pytest.mark.parametrize("state", [Running(), Success()])
    async def test_set_task_run_state_fails_with_wrong_task_run_id(
        self, state, running_flow_run_id
    ):
        """A freshly generated (unknown) task run ID raises ``ValueError``."""
        with pytest.raises(ValueError, match="Invalid task run ID"):
            await states.set_task_run_state(
                task_run_id=str(uuid.uuid4()), state=state
            )

    async def test_trigger_failed_state_does_not_set_end_time(self, task_run_id):
        """After ``TriggerFailed`` both ``start_time`` and ``end_time`` are falsy."""
        await states.set_task_run_state(
            task_run_id=task_run_id, state=TriggerFailed()
        )

        task_run_info = await models.TaskRun.where(id=task_run_id).first(
            {"id", "start_time", "end_time"}
        )

        assert not task_run_info.start_time
        assert not task_run_info.end_time

    @pytest.mark.parametrize(
        "state_cls", [s for s in State.children() if s not in _MetaState.children()]
    )
    async def test_setting_a_task_run_state_pulls_cached_inputs_if_possible(
        self, running_flow_run_id, task_run_id, state_cls
    ):
        """Cached inputs already stored on the run appear in the new state."""
        # set up a Failed state with cached inputs and write it straight to the
        # stored row
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Failed(cached_inputs=complex_result)
        await models.TaskRun.where(id=task_run_id).update(
            set=dict(serialized_state=cached_state.serialize())
        )

        # move the task run to the parametrized state through the API
        await states.set_task_run_state(task_run_id=task_run_id, state=state_cls())

        task_run = await models.TaskRun.where(id=task_run_id).first(
            {"serialized_state"}
        )

        # ensure the state change took place and the cached inputs came along
        assert task_run.serialized_state["type"] == state_cls.__name__
        assert task_run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert task_run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state",
        [
            s(cached_inputs=None)
            for s in State.children()
            if s not in _MetaState.children()
        ],
        ids=[s.__name__ for s in State.children() if s not in _MetaState.children()],
    )
    async def test_task_runs_with_null_cached_inputs_do_not_overwrite_cache(
        self, running_flow_run_id, state, task_run_id
    ):
        """A Retrying state's cached inputs are stored after a null-input state."""
        # apply the parametrized state, which was built with cached_inputs=None
        await states.set_task_run_state(task_run_id=task_run_id, state=state)

        # set up a Retrying state with non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await states.set_task_run_state(task_run_id=task_run_id, state=cached_state)

        run = await models.TaskRun.where(id=task_run_id).first({"serialized_state"})

        assert run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state_cls", [s for s in State.children() if s not in _MetaState.children()]
    )
    async def test_task_runs_cached_inputs_give_preference_to_new_cached_inputs(
        self, running_flow_run_id, state_cls, task_run_id
    ):
        """Inputs from the most recent state are the ones found on the run."""
        # set up the parametrized state with an initial set of (non-null)
        # cached inputs under keys "b" and "c"
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"a": 2}, result_handler=JSONResultHandler())
        complex_result = {"b": res1, "c": res2}
        cached_state = state_cls(cached_inputs=complex_result)
        await states.set_task_run_state(task_run_id=task_run_id, state=cached_state)

        # set up a Retrying state with different cached inputs under "x" and "y"
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await states.set_task_run_state(task_run_id=task_run_id, state=cached_state)

        run = Box(
            await models.TaskRun.where(id=task_run_id).first({"serialized_state"})
        )

        # verify that we have cached inputs, and that preference has been given
        # to the new cached inputs
        assert run.serialized_state.cached_inputs
        assert run.serialized_state.cached_inputs.x.value == 1
        assert run.serialized_state.cached_inputs.y.value == {"z": 2}