def test_task_runner_handles_looping_with_retries(client):
    """A looping task that fails once mid-loop recovers through a retry.

    ``looper`` raises ``LOOP`` while the loop count is below 3, each time
    adding 10 to the previous loop result. On its first run attempt during
    loop iteration 2 it raises ``ValueError`` instead, so the retry settings
    (one retry, zero delay) are exercised. The mocked ``client`` replays a
    Pending state followed by Looped states of increasing ``loop_count`` from
    ``get_task_run_info``.
    """

    # Looping in Cloud requires a result handler, hence PrefectResult().
    @prefect.task(
        max_retries=1,
        retry_delay=datetime.timedelta(seconds=0),
        result=PrefectResult(),
    )
    def looper():
        loop_count = prefect.context.get("task_loop_count", 1)
        first_attempt = prefect.context.get("task_run_count", 1) == 1
        # Fail exactly once, on the first attempt of the second iteration.
        if loop_count == 2 and first_attempt:
            raise ValueError("Stop")
        if loop_count < 3:
            previous = prefect.context.get("task_loop_result", 0)
            raise LOOP(result=previous + 10)
        return prefect.context.get("task_loop_result")

    # Version 0 is Pending; every later version reports Looped(loop_count=v).
    client.get_task_run_info.side_effect = [
        MagicMock(
            version=v,
            state=Looped(loop_count=v) if v != 0 else Pending(),
        )
        for v in range(5)
    ]

    runner = CloudTaskRunner(task=looper)
    final_state = runner.run(
        state=None, upstream_states={}, context={"task_run_version": 1}
    )

    ## assertions
    assert final_state.is_successful()
    assert client.get_task_run_info.call_count == 4
    # Pending -> Running -> Looped (1) -> Running -> Failed -> Retrying
    #   -> Running -> Looped(2) -> Running -> Success
    assert client.set_task_run_state.call_count == 9

    # Only truthy "version" kwargs are compared.
    sent_versions = (
        recorded[1]["version"]
        for recorded in client.set_task_run_state.call_args_list
    )
    assert list(filter(None, sent_versions)) == [1, 2, 3]
async def test_set_task_run_state_does_not_increment_run_count_when_looping(
    self, task_run_id
):
    """Running -> Looped -> Running must leave the task run's run_count at 1.

    Args:
        task_run_id: id of the task run whose state is driven through a loop.
    """
    # simulate some looping
    await states.set_task_run_state(task_run_id=task_run_id, state=Running())
    await states.set_task_run_state(task_run_id=task_run_id, state=Looped())
    # Re-enter Running after the loop. Only the persisted run_count is checked
    # below, so the call's return value is not kept.
    await states.set_task_run_state(task_run_id=task_run_id, state=Running())

    task_run = await models.TaskRun.where(id=task_run_id).first({"run_count"})
    assert task_run.run_count == 1
async def test_set_task_run_state_does_not_increment_run_count_when_looping(
    self, task_run_id, flow_run_id
):
    """Coming back to Running from Looped does not grow ``run_count``.

    The flow run is put into Running first. The task run then cycles
    Running -> Looped -> Running. The final transition must report the same
    task run id, and the stored ``run_count`` must still be 1.

    Args:
        task_run_id: id of the task run being looped.
        flow_run_id: id of the flow run that owns it.
    """
    # ensure the flow run is running
    await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=Running())

    # One loop iteration: Running, then Looped (each state built just in time).
    for make_state in (Running, Looped):
        await api.states.set_task_run_state(
            task_run_id=task_run_id, state=make_state()
        )

    # The transition under test: back to Running after looping.
    final = await api.states.set_task_run_state(
        task_run_id=task_run_id, state=Running()
    )
    assert final.task_run_id == task_run_id

    stored = await models.TaskRun.where(id=task_run_id).first({"run_count"})
    assert stored.run_count == 1
def test_looped_stores_default_loop_count_in_context():
    """Without an explicit loop_count, ``Looped`` takes ``task_loop_count`` from context."""
    with prefect.context(task_loop_count=5):
        looped = Looped()
    assert looped.loop_count == 5
def test_looped_stores_default_loop_count():
    """With no context value and no argument, ``Looped`` defaults loop_count to 1."""
    assert Looped().loop_count == 1
def test_retry_stores_loop_count():
    """An explicit ``loop_count`` passed to ``Looped`` is kept on the state.

    Note: despite the "retry" in its name, this test exercises ``Looped``.
    """
    assert Looped(loop_count=2).loop_count == 2
assert issubclass(TimedOut, Failed) def test_trigger_failed_is_failed(self): assert issubclass(TriggerFailed, Failed) @pytest.mark.parametrize( "state_check", [ dict(state=Cancelled(), assert_true={"is_finished"}), dict(state=Cached(), assert_true={"is_cached", "is_finished", "is_successful"}), dict(state=ClientFailed(), assert_true={"is_meta_state"}), dict(state=Failed(), assert_true={"is_finished", "is_failed"}), dict(state=Finished(), assert_true={"is_finished"}), dict(state=Looped(), assert_true={"is_finished", "is_looped"}), dict(state=Mapped(), assert_true={"is_finished", "is_mapped", "is_successful"}), dict(state=Paused(), assert_true={"is_pending", "is_scheduled"}), dict(state=Pending(), assert_true={"is_pending"}), dict(state=Queued(), assert_true={"is_meta_state", "is_queued"}), dict(state=Resume(), assert_true={"is_pending", "is_scheduled"}), dict(state=Retrying(), assert_true={"is_pending", "is_scheduled", "is_retrying"}), dict(state=Running(), assert_true={"is_running"}), dict(state=Scheduled(), assert_true={"is_pending", "is_scheduled"}), dict(state=Skipped(), assert_true={"is_finished", "is_successful", "is_skipped"}), dict(state=Submitted(), assert_true={"is_meta_state", "is_submitted"}), dict(state=Success(), assert_true={"is_finished", "is_successful"}), dict(state=TimedOut(), assert_true={"is_finished", "is_failed"}),