예제 #1
0
def checkpoint_handler(task_runner: DSTaskRunner, old_state: State, new_state: State) -> State:
    """
    A handler designed to implement result caching by filename. If the result handler's ``read``
    method can be successfully run, this handler loads the result of that method as the task result
    and sets the task state to ``Success``. Similarly, on successful
    completion of the task, if the task was actually run and not loaded from cache, this handler
    will apply the result handler's ``write`` method to the task.

    Parameters
    ----------
    task_runner : instance of DSTaskRunner
        The task runner associated with the flow the handler is used in.
    old_state : instance of prefect.engine.state.State
        The current state of the task.
    new_state : instance of prefect.engine.state.State
        The expected new state of the task.

    Returns
    -------
    new_state : instance of prefect.engine.state.State
        The actual new state of the task.

    Raises
    ------
    AttributeError
        If the ``PREFECT__FLOWS__CHECKPOINTING`` environment variable is set to ``"true"``.
    TypeError
        If the task runner has no ``upstream_states`` attribute, or if the result handler's
        ``read`` call raises ``TypeError`` (the original error is attached as the cause).
    """
    if os.environ.get("PREFECT__FLOWS__CHECKPOINTING") == "true":
        raise AttributeError("Cannot use standard prefect checkpointing with this handler")

    # NOTE(review): the guards below test ``task_runner.result_handler`` while the read/write
    # calls go through ``task_runner.task.result_handler`` -- confirm these always refer to the
    # same handler.
    if task_runner.result_handler is not None and old_state.is_pending() and new_state.is_running():
        if not hasattr(task_runner, "upstream_states"):
            raise TypeError(
                "upstream_states not found in task runner. Make sure to use "
                "prefect_ds.task_runner.DSTaskRunner."
            )
        input_mapping = _create_input_mapping(task_runner.upstream_states)
        try:
            data = task_runner.task.result_handler.read(input_mapping=input_mapping)
        except FileNotFoundError:
            # nothing cached on disk: let the task run normally
            return new_state
        except TypeError as exc:  # unexpected argument input_mapping
            # chain the original error so a TypeError raised inside ``read`` itself stays visible
            raise TypeError(
                "Result handler could not accept input_mapping argument. "
                "Please ensure that you are using a handler from prefect_ds."
            ) from exc
        result = Result(value=data, result_handler=task_runner.task.result_handler)
        return Success(result=result, message="Task loaded from disk.")

    # the task actually ran and succeeded: persist its result through the handler
    if task_runner.result_handler is not None and old_state.is_running() and new_state.is_successful():
        input_mapping = _create_input_mapping(task_runner.upstream_states)
        task_runner.task.result_handler.write(new_state.result, input_mapping=input_mapping)

    return new_state
예제 #2
0
    def check_flow_is_pending_or_running(self, state: State) -> State:
        """
        Verify that the flow run may start from the given state.

        Both Pending and Running are accepted as starting points, since simultaneous
        runs of the same flow run are allowed.

        Args:
            - state (State): the current state of this flow

        Returns:
            - State: the unchanged state, when it is Pending or Running

        Raises:
            - ENDRUN: if the flow is already finished, or is neither pending nor running
        """
        # a finished flow run has nothing left to execute
        if state.is_finished() is True:
            self.logger.info("Flow run has already finished.")
            raise ENDRUN(state)

        # Pending and Running are the only valid starting points
        if state.is_pending() or state.is_running():
            return state

        # any other state (possibly redundant with the finished check above)
        self.logger.info("Flow is not ready to run.")
        raise ENDRUN(state)
예제 #3
0
    def get_task_run_state(
        self,
        state: State,
        inputs: Dict[str, Result],
        timeout_handler: Optional[Callable],
    ) -> State:
        """
        Execute the task's `run()` method and translate the outcome into a state.

        When the run succeeds and `task.checkpoint` is `True`, the safe value of the
        result is stored as well.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments; each `Result` is unwrapped to its `value`
            - timeout_handler (Callable, optional): function for timing out
                task execution, with call signature `handler(fn, *args, **kwargs)`. When
                falsy, `main_thread_timeout` is used

        Returns:
            - State: `Success` wrapping the task's return value, or `TimedOut` if the
                run raised `TimeoutError`

        Raises:
            - ENDRUN: if the task is not in a Running state
            - TimeoutError: if the run timed out and `raise_on_exception` is set in context
        """
        if not state.is_running():
            task_label = prefect.context.get("task_full_name", self.task.name)
            self.logger.debug(
                "Task '{name}': can't run task because it's not in a "
                "Running state; ending run.".format(name=task_label)
            )
            raise ENDRUN(state)

        try:
            task_label = prefect.context.get("task_full_name", self.task.name)
            self.logger.debug(
                "Task '{name}': Calling task.run() method...".format(name=task_label)
            )
            run_with_timeout = timeout_handler or main_thread_timeout
            # unwrap every Result so run() receives plain values
            call_kwargs = {key: res.value for key, res in inputs.items()}
            raw_value = run_with_timeout(
                self.task.run, timeout=self.task.timeout, **call_kwargs
            )
        except TimeoutError as exc:
            # surface the timeout either as an exception or as a TimedOut state
            if prefect.context.get("raise_on_exception"):
                raise exc
            return TimedOut(
                "Task timed out during execution.",
                result=exc,
                cached_inputs=inputs,
            )

        success = Success(
            result=Result(value=raw_value, result_handler=self.result_handler),
            message="Task run succeeded.",
        )

        if success.is_successful() and self.task.checkpoint is True:
            success._result.store_safe_value()

        return success
예제 #4
0
async def set_task_run_state(task_run_id: str, state: State, force=False) -> None:
    """
    Record a new state for a task run.

    Args:
        - task_run_id (str): the task run id to update
        - state (State): the new state
        - force (bool): if True, skips the check that the flow run is running

    Raises:
        - ValueError: if `task_run_id` is None or matches no task run, or if a running
            state is supplied (without `force`) while the flow run is not "Running"
    """
    if task_run_id is None:
        raise ValueError("Invalid task run ID.")

    # fields of the task run (and its flow run) needed for the checks below
    selection = {
        "id": True,
        "version": True,
        "state": True,
        "serialized_state": True,
        "flow_run": {"id": True, "state": True},
    }
    task_run = await models.TaskRun.where({"id": {"_eq": task_run_id}}).first(selection)

    if not task_run:
        raise ValueError(f"Invalid task run ID: {task_run_id}.")

    # a task may only enter a running state while its flow run is running,
    # unless the caller explicitly forces the update
    if not force:
        if state.is_running() and task_run.flow_run.state != "Running":
            raise ValueError(
                f"State update failed for task run ID {task_run_id}: provided "
                f"a running state but associated flow run {task_run.flow_run.id} is not "
                "in a running state."
            )

    # carry cached inputs over from the previous state when the new state has none
    if not state.cached_inputs:
        if task_run.serialized_state.get("cached_inputs"):
            previous_state = state_schema.load(task_run.serialized_state)
            state.cached_inputs = previous_state.cached_inputs

    # build and persist the new state row
    next_version = (task_run.version or 0) + 1
    new_record = models.TaskRunState(
        task_run_id=task_run.id,
        version=next_version,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        start_time=getattr(state, "start_time", None),
        state=type(state).__name__,
        serialized_state=state.serialize(),
    )

    await new_record.insert()
예제 #5
0
    def check_task_is_ready(self, state: State) -> State:
        """
        Decide whether the task may proceed from the given state.

        Pending, Mapped and Cached states are passed through unchanged; every other
        state ends the run.

        Args:
            - state (State): the current state of this task

        Returns:
            - State: the unchanged state, when the task may proceed

        Raises:
            - ENDRUN: if the task is already running, already finished, or in any
                other state that is not ready to run
        """

        def task_label():
            # name used in log messages; computed only when something is logged
            return prefect.context.get("task_full_name", self.task.name)

        # the task is ready
        if state.is_pending():
            return state

        # a mapped task still proceeds so that its children tasks are generated
        if state.is_mapped():
            self.logger.debug(
                "Task '{name}': task is mapped, but run will proceed so children are generated.".format(
                    name=task_label()
                )
            )
            return state

        # this task is already running
        if state.is_running():
            self.logger.debug(
                "Task '{name}': task is already running.".format(name=task_label())
            )
            raise ENDRUN(state)

        # cached states are passed through
        if state.is_cached():
            return state

        if state.is_finished():
            # this task is already finished
            message = "Task '{name}': task is already finished.".format(
                name=task_label()
            )
        else:
            # anything else is not a state the task can start from
            message = "Task '{name}' is not ready to run or state was unrecognized ({state}).".format(
                name=task_label(),
                state=state,
            )

        self.logger.debug(message)
        raise ENDRUN(state)
예제 #6
0
    def set_flow_to_running(self, state: State) -> State:
        """
        Move a Pending flow into a Running state; a Running flow is returned as-is.

        Args:
            - state (State): the current state of this flow

        Returns:
            - State: a new `Running` state for a pending flow, otherwise the given
                running state

        Raises:
            - ENDRUN: if the flow is not pending or running
        """
        if state.is_pending():
            return Running(message="Running flow.")

        # neither pending nor running: the flow cannot be started
        if not state.is_running():
            raise ENDRUN(state)

        return state
예제 #7
0
    def get_flow_run_state(
        self,
        state: State,
        task_states: Dict[Task, State],
        task_contexts: Dict[Task, Dict[str, Any]],
        return_tasks: Set[Task],
        task_runner_state_handlers: Iterable[Callable],
        executor: "prefect.executors.base.Executor",
    ) -> State:
        """
        Runs the flow.

        Every task from `self.flow.sorted_tasks()` is submitted to the executor via
        `run_task` (mapped tasks are expanded into one submission per child), then the
        states of the terminal, reference and return tasks are awaited and handed to
        `self.determine_final_state`.

        Args:
            - state (State): starting state for the Flow; must be a Running state
            - task_states (dict): dictionary of task states to begin
                computation with, with keys being Tasks and values their corresponding state.
                This dictionary is updated in place as tasks are submitted
            - task_contexts (Dict[Task, Dict[str, Any]]): contexts that will be provided to
                each task
            - return_tasks ([Task], optional): list of Tasks to include in the
                final returned Flow state. `None` is treated as an empty set
            - task_runner_state_handlers (Iterable[Callable]): A list of state change handlers
                that will be provided to the task_runner, and called whenever a task changes
                state.
            - executor (Executor): executor used to submit tasks and wait on their states

        Returns:
            - State: `State` representing the final post-run state of the `Flow`.

        Raises:
            - ENDRUN: if `state` is not a Running state
            - ValueError: if `return_tasks` contains tasks that are not in the flow

        """
        # this dictionary is used for tracking the states of "children" mapped tasks;
        # when running on Dask, we want to avoid serializing futures, so instead
        # of storing child task states in the `map_states` attribute we instead store
        # in this dictionary and only after they are resolved do we attach them to the Mapped state
        mapped_children = dict()  # type: Dict[Task, list]

        if not state.is_running():
            self.logger.info("Flow is not in a Running state.")
            raise ENDRUN(state)

        if return_tasks is None:
            return_tasks = set()
        if set(return_tasks).difference(self.flow.tasks):
            raise ValueError("Some tasks in return_tasks were not found in the flow.")

        # builds the `extra_context` payload handed to `run_task` for a task
        # (and, for mapped children, the child's index)
        def extra_context(task: Task, task_index: int = None) -> dict:
            return {
                "task_name": task.name,
                "task_tags": task.tags,
                "task_index": task_index,
            }

        # -- process each task in order

        with self.check_for_cancellation(), executor.start():

            for task in self.flow.sorted_tasks():
                task_state = task_states.get(task)

                # if a task is a constant task, we already know its return value
                # no need to use up resources by running it through a task runner
                if task_state is None and isinstance(
                    task, prefect.tasks.core.constants.Constant
                ):
                    task_states[task] = task_state = Success(result=task.value)

                # Always restart completed resource setup/cleanup tasks and
                # secret tasks unless they were explicitly cached.
                # TODO: we only need to rerun these tasks if any pending
                # downstream tasks depend on them.
                if (
                    isinstance(
                        task,
                        (
                            prefect.tasks.core.resource_manager.ResourceSetupTask,
                            prefect.tasks.core.resource_manager.ResourceCleanupTask,
                            prefect.tasks.secrets.SecretBase,
                        ),
                    )
                    and task_state is not None
                    and task_state.is_finished()
                    and not task_state.is_cached()
                ):
                    task_states[task] = task_state = Pending()

                # if the state is finished, don't run the task, just use the provided state if
                # the state is cached / mapped, we still want to run the task runner pipeline
                # steps to either ensure the cache is still valid / or to recreate the mapped
                # pipeline for possible retries
                if (
                    isinstance(task_state, State)
                    and task_state.is_finished()
                    and not task_state.is_cached()
                    and not task_state.is_mapped()
                ):
                    continue

                upstream_states = {}  # type: Dict[Edge, State]

                # this dictionary is used exclusively for "reduce" tasks in particular we store
                # the states / futures corresponding to the upstream children, and if running
                # on Dask, let Dask resolve them at the appropriate time.
                # Note: this is an optimization that allows Dask to resolve the mapped
                # dependencies by "elevating" them to a function argument.
                upstream_mapped_states = {}  # type: Dict[Edge, list]

                # -- process each edge to the task
                for edge in self.flow.edges_to(task):

                    # load the upstream task states (supplying Pending as a default)
                    upstream_states[edge] = task_states.get(
                        edge.upstream_task, Pending(message="Task state not available.")
                    )

                    # if the edge is flattened and not the result of a map, then we
                    # preprocess the upstream states. If it IS the result of a
                    # map, it will be handled in `prepare_upstream_states_for_mapping`
                    if edge.flattened:
                        if not isinstance(upstream_states[edge], Mapped):
                            upstream_states[edge] = executor.submit(
                                executors.flatten_upstream_state, upstream_states[edge]
                            )

                    # this checks whether the task is a "reduce" task for a mapped pipeline
                    # and if so, collects the appropriate upstream children
                    if not edge.mapped and isinstance(upstream_states[edge], Mapped):
                        children = mapped_children.get(edge.upstream_task, [])

                        # if the edge is flattened, then we need to wait for the mapped children
                        # to complete and then flatten them
                        if edge.flattened:
                            children = executors.flatten_mapped_children(
                                mapped_children=children, executor=executor
                            )

                        upstream_mapped_states[edge] = children

                # augment edges with upstream constants
                for key, val in self.flow.constants[task].items():
                    edge = Edge(
                        upstream_task=prefect.tasks.core.constants.Constant(val),
                        downstream_task=task,
                        key=key,
                    )
                    upstream_states[edge] = Success(
                        "Auto-generated constant value",
                        result=ConstantResult(value=val),
                    )

                # handle mapped tasks
                if any(edge.mapped for edge in upstream_states.keys()):

                    # wait on upstream states to determine the width of the pipeline
                    # this is the key to depth-first execution
                    upstream_states = executor.wait(
                        {e: state for e, state in upstream_states.items()}
                    )
                    # we submit the task to the task runner to determine if
                    # we can proceed with mapping - if the new task state is not a Mapped
                    # state then we don't proceed
                    task_states[task] = executor.wait(
                        executor.submit(
                            run_task,
                            task=task,
                            state=task_state,  # original state
                            upstream_states=upstream_states,
                            context=dict(
                                prefect.context, **task_contexts.get(task, {})
                            ),
                            flow_result=self.flow.result,
                            task_runner_cls=self.task_runner_cls,
                            task_runner_state_handlers=task_runner_state_handlers,
                            upstream_mapped_states=upstream_mapped_states,
                            is_mapped_parent=True,
                            extra_context=extra_context(task),
                        )
                    )

                    # either way, we should now have enough resolved states to restructure
                    # the upstream states into a list of upstream state dictionaries to iterate over
                    list_of_upstream_states = (
                        executors.prepare_upstream_states_for_mapping(
                            task_states[task],
                            upstream_states,
                            mapped_children,
                            executor=executor,
                        )
                    )

                    submitted_states = []

                    for idx, states in enumerate(list_of_upstream_states):
                        # if we are on a future rerun of a partially complete flow run,
                        # there might be mapped children in a retrying state; this check
                        # looks into the current task state's map_states for such info
                        if (
                            isinstance(task_state, Mapped)
                            and len(task_state.map_states) >= idx + 1
                        ):
                            current_state = task_state.map_states[
                                idx
                            ]  # type: Optional[State]
                        elif isinstance(task_state, Mapped):
                            # the parent was Mapped but has no recorded state for this child
                            current_state = None
                        else:
                            # the parent was not Mapped: every child starts from the parent's state
                            current_state = task_state

                        # this is where each child is submitted for actual work
                        submitted_states.append(
                            executor.submit(
                                run_task,
                                task=task,
                                state=current_state,
                                upstream_states=states,
                                context=dict(
                                    prefect.context,
                                    **task_contexts.get(task, {}),
                                    map_index=idx,
                                ),
                                flow_result=self.flow.result,
                                task_runner_cls=self.task_runner_cls,
                                task_runner_state_handlers=task_runner_state_handlers,
                                upstream_mapped_states=upstream_mapped_states,
                                extra_context=extra_context(task, task_index=idx),
                            )
                        )
                    # children are only tracked when the parent actually resolved to Mapped
                    if isinstance(task_states.get(task), Mapped):
                        mapped_children[task] = submitted_states  # type: ignore

                else:
                    # non-mapped task: a single submission whose future is stored as its state
                    task_states[task] = executor.submit(
                        run_task,
                        task=task,
                        state=task_state,
                        upstream_states=upstream_states,
                        context=dict(prefect.context, **task_contexts.get(task, {})),
                        flow_result=self.flow.result,
                        task_runner_cls=self.task_runner_cls,
                        task_runner_state_handlers=task_runner_state_handlers,
                        upstream_mapped_states=upstream_mapped_states,
                        extra_context=extra_context(task),
                    )

            # ---------------------------------------------
            # Collect results
            # ---------------------------------------------

            # terminal tasks determine if the flow is finished
            terminal_tasks = self.flow.terminal_tasks()

            # reference tasks determine flow state
            reference_tasks = self.flow.reference_tasks()

            # wait until all terminal tasks are finished
            final_tasks = terminal_tasks.union(reference_tasks).union(return_tasks)
            final_states = executor.wait(
                {
                    t: task_states.get(t, Pending("Task not evaluated by FlowRunner."))
                    for t in final_tasks
                }
            )

            # also wait for any children of Mapped tasks to finish, and add them
            # to the dictionary to determine flow state
            all_final_states = final_states.copy()
            for t, s in list(final_states.items()):
                if s.is_mapped():
                    # ensure we wait for any mapped children to complete
                    if t in mapped_children:
                        s.map_states = executor.wait(mapped_children[t])
                    s.result = [ms.result for ms in s.map_states]
                    # for mapped tasks the children's states replace the parent entry
                    all_final_states[t] = s.map_states

            assert isinstance(final_states, dict)

        # flatten because entries for mapped tasks are lists of child states
        key_states = set(flatten_seq([all_final_states[t] for t in reference_tasks]))
        terminal_states = set(
            flatten_seq([all_final_states[t] for t in terminal_tasks])
        )
        # return states use the parent states (not the expanded children)
        return_states = {t: final_states[t] for t in return_tasks}

        state = self.determine_final_state(
            state=state,
            key_states=key_states,
            return_states=return_states,
            terminal_states=terminal_states,
        )

        return state
예제 #8
0
    def get_task_run_state(self, state: State, inputs: Dict[str,
                                                            Result]) -> State:
        """
        Execute the task's `run()` method and translate the outcome into a state.

        The return value is written through `self.result` when checkpointing is enabled
        in context, the task has not opted out (`task.checkpoint is not False`) and the
        value is not `None`; otherwise it is wrapped with `self.result.from_value`.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments; each `Result` is unwrapped to its `value`

        Returns:
            - State: `Success` on completion, `TimedOut` if the run raised `TimeoutError`,
                or the `Looped` state carried by a `signals.LOOP`

        Raises:
            - ENDRUN: if the task is not in a Running state
            - TimeoutError: if the run timed out and `raise_on_exception` is set in context
        """
        if not state.is_running():
            task_label = prefect.context.get("task_full_name", self.task.name)
            self.logger.debug(
                "Task '{name}': can't run task because it's not in a "
                "Running state; ending run.".format(name=task_label))

            raise ENDRUN(state)

        value = None
        # unwrap every Result so run() receives plain values
        raw_inputs = {key: res.value for key, res in inputs.items()}
        try:
            task_label = prefect.context.get("task_full_name", self.task.name)
            self.logger.debug(
                "Task '{name}': Calling task.run() method...".format(
                    name=task_label))
            timeout_handler = prefect.utilities.executors.timeout_handler

            def run_with_timeout():
                # single place where the task is actually invoked
                return timeout_handler(self.task.run,
                                       timeout=self.task.timeout,
                                       **raw_inputs)

            if getattr(self.task, "log_stdout", False):
                # route anything the task prints to the task logger
                log_sink = prefect.utilities.logging.RedirectToLog(
                    self.logger)  # type: ignore
                with redirect_stdout(log_sink):
                    value = run_with_timeout()
            else:
                value = run_with_timeout()

        except TimeoutError as exc:
            # surface the timeout either as an exception or as a TimedOut state
            if prefect.context.get("raise_on_exception"):
                raise exc
            return TimedOut("Task timed out during execution.", result=exc)

        except signals.LOOP as exc:
            looped = exc.state
            assert isinstance(looped, Looped)
            looped.result = self.result.from_value(value=looped.result)
            default_message = "Task is looping ({})".format(looped.loop_count)
            looped.message = exc.state.message or default_message
            return looped

        # checkpoint tasks if a result is present, except for when the user has opted out by
        # disabling checkpointing
        should_checkpoint = (prefect.context.get("checkpointing") is True
                             and self.task.checkpoint is not False
                             and value is not None)
        if should_checkpoint:
            try:
                # later entries win: context overrides inputs, which override parameters
                formatting_kwargs = {
                    **prefect.context.get("parameters", {}).copy(),
                    **raw_inputs,
                    **prefect.context,
                }
                result = self.result.write(value, **formatting_kwargs)
            except NotImplementedError:
                # the result type cannot write; fall back to an in-memory result
                result = self.result.from_value(value=value)
        else:
            result = self.result.from_value(value=value)

        return Success(result=result, message="Task run succeeded.")
예제 #9
0
    def get_flow_run_state(
        self,
        state: State,
        task_states: Dict[Task, State],
        task_contexts: Dict[Task, Dict[str, Any]],
        return_tasks: Set[Task],
        task_runner_state_handlers: Iterable[Callable],
        executor: "prefect.engine.executors.base.Executor",
    ) -> State:
        """
        Runs the flow.

        Each task from `self.flow.sorted_tasks()` that is not already finished is
        submitted to the executor through `self.run_task`; the states of the terminal,
        reference and return tasks are then awaited and passed to
        `self.determine_final_state`.

        Args:
            - state (State): starting state for the Flow; must be a Running state
            - task_states (dict): dictionary of task states to begin
                computation with, with keys being Tasks and values their corresponding
                state. Updated in place as tasks are submitted
            - task_contexts (Dict[Task, Dict[str, Any]]): contexts that will be provided to each task
            - return_tasks ([Task], optional): list of Tasks to include in the
                final returned Flow state. `None` is treated as an empty set
            - task_runner_state_handlers (Iterable[Callable]): A list of state change
                handlers that will be provided to the task_runner, and called whenever a task changes
                state.
            - executor (Executor): executor used to submit tasks and wait on their states

        Returns:
            - State: `State` representing the final post-run state of the `Flow`.

        Raises:
            - ENDRUN: if `state` is not a Running state
            - ValueError: if `return_tasks` contains tasks that are not in the flow

        """

        if not state.is_running():
            self.logger.info("Flow is not in a Running state.")
            raise ENDRUN(state)

        return_tasks = set() if return_tasks is None else return_tasks
        unknown_tasks = set(return_tasks).difference(self.flow.tasks)
        if unknown_tasks:
            raise ValueError(
                "Some tasks in return_tasks were not found in the flow.")

        # -- process each task in order

        with executor.start():

            for task in self.flow.sorted_tasks():

                current = task_states.get(task)

                # constant tasks have a known value, so record a Success directly
                if current is None and isinstance(
                        task, prefect.tasks.core.constants.Constant):
                    current = Success(result=task.value)
                    task_states[task] = current

                # a finished state is reused as-is, unless it is cached or mapped
                already_done = (isinstance(current, State)
                                and current.is_finished()
                                and not current.is_cached()
                                and not current.is_mapped())
                if already_done:
                    continue

                # -- collect the state of every upstream task (Pending when unknown)
                upstream_states = {
                    edge: task_states.get(
                        edge.upstream_task,
                        Pending(message="Task state not available."))
                    for edge in self.flow.edges_to(task)
                }  # type: Dict[Edge, Union[State, Iterable]]

                # -- run the task

                with prefect.context(task_full_name=task.name,
                                     task_tags=task.tags):
                    task_context = dict(prefect.context,
                                        **task_contexts.get(task, {}))
                    task_states[task] = executor.submit(
                        self.run_task,
                        task=task,
                        state=current,
                        upstream_states=upstream_states,
                        context=task_context,
                        task_runner_state_handlers=task_runner_state_handlers,
                        executor=executor,
                    )

            # ---------------------------------------------
            # Collect results
            # ---------------------------------------------

            # terminal tasks determine if the flow is finished
            terminal_tasks = self.flow.terminal_tasks()

            # reference tasks determine flow state
            reference_tasks = self.flow.reference_tasks()

            # wait until all terminal, reference and return tasks are finished
            final_tasks = terminal_tasks.union(reference_tasks).union(
                return_tasks)
            awaited = {
                t: task_states.get(
                    t, Pending("Task not evaluated by FlowRunner."))
                for t in final_tasks
            }
            final_states = executor.wait(awaited)

            # also wait for any children of Mapped tasks to finish; their states
            # replace the parent entry when determining the flow state
            all_final_states = final_states.copy()
            for final_task, final_state in list(final_states.items()):
                if not final_state.is_mapped():
                    continue
                final_state.map_states = executor.wait(final_state.map_states)
                final_state.result = [
                    child.result for child in final_state.map_states
                ]
                all_final_states[final_task] = final_state.map_states

            assert isinstance(final_states, dict)

        # flatten because entries for mapped tasks are lists of child states
        key_states = set(
            flatten_seq([all_final_states[t] for t in reference_tasks]))
        terminal_states = set(
            flatten_seq([all_final_states[t] for t in terminal_tasks]))
        return_states = {t: final_states[t] for t in return_tasks}

        return self.determine_final_state(
            state=state,
            key_states=key_states,
            return_states=return_states,
            terminal_states=terminal_states,
        )
예제 #10
0
    def get_task_run_state(
        self, state: State, inputs: Dict[str, Result]
    ) -> State:
        """
        Call the task's `run()` and convert the outcome into a `State`.

        The produced value is written through `self.result.write` when
        `checkpointing` is `True` in context, `task.checkpoint` is not `False`
        and the value is not `None`; in every other case (including a
        `ResultNotImplementedError` from the write) the value is only wrapped
        with `self.result.from_value`.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: `Success` after a normal run; the state carried by a `LOOP` or
                `SUCCESS` signal (with its result replaced by the wrapped value);
                `TimedOut` on a timeout signal; `Failed` on any other `Exception`

        Raises:
            - signals.PAUSE: if the task raises PAUSE (as stated by the original
                documentation; PAUSE is not handled explicitly in this method)
            - ENDRUN: if the task is not ready to run
            - the original timeout signal or exception, when `raise_on_exception`
                is truthy in context
        """
        label = prefect.context.get("task_full_name", self.task.name)

        # Guard clause: only a Running state may proceed to execution.
        if not state.is_running():
            self.logger.debug(
                f"Task {label!r}: Can't run task because it's not in a Running "
                "state; ending run.")
            raise ENDRUN(state)

        output = None
        call_kwargs = {name: res.value for name, res in inputs.items()}
        # Set only when the task ended by raising a LOOP or SUCCESS signal.
        signal_state = None

        try:
            self.logger.debug(f"Task {label!r}: Calling task.run() method...")

            # Send stdout to the logger only if the task asked for it.
            if getattr(self.task, "log_stdout", False):
                stdout_ctx = redirect_stdout(
                    prefect.utilities.logging.RedirectToLog(self.logger)
                )  # type: AbstractContextManager
            else:
                stdout_ctx = nullcontext()

            with stdout_ctx:
                output = prefect.utilities.executors.run_task_with_timeout(
                    task=self.task,
                    args=(),
                    kwargs=call_kwargs,
                    logger=self.logger,
                )

        except TaskTimeoutSignal as exc:
            # A timeout becomes a `TimedOut` state unless the caller wants it raised.
            if prefect.context.get("raise_on_exception"):
                raise exc
            return TimedOut("Task timed out during execution.", result=exc)

        except signals.LOOP as exc:
            # A loop signal carries its own `Looped` state; reuse it.
            signal_state = exc.state
            assert isinstance(signal_state, Looped)
            output = signal_state.result
            signal_state.message = (
                exc.state.message
                or "Task is looping ({})".format(signal_state.loop_count))

        except signals.SUCCESS as exc:
            # An explicit success signal is handled like an ordinary return value.
            signal_state = exc.state
            assert isinstance(signal_state, Success)
            output = signal_state.result

        except Exception as exc:
            # Anything else raised by the task becomes a `Failed` state.
            if prefect.context.get("raise_on_exception"):
                raise
            self.logger.error(
                f"Task {label!r}: Exception encountered during task execution!",
                exc_info=True,
            )
            return Failed(f"Error during execution of task: {exc!r}",
                          result=exc)

        # Checkpoint only when globally enabled, not opted out of on the task,
        # and there is an actual value to persist.
        wants_checkpoint = (prefect.context.get("checkpointing") is True
                            and self.task.checkpoint is not False
                            and output is not None)
        if not wants_checkpoint:
            result = self.result.from_value(value=output)
        else:
            try:
                # Later entries win: task inputs override context, which
                # overrides flow parameters.
                template_values = {
                    **prefect.context.get("parameters", {}).copy(),
                    **prefect.context,
                    **call_kwargs,
                }
                result = self.result.write(output, **template_values)
            except ResultNotImplementedError:
                result = self.result.from_value(value=output)

        if signal_state is None:
            return Success(result=result, message="Task run succeeded.")

        signal_state.result = result
        return signal_state
예제 #11
0
    def get_task_run_state(
        self,
        state: State,
        inputs: Dict[str, Result],
        timeout_handler: Optional[Callable] = None,
    ) -> State:
        """
        Execute `task.run()` through a timeout handler and report the resulting state.

        After a successful run the result's `store_safe_value()` is called when
        `checkpointing` is `True` in context, `task.checkpoint` is not `False`
        and this runner has a result handler.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.
            - timeout_handler (Callable, optional): function for timing out
                task execution, with call signature `handler(fn, *args, **kwargs)`. Defaults to
                `prefect.utilities.executors.timeout_handler`

        Returns:
            - State: `Success` after a normal run, `Cancelled` on `KeyboardInterrupt`,
                `TimedOut` on `TimeoutError`, or the `Looped` state carried by a
                `LOOP` signal

        Raises:
            - signals.PAUSE: if the task raises PAUSE (as stated by the original
                documentation; PAUSE is not handled explicitly in this method)
            - ENDRUN: if the task is not ready to run
            - TimeoutError: when the run times out and `raise_on_exception` is
                truthy in context
        """
        if not state.is_running():
            display_name = prefect.context.get("task_full_name",
                                               self.task.name)
            self.logger.debug(
                "Task '{name}': can't run task because it's not in a "
                "Running state; ending run.".format(name=display_name))
            raise ENDRUN(state)

        try:
            display_name = prefect.context.get("task_full_name",
                                               self.task.name)
            self.logger.debug(
                "Task '{name}': Calling task.run() method...".format(
                    name=display_name))

            handler = (timeout_handler
                       or prefect.utilities.executors.timeout_handler)
            call_kwargs = {key: res.value for key, res in inputs.items()}

            def invoke():
                # Single place where the task is actually executed.
                return handler(self.task.run,
                               timeout=self.task.timeout,
                               **call_kwargs)

            if not getattr(self.task, "log_stdout", False):
                output = invoke()
            else:
                # Route anything the task prints into this runner's logger.
                log_sink = prefect.utilities.logging.RedirectToLog(self.logger)
                with redirect_stdout(log_sink):  # type: ignore
                    output = invoke()

        except KeyboardInterrupt:
            self.logger.debug("Interrupt signal raised, cancelling task run.")
            return Cancelled(
                message="Interrupt signal raised, cancelling task run.")

        except TimeoutError as exc:
            # Report the timeout as a state unless the caller wants it raised.
            if prefect.context.get("raise_on_exception"):
                raise exc
            return TimedOut("Task timed out during execution.",
                            result=exc,
                            cached_inputs=inputs)

        except signals.LOOP as exc:
            # Reuse the state carried by the signal, wrapping its raw result.
            looped = exc.state
            assert isinstance(looped, Looped)
            looped.result = Result(value=looped.result,
                                   result_handler=self.result_handler)
            looped.cached_inputs = inputs
            looped.message = (exc.state.message
                              or "Task is looping ({})".format(
                                  looped.loop_count))
            return looped

        wrapped = Result(value=output, result_handler=self.result_handler)
        finished = Success(result=wrapped,
                           message="Task run succeeded.",
                           cached_inputs=inputs)

        # Checkpoint only when enabled in context, not opted out of on the
        # task, and a result handler is available.
        checkpoint_now = (finished.is_successful()
                          and prefect.context.get("checkpointing") is True
                          and self.task.checkpoint is not False
                          and self.result_handler is not None)
        if checkpoint_now:
            finished._result.store_safe_value()

        return finished
예제 #12
0
파일: states.py 프로젝트: vitasiku/server
async def set_flow_run_state(flow_run_id: str,
                             state: State,
                             version: int = None,
                             agent_id: str = None) -> models.FlowRunState:
    """
    Record a new state for a flow run and apply its follow-on effects.

    The flow run is looked up with a filter that only matches when either
    version locking is enabled for its flow group and `version` equals the
    stored version, or version locking is not enabled. A `Cancelled` state is
    also pushed to every task run of the flow run that is not yet finished.

    Args:
        - flow_run_id (str): the flow run id to update
        - state (State): the new state
        - version (int): a version to enforce version-locking
        - agent_id (str): the ID of an agent instance setting the state

    Returns:
        - models.FlowRunState: the newly inserted state row

    Raises:
        - ValueError: if `flow_run_id` is `None`, or no flow run matches the filter
    """

    if flow_run_id is None:
        raise ValueError("Invalid flow run ID.")

    def locking_enabled():
        # Fresh dict on every call so the two filter branches share no object.
        return {"settings": {"_contains": {"version_locking_enabled": True}}}

    # EITHER version locking is enabled and versions match
    locked_and_matching = {
        "version": {"_eq": version},
        "flow": {"flow_group": locking_enabled()},
    }
    # OR version locking is not enabled
    not_locked = {"flow": {"flow_group": {"_not": locking_enabled()}}}

    where = {
        "id": {"_eq": flow_run_id},
        "_or": [locked_and_matching, not_locked],
    }

    selection = {
        "id": True,
        "state": True,
        "name": True,
        "version": True,
        "flow": {"id", "name", "flow_group_id", "version_group_id"},
        "tenant": {"id", "slug"},
    }
    flow_run = await models.FlowRun.where(where).first(selection)

    if not flow_run:
        raise ValueError(f"State update failed for flow run ID {flow_run_id}")

    # --------------------------------------------------------
    # cancellation: propagate to unfinished task runs
    # --------------------------------------------------------
    if isinstance(state, Cancelled):
        task_run_filter = {"flow_run_id": {"_eq": flow_run_id}}
        task_runs = await models.TaskRun.where(task_run_filter).get(
            {"id", "serialized_state"})

        unfinished = []
        for run in task_runs:
            if not state_schema.load(run.serialized_state).is_finished():
                unfinished.append(run)

        # All updates are started together; a very large run might warrant
        # batching. Individual failures are collected rather than raised
        # (return_exceptions=True).
        cancellations = [
            api.states.set_task_run_state(run.id, state) for run in unfinished
        ]
        await asyncio.gather(*cancellations, return_exceptions=True)

    # --------------------------------------------------------
    # insert the new state in the database
    # --------------------------------------------------------

    # NOTE(review): `tenant_id` is read from `flow_run`, but the selection above
    # requests `tenant` {id, slug} rather than `tenant_id` - confirm the model
    # populates it.
    flow_run_state = models.FlowRunState(
        id=str(uuid.uuid4()),
        tenant_id=flow_run.tenant_id,
        flow_run_id=flow_run_id,
        version=(flow_run.version or 0) + 1,
        state=type(state).__name__,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        start_time=getattr(state, "start_time", None),
        serialized_state=state.serialize(),
    )
    await flow_run_state.insert()

    # --------------------------------------------------------
    # follow-on updates
    # --------------------------------------------------------

    # Running and submitted states refresh the flow run heartbeat.
    if state.is_running() or state.is_submitted():
        await api.runs.update_flow_run_heartbeat(flow_run_id=flow_run_id)

    # A submitted state that names an agent records that agent on the flow run.
    if state.is_submitted() and agent_id:
        await api.runs.update_flow_run_agent(flow_run_id=flow_run_id,
                                             agent_id=agent_id)

    # --------------------------------------------------------
    # cloud hooks
    # --------------------------------------------------------

    change_event = events.FlowRunStateChange(
        flow_run=flow_run,
        state=flow_run_state,
        flow=flow_run.flow,
        tenant=flow_run.tenant,
    )

    # Hooks are scheduled, not awaited.
    # NOTE(review): the task returned by create_task is not retained - confirm
    # it cannot be garbage-collected before it finishes.
    asyncio.create_task(api.cloud_hooks.call_hooks(change_event))

    return flow_run_state
예제 #13
0
파일: states.py 프로젝트: vitasiku/server
async def set_task_run_state(
        task_run_id: str,
        state: State,
        version: int = None,
        flow_run_version: int = None) -> models.TaskRunState:
    """
    Record a new state for a task run.

    The task run is looked up with a filter that only matches when either
    version locking is enabled for its flow group and both `version` and
    `flow_run_version` equal the stored versions, or version locking is not
    enabled. If the new state has no cached inputs but the stored serialized
    state does, the stored cached inputs are copied onto the new state before
    it is saved.

    Args:
        - task_run_id (str): the task run id to update
        - state (State): the new state
        - version (int): a version to enforce version-locking
        - flow_run_version (int): a flow run version to enforce version-locking

    Returns:
        - models.TaskRunState: the newly inserted state row

    Raises:
        - ValueError: if `task_run_id` is `None`, if no task run matches the
            filter, or if `state` is running while the flow run's state is not
            "Running"
    """

    if task_run_id is None:
        raise ValueError("Invalid task run ID.")

    def locking_enabled():
        # Fresh dict on every call so the two filter branches share no object.
        return {"settings": {"_contains": {"version_locking_enabled": True}}}

    # EITHER version locking is enabled and the versions match
    locked_and_matching = {
        "version": {"_eq": version},
        "flow_run": {
            "version": {"_eq": flow_run_version},
            "flow": {"flow_group": locking_enabled()},
        },
    }
    # OR version locking is not enabled
    not_locked = {
        "flow_run": {"flow": {"flow_group": {"_not": locking_enabled()}}},
    }

    where = {
        "id": {"_eq": task_run_id},
        "_or": [locked_and_matching, not_locked],
    }

    selection = {
        "id": True,
        "tenant_id": True,
        "version": True,
        "state": True,
        "serialized_state": True,
        "flow_run": {"id": True, "state": True},
    }
    task_run = await models.TaskRun.where(where).first(selection)

    if not task_run:
        raise ValueError(f"State update failed for task run ID {task_run_id}")

    # ------------------------------------------------------
    # a running task state requires a running flow run
    # ------------------------------------------------------
    if state.is_running() and task_run.flow_run.state != "Running":
        raise ValueError(
            f"State update failed for task run ID {task_run_id}: provided "
            f"a running state but associated flow run {task_run.flow_run.id} is not "
            "in a running state.")

    # ------------------------------------------------------
    # carry cached inputs forward from the stored state when the new one has none
    # ------------------------------------------------------
    if not state.cached_inputs:
        if task_run.serialized_state.get("cached_inputs", None):
            previous = state_schema.load(task_run.serialized_state)
            state.cached_inputs = previous.cached_inputs

    # --------------------------------------------------------
    # build and insert the new state row
    # --------------------------------------------------------

    task_run_state = models.TaskRunState(
        id=str(uuid.uuid4()),
        tenant_id=task_run.tenant_id,
        task_run_id=task_run.id,
        version=(task_run.version or 0) + 1,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        start_time=getattr(state, "start_time", None),
        state=type(state).__name__,
        serialized_state=state.serialize(),
    )
    await task_run_state.insert()

    # --------------------------------------------------------
    # follow-on updates
    # --------------------------------------------------------

    # Running states refresh the task run heartbeat.
    if state.is_running():
        await api.runs.update_task_run_heartbeat(task_run_id=task_run_id)

    return task_run_state