Exemplo n.º 1
0
async def cancel_flow_run(flow_run_id: str) -> models.FlowRun:
    """
    Request cancellation of a flow run.

    A flow run whose current state is already finished is returned untouched.
    Otherwise a running flow run is moved to `Cancelling`, and any other
    unfinished flow run is moved straight to `Cancelled`.

    Args:
        - flow_run_id (str): the flow run to cancel

    Returns:
        - the flow run as loaded here when it is already finished, otherwise
            whatever `set_flow_run_state` returns for the new state

    Raises:
        - ValueError: if `flow_run_id` is empty or matches no flow run
    """
    if not flow_run_id:
        raise ValueError("Invalid flow run ID.")

    selection = {"id", "state", "serialized_state"}
    flow_run = await models.FlowRun.where(id=flow_run_id).first(selection)
    if not flow_run:
        raise ValueError(f"Invalid flow run ID: {flow_run_id}.")

    current = state_schema.load(flow_run.serialized_state)

    # Nothing to cancel once the run has reached a finished state.
    if current.is_finished():
        return flow_run

    if current.is_running():
        target = Cancelling("Flow run is cancelling")
    else:
        target = Cancelled("Flow run is cancelled")
    return await set_flow_run_state(flow_run_id=flow_run_id, state=target)
Exemplo n.º 2
0
 async def test_setting_flow_run_to_cancelled_state_sets_unfinished_task_runs_to_cancelled(
         self, flow_run_id):
     """
     Cancelling a flow run marks its pending and running task runs as
     "Cancelled" while task runs already in "Success" keep that state.
     """
     fetched = await models.TaskRun.where(
         {"flow_run_id": {"_eq": flow_run_id}}).get({"id"})
     task_run_ids = [task_run.id for task_run in fetched]

     # move the flow run into Running before touching its task runs
     await api.states.set_flow_run_state(
         flow_run_id=flow_run_id, state=Running())

     # The flow_run_id fixture is expected to provide at least 3 task runs;
     # this test has to be revisited if that ever stops being true.
     assert len(task_run_ids) >= 3, "flow_run_id fixture has changed"

     # first task run -> Pending, second -> Running, all others -> Success
     pending_task_run, running_task_run, *rest = task_run_ids
     await api.states.set_task_run_state(
         task_run_id=pending_task_run, state=Pending())
     await api.states.set_task_run_state(
         task_run_id=running_task_run, state=Running())
     for finished_id in rest:
         await api.states.set_task_run_state(
             task_run_id=finished_id, state=Success())

     # cancel the flow run itself
     await api.states.set_flow_run_state(
         flow_run_id=flow_run_id, state=Cancelled())

     # re-read the task runs and check which ones were cancelled
     refreshed = await models.TaskRun.where(
         {"flow_run_id": {"_eq": flow_run_id}}).get({"id", "state"})
     new_states = {task_run.id: task_run.state for task_run in refreshed}
     assert new_states[pending_task_run] == "Cancelled"
     assert new_states[running_task_run] == "Cancelled"
     assert all(new_states[finished_id] == "Success" for finished_id in rest)
Exemplo n.º 3
0
    def get_task_run_state(
        self,
        state: State,
        inputs: Dict[str, Result],
        timeout_handler: Optional[Callable] = None,
    ) -> State:
        """
        Execute the task's `run()` method and translate the outcome into a state.

        A successful result is additionally checkpointed when checkpointing is
        enabled in context and `task.checkpoint` is `True`.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys
                correspond to the task's `run()` arguments; the `value` of each
                entry is what gets passed to `run()`
            - timeout_handler (Callable, optional): function for timing out
                task execution, with call signature `handler(fn, *args, **kwargs)`.
                Defaults to `prefect.utilities.executors.timeout_handler`

        Returns:
            - State: `Success` on normal completion, `Cancelled` on a keyboard
                interrupt, `TimedOut` on a timeout, or the `Looped` state carried
                by a `LOOP` signal

        Raises:
            - signals.PAUSE: if the task raises PAUSE
            - ENDRUN: if the provided state is not a Running state
            - TimeoutError: if the task times out and `raise_on_exception` is set
                in context
        """
        if not state.is_running():
            task_name = prefect.context.get("task_full_name", self.task.name)
            self.logger.debug(
                "Task '{name}': can't run task because it's not in a "
                "Running state; ending run.".format(name=task_name)
            )
            raise ENDRUN(state)

        try:
            task_name = prefect.context.get("task_full_name", self.task.name)
            self.logger.debug(
                "Task '{name}': Calling task.run() method...".format(name=task_name)
            )
            handler = timeout_handler
            if not handler:
                handler = prefect.utilities.executors.timeout_handler
            call_kwargs = {key: res.value for key, res in inputs.items()}
            output = handler(self.task.run, timeout=self.task.timeout, **call_kwargs)

        except KeyboardInterrupt:
            self.logger.debug("Interrupt signal raised, cancelling task run.")
            return Cancelled(message="Interrupt signal raised, cancelling task run.")

        # surface timeouts as a TimedOut state unless the caller wants the exception
        except TimeoutError as exc:
            if prefect.context.get("raise_on_exception"):
                raise exc
            return TimedOut(
                "Task timed out during execution.", result=exc, cached_inputs=inputs
            )

        except signals.LOOP as exc:
            looped = exc.state
            assert isinstance(looped, Looped)
            # wrap the raw loop result so it carries this runner's result handler
            looped.result = Result(
                value=looped.result, result_handler=self.result_handler
            )
            looped.message = exc.state.message or "Task is looping ({})".format(
                looped.loop_count
            )
            return looped

        wrapped = Result(value=output, result_handler=self.result_handler)
        state = Success(result=wrapped, message="Task run succeeded.")

        # only checkpoint when it is switched on both globally and for this task
        if state.is_successful():
            if prefect.context.get("checkpointing") is True:
                if self.task.checkpoint is True:
                    state._result.store_safe_value()

        return state
Exemplo n.º 4
0
        assert issubclass(Skipped, Finished)

    def test_skipped_is_success(self):
        """Check that `Skipped` is a subclass of `Success`."""
        assert issubclass(Skipped, Success)

    def test_timedout_is_failed(self):
        """Check that `TimedOut` is a subclass of `Failed`."""
        assert issubclass(TimedOut, Failed)

    def test_trigger_failed_is_failed(self):
        """Check that `TriggerFailed` is a subclass of `Failed`."""
        assert issubclass(TriggerFailed, Failed)


@pytest.mark.parametrize(
    "state_check",
    [
        dict(state=Cancelled(), assert_true={"is_finished"}),
        dict(state=Cached(),
             assert_true={"is_cached", "is_finished", "is_successful"}),
        dict(state=ClientFailed(), assert_true={"is_meta_state"}),
        dict(state=Failed(), assert_true={"is_finished", "is_failed"}),
        dict(state=Finished(), assert_true={"is_finished"}),
        dict(state=Looped(), assert_true={"is_finished", "is_looped"}),
        dict(state=Mapped(),
             assert_true={"is_finished", "is_mapped", "is_successful"}),
        dict(state=Paused(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Pending(), assert_true={"is_pending"}),
        dict(state=Queued(), assert_true={"is_meta_state", "is_queued"}),
        dict(state=Resume(), assert_true={"is_pending", "is_scheduled"}),
        dict(state=Retrying(),
             assert_true={"is_pending", "is_scheduled", "is_retrying"}),
        dict(state=Running(), assert_true={"is_running"}),
Exemplo n.º 5
0
    def check_for_cancellation(self) -> Iterator:
        """Contextmanager used to wrap a cancellable section of a flow run.

        While the wrapped section runs, a daemon thread polls
        `self.client.get_flow_run_info` for the flow run named by
        `prefect.context["flow_run_id"]`, waiting up to
        `prefect.config.cloud.check_cancellation_interval` between polls.
        When the reported state is a `Cancelling` instance, the thread records
        the flow run version and, unless the wrapped section is already
        exiting, delivers an interrupt to the main thread.

        On exit the poller is stopped and joined. If a cancellation was
        observed, `flow_run_version` is written to `prefect.context` and
        `ENDRUN` is raised with a `Cancelled` state; a `KeyboardInterrupt`
        coming out of the wrapped section is swallowed only in that case.

        NOTE(review): this is a generator that yields exactly once; the
        decorator that turns it into a context manager is presumably applied
        outside the lines visible here - confirm.

        Raises:
            - ENDRUN: on exit, if the flow run was seen in a `Cancelling` state
            - KeyboardInterrupt: re-raised when it was not caused by a
                cancellation observed here
        """

        cancelling = False
        done = threading.Event()
        flow_run_version = None
        context = prefect.context.to_dict()

        def interrupt_if_cancelling() -> None:
            nonlocal cancelling
            nonlocal flow_run_version
            # We need to copy the context into this thread, since context is a
            # thread local.
            with prefect.context(context):
                flow_run_id = prefect.context["flow_run_id"]
                while True:
                    # True once `done` is set, i.e. the wrapped section is exiting;
                    # False when the wait simply timed out.
                    exiting_context = done.wait(
                        prefect.config.cloud.check_cancellation_interval
                    )
                    try:
                        self.logger.debug("Checking flow run state...")
                        flow_run_info = self.client.get_flow_run_info(flow_run_id)
                    except Exception:
                        self.logger.warning(
                            "Error getting flow run info", exc_info=True
                        )
                        if exiting_context:
                            # `done` is already set, so `wait()` would return
                            # immediately on every further iteration. Retrying
                            # here would spin without delay and keep the
                            # `thread.join()` below blocked for as long as the
                            # API keeps failing, so give up on the final check.
                            break
                        continue
                    if not flow_run_info.state.is_running():
                        self.logger.warning(
                            "Flow run is no longer in a running state; the current state is: %r",
                            flow_run_info.state,
                        )
                    if isinstance(flow_run_info.state, Cancelling):
                        self.logger.info(
                            "Flow run has been cancelled, cancelling active tasks"
                        )
                        cancelling = True
                        flow_run_version = flow_run_info.version
                        # If not already leaving context, raise KeyboardInterrupt in the main thread
                        if not exiting_context:
                            if hasattr(signal, "raise_signal"):
                                # New in python 3.8
                                signal.raise_signal(signal.SIGINT)  # type: ignore
                            else:
                                if os.name == "nt":
                                    # This doesn't actually send a signal, so it will only
                                    # interrupt the next Python bytecode instruction - if the
                                    # main thread is blocked in a c extension the interrupt
                                    # won't be seen until that returns.
                                    from _thread import interrupt_main

                                    interrupt_main()
                                else:
                                    signal.pthread_kill(
                                        threading.main_thread().ident, signal.SIGINT  # type: ignore
                                    )
                        break
                    elif exiting_context:
                        break

        thread = threading.Thread(target=interrupt_if_cancelling, daemon=True)
        thread.start()
        try:
            yield
        except KeyboardInterrupt:
            # An interrupt that we did not trigger ourselves must propagate.
            if not cancelling:
                raise
        finally:
            # Wake the poller for one last check and wait for it to finish.
            done.set()
            thread.join()
            if cancelling:
                prefect.context.update(flow_run_version=flow_run_version)
                raise ENDRUN(state=Cancelled("Flow run is cancelled"))
Exemplo n.º 6
0
    def run(
        self,
        state: State = None,
        task_states: Dict[Task, State] = None,
        return_tasks: Iterable[Task] = None,
        parameters: Dict[str, Any] = None,
        task_runner_state_handlers: Iterable[Callable] = None,
        executor: "prefect.engine.executors.Executor" = None,
        context: Dict[str, Any] = None,
        task_contexts: Dict[Task, Dict[str, Any]] = None,
    ) -> State:
        """
        Main entry point of the FlowRunner: runs the Flow and returns its final state.

        Args:
            - state (State, optional): starting state for the Flow. Defaults to
                `Pending`
            - task_states (dict, optional): dictionary of task states to begin
                computation with, with keys being Tasks and values their corresponding state
            - return_tasks ([Task], optional): list of Tasks to include in the
                final returned Flow state. Defaults to `None`
            - parameters (dict, optional): dictionary of any needed Parameter
                values, with keys being strings representing Parameter names and values being
                their corresponding values
            - task_runner_state_handlers (Iterable[Callable], optional): A list of state change
                handlers that will be provided to the task_runner, and called whenever a task
                changes state.
            - executor (Executor, optional): executor to use when performing
                computation; when omitted, an instance of the class returned by
                `prefect.engine.get_default_executor_class()` is used
            - context (Dict[str, Any], optional): prefect.Context to use for execution
            - task_contexts (Dict[Task, Dict[str, Any]], optional): contexts that will be
                provided to each task

        Returns:
            - State: `State` representing the final post-run state of the `Flow`.

        Raises:
            - Exception: any unexpected error is re-raised when `raise_on_exception`
                is set in context; otherwise it is converted into a `Failed` state
        """

        self.logger.info("Beginning Flow run for '{}'".format(self.flow.name))

        # work on shallow copies so the caller's objects are never modified
        task_states = dict(task_states) if task_states else {}
        context = dict(context) if context else {}
        task_contexts = dict(task_contexts) if task_contexts else {}
        parameters = dict(parameters) if parameters else {}
        if executor is None:
            executor_cls = prefect.engine.get_default_executor_class()
            executor = executor_cls()

        try:
            state, task_states, context, task_contexts = self.initialize_run(
                state=state,
                task_states=task_states,
                context=context,
                task_contexts=task_contexts,
                parameters=parameters,
            )

            with prefect.context(context):
                # each step receives the current state and returns the next one
                pipeline = (
                    self.check_flow_is_pending_or_running,
                    self.check_flow_reached_start_time,
                    self.set_flow_to_running,
                )
                for step in pipeline:
                    state = step(state)
                state = self.get_flow_run_state(
                    state,
                    task_states=task_states,
                    task_contexts=task_contexts,
                    return_tasks=return_tasks,
                    task_runner_state_handlers=task_runner_state_handlers,
                    executor=executor,
                )

        except ENDRUN as exc:
            # a step ended the run early; its state is the final state
            state = exc.state

        except KeyboardInterrupt:
            self.logger.debug("Interrupt signal raised, cancelling Flow run.")
            state = Cancelled(
                message="Interrupt signal raised, cancelling flow run.")

        # anything else is trapped and reported as a Failed state
        except Exception as exc:
            failure_message = "Unexpected error while running flow: {}".format(
                repr(exc))
            self.logger.exception(failure_message)
            if prefect.context.get("raise_on_exception"):
                raise exc
            failure = Failed(message=failure_message, result=exc)
            state = self.handle_state_change(state or Pending(), failure)

        return state
Exemplo n.º 7
0
    def get_task_run_state(
        self,
        state: State,
        inputs: Dict[str, Result],
    ) -> State:
        """
        Execute the task's `run()` method and translate the outcome into a state.

        When checkpointing is enabled in context, the task has not opted out
        (`task.checkpoint is not False`) and the task returned a non-`None`
        value, the value is written through `self.result.write`; otherwise it
        is only wrapped with `self.result.from_value`.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys
                correspond to the task's `run()` arguments; the `value` of each
                entry is what gets passed to `run()`

        Returns:
            - State: `Success` on normal completion, `Cancelled` on a keyboard
                interrupt, `TimedOut` on a timeout, or the `Looped` state carried
                by a `LOOP` signal

        Raises:
            - signals.PAUSE: if the task raises PAUSE
            - ENDRUN: if the provided state is not a Running state
            - TimeoutError: if the task times out and `raise_on_exception` is set
                in context
        """
        if not state.is_running():
            task_name = prefect.context.get("task_full_name", self.task.name)
            self.logger.debug(
                "Task '{name}': can't run task because it's not in a "
                "Running state; ending run.".format(name=task_name))
            raise ENDRUN(state)

        value = None
        raw_inputs = {key: res.value for key, res in inputs.items()}
        try:
            task_name = prefect.context.get("task_full_name", self.task.name)
            self.logger.debug(
                "Task '{name}': Calling task.run() method...".format(
                    name=task_name))
            timeout_handler = prefect.utilities.executors.timeout_handler
            if not getattr(self.task, "log_stdout", False):
                value = timeout_handler(
                    self.task.run, timeout=self.task.timeout, **raw_inputs)
            else:
                # route anything the task prints to stdout into the runner's logger
                log_sink = prefect.utilities.logging.RedirectToLog(self.logger)
                with redirect_stdout(log_sink):  # type: ignore
                    value = timeout_handler(
                        self.task.run, timeout=self.task.timeout, **raw_inputs)

        except KeyboardInterrupt:
            self.logger.debug("Interrupt signal raised, cancelling task run.")
            return Cancelled(
                message="Interrupt signal raised, cancelling task run.")

        # surface timeouts as a TimedOut state unless the caller wants the exception
        except TimeoutError as exc:
            if prefect.context.get("raise_on_exception"):
                raise exc
            return TimedOut("Task timed out during execution.", result=exc)

        except signals.LOOP as exc:
            looped = exc.state
            assert isinstance(looped, Looped)
            looped.result = self.result.from_value(value=looped.result)
            looped.message = exc.state.message or "Task is looping ({})".format(
                looped.loop_count)
            return looped

        # checkpoint only when it is enabled, the task has not opted out, and
        # there is an actual value to persist
        should_checkpoint = (
            prefect.context.get("checkpointing") is True
            and self.task.checkpoint is not False
            and value is not None
        )
        if not should_checkpoint:
            result = self.result.from_value(value=value)
        else:
            try:
                # later entries win: context overrides inputs, inputs override parameters
                formatting_kwargs = {
                    **prefect.context.get("parameters", {}).copy(),
                    **raw_inputs,
                    **prefect.context,
                }
                result = self.result.write(value, **formatting_kwargs)
            except NotImplementedError:
                # fall back to an unwritten result when writing is not implemented
                result = self.result.from_value(value=value)

        return Success(result=result, message="Task run succeeded.")
class TestCloudFlowRunnerQueuedState:
    """Tests for how `CloudFlowRunner.run` handles flow runs that come back `Queued`."""

    # offset, in seconds, of the `start_time` put on the mocked Queued states
    queue_time = 55
    # value used for the `cloud.check_cancellation_interval` config setting
    check_cancellation_interval = 8

    def do_mocked_run(self,
                      client,
                      monkeypatch,
                      n_attempts=None,
                      n_queries=None,
                      query_end_state=None):
        """
        Run a one-task flow through `CloudFlowRunner` with `FlowRunner.run`,
        `time_sleep` and `client.get_flow_run_info` replaced by mocks.

        The mocked run returns `Success` once it has been called `n_attempts`
        times; before that it returns a `Queued` state while the flow run info
        reports a queued state, and the reported state otherwise. The mocked
        flow run info reports `Queued` until it has been queried `n_queries`
        times, and `query_end_state` afterwards.

        Returns:
            - tuple: the final state, the sleep mock and the run mock
        """
        mock_sleep = MagicMock()

        def get_flow_run_info(*args, **kwargs):
            still_queued = (n_queries is None
                            or mock_get_flow_run_info.call_count < n_queries)
            reported = Queued() if still_queued else query_end_state
            return MagicMock(version=mock_get_flow_run_info.call_count,
                             state=reported)

        mock_get_flow_run_info = MagicMock(side_effect=get_flow_run_info)

        def run(*args, **kwargs):
            if n_attempts is not None and mock_run.call_count >= n_attempts:
                return Success()
            info = get_flow_run_info()
            if not info.state.is_queued():
                return info.state
            start_time = pendulum.now("UTC").add(seconds=self.queue_time)
            return Queued(start_time=start_time)

        mock_run = MagicMock(side_effect=run)

        client.get_flow_run_info = mock_get_flow_run_info
        monkeypatch.setattr("prefect.engine.cloud.flow_runner.FlowRunner.run",
                            mock_run)
        monkeypatch.setattr("prefect.engine.cloud.flow_runner.time_sleep",
                            mock_sleep)

        @prefect.task
        def return_one():
            return 1

        with prefect.Flow("test-cloud-flow-runner-with-queues") as flow:
            return_one()

        config_overrides = {
            "cloud.check_cancellation_interval":
            self.check_cancellation_interval
        }
        with set_temporary_config(config_overrides):
            final_state = CloudFlowRunner(flow=flow).run()
        return final_state, mock_sleep, mock_run

    @pytest.mark.parametrize("n_attempts", [5, 10])
    def test_rety_queued_state_until_success(self, client, monkeypatch,
                                             n_attempts):
        """The runner keeps retrying a queued run until the run succeeds."""
        state, mock_sleep, mock_run = self.do_mocked_run(
            client, monkeypatch, n_attempts=n_attempts)

        assert state.is_successful()
        assert mock_run.call_count == n_attempts

        # first positional argument of every recorded sleep call
        sleep_times = [call[0][0] for call in mock_sleep.call_args_list]
        assert max(sleep_times) == self.check_cancellation_interval

        slept = sum(sleep_times)
        expected = (n_attempts - 1) * self.queue_time
        # The total only has to be approximately right: processing time can
        # make the time spent sleeping slightly shorter than expected.
        assert expected - 2 < slept < expected + 2

    @pytest.mark.parametrize("n_queries", [5, 10])
    @pytest.mark.parametrize("final_state", [Cancelled(), Success()])
    def test_exit_queued_loop_early_if_no_longer_queued(
            self, client, monkeypatch, n_queries, final_state):
        """The runner leaves the queued loop once the run is no longer queued."""
        state, mock_sleep, mock_run = self.do_mocked_run(
            client,
            monkeypatch,
            n_queries=n_queries,
            query_end_state=final_state)

        assert type(state) is type(final_state)

        # first positional argument of every recorded sleep call
        sleep_times = [call[0][0] for call in mock_sleep.call_args_list]
        assert max(sleep_times) == self.check_cancellation_interval

        slept = sum(sleep_times)
        expected = n_queries * self.check_cancellation_interval
        # The total only has to be approximately right: processing time can
        # make the time spent sleeping slightly shorter than expected.
        assert expected - 2 < slept < expected + 2