Example #1
class TestCheckFlowPendingOrRunning:
    @pytest.mark.parametrize(
        "state", [Pending(), Running(),
                  Retrying(), Scheduled()])
    def test_pending_or_running_are_ok(self, state):
        flow = Flow(name="test", tasks=[Task()])
        new_state = FlowRunner(flow=flow).check_flow_is_pending_or_running(
            state=state)
        assert new_state is state

    @pytest.mark.parametrize(
        "state",
        [Finished(), Success(),
         Failed(), Skipped(),
         State()])
    def test_not_pending_or_running_raise_endrun(self, state):
        flow = Flow(name="test", tasks=[Task()])
        with pytest.raises(ENDRUN):
            FlowRunner(flow=flow).check_flow_is_pending_or_running(state=state)
Example #2
def test_flow_runner_prioritizes_kwarg_states_over_db_states(
        monkeypatch, state, client):
    flow = prefect.Flow(name="test")
    db_state = state("already", result=10)
    get_flow_run_info = MagicMock(return_value=MagicMock(state=db_state))
    client.get_flow_run_info = get_flow_run_info

    monkeypatch.setattr("prefect.engine.cloud.flow_runner.Client",
                        MagicMock(return_value=client))
    res = CloudFlowRunner(flow=flow).run(state=Pending("let's do this"))

    ## assertions
    assert get_flow_run_info.call_count == 2  # initial state & cancel check
    assert client.set_flow_run_state.call_count == 2  # Pending -> Running -> Success

    states = [
        call[1]["state"] for call in client.set_flow_run_state.call_args_list
    ]
    assert states == [Running(), Success(result={})]
Example #3
    def set_flow_to_running(self, state: State) -> State:
        """
        Puts Pending flows in a Running state; leaves Running flows Running.

        Args:
            - state (State): the current state of this flow

        Returns:
            - State: the state of the flow after running the check

        Raises:
            - ENDRUN: if the flow is not pending or running
        """
        if state.is_pending():
            return Running(message="Running flow.")
        elif state.is_running():
            return state
        else:
            raise ENDRUN(state)
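A minimal usage sketch of the three branches above (not part of the original example; it assumes Prefect 1.x imports and an illustrative one-task flow):

from prefect import Flow, Task
from prefect.engine import FlowRunner
from prefect.engine.runner import ENDRUN
from prefect.engine.state import Failed, Pending, Running

runner = FlowRunner(flow=Flow(name="demo", tasks=[Task()]))

# a Pending flow is moved into a fresh Running state
assert runner.set_flow_to_running(state=Pending()).is_running()

# a Running flow is passed through unchanged
running = Running()
assert runner.set_flow_to_running(state=running) is running

# any other state aborts the run; the ENDRUN signal carries that state
try:
    runner.set_flow_to_running(state=Failed())
except ENDRUN as exc:
    assert exc.state.is_failed()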
Example #4
    async def test_set_flow_run_state_with_bad_version(self, run_query,
                                                       locked_flow_run_id):
        result = await run_query(
            query=self.mutation,
            variables=dict(input=dict(states=[
                dict(
                    flow_run_id=locked_flow_run_id,
                    version=10,
                    state=Running().serialize(),
                )
            ])),
        )

        assert "State update failed" in result.errors[0].message

        fr = await models.FlowRun.where(id=locked_flow_run_id
                                        ).first({"state", "version"})
        assert fr.version == 1
        assert fr.state == "Scheduled"
Example #5
    async def test_set_task_run_state_with_version(self, run_query,
                                                   task_run_id,
                                                   running_flow_run_id):
        result = await run_query(
            query=self.mutation,
            variables=dict(input=dict(states=[
                dict(
                    task_run_id=task_run_id,
                    version=0,
                    state=Running().serialize(),
                )
            ])),
        )

        assert result.data.set_task_run_states.states[0].id == task_run_id
        tr = await models.TaskRun.where(id=task_run_id
                                        ).first({"state", "version"})
        assert tr.version == 1
        assert tr.state == "Running"
Example #6
    def determine_final_state(
        self,
        key_states: Set[State],
        return_states: Dict[Task, State],
        terminal_states: Set[State],
    ) -> State:
        """
        Implements the logic for determining the final state of the flow run.

        Args:
            - key_states (Set[State]): the states which will determine the success / failure of the flow run
            - return_states (Dict[Task, State]): states to return as results
            - terminal_states (Set[State]): the states of the terminal tasks for this flow

        Returns:
            - State: the final state of the flow run
        """
        state = State()  # mypy initialization

        # check that the flow is finished
        if not all(s.is_finished() for s in terminal_states):
            self.logger.info("Flow run RUNNING: terminal tasks are incomplete.")
            state = Running(message="Flow run in progress.", result=return_states)

        # check if any key task failed
        elif any(s.is_failed() for s in key_states):
            self.logger.info("Flow run FAILED: some reference tasks failed.")
            state = Failed(message="Some reference tasks failed.", result=return_states)

        # check if all reference tasks succeeded
        elif all(s.is_successful() for s in key_states):
            self.logger.info("Flow run SUCCESS: all reference tasks succeeded")
            state = Success(
                message="All reference tasks succeeded.", result=return_states
            )

        # check for any unanticipated state that is finished but neither success nor failed
        else:
            self.logger.info("Flow run SUCCESS: no reference tasks failed")
            state = Success(message="No reference tasks failed.", result=return_states)

        return state
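A rough sketch of how these branches map inputs to a final state, assuming the three-argument signature shown above and two illustrative tasks (not part of the original example):

from prefect import Flow, Task
from prefect.engine import FlowRunner
from prefect.engine.state import Failed, Running, Success

t1, t2 = Task(name="t1"), Task(name="t2")
runner = FlowRunner(flow=Flow(name="demo", tasks=[t1, t2]))

# an unfinished terminal task keeps the flow run Running
state = runner.determine_final_state(
    key_states={Success()},
    return_states={t1: Success(), t2: Running()},
    terminal_states={Success(), Running()},
)
assert state.is_running()

# a failed reference (key) task fails the whole flow run
state = runner.determine_final_state(
    key_states={Failed(), Success()},
    return_states={t1: Failed(), t2: Success()},
    terminal_states={Failed(), Success()},
)
assert state.is_failed()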
Example #7
    async def test_set_multiple_task_run_states(
        self, run_query, task_run_id, task_run_id_2, task_run_id_3, running_flow_run_id
    ):
        result = await run_query(
            query=self.mutation,
            variables=dict(
                input=dict(
                    states=[
                        dict(task_run_id=task_run_id, state=Running().serialize()),
                        dict(task_run_id=task_run_id_2, state=Success().serialize()),
                        dict(
                            task_run_id=task_run_id_3,
                            version=1,
                            state=Retrying().serialize(),
                        ),
                    ]
                )
            ),
        )
        assert result.data.set_task_run_states.states == [
            {"id": task_run_id, "status": "SUCCESS", "message": None},
            {"id": task_run_id_2, "status": "SUCCESS", "message": None},
            {"id": task_run_id_3, "status": "SUCCESS", "message": None},
        ]

        tr1 = await models.TaskRun.where(
            id=result.data.set_task_run_states.states[0].id
        ).first({"state", "version"})
        assert tr1.version == 2
        assert tr1.state == "Running"

        tr2 = await models.TaskRun.where(
            id=result.data.set_task_run_states.states[1].id
        ).first({"state", "version"})
        assert tr2.version == 3
        assert tr2.state == "Success"

        tr3 = await models.TaskRun.where(
            id=result.data.set_task_run_states.states[2].id
        ).first({"state", "version"})
        assert tr3.version == 3
        assert tr3.state == "Retrying"
Example #8
def test_watch_flow_run_default_timeout(monkeypatch):
    # Test the default behavior, which sets the timeout to 12 hours
    # when the `max_duration` kwarg is not provided
    flow_run = FlowRunView._from_flow_run_data(FLOW_RUN_DATA_1)
    flow_run.state = Running()  # Not finished
    flow_run.get_latest = MagicMock(return_value=flow_run)
    flow_run.get_logs = MagicMock()

    MockView = MagicMock()
    MockView.from_flow_run_id.return_value = flow_run

    monkeypatch.setattr("prefect.backend.flow_run.FlowRunView", MockView)

    # Mock sleep so that we do not have a slow test
    monkeypatch.setattr("prefect.backend.flow_run.time.sleep", MagicMock())

    with pytest.raises(RuntimeError,
                       match="timed out after 12.0 hours of waiting"):
        for log in watch_flow_run("id"):
            pass
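For context, a hedged sketch of how a caller might consume watch_flow_run outside of a test; the flow run id below is a placeholder and real backend access is assumed:

from datetime import timedelta
from prefect.backend.flow_run import watch_flow_run

# yields log records until the run finishes or max_duration elapses
# (the default is 12 hours when max_duration is not given)
for log in watch_flow_run("<flow-run-id>", max_duration=timedelta(hours=1)):
    print(log)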
Example #9
    async def test_set_flow_run_state(self, run_query, flow_run_id):
        result = await run_query(
            query=self.mutation,
            variables=dict(input=dict(states=[
                dict(
                    flow_run_id=flow_run_id,
                    version=1,
                    state=Running().serialize(),
                )
            ])),
        )

        assert result.data.set_flow_run_states.states[0].status == "SUCCESS"
        assert result.data.set_flow_run_states.states[0].message is None

        fr = await models.FlowRun.where(
            id=result.data.set_flow_run_states.states[0].id
        ).first({"state", "version"})
        assert fr.version == 2
        assert fr.state == "Running"
Example #10
    async def test_running_states_can_not_be_set_if_flow_run_is_not_running(
            self, flow_run_id, task_run_id, flow_run_state):

        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            state=flow_run_state)

        set_running_coroutine = api.states.set_task_run_state(
            task_run_id=task_run_id, state=Running())

        if flow_run_state.is_running():
            assert await set_running_coroutine
            assert (await
                    models.TaskRun.where(id=task_run_id
                                         ).first({"state"})).state == "Running"
        else:

            with pytest.raises(ValueError, match="is not in a running state"):
                await set_running_coroutine
            assert (await models.TaskRun.where(id=task_run_id).first(
                {"state"})).state != "Running"
Example #11
    def set_task_to_running(self, state: State) -> State:
        """
        Sets the task to running

        Args:
            - state (State): the current state of this task

        Returns:
            - State: the state of the task after running the check

        Raises:
            - ENDRUN: if the task is not ready to run
        """
        if not state.is_pending():
            self.logger.debug(
                "Task '{name}': can't set state to Running because it "
                "isn't Pending; ending run.".format(name=prefect.context.get(
                    "task_full_name", self.task.name)))
            raise ENDRUN(state)

        return Running(message="Starting task run.")
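By analogy with the flow-level check in Example #3, a minimal sketch of this task-level transition, assuming Prefect 1.x and the single-argument signature shown above (not part of the original example):

from prefect import Task
from prefect.engine import TaskRunner
from prefect.engine.runner import ENDRUN
from prefect.engine.state import Pending, Success

runner = TaskRunner(task=Task(name="demo"))

# a Pending task is moved into a fresh Running state
assert runner.set_task_to_running(state=Pending()).is_running()

# any non-Pending state ends the run via ENDRUN
try:
    runner.set_task_to_running(state=Success())
except ENDRUN as exc:
    assert exc.state.is_successful()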
Example #12
def test_watch_flow_run_timeout(monkeypatch):
    flow_run = FlowRunView._from_flow_run_data(FLOW_RUN_DATA_1)
    flow_run.state = Running()  # Not finished
    flow_run.get_latest = MagicMock(return_value=flow_run)
    flow_run.get_logs = MagicMock()

    MockView = MagicMock()
    MockView.from_flow_run_id.return_value = flow_run

    monkeypatch.setattr("prefect.backend.flow_run.FlowRunView", MockView)

    # Mock sleep so that we do not have a slow test
    monkeypatch.setattr("prefect.backend.flow_run.time.sleep", MagicMock())

    with pytest.raises(RuntimeError,
                       match="timed out after 36.5 hours of waiting"):
        for log in watch_flow_run("id",
                                  max_duration=timedelta(days=1,
                                                         hours=12.5,
                                                         seconds=1)):
            pass
Example #13
@pytest.fixture()
def vclient(monkeypatch):
    cloud_client = MagicMock(
        get_flow_run_info=MagicMock(return_value=MagicMock(state=None)),
        set_flow_run_state=MagicMock(),
        get_task_run_info=MagicMock(return_value=MagicMock(state=None)),
        set_task_run_state=MagicMock(
            side_effect=VersionLockError(),
            return_value=Running()
            # side_effect=lambda task_run_id, version, state, cache_for: state
        ),
        get_latest_task_run_states=MagicMock(
            side_effect=lambda flow_run_id, states: states
        ),
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client", MagicMock(return_value=cloud_client)
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.flow_runner.Client", MagicMock(return_value=cloud_client)
    )
    yield cloud_client
Example #14
 async def test_set_task_run_states_rejects_states_with_large_payloads(
         self, run_query, task_run_id, task_run_id_2, running_flow_run_id):
     result = await run_query(
         query=self.mutation,
         variables=dict(input=dict(states=[
             dict(
                 task_run_id=task_run_id,
                 # this state should successfully set
                 state=Running().serialize(),
             ),
             dict(
                 task_run_id=task_run_id_2,
                 # nonsense payload, just large
                 state={
                     i: os.urandom(2 * 1000000).decode("latin")
                     for i in range(2)
                 },
             ),
         ])),
     )
     assert "State payload is too large" in result.errors[0].message
Example #15
class TestPrefectMessageCloudHook:
    @pytest.mark.parametrize("state", [Running(), Success(), Failed()])
    async def test_prefect_message_cloud_hook(
        self, run_query, flow_run_id, state, tenant_id
    ):
        await models.Message.where().delete()

        cloud_hook_id = await api.cloud_hooks.create_cloud_hook(
            tenant_id=tenant_id,
            type="PREFECT_MESSAGE",
            config={},
            states=[type(state).__name__],
        )
        set_flow_run_state_mutation = """
            mutation($input: set_flow_run_states_input!) {
                set_flow_run_states(input: $input) {
                    states {
                        id
                        status
                        message
                    }
                }
            }
            """
        await run_query(
            query=set_flow_run_state_mutation,
            variables=dict(
                input=dict(
                    states=[
                        dict(
                            flow_run_id=flow_run_id, version=1, state=state.serialize()
                        )
                    ]
                )
            ),
        )
        await asyncio.sleep(1)
        assert (
            await models.Message.where({"tenant_id": {"_eq": tenant_id}}).count() == 1
        )
Example #16
class TestCheckScheduledStep:
    @pytest.mark.parametrize("state", [Failed(), Pending(), Running(), Success()])
    def test_non_scheduled_states(self, state):
        assert (
            FlowRunner(flow=Flow(name="test")).check_flow_reached_start_time(
                state=state
            )
            is state
        )

    def test_scheduled_states_without_start_time(self):
        state = Scheduled(start_time=None)
        assert (
            FlowRunner(flow=Flow(name="test")).check_flow_reached_start_time(
                state=state
            )
            is state
        )

    def test_scheduled_states_with_future_start_time(self):
        state = Scheduled(
            start_time=pendulum.now("utc") + datetime.timedelta(minutes=10)
        )
        with pytest.raises(ENDRUN) as exc:
            FlowRunner(flow=Flow(name="test")).check_flow_reached_start_time(
                state=state
            )
        assert exc.value.state is state

    def test_scheduled_states_with_past_start_time(self):
        state = Scheduled(
            start_time=pendulum.now("utc") - datetime.timedelta(minutes=1)
        )
        assert (
            FlowRunner(flow=Flow(name="test")).check_flow_reached_start_time(
                state=state
            )
            is state
        )
Example #17
def test_simple_two_task_flow_with_final_task_already_running(
        monkeypatch, executor):

    flow_run_id = str(uuid.uuid4())
    task_run_id_1 = str(uuid.uuid4())
    task_run_id_2 = str(uuid.uuid4())

    with prefect.Flow(name="test") as flow:
        t1 = prefect.Task()
        t2 = prefect.Task()
        t2.set_upstream(t1)

    client = MockedCloudClient(
        flow_runs=[FlowRun(id=flow_run_id)],
        task_runs=[
            TaskRun(id=task_run_id_1,
                    task_slug=flow.slugs[t1],
                    flow_run_id=flow_run_id),
            TaskRun(
                id=task_run_id_2,
                task_slug=flow.slugs[t2],
                version=1,
                flow_run_id=flow_run_id,
                state=Running(),
            ),
        ],
        monkeypatch=monkeypatch,
    )

    with prefect.context(flow_run_id=flow_run_id):
        state = CloudFlowRunner(flow=flow).run(return_tasks=flow.tasks,
                                               executor=executor)

    assert state.is_running()
    assert client.flow_runs[flow_run_id].state.is_running()
    assert client.task_runs[task_run_id_1].state.is_successful()
    assert client.task_runs[task_run_id_1].version == 2
    assert client.task_runs[task_run_id_2].state.is_running()
    assert client.task_runs[task_run_id_2].version == 1
Example #18
def test_set_flow_run_state_gets_queued(patch_post):
    response = {
        "data": {
            "set_flow_run_states": {
                "states": [{"id": "74-salt", "status": "QUEUED", "message": None}]
            }
        }
    }
    post = patch_post(response)
    with set_temporary_config(
        {
            "cloud.api": "http://my-cloud.foo",
            "cloud.auth_token": "secret_token",
            "backend": "cloud",
        }
    ):
        client = Client()

    state = Running()
    result = client.set_flow_run_state(flow_run_id="74-salt", version=0, state=state)
    assert isinstance(result, State)
    assert state != result
    assert result.is_queued()
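The test above shows that the client may hand back a different state than the one it submitted; a small illustrative sketch of checking for that, assuming configured Cloud credentials and a real flow run id:

from prefect.client import Client
from prefect.engine.state import Running

client = Client()  # assumes cloud credentials are already configured
result = client.set_flow_run_state(
    flow_run_id="<flow-run-id>", version=0, state=Running()
)
# the backend may queue the transition instead of applying it immediately
if result.is_queued():
    print("state change was queued by the backend")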
Example #19
class TestCancelFlowRun:
    mutation = """
        mutation($input: cancel_flow_run_input!) {
            cancel_flow_run(input: $input) {
                state
            }
        }
    """

    @pytest.mark.parametrize(
        "state,res_state,version",
        [
            (Running(), "Cancelling", 4),
            (Success(), "Success", 3),
            (Submitted(), "Cancelled", 4),
        ],
    )
    async def test_cancel_flow_run(self, run_query, flow_run_id, state,
                                   res_state, version):
        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            version=1,
                                            state=state)

        result = await run_query(
            query=self.mutation,
            variables={"input": {
                "flow_run_id": flow_run_id
            }},
        )

        assert result.data.cancel_flow_run.state == res_state

        fr = await models.FlowRun.where(id=flow_run_id
                                        ).first({"state", "version"})
        assert fr.version == version
        assert fr.state == res_state
Example #20
 async def test_set_multiple_flow_run_states_with_error(
         self, run_query, flow_run_id, flow_run_id_2, flow_run_id_3):
     result = await run_query(
         query=self.mutation,
         variables=dict(input=dict(states=[
             dict(
                 flow_run_id=flow_run_id,
                 version=1,
                 state=Running().serialize(),
             ),
             dict(
                 flow_run_id=flow_run_id_2,
                 version=10,
                 state="a bad state",
             ),
             dict(
                 flow_run_id=flow_run_id_3,
                 version=3,
                 state=Retrying().serialize(),
             ),
         ])),
     )
     assert result.data.set_flow_run_states is None
     assert result.errors[0].message
Example #21
    async def test_returns_status_from_underlying_call(
        self, run_query, flow_run_id, payload_response, monkeypatch
    ):
        """
        This test should ensure that the `status` field should be determined
        based on the underlying `api.states.set_flow_run_state()` call. 
        """

        mock_state_api = CoroutineMock(return_value=payload_response)

        monkeypatch.setattr(
            "src.prefect_server.graphql.states.api.states.set_flow_run_state",
            mock_state_api,
        )

        result = await run_query(
            query=self.mutation,
            variables=dict(
                input=dict(
                    states=[
                        dict(
                            flow_run_id=flow_run_id,
                            version=1,
                            state=Running().serialize(),
                        )
                    ]
                )
            ),
        )

        mock_state_api.assert_awaited_once()

        assert (
            result.data.set_flow_run_states.states[0].status
            == payload_response["status"]
        )
Example #22
class TestFlowRunStates:
    async def test_returns_status_dict(self, flow_run_id: str):
        result = await states.set_flow_run_state(flow_run_id, state=Success())
        assert result["status"] == "SUCCESS"

    @pytest.mark.parametrize(
        "state_cls",
        [s for s in State.children() if s not in _MetaState.children()])
    async def test_set_flow_run_state(self, flow_run_id, state_cls):
        result = await states.set_flow_run_state(flow_run_id=flow_run_id,
                                                 state=state_cls())

        query = await models.FlowRun.where(id=flow_run_id).first(
            {"version", "state", "serialized_state"})

        assert query.version == 2
        assert query.state == state_cls.__name__
        assert query.serialized_state["type"] == state_cls.__name__

    @pytest.mark.parametrize("state", [Running(), Success()])
    async def test_set_flow_run_state_fails_with_wrong_flow_run_id(
            self, state):
        with pytest.raises(ValueError, match="Invalid flow run ID"):
            await states.set_flow_run_state(flow_run_id=str(uuid.uuid4()),
                                            state=state)

    async def test_trigger_failed_state_does_not_set_end_time(
            self, flow_run_id):
        # there is no logic in Prefect that would create this sequence of
        # events, but a user could manually do this
        await states.set_flow_run_state(flow_run_id=flow_run_id,
                                        state=TriggerFailed())
        flow_run_info = await models.FlowRun.where(id=flow_run_id).first(
            {"id", "start_time", "end_time"})
        assert not flow_run_info.start_time
        assert not flow_run_info.end_time
Example #23
@pytest.mark.parametrize(
    "state_check",
    [
        dict(state=Cancelled(), assert_true={"is_finished"}),
        dict(state=Cached(),
             assert_true={"is_cached", "is_finished", "is_successful"}),
        dict(state=ClientFailed(), assert_true={"is_meta_state"}),
        dict(state=Failed(), assert_true={"is_finished", "is_failed"}),
        dict(state=Finished(), assert_true={"is_finished"}),
        dict(state=Looped(), assert_true={"is_finished", "is_looped"}),
        dict(state=Mapped(),
             assert_true={"is_finished", "is_mapped", "is_successful"}),
        dict(state=Paused(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Pending(), assert_true={"is_pending"}),
        dict(state=Queued(), assert_true={"is_meta_state", "is_queued"}),
        dict(state=Resume(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Retrying(),
             assert_true={"is_pending", "is_scheduled", "is_retrying"}),
        dict(state=Running(), assert_true={"is_running"}),
        dict(state=Scheduled(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Skipped(),
             assert_true={"is_finished", "is_successful", "is_skipped"}),
        dict(state=Submitted(), assert_true={"is_meta_state", "is_submitted"}),
        dict(state=Success(), assert_true={"is_finished", "is_successful"}),
        dict(state=TimedOut(), assert_true={"is_finished", "is_failed"}),
        dict(state=TriggerFailed(), assert_true={"is_finished", "is_failed"}),
    ],
)
def test_state_is_methods(state_check):
    """
    Iterates over all of the "is_*()" methods of the state, asserting that each one is
    False, unless the name of that method is provided as `assert_true`.

    For example, if `state_check == (Pending(), {'is_pending'})`, then this method will
    assert that `Pending().is_pending()` is True and that every other `is_*()` method
    returns False.
    """
Example #24
class TestTaskRunStates:
    async def test_set_task_run_state(self, task_run_id):
        result = await api.states.set_task_run_state(
            task_run_id=task_run_id, state=Failed()
        )

        assert result.task_run_id == task_run_id

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"}
        )

        assert query.version == 2
        assert query.state == "Failed"
        assert query.serialized_state["type"] == "Failed"

    @pytest.mark.parametrize("state", [Failed(), Success()])
    async def test_set_task_run_state_fails_with_wrong_task_run_id(self, state):
        with pytest.raises(ValueError, match="State update failed"):
            await api.states.set_task_run_state(
                task_run_id=str(uuid.uuid4()), state=state
            )

    @pytest.mark.parametrize(
        "state", [s() for s in State.children() if not s().is_running()]
    )
    async def test_state_does_not_set_heartbeat_unless_running(
        self, state, task_run_id
    ):
        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)

        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

    async def test_running_state_sets_heartbeat(self, task_run_id, running_flow_run_id):
        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

        dt = pendulum.now("UTC")
        await api.states.set_task_run_state(task_run_id=task_run_id, state=Running())

        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat > dt

    async def test_trigger_failed_state_does_not_set_end_time(self, task_run_id):
        await api.states.set_task_run_state(
            task_run_id=task_run_id, state=TriggerFailed()
        )
        task_run_info = await models.TaskRun.where(id=task_run_id).first(
            {"id", "start_time", "end_time"}
        )
        assert not task_run_info.start_time
        assert not task_run_info.end_time

    @pytest.mark.parametrize(
        "state",
        [s() for s in State.children() if s not in _MetaState.children()],
        ids=[s.__name__ for s in State.children() if s not in _MetaState.children()],
    )
    async def test_setting_a_task_run_state_pulls_cached_inputs_if_possible(
        self, task_run_id, state, running_flow_run_id
    ):

        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Failed(cached_inputs=complex_result)
        await models.TaskRun.where(id=task_run_id).update(
            set=dict(serialized_state=cached_state.serialize())
        )

        # attempt to transition the task run to the parametrized state
        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)

        task_run = await models.TaskRun.where(id=task_run_id).first(
            {"serialized_state"}
        )

        # ensure the state change took place
        assert task_run.serialized_state["type"] == type(state).__name__
        assert task_run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert task_run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state",
        [
            s(cached_inputs=None)
            for s in State.children()
            if s not in _MetaState.children()
        ],
        ids=[s.__name__ for s in State.children() if s not in _MetaState.children()],
    )
    async def test_task_runs_with_null_cached_inputs_do_not_overwrite_cache(
        self, state, task_run_id, running_flow_run_id
    ):

        await api.states.set_task_run_state(task_run_id=task_run_id, state=state)
        # set up a Retrying state with non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id, state=cached_state)
        run = await models.TaskRun.where(id=task_run_id).first({"serialized_state"})

        assert run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state_cls", [s for s in State.children() if s not in _MetaState.children()]
    )
    async def test_task_runs_cached_inputs_give_preference_to_new_cached_inputs(
        self, state_cls, task_run_id, running_flow_run_id
    ):

        # set up an initial state with cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"a": 2}, result_handler=JSONResultHandler())
        complex_result = {"b": res1, "c": res2}
        cached_state = state_cls(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id, state=cached_state)
        # set up a Retrying state with non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id, state=cached_state)
        run = Box(
            await models.TaskRun.where(id=task_run_id).first({"serialized_state"})
        )

        # verify that we have cached inputs, and that preference has been given to the new
        # cached inputs
        assert run.serialized_state.cached_inputs
        assert run.serialized_state.cached_inputs.x.value == 1
        assert run.serialized_state.cached_inputs.y.value == {"z": 2}

    @pytest.mark.parametrize(
        "flow_run_state", [Pending(), Running(), Failed(), Success()]
    )
    async def test_running_states_can_not_be_set_if_flow_run_is_not_running(
        self, flow_run_id, task_run_id, flow_run_state
    ):

        await api.states.set_flow_run_state(
            flow_run_id=flow_run_id, state=flow_run_state
        )

        set_running_coroutine = api.states.set_task_run_state(
            task_run_id=task_run_id, state=Running()
        )

        if flow_run_state.is_running():
            assert await set_running_coroutine
            assert (
                await models.TaskRun.where(id=task_run_id).first({"state"})
            ).state == "Running"
        else:

            with pytest.raises(ValueError, match="is not in a running state"):
                await set_running_coroutine
            assert (
                await models.TaskRun.where(id=task_run_id).first({"state"})
            ).state != "Running"
Example #25
class TestFlowRunStates:
    async def test_set_flow_run_state(self, flow_run_id):
        result = await api.states.set_flow_run_state(
            flow_run_id=flow_run_id, state=Running()
        )

        assert result.flow_run_id == flow_run_id

        query = await models.FlowRun.where(id=flow_run_id).first(
            {"version", "state", "serialized_state"}
        )

        assert query.version == 3
        assert query.state == "Running"
        assert query.serialized_state["type"] == "Running"

    @pytest.mark.parametrize("state", [Running(), Success()])
    async def test_set_flow_run_state_fails_with_wrong_flow_run_id(self, state):
        with pytest.raises(ValueError, match="State update failed"):
            await api.states.set_flow_run_state(
                flow_run_id=str(uuid.uuid4()), state=state
            )

    async def test_trigger_failed_state_does_not_set_end_time(self, flow_run_id):
        # there is no logic in Prefect that would create this sequence of
        # events, but a user could manually do this
        await api.states.set_flow_run_state(
            flow_run_id=flow_run_id, state=TriggerFailed()
        )
        flow_run_info = await models.FlowRun.where(id=flow_run_id).first(
            {"id", "start_time", "end_time"}
        )
        assert not flow_run_info.start_time
        assert not flow_run_info.end_time

    @pytest.mark.parametrize(
        "state",
        [
            s()
            for s in State.children()
            if not s().is_running() and not s().is_submitted()
        ],
    )
    async def test_state_does_not_set_heartbeat_unless_running_or_submitted(
        self, state, flow_run_id
    ):
        flow_run = await models.FlowRun.where(id=flow_run_id).first({"heartbeat"})
        assert flow_run.heartbeat is None

        dt = pendulum.now("UTC")
        await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=state)

        flow_run = await models.FlowRun.where(id=flow_run_id).first({"heartbeat"})
        assert flow_run.heartbeat is None

    @pytest.mark.parametrize("state", [Running(), Submitted()])
    async def test_running_and_submitted_state_sets_heartbeat(self, state, flow_run_id):
        """
        Both Running and Submitted states need to set heartbeats for services like Lazarus to
        function properly.
        """
        flow_run = await models.FlowRun.where(id=flow_run_id).first({"heartbeat"})
        assert flow_run.heartbeat is None

        dt = pendulum.now("UTC")
        await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=state)

        flow_run = await models.FlowRun.where(id=flow_run_id).first({"heartbeat"})
        assert flow_run.heartbeat > dt

    async def test_setting_flow_run_to_cancelled_state_sets_unfinished_task_runs_to_cancelled(
        self, flow_run_id
    ):
        task_runs = await models.TaskRun.where(
            {"flow_run_id": {"_eq": flow_run_id}}
        ).get({"id"})
        task_run_ids = [run.id for run in task_runs]
        # update the state to Running
        await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=Running())
        # Currently this flow_run_id fixture has at least 3 tasks; if this
        # changes, the test will need to be updated
        assert len(task_run_ids) >= 3, "flow_run_id fixture has changed"
        # Set one task run to pending, one to running, and the rest to success
        pending_task_run = task_run_ids[0]
        running_task_run = task_run_ids[1]
        rest = task_run_ids[2:]
        await api.states.set_task_run_state(
            task_run_id=pending_task_run, state=Pending()
        )
        await api.states.set_task_run_state(
            task_run_id=running_task_run, state=Running()
        )
        for task_run_id in rest:
            await api.states.set_task_run_state(
                task_run_id=task_run_id, state=Success()
            )
        # set the flow run to a cancelled state
        await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=Cancelled())
        # Confirm the unfinished task runs have been marked as cancelled
        task_runs = await models.TaskRun.where(
            {"flow_run_id": {"_eq": flow_run_id}}
        ).get({"id", "state"})
        new_states = {run.id: run.state for run in task_runs}
        assert new_states[pending_task_run] == "Cancelled"
        assert new_states[running_task_run] == "Cancelled"
        assert all(new_states[id] == "Success" for id in rest)
Example #26
class TestExecuteFlowRunInSubprocess:
    @pytest.fixture()
    def mocks(self, monkeypatch):
        class Mocks:
            subprocess = MagicMock()
            wait_for_flow_run_start_time = MagicMock()
            fail_flow_run = MagicMock()

        mocks = Mocks()
        monkeypatch.setattr("prefect.backend.execution.subprocess",
                            mocks.subprocess)
        monkeypatch.setattr(
            "prefect.backend.execution._wait_for_flow_run_start_time",
            mocks.wait_for_flow_run_start_time,
        )
        monkeypatch.setattr("prefect.backend.execution._fail_flow_run",
                            mocks.fail_flow_run)

        # Since the subprocess module is mocked, restore the real exception type so
        # it can be raised and caught correctly
        mocks.subprocess.CalledProcessError = CalledProcessError

        return mocks

    def test_creates_subprocess_correctly(self, cloud_mocks, mocks):
        # Return a scheduled flow run to start
        cloud_mocks.FlowRunView.from_flow_run_id().state = Scheduled()
        # Return a finished flow run after the first iteration
        cloud_mocks.FlowRunView().get_latest().state = Success()

        execute_flow_run_in_subprocess("flow-run-id")

        # Should pass the correct flow run id to wait for
        mocks.wait_for_flow_run_start_time.assert_called_once_with(
            "flow-run-id")

        # Calls the correct command w/ environment variables
        mocks.subprocess.run.assert_called_once_with(
            [sys.executable, "-m", "prefect", "execute", "flow-run"],
            env={
                "PREFECT__CLOUD__SEND_FLOW_RUN_LOGS": "True",
                "PREFECT__LOGGING__LEVEL": "INFO",
                "PREFECT__LOGGING__FORMAT": "[%(asctime)s] %(levelname)s - %(name)s | %(message)s",
                "PREFECT__LOGGING__DATEFMT": "%Y-%m-%d %H:%M:%S%z",
                "PREFECT__BACKEND": "cloud",
                "PREFECT__CLOUD__API": "https://api.prefect.io",
                "PREFECT__CLOUD__TENANT_ID": "",
                "PREFECT__CLOUD__API_KEY": cloud_mocks.Client().api_key,
                "PREFECT__CONTEXT__FLOW_RUN_ID": "flow-run-id",
                "PREFECT__CONTEXT__FLOW_ID": cloud_mocks.FlowRunView.from_flow_run_id().flow_id,
                "PREFECT__ENGINE__FLOW_RUNNER__DEFAULT_CLASS": "prefect.engine.cloud.CloudFlowRunner",
                "PREFECT__ENGINE__TASK_RUNNER__DEFAULT_CLASS": "prefect.engine.cloud.CloudTaskRunner",
            },
        )

        # Return code is checked
        mocks.subprocess.run().check_returncode.assert_called_once()

    @pytest.mark.parametrize("start_state", [Submitted(), Running()])
    def test_fails_immediately_if_flow_run_is_being_executed_elsewhere(
            self, cloud_mocks, start_state, mocks):
        cloud_mocks.FlowRunView.from_flow_run_id().state = start_state
        with pytest.raises(RuntimeError, match="already in state"):
            execute_flow_run_in_subprocess("flow-run-id")

    def test_handles_signal_interrupt(self, cloud_mocks, mocks):
        cloud_mocks.FlowRunView.from_flow_run_id().state = Scheduled()
        mocks.subprocess.run.side_effect = KeyboardInterrupt()

        # Keyboard interrupt should be re-raised
        with pytest.raises(KeyboardInterrupt):
            execute_flow_run_in_subprocess("flow-run-id")

        # Only tried to run once
        mocks.subprocess.run.assert_called_once()

        # Flow run is failed with the proper message
        mocks.fail_flow_run.assert_called_once_with(
            flow_run_id="flow-run-id",
            message="Flow run received an interrupt signal.")

    def test_handles_unexpected_exception(self, cloud_mocks, mocks):
        cloud_mocks.FlowRunView.from_flow_run_id().state = Scheduled()
        mocks.subprocess.run.side_effect = Exception("Foobar")

        # Re-raised as `RuntimeError`
        with pytest.raises(
                RuntimeError,
                match="encountered unexpected exception during execution"):
            execute_flow_run_in_subprocess("flow-run-id")

        # Only tried to run once
        mocks.subprocess.run.assert_called_once()

        # Flow run is failed with the proper message
        mocks.fail_flow_run.assert_called_once_with(
            flow_run_id="flow-run-id",
            message=(
                "Flow run encountered unexpected exception during execution: "
                f"{Exception('Foobar')!r}"),
        )

    def test_handles_bad_subprocess_result(self, cloud_mocks, mocks):
        cloud_mocks.FlowRunView.from_flow_run_id().state = Scheduled()
        mocks.subprocess.run.return_value.check_returncode.side_effect = (
            CalledProcessError(cmd="foo", returncode=1))

        # Re-raised as `RuntimeError`
        with pytest.raises(RuntimeError, match="flow run process failed"):
            execute_flow_run_in_subprocess("flow-run-id")

        # Only tried to run once
        mocks.subprocess.run.assert_called_once()

        # Flow run is not failed at this time -- left to the FlowRunner
        mocks.fail_flow_run.assert_not_called()

    def test_loops_until_flow_run_is_finished(self, cloud_mocks, mocks):
        cloud_mocks.FlowRunView.from_flow_run_id().state = Scheduled()
        cloud_mocks.FlowRunView.from_flow_run_id().get_latest.side_effect = [
            MagicMock(state=Running()),
            MagicMock(state=Running()),
            MagicMock(state=Success()),
        ]

        execute_flow_run_in_subprocess("flow-run-id")

        # Ran the subprocess twice
        assert mocks.subprocess.run.call_count == 2
        # Waited each time
        assert mocks.wait_for_flow_run_start_time.call_count == 2
Example #27
 def get_flow_run_info(*args, _version=itertools.count(), **kwargs):
     state = Cancelling() if trigger.is_set() else Running()
     return MagicMock(version=next(_version), state=state)
Example #28
 async def test_set_running_task_run_state_fails_when_flow_run_is_not_running(
         self, flow_run_id, task_run_id):
     await states.set_flow_run_state(flow_run_id, state=Success())
     with pytest.raises(ValueError, match="State update failed"):
         await states.set_task_run_state(task_run_id=task_run_id,
                                         state=Running())
Example #29
class TestTaskRunStates:
    async def test_returns_status_dict(self, running_flow_run_id: str,
                                       task_run_id: str):
        result = await states.set_task_run_state(task_run_id, state=Success())
        assert result["status"] == "SUCCESS"

    @pytest.mark.parametrize(
        "state_cls",
        [s for s in State.children() if s not in _MetaState.children()])
    async def test_set_task_run_state(self, running_flow_run_id, task_run_id,
                                      state_cls):
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=state_cls())

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"})

        assert query.version == 1
        assert query.state == state_cls.__name__
        assert query.serialized_state["type"] == state_cls.__name__

    @pytest.mark.parametrize(
        "state_cls",
        [
            s for s in State.children()
            if s not in _MetaState.children() and not s().is_running()
        ],
    )
    async def test_set_non_running_task_run_state_works_when_flow_run_is_not_running(
            self, flow_run_id, task_run_id, state_cls):
        await states.set_flow_run_state(flow_run_id, state=Success())
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=state_cls())

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"})

        assert query.version == 1
        assert query.state == state_cls.__name__
        assert query.serialized_state["type"] == state_cls.__name__

    async def test_set_running_task_run_state_fails_when_flow_run_is_not_running(
            self, flow_run_id, task_run_id):
        await states.set_flow_run_state(flow_run_id, state=Success())
        with pytest.raises(ValueError, match="State update failed"):
            await states.set_task_run_state(task_run_id=task_run_id,
                                            state=Running())

    async def test_set_running_task_run_state_works_when_flow_run_is_not_running_if_force(
        self,
        flow_run_id,
        task_run_id,
    ):
        await states.set_flow_run_state(flow_run_id, state=Success())
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=Running(),
                                        force=True)

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"})

        assert query.version == 1
        assert query.state == "Running"
        assert query.serialized_state["type"] == "Running"

    async def test_set_task_run_state_does_not_increment_run_count_when_looping(
            self, task_run_id, running_flow_run_id):
        # simulate some looping
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=Running())
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=Looped())
        result = await states.set_task_run_state(task_run_id=task_run_id,
                                                 state=Running())

        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"run_count"})
        assert task_run.run_count == 1

    @pytest.mark.parametrize("state", [Running(), Success()])
    async def test_set_task_run_state_fails_with_wrong_task_run_id(
            self, state, running_flow_run_id):
        with pytest.raises(ValueError, match="Invalid task run ID"):
            await states.set_task_run_state(task_run_id=str(uuid.uuid4()),
                                            state=state)

    async def test_trigger_failed_state_does_not_set_end_time(
            self, task_run_id):
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=TriggerFailed())
        task_run_info = await models.TaskRun.where(id=task_run_id).first(
            {"id", "start_time", "end_time"})
        assert not task_run_info.start_time
        assert not task_run_info.end_time

    @pytest.mark.parametrize(
        "state_cls",
        [s for s in State.children() if s not in _MetaState.children()])
    async def test_setting_a_task_run_state_pulls_cached_inputs_if_possible(
            self, running_flow_run_id, task_run_id, state_cls):
        # set up a Failed state with cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Failed(cached_inputs=complex_result)
        await models.TaskRun.where(id=task_run_id).update(set=dict(
            serialized_state=cached_state.serialize()))

        # attempt to transition the task run to the parametrized state
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=state_cls())

        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"serialized_state"})

        # ensure the state change took place
        assert task_run.serialized_state["type"] == state_cls.__name__
        assert task_run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert task_run.serialized_state["cached_inputs"]["y"]["value"] == {
            "z": 2
        }

    @pytest.mark.parametrize(
        "state",
        [
            s(cached_inputs=None)
            for s in State.children() if s not in _MetaState.children()
        ],
        ids=[
            s.__name__ for s in State.children()
            if s not in _MetaState.children()
        ],
    )
    async def test_task_runs_with_null_cached_inputs_do_not_overwrite_cache(
            self, running_flow_run_id, state, task_run_id):
        # set the task run to a state with null cached inputs
        await states.set_task_run_state(task_run_id=task_run_id, state=state)
        # set up a Retrying state with non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=cached_state)
        run = await models.TaskRun.where(id=task_run_id
                                         ).first({"serialized_state"})

        assert run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state_cls",
        [s for s in State.children() if s not in _MetaState.children()])
    async def test_task_runs_cached_inputs_give_preference_to_new_cached_inputs(
            self, running_flow_run_id, state_cls, task_run_id):
        # set up an initial state with cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"a": 2}, result_handler=JSONResultHandler())
        complex_result = {"b": res1, "c": res2}
        cached_state = state_cls(cached_inputs=complex_result)
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=cached_state)
        # set up a Retrying state with non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await states.set_task_run_state(task_run_id=task_run_id,
                                        state=cached_state)
        run = Box(await models.TaskRun.where(id=task_run_id
                                             ).first({"serialized_state"}))

        # verify that we have cached inputs, and that preference has been given to the new
        # cached inputs
        assert run.serialized_state.cached_inputs
        assert run.serialized_state.cached_inputs.x.value == 1
        assert run.serialized_state.cached_inputs.y.value == {"z": 2}
Example #30
class TestTaskRunStates:
    async def test_set_task_run_state(self, task_run_id):
        result = await api.states.set_task_run_state(task_run_id=task_run_id,
                                                     state=Failed())

        assert result.task_run_id == task_run_id

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"})

        assert query.version == 2
        assert query.state == "Failed"
        assert query.serialized_state["type"] == "Failed"

    @pytest.mark.parametrize("state", [Failed(), Success()])
    async def test_set_task_run_state_fails_with_wrong_task_run_id(
            self, state):
        with pytest.raises(ValueError, match="State update failed"):
            await api.states.set_task_run_state(task_run_id=str(uuid.uuid4()),
                                                state=state)

    @pytest.mark.parametrize(
        "state", [s() for s in State.children() if not s().is_running()])
    async def test_state_does_not_set_heartbeat_unless_running(
            self, state, task_run_id):
        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"heartbeat"})
        assert task_run.heartbeat is None

        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=state)

        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"heartbeat"})
        assert task_run.heartbeat is None

    async def test_running_state_sets_heartbeat(self, task_run_id,
                                                running_flow_run_id):
        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"heartbeat"})
        assert task_run.heartbeat is None

        dt = pendulum.now("UTC")
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=Running())

        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"heartbeat"})
        assert task_run.heartbeat > dt

    async def test_trigger_failed_state_does_not_set_end_time(
            self, task_run_id):
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=TriggerFailed())
        task_run_info = await models.TaskRun.where(id=task_run_id).first(
            {"id", "start_time", "end_time"})
        assert not task_run_info.start_time
        assert not task_run_info.end_time

    @pytest.mark.parametrize(
        "flow_run_state",
        [Pending(), Running(), Failed(),
         Success()])
    async def test_running_states_can_not_be_set_if_flow_run_is_not_running(
            self, flow_run_id, task_run_id, flow_run_state):

        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            state=flow_run_state)

        set_running_coroutine = api.states.set_task_run_state(
            task_run_id=task_run_id, state=Running())

        if flow_run_state.is_running():
            assert await set_running_coroutine
            assert (await
                    models.TaskRun.where(id=task_run_id
                                         ).first({"state"})).state == "Running"
        else:

            with pytest.raises(ValueError, match="is not in a running state"):
                await set_running_coroutine
            assert (await models.TaskRun.where(id=task_run_id).first(
                {"state"})).state != "Running"