Пример #1
0
class TaskRunner(Runner):
    """
    TaskRunners handle the execution of Tasks and determine the State of a Task
    before, during and after the Task is run.

    In particular, through the TaskRunner you can specify the states of any upstream dependencies
    and what state the Task should be initialized with.

    Args:
        - task (Task): the Task to be run / executed
        - state_handlers (Iterable[Callable], optional): A list of state change handlers that
            will be called whenever the task changes state, providing an opportunity to inspect
            or modify the new state. The handler will be passed the task runner instance, the
            old (prior) state, and the new (current) state, with the following signature:
            `state_handler(TaskRunner, old_state, new_state) -> Optional[State]`; If multiple
            functions are passed, then the `new_state` argument will be the result of the
            previous handler.
        - flow_result: the result instance configured for the flow (if any)
    """
    def __init__(
        self,
        task: Task,
        state_handlers: Iterable[Callable] = None,
        flow_result: Result = None,
    ):
        """
        Sets up the runner for a single task.

        Args:
            - task (Task): the Task this runner will execute
            - state_handlers (Iterable[Callable], optional): state change handlers, forwarded
                unchanged to the base Runner's constructor
            - flow_result (Result, optional): the flow-level result; a copy of it is used
                only when the task carries no (truthy) result of its own
        """
        self.context = prefect.context.to_dict()
        self.task = task

        # A result attached to the task wins over the flow-level one; when neither is
        # available, fall back to a copy of a default `Result`.
        if task.result:
            chosen_result = task.result
        elif flow_result is not None:
            chosen_result = flow_result.copy()
        else:
            chosen_result = Result().copy()
        self.result = chosen_result

        self.flow_result = flow_result
        super().__init__(state_handlers=state_handlers)

    def __repr__(self) -> str:
        """Returns a short `<ClassName: task name>` description of this runner."""
        return f"<{type(self).__name__}: {self.task.name}>"

    def call_runner_target_handlers(self, old_state: State,
                                    new_state: State) -> State:
        """
        Runs the task's own state handlers over a state transition.

        Each handler is called with the task, the old state and the state produced so far.
        A handler that returns a falsy value (e.g. `None`) leaves the state unchanged.

        Args:
            - old_state (State): the old (previous) state
            - new_state (State): the new (current) state

        Returns:
            - State: the state after all of the task's handlers have been applied
        """
        task_label = prefect.context.get("task_full_name", self.task.name)
        self.logger.debug(
            "Task '{name}': Handling state change from {old} to {new}".format(
                name=task_label,
                old=type(old_state).__name__,
                new=type(new_state).__name__,
            ))

        current = new_state
        for handler in self.task.state_handlers:
            replacement = handler(self.task, old_state, current)
            if replacement:
                current = replacement

        return current

    def initialize_run(  # type: ignore
            self, state: Optional[State],
            context: Dict[str, Any]) -> TaskRunnerInitializeResult:
        """
        Prepares the state and context for a task run.

        After delegating to the base Runner's `initialize_run`, this method:
            - stores the run count in context as `task_run_count` (the `Retrying` state's
                run count plus one, otherwise the state's own `task_run_count`, default 1)
            - sets `resume=True` in context when the state is a `Resume` state
            - moves `_loop_count` out of the state's context and exposes it, together with
                the previous loop result, as `task_loop_count` / `task_loop_result`
            - fills in `task_name`, `task_tags`, `checkpointing` and `logger` in context
            - points this runner's result at the task's target, if one is set

        Args:
            - state (Optional[State]): the initial state of the run
            - context (Dict[str, Any]): the context to be updated with relevant information

        Returns:
            - TaskRunnerInitializeResult: the updated state and context
        """
        state, context = super().initialize_run(state=state, context=context)

        run_count = (state.run_count + 1 if isinstance(state, Retrying) else
                     state.context.get("task_run_count", 1))

        if isinstance(state, Resume):
            context.update(resume=True)

        if "_loop_count" in state.context:
            previous = state._result
            # the value may not be in memory; read it back when only a location is known
            value_missing = (
                previous.value is None  # type: ignore
                and previous.location is not None)  # type: ignore
            if value_missing:
                previous_value = self.result.read(
                    previous.location).value  # type: ignore
            else:
                previous_value = previous.value  # type: ignore
            context.update(
                task_loop_count=state.context.pop("_loop_count"),
                task_loop_result=previous_value,
            )

        context.update(
            task_run_count=run_count,
            task_name=self.task.name,
            task_tags=self.task.tags,
        )

        # Prefer the config carried in context; fall back to the global config
        try:
            checkpointing = context["config"]["flows"]["checkpointing"]
        except KeyError:
            checkpointing = config.flows.checkpointing
        context.setdefault("checkpointing", checkpointing)

        # Runs with an integer map index and a full name get a logger named after
        # that full name; every other run uses the task's logger
        if isinstance(context.get("map_index", None),
                      int) and context.get("task_full_name"):
            run_logger = prefect.utilities.logging.get_logger(
                context.get("task_full_name"))
        else:
            run_logger = self.task.logger
        context.update(logger=run_logger)

        # A string target is used directly as the result location; any other truthy
        # target is installed as the result's formatter and the location is cleared
        if self.task.target:
            if isinstance(self.task.target, str):
                self.result.location = self.task.target
            else:
                self.result._formatter = self.task.target
                self.result.location = None

        return TaskRunnerInitializeResult(state=state, context=context)

    @tail_recursive
    def run(
        self,
        state: State = None,
        upstream_states: Dict[Edge, State] = None,
        context: Dict[str, Any] = None,
        is_mapped_parent: bool = False,
    ) -> State:
        """
        The main endpoint for TaskRunners.  Calling this method will conditionally execute
        `self.task.run` with any provided inputs, assuming the upstream dependencies are in a
        state which allow this Task to run.

        The run is a fixed pipeline of checks and state transitions executed inside
        `prefect.context(context)`. Any step may end the run early by raising `ENDRUN` or a
        `signals.PrefectStateSignal`; the state carried by that exception becomes the
        final state.

        Args:
            - state (State, optional): initial `State` to begin task run from;
                defaults to `Pending()`
            - upstream_states (Dict[Edge, State]): a dictionary
                representing the states of any tasks upstream of this one. The keys of the
                dictionary should correspond to the edges leading to the task.
            - context (dict, optional): prefect Context to use for execution; when falsy,
                a copy of the current `prefect.context` is used
            - is_mapped_parent (bool): a boolean indicating whether this task run is the run of
                a parent mapped task

        Returns:
            - `State` object representing the final post-run state of the Task

        Raises:
            - RecursiveCall: re-raised untouched when raised inside the pipeline
            - Exception: any other unexpected error is logged and converted into a `Failed`
                state, and is re-raised only if `raise_on_exception` is set in
                `prefect.context`
        """
        upstream_states = upstream_states or {}
        context = context or prefect.context.to_dict()
        map_index = context.setdefault("map_index", None)
        # the full name is the task name, suffixed with "[<map_index>]" when a map index is set
        context["task_full_name"] = "{name}{index}".format(
            name=self.task.name,
            index=("" if map_index is None else "[{}]".format(map_index)),
        )

        task_inputs = {}  # type: Dict[str, Any]

        try:
            # initialize the run
            state, context = self.initialize_run(state, context)

            # run state transformation pipeline
            with prefect.context(context):

                # only announce the start when this is not a loop iteration
                if prefect.context.get("task_loop_count") is None:
                    self.logger.info(
                        "Task '{name}': Starting task run...".format(
                            name=context["task_full_name"]))

                # check to make sure the task is in a pending state
                state = self.check_task_is_ready(state)

                # check if the task has reached its scheduled time
                state = self.check_task_reached_start_time(state)

                # Tasks never run if the upstream tasks haven't finished
                state = self.check_upstream_finished(
                    state, upstream_states=upstream_states)

                # check if any upstream tasks skipped (and if we need to skip)
                state = self.check_upstream_skipped(
                    state, upstream_states=upstream_states)

                # populate / hydrate all result objects
                state, upstream_states = self.load_results(
                    state=state, upstream_states=upstream_states)

                # retrieve task inputs from upstream and also explicitly passed inputs
                task_inputs = self.get_task_inputs(
                    state=state, upstream_states=upstream_states)

                # a mapped parent never runs the task itself: this check always ends the run
                if is_mapped_parent:
                    state = self.check_task_ready_to_map(
                        state, upstream_states=upstream_states)

                if self.task.target:
                    # check to see if there is a Result at the task's target
                    state = self.check_target(state, inputs=task_inputs)
                else:
                    # check to see if the task has a cached result
                    state = self.check_task_is_cached(state,
                                                      inputs=task_inputs)

                # check if the task's trigger passes
                # triggers can raise Pauses, which require task_inputs to be available for caching
                # so we run this after the previous step
                state = self.check_task_trigger(
                    state, upstream_states=upstream_states)

                # set the task state to running
                state = self.set_task_to_running(state, inputs=task_inputs)

                # run the task
                state = self.get_task_run_state(state, inputs=task_inputs)

                # cache the output, if appropriate
                state = self.cache_result(state, inputs=task_inputs)

                # check if the task needs to be retried
                state = self.check_for_retry(state, inputs=task_inputs)

                state = self.check_task_is_looping(
                    state,
                    inputs=task_inputs,
                    upstream_states=upstream_states,
                    context=context,
                )

        # for pending signals, including retries and pauses we need to make sure the
        # task_inputs are set
        except (ENDRUN, signals.PrefectStateSignal) as exc:
            # the exception carries the state the run should end in
            state = exc.state
        except RecursiveCall as exc:
            # must propagate unchanged; presumably consumed by the `tail_recursive`
            # decorator on this method -- TODO confirm
            raise exc

        except Exception as exc:
            # anything unexpected becomes a Failed state carrying the exception as its result
            msg = "Task '{name}': unexpected error while running task: {exc}".format(
                name=context["task_full_name"], exc=repr(exc))
            self.logger.exception(msg)
            state = Failed(message=msg, result=exc)
            if prefect.context.get("raise_on_exception"):
                raise exc

        # to prevent excessive repetition of this log
        # since looping relies on recursively calling self.run
        # TODO: figure out a way to only log this one single time instead of twice
        if prefect.context.get("task_loop_count") is None:
            # wrapping this final log in prefect.context(context) ensures
            # that any run-context, including task-run-ids, are respected
            with prefect.context(context):
                self.logger.info(
                    "Task '{name}': finished task run for task with final state: '{state}'"
                    .format(name=context["task_full_name"],
                            state=type(state).__name__))

        return state

    @call_state_handlers
    def check_upstream_finished(self, state: State,
                                upstream_states: Dict[Edge, State]) -> State:
        """
        Checks that every upstream task has finished.

        Args:
            - state (State): the current state of this task
            - upstream_states (Dict[Edge, Union[State, List[State]]]): the upstream states

        Returns:
            - State: the unchanged state, when all upstream states are finished

        Raises:
            - ENDRUN: if upstream tasks are not finished.
        """
        states_to_check = set()  # type: Set[State]
        for edge, upstream in upstream_states.items():
            # A Mapped upstream reached through a non-mapped edge is expanded into its
            # child states; in every other case the upstream state is checked as-is.
            expand_children = isinstance(upstream, Mapped) and not edge.mapped
            if expand_children:
                states_to_check.update(upstream.map_states)
            else:
                states_to_check.add(upstream)

        if any(not candidate.is_finished() for candidate in states_to_check):
            task_label = prefect.context.get("task_full_name", self.task.name)
            self.logger.debug(
                "Task '{name}': not all upstream states are finished; ending run."
                .format(name=task_label))
            raise ENDRUN(state)

        return state

    @call_state_handlers
    def check_upstream_skipped(self, state: State,
                               upstream_states: Dict[Edge, State]) -> State:
        """
        Skips this task if any upstream task was skipped and the task is configured with
        `skip_on_upstream_skip`.

        Args:
            - state (State): the current state of this task
            - upstream_states (Dict[Edge, State]): the upstream states

        Returns:
            - State: the unchanged state, when the task should not be skipped

        Raises:
            - ENDRUN: with a `Skipped` state, if the task opts in to skipping and at least
                one upstream state is skipped
        """
        states_to_check = set()  # type: Set[State]
        for edge, upstream in upstream_states.items():
            # A Mapped upstream reached through a non-mapped edge is expanded into its
            # child states; in every other case the upstream state is checked as-is.
            expand_children = isinstance(upstream, Mapped) and not edge.mapped
            if expand_children:
                states_to_check.update(upstream.map_states)
            else:
                states_to_check.add(upstream)

        if not self.task.skip_on_upstream_skip:
            return state
        if not any(candidate.is_skipped() for candidate in states_to_check):
            return state

        task_label = prefect.context.get("task_full_name", self.task.name)
        self.logger.debug(
            "Task '{name}': Upstream states were skipped; ending run.".format(
                name=task_label))
        skip_message = (
            "Upstream task was skipped; if this was not the intended "
            "behavior, consider changing `skip_on_upstream_skip=False` "
            "for this task.")
        raise ENDRUN(state=Skipped(message=skip_message))

    @call_state_handlers
    def check_task_ready_to_map(self, state: State,
                                upstream_states: Dict[Edge, State]) -> State:
        """
        Decides whether the parent task can proceed with mapping.

        This method never returns normally; every path raises `ENDRUN` carrying the
        resulting state:
            - the current state, if it is already mapped
            - `Failed`, if there are upstream states but none is a successful state on a
                mapped edge
            - `Failed`, if a successful, non-mapped upstream on a mapped edge has a result
                without `__getitem__`
            - `Mapped`, otherwise

        Args:
            - state (State): the current state of this task
            - upstream_states (Dict[Edge, Union[State, List[State]]]): the upstream states

        Raises:
            - ENDRUN: either way, we dont continue past this point
        """
        if state.is_mapped():
            raise ENDRUN(state)

        # we can't map if there are no success states with iterables upstream
        if upstream_states:
            mappable_flags = [
                edge.mapped and upstream.is_successful()
                for edge, upstream in upstream_states.items()
            ]
            if not any(mappable_flags):
                failed_state = Failed(
                    "No upstream states can be mapped over.")  # type: State
                raise ENDRUN(failed_state)

        indexable_flags = [
            hasattr(upstream.result, "__getitem__")
            for edge, upstream in upstream_states.items()
            if upstream.is_successful() and not upstream.is_mapped()
            and edge.mapped
        ]
        if not all(indexable_flags):
            raise ENDRUN(
                Failed("At least one upstream state has an unmappable result."))

        raise ENDRUN(Mapped("Ready to proceed with mapping."))

    @call_state_handlers
    def check_task_trigger(self, state: State,
                           upstream_states: Dict[Edge, State]) -> State:
        """
        Evaluates the task's trigger function against the upstream states.

        Args:
            - state (State): the current state of this task
            - upstream_states (Dict[Edge, Union[State, List[State]]]): the upstream states

        Returns:
            - State: the unchanged state, when the trigger passes

        Raises:
            - ENDRUN: if the trigger returns a falsy value, raises a state signal, or
                raises any other exception; when `raise_on_exception` is set in context
                the underlying signal / exception is re-raised instead
        """
        try:
            trigger_passed = self.task.trigger(upstream_states)
            if not trigger_passed:
                raise signals.TRIGGERFAIL(message="Trigger failed")

        except signals.PrefectStateSignal as signal:
            # covers both signals raised by the trigger and the TRIGGERFAIL raised above
            self.logger.debug(
                "Task '{name}': {signal} signal raised during execution.".
                format(
                    name=prefect.context.get("task_full_name", self.task.name),
                    signal=type(signal).__name__,
                ))
            if prefect.context.get("raise_on_exception"):
                raise signal
            raise ENDRUN(signal.state)

        # Exceptions are trapped and turned into TriggerFailed states
        except Exception as error:
            self.logger.exception(
                "Task '{name}': unexpected error while evaluating task trigger: {exc}"
                .format(
                    exc=repr(error),
                    name=prefect.context.get("task_full_name", self.task.name),
                ))
            if prefect.context.get("raise_on_exception"):
                raise error
            failure_message = "Unexpected error while checking task trigger: {}".format(
                repr(error))
            raise ENDRUN(TriggerFailed(failure_message, result=error))

        else:
            return state

    @call_state_handlers
    def check_task_is_ready(self, state: State) -> State:
        """
        Checks that the task is in a state from which a run may proceed.

        Pending, Mapped and Cached states are let through; every other state ends the run.

        Args:
            - state (State): the current state of this task

        Returns:
            - State: the unchanged state, when the task may proceed

        Raises:
            - ENDRUN: if the task is not ready to run
        """
        # the task is ready
        if state.is_pending():
            return state

        # a mapped task still proceeds so that its children tasks are generated
        if state.is_mapped():
            self.logger.debug(
                "Task '%s': task is mapped, but run will proceed so children are generated.",
                prefect.context.get("task_full_name", self.task.name),
            )
            return state

        # this task is already running
        if state.is_running():
            self.logger.debug(
                "Task '%s': task is already running.",
                prefect.context.get("task_full_name", self.task.name),
            )
            raise ENDRUN(state)

        # cached states are validated later in the pipeline
        if state.is_cached():
            return state

        task_label = prefect.context.get("task_full_name", self.task.name)
        if state.is_finished():
            # this task is already finished
            self.logger.debug(
                "Task '{name}': task is already finished.".format(
                    name=task_label))
        else:
            # this task is not pending
            self.logger.debug(
                "Task '{name}' is not ready to run or state was unrecognized ({state})."
                .format(name=task_label, state=state))
        raise ENDRUN(state)

    @call_state_handlers
    def check_task_reached_start_time(self, state: State) -> State:
        """
        For `Scheduled` states, checks that the scheduled start time has been reached.
        States that are not `Scheduled` pass through untouched. A `Scheduled` state with
        `start_time = None` is never considered ready and always ends the run.

        Args:
            - state (State): the current state of this task

        Returns:
            - State: the unchanged state, when the task may proceed

        Raises:
            - ENDRUN: if the task is Scheduled with no start time or with a start time
                that is still in the future (compared against the current UTC time)
        """
        if not isinstance(state, Scheduled):
            return state

        start_time = state.start_time

        # handle case where no start_time is set
        if start_time is None:
            self.logger.debug(
                "Task '{name}' is scheduled without a known start_time; ending run."
                .format(name=prefect.context.get("task_full_name",
                                                 self.task.name)))
            raise ENDRUN(state)

        # handle case where start time is in the future
        if start_time and start_time > pendulum.now("utc"):
            self.logger.debug(
                "Task '{name}': start_time has not been reached; ending run."
                .format(name=prefect.context.get("task_full_name",
                                                 self.task.name)))
            raise ENDRUN(state)

        return state

    def get_task_inputs(
            self, state: State,
            upstream_states: Dict[Edge, State]) -> Dict[str, Result]:
        """
        Builds the inputs for this task from its upstream states.

        Only edges that carry a key contribute an input; the input is the upstream
        state's result object, stored under the edge's key.

        Args:
            - state (State): the task's current state.
            - upstream_states (Dict[Edge, State]): the upstream states

        Returns:
            - Dict[str, Result]: the task inputs

        """
        return {
            edge.key: upstream._result  # type: ignore
            for edge, upstream in upstream_states.items()
            if edge.key is not None
        }

    def load_results(
            self, state: State,
            upstream_states: Dict[Edge,
                                  State]) -> Tuple[State, Dict[Edge, State]]:
        """
        Populates the result objects relevant to this task run.

        This implementation performs no loading and hands both arguments back untouched.

        Args:
            - state (State): the task's current state.
            - upstream_states (Dict[Edge, State]): the upstream states

        Returns:
            - Tuple[State, dict]: a tuple of (state, upstream_states)

        """
        return (state, upstream_states)

    @call_state_handlers
    def check_target(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Looks for an existing Result at the task's target and, if one is found, returns a
        `Cached` state built from it.

        The target is formatted with the context parameters, the raw input values and the
        full context (later entries override earlier ones). A non-string target is called
        with those values to produce the target first.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: a `Cached` state if a result exists at the target; otherwise the
                unchanged state
        """
        result = self.result
        target = self.task.target

        if not (result and target):
            return state

        input_values = {name: res.value for name, res in inputs.items()}
        formatting_kwargs = {
            **prefect.context.get("parameters", {}).copy(),
            **input_values,
            **prefect.context,
        }

        if not isinstance(target, str):
            target = target(**formatting_kwargs)

        if not result.exists(target, **formatting_kwargs):
            return state

        known_location = target.format(**formatting_kwargs)
        stored_result = result.read(known_location)
        input_hashes = {
            name: tokenize(res.value)
            for name, res in inputs.items()
        }
        return Cached(
            result=stored_result,
            hashed_inputs=input_hashes,
            cached_result_expiration=None,
            cached_parameters=formatting_kwargs.get("parameters"),
            message=f"Result found at task target {known_location}",
        )

    @call_state_handlers
    def check_task_is_cached(self, state: State,
                             inputs: Dict[str, Result]) -> State:
        """
        Checks whether a valid cached state is available for this task.

        A `Cached` incoming state is returned as-is when the task's cache validator
        accepts it; otherwise it is replaced by `Pending`. When the task has `cache_for`
        set, candidate states stored in context under `caches` (keyed by the task's
        `cache_key` or name) are validated in order and the first valid one is returned.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: a valid cached state if one was found; otherwise the (possibly
                reset to `Pending`) current state
        """
        if state.is_cached():
            assert isinstance(state, Cached)  # mypy assert
            input_values = {name: res.value for name, res in inputs.items()}
            still_valid = self.task.cache_validator(
                state, input_values, prefect.context.get("parameters"))
            if still_valid:
                return state
            state = Pending("Cache was invalid; ready to run.")

        if self.task.cache_for is not None:
            candidates = []
            if prefect.context.get("caches"):
                candidates = prefect.context.caches.get(
                    self.task.cache_key or self.task.name, [])
            input_values = {name: res.value for name, res in inputs.items()}
            for candidate in candidates:
                if self.task.cache_validator(
                        candidate, input_values,
                        prefect.context.get("parameters")):
                    return candidate

            # caching is enabled but no candidate validated
            self.logger.warning(
                "Task '{name}': can't use cache because it "
                "is now invalid".format(name=prefect.context.get(
                    "task_full_name", self.task.name)))

        return state or Pending("Cache was invalid; ready to run.")

    @call_state_handlers
    def set_task_to_running(self, state: State, inputs: Dict[str,
                                                             Result]) -> State:
        """
        Moves a Pending task into the Running state.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: a new `Running` state

        Raises:
            - ENDRUN: if the task is not in a Pending state
        """
        if state.is_pending():
            return Running(message="Starting task run.")

        task_label = prefect.context.get("task_full_name", self.task.name)
        self.logger.debug(
            "Task '{name}': can't set state to Running because it "
            "isn't Pending; ending run.".format(name=task_label))
        raise ENDRUN(state)

    @run_with_heartbeat
    @call_state_handlers
    def get_task_run_state(self, state: State, inputs: Dict[str,
                                                            Result]) -> State:
        """
        Calls `task.run()` with the raw input values and converts the outcome into a state.

        A timeout becomes a `TimedOut` state and a `LOOP` signal returns its `Looped`
        state. A normal return becomes `Success`; its value is written through the
        runner's result when checkpointing is enabled in context, the task has not opted
        out and the value is not `None`.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: the state of the task after running it

        Raises:
            - ENDRUN: if the task is not in a Running state
            - TimeoutError: if the task times out and `raise_on_exception` is set in context
        """
        if not state.is_running():
            task_label = prefect.context.get("task_full_name", self.task.name)
            self.logger.debug(
                "Task '{name}': can't run task because it's not in a "
                "Running state; ending run.".format(name=task_label))

            raise ENDRUN(state)

        value = None
        raw_inputs = {name: res.value for name, res in inputs.items()}
        try:
            self.logger.debug(
                "Task '{name}': Calling task.run() method...".format(
                    name=prefect.context.get("task_full_name",
                                             self.task.name)))
            timeout_handler = prefect.utilities.executors.timeout_handler
            if getattr(self.task, "log_stdout", False):
                # route anything the task prints to stdout into the runner's logger
                stdout_sink = prefect.utilities.logging.RedirectToLog(
                    self.logger)  # type: ignore
                with redirect_stdout(stdout_sink):
                    value = timeout_handler(self.task.run,
                                            timeout=self.task.timeout,
                                            **raw_inputs)
            else:
                value = timeout_handler(self.task.run,
                                        timeout=self.task.timeout,
                                        **raw_inputs)

        # inform user of timeout
        except TimeoutError as exc:
            if prefect.context.get("raise_on_exception"):
                raise exc
            return TimedOut("Task timed out during execution.", result=exc)

        except signals.LOOP as exc:
            looped_state = exc.state
            assert isinstance(looped_state, Looped)
            # wrap the raw loop value in a result object built by this runner's result
            looped_state.result = self.result.from_value(
                value=looped_state.result)
            looped_state.message = exc.state.message or "Task is looping ({})".format(
                looped_state.loop_count)
            return looped_state

        # checkpoint tasks if a result is present, except for when the user has opted out by
        # disabling checkpointing
        should_checkpoint = (prefect.context.get("checkpointing") is True
                             and self.task.checkpoint is not False
                             and value is not None)
        if should_checkpoint:
            try:
                formatting_kwargs = {
                    **prefect.context.get("parameters", {}).copy(),
                    **raw_inputs,
                    **prefect.context,
                }
                result = self.result.write(value, **formatting_kwargs)
            except NotImplementedError:
                # the result type cannot write; keep the value in memory only
                result = self.result.from_value(value=value)
        else:
            result = self.result.from_value(value=value)

        return Success(result=result, message="Task run succeeded.")

    @call_state_handlers
    def cache_result(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Converts a successful state into a `Cached` state when the task uses caching.

        A state is cached only if all of the following hold:
            - the task state is Successful
            - the task state is not Skipped
            - task.cache_for is not None

        The cache expires `task.cache_for` after the current UTC time.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: a `Cached` state when the result is cacheable; otherwise the
                unchanged state

        """
        cacheable = (state.is_successful() and not state.is_skipped()
                     and self.task.cache_for is not None)
        if not cacheable:
            return state

        expiration = pendulum.now("utc") + self.task.cache_for
        input_hashes = {
            name: tokenize(res.value)
            for name, res in inputs.items()
        }
        return Cached(
            result=state._result,
            hashed_inputs=input_hashes,
            cached_result_expiration=expiration,
            cached_parameters=prefect.context.get("parameters"),
            message=state.message,
        )

    @call_state_handlers
    def check_for_retry(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Decides whether a FAILED task should be placed into a `Retrying` state.

        If the failure happened while the task was looping (`task_loop_count` is
        present in context), the last loop result is carried over to the retry and,
        when checkpointing is enabled, written out via `self.result.write`.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: a `Retrying` state if the task failed and has retries left,
                otherwise the state that was passed in
        """
        if not state.is_failed():
            return state

        attempt = prefect.context.get("task_run_count", 1)
        carried_result = None
        retry_context = None

        if prefect.context.get("task_loop_count") is not None:
            carried_result = self.result.from_value(
                value=prefect.context.get("task_loop_result")
            )

            # checkpoint the loop result if one is present, unless the user has
            # opted out by disabling checkpointing
            should_checkpoint = (
                prefect.context.get("checkpointing") is True
                and self.task.checkpoint is not False
                and carried_result.value is not None
            )
            if should_checkpoint:
                try:
                    input_values = {key: res.value for key, res in inputs.items()}
                    write_kwargs = {
                        **prefect.context.get("parameters", {}).copy(),
                        **input_values,
                        **prefect.context,
                    }
                    carried_result = self.result.write(
                        carried_result.value, **write_kwargs
                    )
                except NotImplementedError:
                    # this result type cannot write; keep the in-memory result
                    pass

            # remember the loop position so the retry can resume from it
            retry_context = {"_loop_count": prefect.context["task_loop_count"]}

        if attempt <= self.task.max_retries:
            resume_at = pendulum.now("utc") + self.task.retry_delay
            retry_message = "Retrying Task (after attempt {n} of {m})".format(
                n=attempt, m=self.task.max_retries + 1
            )
            return Retrying(
                start_time=resume_at,
                context=retry_context,
                message=retry_message,
                run_count=attempt,
                result=carried_result,
            )

        return state

    def check_task_is_looping(
        self,
        state: State,
        inputs: Dict[str, Result] = None,
        upstream_states: Dict[Edge, State] = None,
        context: Dict[str, Any] = None,
    ) -> State:
        """
        Checks to see if the task is in a `Looped` state and, if so, reruns the pipeline
        with an incremented `loop_count`.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.
            - upstream_states (Dict[Edge, State]): a dictionary
                representing the states of any tasks upstream of this one. The keys of the
                dictionary should correspond to the edges leading to the task.
            - context (dict, optional): prefect Context to use for execution; it is
                updated in place with the loop result, the next loop count and the
                current task run version

        Returns:
            - State: the state that was passed in, if it is not a `Looped` state

        Raises:
            - RecursiveCall: if the state is `Looped`, to re-enter `self.run` with a
                fresh `Pending` state
        """
        if not state.is_looped():
            return state

        assert isinstance(state, Looped)  # mypy assert
        assert isinstance(context, dict)  # mypy assert

        # hand the loop result and the next loop index to the next iteration
        context.update(
            task_loop_result=state.result,
            task_loop_count=state.loop_count + 1,
            task_run_version=prefect.context.get("task_run_version"),
        )
        next_state = Pending(
            message="Looping task (on loop index {})".format(state.loop_count)
        )
        raise RecursiveCall(
            self.run,
            self,
            next_state,
            upstream_states=upstream_states,
            context=context,
        )
Пример #2
0
class TaskRunner(Runner):
    """
    TaskRunners handle the execution of Tasks and determine the State of a Task
    before, during and after the Task is run.

    In particular, through the TaskRunner you can specify the states of any upstream dependencies
    and what state the Task should be initialized with.

    Args:
        - task (Task): the Task to be run / executed
        - state_handlers (Iterable[Callable], optional): A list of state change handlers
            that will be called whenever the task changes state, providing an
            opportunity to inspect or modify the new state. The handler
            will be passed the task runner instance, the old (prior) state, and the new
            (current) state, with the following signature: `state_handler(TaskRunner, old_state, new_state) -> Optional[State]`;
            If multiple functions are passed, then the `new_state` argument will be the
            result of the previous handler.
        - result (Result, optional): the result type to use for retrieving and storing state results
            during execution (if the Task doesn't already have one)
        - default_result (Result, optional): the fallback result type to use for retrieving and storing state results
            during execution (to be used on upstream inputs if they don't provide their own results)
    """

    def __init__(
        self,
        task: Task,
        state_handlers: Iterable[Callable] = None,
        result: Result = None,
        default_result: Result = None,
    ):
        """
        Stores the task, snapshots the current context and picks the result
        object this runner will use.
        """
        self.context = prefect.context.to_dict()
        self.task = task

        if task.result:
            # a result configured on the task itself always wins
            chosen_result = task.result
        else:
            # otherwise work on a copy of the provided result (or a blank one) and
            # point it at the task's target, if the task declares one
            template = Result() if result is None else result
            chosen_result = template.copy()
            if task.target:
                chosen_result.location = task.target

        self.result = chosen_result
        self.default_result = default_result or Result()
        super().__init__(state_handlers=state_handlers)

    def __repr__(self) -> str:
        """Returns `<ClassName: task name>` for this runner."""
        return f"<{type(self).__name__}: {self.task.name}>"

    def call_runner_target_handlers(self, old_state: State, new_state: State) -> State:
        """
        Runs the task's own state handlers over a state transition.
        This method is called as part of the base Runner's `handle_state_change()` method.

        Each handler receives the task, the old state and the state produced by the
        previous handler; a handler that returns a falsy value leaves the state as is.

        Args:
            - old_state (State): the old (previous) state
            - new_state (State): the new (current) state

        Returns:
            - State: the new state, after all handlers have been applied
        """
        task_label = prefect.context.get("task_full_name", self.task.name)
        self.logger.debug(
            "Task '{name}': Handling state change from {old} to {new}".format(
                name=task_label,
                old=type(old_state).__name__,
                new=type(new_state).__name__,
            )
        )

        current = new_state
        for handler in self.task.state_handlers:
            replacement = handler(self.task, old_state, current)
            if replacement:
                current = replacement

        return current

    def initialize_run(  # type: ignore
        self, state: Optional[State], context: Dict[str, Any]
    ) -> TaskRunnerInitializeResult:
        """
        Prepares the state and context for a task run.

        The run count is stored in context as `task_run_count`: for a `Retrying`
        state it is the state's run count plus one, otherwise it is read from the
        state's context and defaults to 1.

        If the task is being resumed through a `Resume` state, context gets `resume=True`.
        If the state's cached inputs carry loop information (`_loop_count` /
        `_loop_result`), it is removed from the cached inputs and moved into context
        as `task_loop_count` / `task_loop_result`.

        Args:
            - state (Optional[State]): the initial state of the run
            - context (Dict[str, Any]): the context to be updated with relevant information

        Returns:
            - TaskRunnerInitializeResult: the updated state and context
        """
        state, context = super().initialize_run(state=state, context=context)

        run_count = (
            state.run_count + 1
            if isinstance(state, Retrying)
            else state.context.get("task_run_count", 1)
        )

        if isinstance(state, Resume):
            context.update(resume=True)

        if "_loop_count" in state.cached_inputs:  # type: ignore
            previous = state.cached_inputs.pop("_loop_result")

            # a loop result that only has a location must be read back first
            needs_read = previous.value is None and previous.location is not None
            if needs_read:
                restored_value = self.result.read(previous.location).value
            else:
                restored_value = previous.value

            # the loop count is JSON-decoded from the `location` of its cached result
            stored_count = state.cached_inputs.pop("_loop_count")  # type: ignore
            context.update(
                task_loop_count=json.loads(stored_count.location),  # type: ignore
                task_loop_result=restored_value,
            )

        context.update(
            task_run_count=run_count,
            task_name=self.task.name,
            task_tags=self.task.tags,
            task_slug=self.task.slug,
        )
        context.setdefault("checkpointing", config.flows.checkpointing)

        # runs with an integer map index log under their full (indexed) name
        full_name = context.get("task_full_name")
        if isinstance(context.get("map_index", None), int) and full_name:
            run_logger = prefect.utilities.logging.get_logger(full_name)
        else:
            run_logger = self.task.logger
        context.update(logger=run_logger)

        return TaskRunnerInitializeResult(state=state, context=context)

    @tail_recursive
    def run(
        self,
        state: State = None,
        upstream_states: Dict[Edge, State] = None,
        context: Dict[str, Any] = None,
        executor: "prefect.engine.executors.Executor" = None,
    ) -> State:
        """
        The main endpoint for TaskRunners.  Calling this method will conditionally execute
        `self.task.run` with any provided inputs, assuming the upstream dependencies are in a
        state which allow this Task to run.

        The run is a pipeline of checks and state transformations, each of which either
        returns the (possibly new) state or raises a signal that ends the pipeline early.
        `ENDRUN` and Prefect state signals are converted into the returned state; a
        `RecursiveCall` is re-raised untouched; any other exception becomes a `Failed`
        state (and is re-raised if `raise_on_exception` is set in context).

        Args:
            - state (State, optional): initial `State` to begin task run from;
                defaults to `Pending()`
            - upstream_states (Dict[Edge, State]): a dictionary
                representing the states of any tasks upstream of this one. The keys of the
                dictionary should correspond to the edges leading to the task.
            - context (dict, optional): prefect Context to use for execution
            - executor (Executor, optional): executor to use when performing
                computation; defaults to the executor specified in your prefect configuration

        Returns:
            - `State` object representing the final post-run state of the Task
        """
        upstream_states = upstream_states or {}
        context = context or {}
        # `task_full_name` is the task name, suffixed with `[map_index]` for mapped children
        map_index = context.setdefault("map_index", None)
        context["task_full_name"] = "{name}{index}".format(
            name=self.task.name,
            index=("" if map_index is None else "[{}]".format(map_index)),
        )

        if executor is None:
            executor = prefect.engine.get_default_executor_class()()

        # if mapped is true, this task run is going to generate a Mapped state. It won't
        # actually run, but rather spawn children tasks to map over its inputs. We
        # detect this case by checking for:
        #   - upstream edges that are `mapped`
        #   - no `map_index` (which indicates that this is the child task, not the parent)
        mapped = any([e.mapped for e in upstream_states]) and map_index is None
        # initialized up front so the exception handlers below can always reference it
        task_inputs = {}  # type: Dict[str, Any]

        try:
            # initialize the run
            state, context = self.initialize_run(state, context)

            # run state transformation pipeline
            with prefect.context(context):

                # the start message is only logged when `task_loop_count` is absent
                # from context
                if prefect.context.get("task_loop_count") is None:
                    self.logger.info(
                        "Task '{name}': Starting task run...".format(
                            name=context["task_full_name"]
                        )
                    )

                # check to make sure the task is in a pending state
                state = self.check_task_is_ready(state)

                # check if the task has reached its scheduled time
                state = self.check_task_reached_start_time(state)

                # Tasks never run if the upstream tasks haven't finished
                state = self.check_upstream_finished(
                    state, upstream_states=upstream_states
                )

                # check if any upstream tasks skipped (and if we need to skip)
                state = self.check_upstream_skipped(
                    state, upstream_states=upstream_states
                )

                # populate / hydrate all result objects
                state, upstream_states = self.load_results(
                    state=state, upstream_states=upstream_states
                )

                # if the task is mapped, process the mapped children and exit
                if mapped:
                    state = self.run_mapped_task(
                        state=state,
                        upstream_states=upstream_states,
                        context=context,
                        executor=executor,
                    )

                    state = self.wait_for_mapped_task(state=state, executor=executor)

                    self.logger.debug(
                        "Task '{name}': task has been mapped; ending run.".format(
                            name=context["task_full_name"]
                        )
                    )
                    raise ENDRUN(state)

                # retrieve task inputs from upstream and also explicitly passed inputs
                task_inputs = self.get_task_inputs(
                    state=state, upstream_states=upstream_states
                )

                if self.task.target:
                    # check to see if there is a Result at the task's target
                    state = self.check_target(state, inputs=task_inputs)
                else:
                    # check to see if the task has a cached result
                    state = self.check_task_is_cached(state, inputs=task_inputs)

                # check if the task's trigger passes
                # triggers can raise Pauses, which require task_inputs to be available for caching
                # so we run this after the previous step
                state = self.check_task_trigger(state, upstream_states=upstream_states)

                # set the task state to running
                state = self.set_task_to_running(state, inputs=task_inputs)

                # run the task
                state = self.get_task_run_state(
                    state, inputs=task_inputs, timeout_handler=executor.timeout_handler
                )

                # cache the output, if appropriate
                state = self.cache_result(state, inputs=task_inputs)

                # check if the task needs to be retried
                state = self.check_for_retry(state, inputs=task_inputs)

                # NOTE(review): `executor` is forwarded here; confirm that this class's
                # `check_task_is_looping` accepts an `executor` keyword argument
                state = self.check_task_is_looping(
                    state,
                    inputs=task_inputs,
                    upstream_states=upstream_states,
                    context=context,
                    executor=executor,
                )

        # for pending signals, including retries and pauses we need to make sure the
        # task_inputs are set
        except (ENDRUN, signals.PrefectStateSignal) as exc:
            exc.state.cached_inputs = task_inputs or {}
            state = exc.state
        # a RecursiveCall is passed through unchanged rather than being treated as a
        # failure by the generic handler below
        except RecursiveCall as exc:
            raise exc

        # anything else is unexpected: log it and convert it into a Failed state
        except Exception as exc:
            msg = "Task '{name}': unexpected error while running task: {exc}".format(
                name=context["task_full_name"], exc=repr(exc)
            )
            self.logger.exception(msg)
            state = Failed(message=msg, result=exc, cached_inputs=task_inputs)
            if prefect.context.get("raise_on_exception"):
                raise exc

        # to prevent excessive repetition of this log
        # since looping relies on recursively calling self.run
        # TODO: figure out a way to only log this one single time instead of twice
        # NOTE: this check runs outside the `with prefect.context(context)` block above,
        # so it reads `prefect.context` as it is at this point
        if prefect.context.get("task_loop_count") is None:
            # wrapping this final log in prefect.context(context) ensures
            # that any run-context, including task-run-ids, are respected
            with prefect.context(context):
                self.logger.info(
                    "Task '{name}': finished task run for task with final state: '{state}'".format(
                        name=context["task_full_name"], state=type(state).__name__
                    )
                )

        return state

    @call_state_handlers
    def check_upstream_finished(
        self, state: State, upstream_states: Dict[Edge, State]
    ) -> State:
        """
        Ensures that every upstream task has finished before this task proceeds.

        Args:
            - state (State): the current state of this task
            - upstream_states (Dict[Edge, Union[State, List[State]]]): the upstream states

        Returns:
            - State: the state of the task after running the check

        Raises:
            - ENDRUN: if upstream tasks are not finished.
        """
        states_to_check = set()  # type: Set[State]
        for edge, upstream_state in upstream_states.items():
            # a Mapped upstream reached through a non-mapped edge is expanded into its
            # child states; through a mapped edge the parent state is kept, because
            # each mapped child of this task checks its own upstream parent
            expand_children = isinstance(upstream_state, Mapped) and not edge.mapped
            if expand_children:
                states_to_check.update(upstream_state.map_states)
            else:
                states_to_check.add(upstream_state)

        if any(not s.is_finished() for s in states_to_check):
            self.logger.debug(
                "Task '{name}': not all upstream states are finished; ending run.".format(
                    name=prefect.context.get("task_full_name", self.task.name)
                )
            )
            raise ENDRUN(state)

        return state

    @call_state_handlers
    def check_upstream_skipped(
        self, state: State, upstream_states: Dict[Edge, State]
    ) -> State:
        """
        Skips this task if any upstream task skipped and the task has
        `skip_on_upstream_skip` enabled.

        Args:
            - state (State): the current state of this task
            - upstream_states (Dict[Edge, State]): the upstream states

        Returns:
            - State: the state of the task after running the check

        Raises:
            - ENDRUN: with a `Skipped` state, if an upstream task skipped and
                `skip_on_upstream_skip` is set on the task
        """
        states_to_check = set()  # type: Set[State]
        for edge, upstream_state in upstream_states.items():
            # a Mapped upstream reached through a non-mapped edge is expanded into its
            # child states; through a mapped edge the parent state is kept, because
            # each mapped child of this task decides for itself whether to skip
            expand_children = isinstance(upstream_state, Mapped) and not edge.mapped
            if expand_children:
                states_to_check.update(upstream_state.map_states)
            else:
                states_to_check.add(upstream_state)

        must_skip = self.task.skip_on_upstream_skip and any(
            s.is_skipped() for s in states_to_check
        )
        if not must_skip:
            return state

        self.logger.debug(
            "Task '{name}': Upstream states were skipped; ending run.".format(
                name=prefect.context.get("task_full_name", self.task.name)
            )
        )
        skip_message = (
            "Upstream task was skipped; if this was not the intended "
            "behavior, consider changing `skip_on_upstream_skip=False` "
            "for this task."
        )
        raise ENDRUN(state=Skipped(message=skip_message))

    @call_state_handlers
    def check_task_trigger(
        self, state: State, upstream_states: Dict[Edge, State]
    ) -> State:
        """
        Evaluates the task's trigger function against the upstream states.

        Args:
            - state (State): the current state of this task
            - upstream_states (Dict[Edge, Union[State, List[State]]]): the upstream states

        Returns:
            - State: the state of the task after running the check

        Raises:
            - ENDRUN: if the trigger returns a falsy value, raises a Prefect state
                signal, or raises any other exception. When `raise_on_exception` is
                set in context, the original exception is re-raised instead.
        """
        try:
            trigger_passed = self.task.trigger(upstream_states)
            if not trigger_passed:
                # handled by the signal branch directly below
                raise signals.TRIGGERFAIL(message="Trigger failed")

        except signals.PrefectStateSignal as exc:
            signal_name = type(exc).__name__
            self.logger.debug(
                "Task '{name}': {signal} signal raised during execution.".format(
                    name=prefect.context.get("task_full_name", self.task.name),
                    signal=signal_name,
                )
            )
            if prefect.context.get("raise_on_exception"):
                raise exc
            raise ENDRUN(exc.state)

        # Exceptions are trapped and turned into TriggerFailed states
        except Exception as exc:
            error_text = repr(exc)
            self.logger.exception(
                "Task '{name}': unexpected error while evaluating task trigger: {exc}".format(
                    exc=error_text,
                    name=prefect.context.get("task_full_name", self.task.name),
                )
            )
            if prefect.context.get("raise_on_exception"):
                raise exc
            failure = TriggerFailed(
                "Unexpected error while checking task trigger: {}".format(error_text),
                result=exc,
            )
            raise ENDRUN(failure)

        else:
            return state

    @call_state_handlers
    def check_task_is_ready(self, state: State) -> State:
        """
        Verifies that the task is in a state from which a run may proceed
        (Pending, Mapped or Cached).

        Args:
            - state (State): the current state of this task

        Returns:
            - State: the state of the task after running the check

        Raises:
            - ENDRUN: if the task is not ready to run
        """
        # the task is ready
        if state.is_pending():
            return state

        task_label = prefect.context.get("task_full_name", self.task.name)

        # a mapped task still proceeds so that its children tasks are generated
        if state.is_mapped():
            self.logger.debug(
                "Task '{name}': task is mapped, but run will proceed so children are generated.".format(
                    name=task_label
                )
            )
            return state

        # this task is already running
        if state.is_running():
            self.logger.debug(
                "Task '{name}': task is already running.".format(name=task_label)
            )
            raise ENDRUN(state)

        # cached states are let through; this must be tested before the finished check
        if state.is_cached():
            return state

        # this task is already finished
        if state.is_finished():
            self.logger.debug(
                "Task '{name}': task is already finished.".format(name=task_label)
            )
            raise ENDRUN(state)

        # anything else is not a state we can start from
        self.logger.debug(
            "Task '{name}' is not ready to run or state was unrecognized ({state}).".format(
                name=task_label,
                state=state,
            )
        )
        raise ENDRUN(state)

    @call_state_handlers
    def check_task_reached_start_time(self, state: State) -> State:
        """
        For a Scheduled state, ensures that the scheduled time has been reached.
        Note: Scheduled states include Retry states. Scheduled states with no start
        time (`start_time = None`) are never considered ready; they must be manually
        placed in another state. States that are not Scheduled pass through unchanged.

        Args:
            - state (State): the current state of this task

        Returns:
            - State: the state of the task after performing the check

        Raises:
            - ENDRUN: if the task is Scheduled with no start time or with a start
                time in the future
        """
        if not isinstance(state, Scheduled):
            return state

        scheduled_for = state.start_time

        # handle case where no start_time is set
        if scheduled_for is None:
            self.logger.debug(
                "Task '{name}' is scheduled without a known start_time; ending run.".format(
                    name=prefect.context.get("task_full_name", self.task.name)
                )
            )
            raise ENDRUN(state)

        # handle case where start time is in the future
        if scheduled_for and scheduled_for > pendulum.now("utc"):
            self.logger.debug(
                "Task '{name}': start_time has not been reached; ending run.".format(
                    name=prefect.context.get("task_full_name", self.task.name)
                )
            )
            raise ENDRUN(state)

        return state

    def get_task_inputs(
        self, state: State, upstream_states: Dict[Edge, State]
    ) -> Dict[str, Result]:
        """
        Given the task's current state and upstream states, generates the inputs for this task.

        Inputs are taken from the results of upstream states whose edge has a key. If the
        current state is pending and has `cached_inputs`, those are used only to fill in
        inputs that are missing or equal to `NoResult`; they never replace an upstream
        result that is present.

        Args:
            - state (State): the task's current state.
            - upstream_states (Dict[Edge, State]): the upstream states

        Returns:
            - Dict[str, Result]: the task inputs, keyed by the task's `run()` argument names
        """
        task_inputs = {}  # type: Dict[str, Result]

        for edge, upstream_state in upstream_states.items():
            # only keyed edges pass data; edges without a key contribute no input
            if edge.key is not None:
                task_inputs[edge.key] = upstream_state._result  # type: ignore

        if state.is_pending() and state.cached_inputs:
            # fill gaps only: keep any upstream result that is not NoResult
            task_inputs.update(
                {
                    k: r
                    for k, r in state.cached_inputs.items()
                    if task_inputs.get(k, NoResult) == NoResult
                }
            )

        return task_inputs

    def load_results(
        self, state: State, upstream_states: Dict[Edge, State]
    ) -> Tuple[State, Dict[Edge, State]]:
        """
        Hook for populating the result objects needed by this task run.

        This implementation performs no loading and hands both arguments back as they
        were received.

        Args:
            - state (State): the task's current state.
            - upstream_states (Dict[Edge, State]): the upstream states

        Returns:
            - Tuple[State, dict]: a tuple of (state, upstream_states)
        """
        return (state, upstream_states)

    @call_state_handlers
    def check_target(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Looks for an existing Result at the task's target and, if one is found,
        returns a `Cached` state built from it.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: a `Cached` state if a Result exists at the target, otherwise the
                state that was passed in
        """
        result = self.result
        target = self.task.target

        # nothing to look up without both a result object and a target
        if not (result and target):
            return state

        if not result.exists(target, **prefect.context):
            return state

        # the target is formatted with the current context before reading
        stored = result.read(target.format(**prefect.context))
        return Cached(
            result=stored,
            cached_inputs=inputs,
            cached_result_expiration=None,
            cached_parameters=prefect.context.get("parameters"),
            message=f"Result found at task target {target}",
        )

    @call_state_handlers
    def check_task_is_cached(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Validates a `Cached` state, or looks for a valid cached state in context when
        the task has caching enabled.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: the given state if it is `Cached` and still valid, a valid cached
                state taken from `prefect.context.caches` if there is one, or otherwise
                the current (possibly reset to `Pending`) state
        """
        validate = self.task.cache_validator

        if state.is_cached():
            assert isinstance(state, Cached)  # mypy assert
            input_values = {name: res.value for name, res in inputs.items()}
            if validate(state, input_values, prefect.context.get("parameters")):
                return state
            state = Pending("Cache was invalid; ready to run.")

        if self.task.cache_for is not None:
            # caches in context are looked up by the task's cache key, or its name
            candidates = []
            if prefect.context.get("caches"):
                candidates = prefect.context.caches.get(
                    self.task.cache_key or self.task.name, []
                )

            input_values = {name: res.value for name, res in inputs.items()}
            for candidate in candidates:
                if validate(
                    candidate, input_values, prefect.context.get("parameters")
                ):
                    return candidate

            # caching is enabled but no candidate passed validation
            self.logger.warning(
                "Task '{name}': can't use cache because it "
                "is now invalid".format(
                    name=prefect.context.get("task_full_name", self.task.name)
                )
            )

        return state or Pending("Cache was invalid; ready to run.")

    def run_mapped_task(
        self,
        state: State,
        upstream_states: Dict[Edge, State],
        context: Dict[str, Any],
        executor: "prefect.engine.executors.Executor",
    ) -> State:
        """
        Fans this task out into one child run per index of its mapped upstream edges and
        submits those children to the executor. Returns a `Mapped` state.

        The method works in three phases:
            1. Build `map_upstream_states`, a list with one `{edge: state}` dictionary per
               child, by indexing into every mapped upstream edge until the shortest one
               is exhausted (signalled by an `IndexError`).
            2. Move the task into a preliminary `Mapped` state via `handle_state_change()`;
               if the handlers return a different state object, that state is returned
               immediately and nothing is submitted.
            3. Submit the children with `executor.map()` and move the task into a final
               `Mapped` state whose `map_states` are whatever the executor returned.

        Args:
            - state (State): the current task state
            - upstream_states (Dict[Edge, State]): the upstream states
            - context (dict): prefect Context to use for execution; it is copied for each
                child and extended with that child's `map_index`
            - executor (Executor): executor to use when performing computation

        Returns:
            - State: the state returned by `handle_state_change()`; either the `Mapped`
                state holding the submitted children, or whatever state the handlers
                substituted for it

        Raises:
            - TypeError: if a mapped edge's upstream result has no `__getitem__`
        """

        # one entry per child run; each entry maps every upstream edge to the state
        # that particular child should see
        map_upstream_states = []

        # we don't know how long the iterables are, but we want to iterate until we reach
        # the end of the shortest one
        counter = itertools.count()

        # infinite loop, if upstream_states has any entries
        # (an empty `upstream_states` skips the loop entirely, yielding zero children)
        # NOTE(review): the loop only ends through an IndexError raised below, and
        # non-mapped edges never raise one -- presumably callers guarantee that at
        # least one edge is mapped; confirm against the caller.
        while True and upstream_states:
            i = next(counter)
            states = {}

            try:

                for edge, upstream_state in upstream_states.items():

                    # if the edge is not mapped over, then we take its state
                    if not edge.mapped:
                        states[edge] = upstream_state

                    # if the edge is mapped and the upstream state is Mapped, then we are mapping
                    # over a mapped task. In this case, we take the appropriately-indexed upstream
                    # state from the upstream tasks's `Mapped.map_states` array.
                    # Note that these "states" might actually be futures at this time; we aren't
                    # blocking until they finish.
                    elif edge.mapped and upstream_state.is_mapped():
                        states[edge] = upstream_state.map_states[i]  # type: ignore

                    # Otherwise, we are mapping over the result of a "vanilla" task. In this
                    # case, we create a copy of the upstream state but set the result to the
                    # appropriately-indexed item from the upstream task's `State.result`
                    # array.
                    else:
                        # shallow copy, so that assigning `.result` below does not touch
                        # the upstream state object shared with the other children
                        states[edge] = copy.copy(upstream_state)

                        # if the current state is already Mapped, then we might be executing
                        # a re-run of the mapping pipeline. In that case, the upstream states
                        # might not have `result` attributes (as any required results could be
                        # in the `cached_inputs` attribute of one of the child states).
                        # Therefore, we only try to get a result if EITHER this task's
                        # state is not already mapped OR the upstream `_result` is not
                        # `NoResult`.
                        if not state.is_mapped() or upstream_state._result != NoResult:
                            if not hasattr(upstream_state.result, "__getitem__"):
                                raise TypeError(
                                    "Cannot map over unsubscriptable object of type {t}: {preview}...".format(
                                        t=type(upstream_state.result),
                                        preview=repr(upstream_state.result)[:10],
                                    )
                                )
                            # indexing past the end of the upstream result raises the
                            # IndexError that terminates the enclosing `while` loop
                            upstream_result = upstream_state._result.from_value(  # type: ignore
                                upstream_state.result[i]
                            )
                            states[edge].result = upstream_result
                        elif state.is_mapped():
                            # no upstream result to index into: fall back to the number of
                            # children recorded on the existing Mapped state to decide
                            # when to stop
                            if i >= len(state.map_states):  # type: ignore
                                raise IndexError()

                # only add this iteration if we made it through all iterables
                map_upstream_states.append(states)

            # index error means we reached the end of the shortest iterable
            except IndexError:
                break

        # the callable handed to the executor; it runs a single child of the map
        def run_fn(
            state: State, map_index: int, upstream_states: Dict[Edge, State]
        ) -> State:
            # each child gets its own copy of the context, tagged with its index
            map_context = context.copy()
            map_context.update(map_index=map_index)
            # run the child inside the context stored on this runner (`self.context`)
            with prefect.context(self.context):
                return self.run(
                    upstream_states=upstream_states,
                    # if we set the state here, then it will not be processed by `initialize_run()`
                    state=state,
                    context=map_context,
                    executor=executor,
                )

        # generate initial states, if available
        if isinstance(state, Mapped):
            initial_states = list(state.map_states)  # type: List[Optional[State]]
        else:
            initial_states = []
        # pad with `None` so there is one initial state per child; when there are already
        # at least as many initial states as children the multiplier is <= 0 and nothing
        # is appended
        initial_states.extend([None] * (len(map_upstream_states) - len(initial_states)))

        current_state = Mapped(
            message="Preparing to submit {} mapped tasks.".format(len(initial_states)),
            map_states=initial_states,  # type: ignore
        )
        state = self.handle_state_change(old_state=state, new_state=current_state)
        # the state handlers replaced the Mapped state with a different object;
        # return it without submitting any children
        if state is not current_state:
            return state

        # map over the initial states, a counter representing the map_index, and also the mapped upstream states
        map_states = executor.map(
            run_fn, initial_states, range(len(map_upstream_states)), map_upstream_states
        )

        self.logger.debug(
            "{} mapped tasks submitted for execution.".format(len(map_states))
        )
        new_state = Mapped(
            message="Mapped tasks submitted for execution.", map_states=map_states
        )
        return self.handle_state_change(old_state=state, new_state=new_state)

    @call_state_handlers
    def wait_for_mapped_task(
        self, state: State, executor: "prefect.engine.executors.Executor"
    ) -> State:
        """
        Waits on the executor for the children of a mapped state.

        States that are not mapped are handed back untouched. For a mapped state,
        `map_states` is replaced in place with whatever `executor.wait()` returns
        for the current `map_states`.

        Args:
            - state (State): the current state; only acted upon if it is `Mapped`
            - executor (Executor): the run's executor

        Returns:
            - State: the same state object that was passed in
        """
        # nothing to wait for unless the state is mapped
        if not state.is_mapped():
            return state

        assert isinstance(state, Mapped)  # mypy assert
        resolved_children = executor.wait(state.map_states)
        state.map_states = resolved_children
        return state

    @call_state_handlers
    def set_task_to_running(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Moves a Pending task into a `Running` state.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments; stored on the new state as
                `cached_inputs`

        Returns:
            - State: a new `Running` state carrying `inputs` as its cached inputs

        Raises:
            - ENDRUN: if the current state is not Pending
        """
        # happy path: a pending task simply transitions to Running
        if state.is_pending():
            return Running(message="Starting task run.", cached_inputs=inputs)

        # any other state cannot be started; log and end the run with the state as-is
        task_name = prefect.context.get("task_full_name", self.task.name)
        self.logger.debug(
            "Task '{name}': can't set state to Running because it "
            "isn't Pending; ending run.".format(name=task_name)
        )
        raise ENDRUN(state)

    @run_with_heartbeat
    @call_state_handlers
    def get_task_run_state(
        self,
        state: State,
        inputs: Dict[str, Result],
        timeout_handler: Optional[Callable] = None,
    ) -> State:
        """
        Invokes `task.run()` and converts its outcome into a State.

        Keyboard interrupts, timeouts and `LOOP` signals raised by the task are
        translated into `Cancelled`, `TimedOut` and `Looped` states respectively. A
        normal return produces a `Success` state; its result is written through
        `self.result.write()` when checkpointing is enabled in context, the task has not
        opted out, and the returned value is not `None`.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.
            - timeout_handler (Callable, optional): function for timing out
                task execution, with call signature `handler(fn, *args, **kwargs)`. Defaults to
                `prefect.utilities.executors.timeout_handler`

        Returns:
            - State: the state of the task after running the check

        Raises:
            - signals.PAUSE: if the task raises PAUSE
            - ENDRUN: if the task is not ready to run
            - TimeoutError: if the task times out and `raise_on_exception` is set in context
        """
        if not state.is_running():
            task_name = prefect.context.get("task_full_name", self.task.name)
            self.logger.debug(
                "Task '{name}': can't run task because it's not in a "
                "Running state; ending run.".format(name=task_name)
            )
            raise ENDRUN(state)

        value = None
        try:
            self.logger.debug(
                "Task '{name}': Calling task.run() method...".format(
                    name=prefect.context.get("task_full_name", self.task.name)
                )
            )
            # fall back to the library's default handler when none was supplied
            if not timeout_handler:
                timeout_handler = prefect.utilities.executors.timeout_handler

            # unwrap the Result objects into the plain values `run()` expects
            raw_inputs = {}
            for arg_name, arg_result in inputs.items():
                raw_inputs[arg_name] = arg_result.value

            if not getattr(self.task, "log_stdout", False):
                value = timeout_handler(
                    self.task.run, timeout=self.task.timeout, **raw_inputs
                )
            else:
                # the task asked for its stdout to be routed to the runner's logger
                log_stream = prefect.utilities.logging.RedirectToLog(self.logger)
                with redirect_stdout(log_stream):  # type: ignore
                    value = timeout_handler(
                        self.task.run, timeout=self.task.timeout, **raw_inputs
                    )

        except KeyboardInterrupt:
            self.logger.debug("Interrupt signal raised, cancelling task run.")
            return Cancelled(message="Interrupt signal raised, cancelling task run.")

        # inform user of timeout
        except TimeoutError as exc:
            if prefect.context.get("raise_on_exception"):
                raise exc
            return TimedOut(
                "Task timed out during execution.", result=exc, cached_inputs=inputs
            )

        except signals.LOOP as exc:
            looped_state = exc.state
            assert isinstance(looped_state, Looped)
            # wrap the raw loop payload in this runner's Result type
            looped_state.result = self.result.from_value(value=looped_state.result)
            looped_state.cached_inputs = inputs
            looped_state.message = exc.state.message or "Task is looping ({})".format(
                looped_state.loop_count
            )
            return looped_state

        # checkpoint tasks if a result is present, except for when the user has opted
        # out by disabling checkpointing
        should_checkpoint = (
            prefect.context.get("checkpointing") is True
            and self.task.checkpoint is not False
            and value is not None
        )
        if not should_checkpoint:
            result = self.result.from_value(value=value)
        else:
            try:
                result = self.result.write(value, filename="output", **prefect.context)
            except NotImplementedError:
                # this Result type cannot persist data; keep the value in memory only
                result = self.result.from_value(value=value)

        return Success(
            result=result, message="Task run succeeded.", cached_inputs=inputs
        )

    @call_state_handlers
    def cache_result(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Records the task inputs on the state and, when the task is cacheable, converts
        a successful state into a `Cached` state.

        `state.cached_inputs` is always set to `inputs`. The state is then replaced by
        a `Cached` state only if all of the following hold:
            - the task state is Successful
            - the task state is not Skipped (which is a subclass of Successful)
            - task.cache_for is not None

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: a new `Cached` state expiring `task.cache_for` from now (UTC) when
                the conditions above are met; otherwise the incoming state

        """
        state.cached_inputs = inputs

        # bail out early for anything that should not be cached
        if (
            not state.is_successful()
            or state.is_skipped()
            or self.task.cache_for is None
        ):
            return state

        expires_at = pendulum.now("utc") + self.task.cache_for
        return Cached(
            result=state._result,
            cached_inputs=inputs,
            cached_result_expiration=expires_at,
            cached_parameters=prefect.context.get("parameters"),
            message=state.message,
        )

    @call_state_handlers
    def check_for_retry(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Decides whether a failed task should be put into a `Retrying` state.

        If the failure happened while the task was looping (a `task_loop_count` is
        present in context), the loop count and loop result are first added to `inputs`
        in place, under the `_loop_count` and `_loop_result` keys.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments. May be mutated, as described above.

        Returns:
            - State: a `Retrying` state if the task failed and its run count does not
                exceed `task.max_retries`; otherwise the incoming state
        """
        # only failed tasks are candidates for a retry
        if not state.is_failed():
            return state

        run_count = prefect.context.get("task_run_count", 1)

        if prefect.context.get("task_loop_count") is not None:
            loop_value = prefect.context.get("task_loop_result")
            loop_result = self.result.from_value(value=loop_value)

            # checkpoint tasks if a result is present, except for when the user has
            # opted out by disabling checkpointing
            checkpoint_loop_result = (
                prefect.context.get("checkpointing") is True
                and self.task.checkpoint is not False
                and loop_result.value is not None
            )
            if checkpoint_loop_result:
                try:
                    loop_result = self.result.write(
                        loop_value, filename="output", **prefect.context
                    )
                except NotImplementedError:
                    # this Result type cannot persist data; keep the in-memory result
                    pass

            # carry the loop position forward in the inputs
            inputs.update(
                {
                    "_loop_count": PrefectResult(
                        location=json.dumps(prefect.context["task_loop_count"]),
                    ),
                    "_loop_result": loop_result,
                }
            )

        if run_count <= self.task.max_retries:
            retry_at = pendulum.now("utc") + self.task.retry_delay
            retry_message = "Retrying Task (after attempt {n} of {m})".format(
                n=run_count, m=self.task.max_retries + 1
            )
            return Retrying(
                start_time=retry_at,
                cached_inputs=inputs,
                message=retry_message,
                run_count=run_count,
            )

        return state

    def check_task_is_looping(
        self,
        state: State,
        inputs: Dict[str, Result] = None,
        upstream_states: Dict[Edge, State] = None,
        context: Dict[str, Any] = None,
        executor: "prefect.engine.executors.Executor" = None,
    ) -> State:
        """
        Checks whether the task is in a `Looped` state and, if so, reruns the pipeline
        with an incremented `loop_count`.

        When the state is looped, `context` is updated in place with the loop result,
        the next loop count and the current `task_run_version`, and a `RecursiveCall`
        targeting `self.run` is raised with a fresh `Pending` state.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.
            - upstream_states (Dict[Edge, State]): a dictionary
                representing the states of any tasks upstream of this one. The keys of the
                dictionary should correspond to the edges leading to the task.
            - context (dict, optional): prefect Context to use for execution; must be a
                `dict` whenever the state is looped
            - executor (Executor, optional): executor to use when performing
                computation; defaults to the executor specified in your prefect configuration

        Returns:
            - State: the incoming state, when it is not `Looped`

        Raises:
            - RecursiveCall: if the state is `Looped`
        """
        # nothing to do unless the task asked to loop
        if not state.is_looped():
            return state

        assert isinstance(state, Looped)  # mypy assert
        assert isinstance(context, dict)  # mypy assert

        loop_message = "Looping task (on loop index {})".format(state.loop_count)
        # hand the loop payload and the next index to the upcoming run
        context.update(
            task_loop_result=state.result,
            task_loop_count=state.loop_count + 1,
            task_run_version=prefect.context.get("task_run_version"),
        )
        pending_state = Pending(message=loop_message, cached_inputs=inputs)
        raise RecursiveCall(
            self.run,
            self,
            pending_state,
            upstream_states=upstream_states,
            context=context,
            executor=executor,
        )