Example #1
0
class TestFlowRunStates:
    """Tests for ``states.set_flow_run_state``."""

    @pytest.mark.parametrize(
        "state_cls",
        [s for s in State.children() if s not in _MetaState.children()])
    async def test_set_flow_run_state(self, flow_run_id, state_cls):
        """Setting any non-meta state persists its type and bumps the version."""
        await states.set_flow_run_state(flow_run_id=flow_run_id,
                                        state=state_cls())

        query = await models.FlowRun.where(id=flow_run_id).first(
            {"version", "state", "serialized_state"})

        assert query.version == 2
        assert query.state == state_cls.__name__
        assert query.serialized_state["type"] == state_cls.__name__

    @pytest.mark.parametrize("state", [Running(), Success()])
    async def test_set_flow_run_state_fails_with_wrong_flow_run_id(
            self, state):
        """An unknown flow run ID is rejected with a ValueError."""
        with pytest.raises(ValueError, match="Invalid flow run ID"):
            await states.set_flow_run_state(flow_run_id=str(uuid.uuid4()),
                                            state=state)

    async def test_trigger_failed_state_does_not_set_end_time(
            self, flow_run_id):
        """A TriggerFailed state leaves both start_time and end_time unset."""
        # there is no logic in Prefect that would create this sequence of
        # events, but a user could manually do this
        await states.set_flow_run_state(flow_run_id=flow_run_id,
                                        state=TriggerFailed())
        flow_run_info = await models.FlowRun.where(id=flow_run_id).first(
            {"id", "start_time", "end_time"})
        assert not flow_run_info.start_time
        assert not flow_run_info.end_time


class TestTaskRunStates:
    """Tests for ``api.states.set_task_run_state``."""

    async def test_set_task_run_state(self, task_run_id):
        """Setting a state returns the run ID and persists the new state."""
        result = await api.states.set_task_run_state(
            task_run_id=task_run_id, state=Failed()
        )

        assert result.task_run_id == task_run_id

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"}
        )

        assert query.version == 2
        assert query.state == "Failed"
        assert query.serialized_state["type"] == "Failed"

    @pytest.mark.parametrize("state", [Failed(), Success()])
    async def test_set_task_run_state_fails_with_wrong_task_run_id(self, state):
        """An unknown task run ID is rejected with a ValueError."""
        with pytest.raises(ValueError, match="State update failed"):
            await api.states.set_task_run_state(
                task_run_id=str(uuid.uuid4()), state=state
            )

    @pytest.mark.parametrize(
        "state",
        [s() for s in State.children() if not s().is_running()],
        # name each case after its state class instead of "state0", "state1", ...
        ids=lambda state: type(state).__name__,
    )
    async def test_state_does_not_set_heartbeat_unless_running(
        self, state, task_run_id
    ):
        """Non-running states leave the task run's heartbeat unset."""
        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)

        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

    async def test_running_state_sets_heartbeat(self, task_run_id, running_flow_run_id):
        """Entering Running stamps a heartbeat later than the pre-call time."""
        # running_flow_run_id is requested only for its fixture side effect
        # (presumably putting the parent flow run into a Running state)
        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

        dt = pendulum.now("UTC")
        await api.states.set_task_run_state(task_run_id=task_run_id, state=Running())

        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat > dt

    async def test_trigger_failed_state_does_not_set_end_time(self, task_run_id):
        """A TriggerFailed state leaves both start_time and end_time unset."""
        await api.states.set_task_run_state(
            task_run_id=task_run_id, state=TriggerFailed()
        )
        task_run_info = await models.TaskRun.where(id=task_run_id).first(
            {"id", "start_time", "end_time"}
        )
        assert not task_run_info.start_time
        assert not task_run_info.end_time

    @pytest.mark.parametrize(
        "state",
        [s() for s in State.children() if s not in _MetaState.children()],
        ids=lambda state: type(state).__name__,
    )
    async def test_setting_a_task_run_state_pulls_cached_inputs_if_possible(
        self, task_run_id, state, running_flow_run_id
    ):
        """A default-constructed new state keeps the run's stored cached inputs."""
        # write a Failed state carrying cached inputs directly to the database
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Failed(cached_inputs=complex_result)
        await models.TaskRun.where(id=task_run_id).update(
            set=dict(serialized_state=cached_state.serialize())
        )

        # transition the task run to the parametrized state
        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)

        task_run = await models.TaskRun.where(id=task_run_id).first(
            {"serialized_state"}
        )

        # ensure the state change took place and the cached inputs carried over
        assert task_run.serialized_state["type"] == type(state).__name__
        assert task_run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert task_run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state",
        [
            s(cached_inputs=None)
            for s in State.children()
            if s not in _MetaState.children()
        ],
        ids=lambda state: type(state).__name__,
    )
    async def test_task_runs_with_null_cached_inputs_do_not_overwrite_cache(
        self, state, task_run_id, running_flow_run_id
    ):
        """Cached inputs set after a null-input state are read back intact."""
        # first, set the parametrized state whose cached inputs are explicitly null
        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)
        # then set a Retrying state with non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id, state=cached_state)
        run = await models.TaskRun.where(id=task_run_id).first({"serialized_state"})

        assert run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state_cls", [s for s in State.children() if s not in _MetaState.children()]
    )
    async def test_task_runs_cached_inputs_give_preference_to_new_cached_inputs(
        self, state_cls, task_run_id, running_flow_run_id
    ):
        """When both states carry cached inputs, the newer state's inputs win."""
        # first, set the parametrized state with non-null cached inputs (keys b, c)
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"a": 2}, result_handler=JSONResultHandler())
        complex_result = {"b": res1, "c": res2}
        cached_state = state_cls(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id, state=cached_state)
        # then set a Retrying state with different non-null cached inputs (keys x, y)
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id, state=cached_state)
        run = Box(
            await models.TaskRun.where(id=task_run_id).first({"serialized_state"})
        )

        # verify that we have cached inputs, and that preference has been given to
        # the new cached inputs
        assert run.serialized_state.cached_inputs
        assert run.serialized_state.cached_inputs.x.value == 1
        assert run.serialized_state.cached_inputs.y.value == {"z": 2}

    @pytest.mark.parametrize(
        "flow_run_state",
        [Pending(), Running(), Failed(), Success()],
        ids=lambda state: type(state).__name__,
    )
    async def test_running_states_can_not_be_set_if_flow_run_is_not_running(
        self, flow_run_id, task_run_id, flow_run_state
    ):
        """A task run may enter Running only while its flow run is running."""
        await api.states.set_flow_run_state(
            flow_run_id=flow_run_id, state=flow_run_state
        )

        # the coroutine is created once and awaited in whichever branch applies
        set_running_coroutine = api.states.set_task_run_state(
            task_run_id=task_run_id, state=Running()
        )

        if flow_run_state.is_running():
            assert await set_running_coroutine
            assert (
                await models.TaskRun.where(id=task_run_id).first({"state"})
            ).state == "Running"
        else:
            with pytest.raises(ValueError, match="is not in a running state"):
                await set_running_coroutine
            assert (
                await models.TaskRun.where(id=task_run_id).first({"state"})
            ).state != "Running"
Example #3
0
class TestTaskRunStates:
    """Tests for ``states.set_task_run_state``."""

    async def test_returns_status_dict(self, running_flow_run_id: str,
                                       task_run_id: str):
        """A successful update reports ``status == "SUCCESS"``."""
        result = await states.set_task_run_state(task_run_id, state=Success())
        assert result["status"] == "SUCCESS"

    @pytest.mark.parametrize(
        "state_cls",
        [s for s in State.children() if s not in _MetaState.children()])
    async def test_set_task_run_state(self, running_flow_run_id, task_run_id,
                                      state_cls):
        """Setting any non-meta state persists its type on the task run."""
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=state_cls())

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"})

        assert query.version == 1
        assert query.state == state_cls.__name__
        assert query.serialized_state["type"] == state_cls.__name__

    @pytest.mark.parametrize(
        "state_cls",
        [
            s for s in State.children()
            if s not in _MetaState.children() and not s().is_running()
        ],
    )
    async def test_set_non_running_task_run_state_works_when_flow_run_is_not_running(
            self, flow_run_id, task_run_id, state_cls):
        """Non-running states can still be set after the flow run succeeds."""
        await states.set_flow_run_state(flow_run_id, state=Success())
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=state_cls())

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"})

        assert query.version == 1
        assert query.state == state_cls.__name__
        assert query.serialized_state["type"] == state_cls.__name__

    async def test_set_running_task_run_state_fails_when_flow_run_is_not_running(
            self, flow_run_id, task_run_id):
        """Running is rejected once the flow run has reached Success."""
        await states.set_flow_run_state(flow_run_id, state=Success())
        with pytest.raises(ValueError, match="State update failed"):
            await states.set_task_run_state(task_run_id=task_run_id,
                                            state=Running())

    async def test_set_running_task_run_state_works_when_flow_run_is_not_running_if_force(
            self, flow_run_id, task_run_id):
        """``force=True`` lets Running be set even though the flow run succeeded."""
        await states.set_flow_run_state(flow_run_id, state=Success())
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=Running(),
                                        force=True)

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"})

        assert query.version == 1
        assert query.state == "Running"
        assert query.serialized_state["type"] == "Running"

    async def test_set_task_run_state_does_not_increment_run_count_when_looping(
            self, task_run_id, running_flow_run_id):
        """Re-entering Running after a Looped state keeps run_count at 1."""
        # simulate some looping: Running -> Looped -> Running
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=Running())
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=Looped())
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=Running())

        task_run = await models.TaskRun.where(id=task_run_id).first(
            {"run_count"})
        assert task_run.run_count == 1

    @pytest.mark.parametrize("state", [Running(), Success()])
    async def test_set_task_run_state_fails_with_wrong_task_run_id(
            self, state, running_flow_run_id):
        """An unknown task run ID is rejected with a ValueError."""
        with pytest.raises(ValueError, match="Invalid task run ID"):
            await states.set_task_run_state(task_run_id=str(uuid.uuid4()),
                                            state=state)

    async def test_trigger_failed_state_does_not_set_end_time(
            self, task_run_id):
        """A TriggerFailed state leaves both start_time and end_time unset."""
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=TriggerFailed())
        task_run_info = await models.TaskRun.where(id=task_run_id).first(
            {"id", "start_time", "end_time"})
        assert not task_run_info.start_time
        assert not task_run_info.end_time

    @pytest.mark.parametrize(
        "state_cls",
        [s for s in State.children() if s not in _MetaState.children()])
    async def test_setting_a_task_run_state_pulls_cached_inputs_if_possible(
            self, running_flow_run_id, task_run_id, state_cls):
        """A default-constructed new state keeps the run's stored cached inputs."""
        # write a Failed state carrying cached inputs directly to the database
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Failed(cached_inputs=complex_result)
        await models.TaskRun.where(id=task_run_id).update(set=dict(
            serialized_state=cached_state.serialize()))

        # transition the task run to the parametrized state
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=state_cls())

        task_run = await models.TaskRun.where(id=task_run_id).first(
            {"serialized_state"})

        # ensure the state change took place and the cached inputs carried over
        assert task_run.serialized_state["type"] == state_cls.__name__
        assert task_run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert task_run.serialized_state["cached_inputs"]["y"]["value"] == {
            "z": 2
        }

    @pytest.mark.parametrize(
        "state",
        [
            s(cached_inputs=None)
            for s in State.children() if s not in _MetaState.children()
        ],
        # name each case after its state class instead of "state0", "state1", ...
        ids=lambda state: type(state).__name__,
    )
    async def test_task_runs_with_null_cached_inputs_do_not_overwrite_cache(
            self, running_flow_run_id, state, task_run_id):
        """Cached inputs set after a null-input state are read back intact."""
        # first, set the parametrized state whose cached inputs are explicitly null
        await states.set_task_run_state(task_run_id=task_run_id, state=state)
        # then set a Retrying state with non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=cached_state)
        run = await models.TaskRun.where(id=task_run_id).first(
            {"serialized_state"})

        assert run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state_cls",
        [s for s in State.children() if s not in _MetaState.children()])
    async def test_task_runs_cached_inputs_give_preference_to_new_cached_inputs(
            self, running_flow_run_id, state_cls, task_run_id):
        """When both states carry cached inputs, the newer state's inputs win."""
        # first, set the parametrized state with non-null cached inputs (b, c)
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"a": 2}, result_handler=JSONResultHandler())
        complex_result = {"b": res1, "c": res2}
        cached_state = state_cls(cached_inputs=complex_result)
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=cached_state)
        # then set a Retrying state with different non-null cached inputs (x, y)
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=cached_state)
        run = Box(await models.TaskRun.where(id=task_run_id).first(
            {"serialized_state"}))

        # verify that we have cached inputs, and that preference has been given
        # to the new cached inputs
        assert run.serialized_state.cached_inputs
        assert run.serialized_state.cached_inputs.x.value == 1
        assert run.serialized_state.cached_inputs.y.value == {"z": 2}