Exemplo n.º 1
0
def test_task_runner_handles_looping_with_retries(client):
    """Run a looping task that fails once mid-loop, retries, then succeeds.

    The task keeps raising LOOP (adding 10 to the previous loop result) until
    ``task_loop_count`` reaches 3.  On loop 2 of the first run it raises
    ValueError, which forces one retry.  The Cloud client comes from the
    ``client`` fixture and has its state calls mocked.
    """
    # note that looping _requires_ a result handler in Cloud
    @prefect.task(
        max_retries=1,
        retry_delay=datetime.timedelta(seconds=0),
        result=PrefectResult(),
    )
    def looper():
        ctx = prefect.context
        loop_count = ctx.get("task_loop_count", 1)
        # fail exactly once: on the second loop iteration of the first run
        if loop_count == 2 and ctx.get("task_run_count", 1) == 1:
            raise ValueError("Stop")
        if loop_count < 3:
            raise LOOP(result=ctx.get("task_loop_result", 0) + 10)
        return ctx.get("task_loop_result")

    # Version 0 is Pending; every later version reports a Looped state.
    client.get_task_run_info.side_effect = [
        MagicMock(version=v, state=Looped(loop_count=v) if v else Pending())
        for v in range(5)
    ]
    runner = CloudTaskRunner(task=looper)
    final_state = runner.run(
        context={"task_run_version": 1}, state=None, upstream_states={}
    )

    # assertions
    assert final_state.is_successful()
    assert client.get_task_run_info.call_count == 4
    # After the initial Pending: Running -> Looped(1) -> Running -> Failed ->
    # Retrying -> Running -> Looped(2) -> Running -> Success (nine updates).
    assert client.set_task_run_state.call_count == 9
    nonzero_versions = [
        kwargs["version"]
        for _, kwargs in client.set_task_run_state.call_args_list
        if kwargs["version"]
    ]
    assert nonzero_versions == [1, 2, 3]
Exemplo n.º 2
0
    async def test_set_task_run_state_does_not_increment_run_count_when_looping(
        self, task_run_id
    ):
        """Going back to Running after a Looped state must not bump run_count.

        Drives the task run Running -> Looped -> Running, then queries the
        task run's ``run_count`` and expects it to still be 1.
        """
        # simulate some looping: one run, one loop iteration, then resume
        await states.set_task_run_state(task_run_id=task_run_id, state=Running())
        await states.set_task_run_state(task_run_id=task_run_id, state=Looped())
        # return value intentionally ignored; only the stored run_count matters
        await states.set_task_run_state(task_run_id=task_run_id, state=Running())

        task_run = await models.TaskRun.where(id=task_run_id).first({"run_count"})
        assert task_run.run_count == 1
Exemplo n.º 3
0
    async def test_set_task_run_state_does_not_increment_run_count_when_looping(
            self, task_run_id, flow_run_id):
        """Looping a task run (Running -> Looped -> Running) keeps run_count at 1.

        The parent flow run is first put into Running.  The result of the last
        task-state update must refer to the same task run.
        """
        # ensure the flow run is running
        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            state=Running())

        # simulate some looping; each state is built just before it is set
        last_result = None
        for state_cls in (Running, Looped, Running):
            last_result = await api.states.set_task_run_state(
                task_run_id=task_run_id, state=state_cls())

        assert last_result.task_run_id == task_run_id
        stored = await models.TaskRun.where(id=task_run_id).first({"run_count"})
        assert stored.run_count == 1
Exemplo n.º 4
0
def test_looped_stores_default_loop_count_in_context():
    """Looped() without loop_count picks up ``task_loop_count`` from context."""
    expected = 5
    with prefect.context(task_loop_count=expected):
        looped = Looped()
    assert looped.loop_count == expected
Exemplo n.º 5
0
def test_looped_stores_default_loop_count():
    """A bare Looped() reports loop_count 1 when this test sets no context."""
    assert Looped().loop_count == 1
Exemplo n.º 6
0
def test_retry_stores_loop_count():
    """An explicit loop_count passed to Looped is stored unchanged.

    NOTE(review): despite its name, this test exercises Looped, not Retrying;
    consider renaming it to ``test_looped_stores_loop_count``.
    """
    assert Looped(loop_count=2).loop_count == 2
Exemplo n.º 7
0
        assert issubclass(TimedOut, Failed)

    def test_trigger_failed_is_failed(self):
        """TriggerFailed must be a subclass of Failed."""
        parent, child = Failed, TriggerFailed
        assert issubclass(child, parent)


@pytest.mark.parametrize(
    "state_check",
    [
        dict(state=Cancelled(), assert_true={"is_finished"}),
        dict(state=Cached(),
             assert_true={"is_cached", "is_finished", "is_successful"}),
        dict(state=ClientFailed(), assert_true={"is_meta_state"}),
        dict(state=Failed(), assert_true={"is_finished", "is_failed"}),
        dict(state=Finished(), assert_true={"is_finished"}),
        dict(state=Looped(), assert_true={"is_finished", "is_looped"}),
        dict(state=Mapped(),
             assert_true={"is_finished", "is_mapped", "is_successful"}),
        dict(state=Paused(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Pending(), assert_true={"is_pending"}),
        dict(state=Queued(), assert_true={"is_meta_state", "is_queued"}),
        dict(state=Resume(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Retrying(),
             assert_true={"is_pending", "is_scheduled", "is_retrying"}),
        dict(state=Running(), assert_true={"is_running"}),
        dict(state=Scheduled(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Skipped(),
             assert_true={"is_finished", "is_successful", "is_skipped"}),
        dict(state=Submitted(), assert_true={"is_meta_state", "is_submitted"}),
        dict(state=Success(), assert_true={"is_finished", "is_successful"}),
        dict(state=TimedOut(), assert_true={"is_finished", "is_failed"}),