Example #1
async def set_task_run_state(task_run_id: str,
                             state: State,
                             force: bool = False) -> Dict[str, str]:
    """
    Updates a task run state.

    Args:
        - task_run_id (str): the task run id to update
        - state (State): the new state
        - force (bool): if True, skips pipeline checks
    Returns:
        - Dict[str, str]: Mapping indicating status of the state
            change operation.
    """

    if task_run_id is None:
        raise ValueError(f"Invalid task run ID.")

    task_run = await models.TaskRun.where({
        "id": {
            "_eq": task_run_id
        },
    }).first({
        "id": True,
        "version": True,
        "state": True,
        "serialized_state": True,
        "flow_run": {
            "id": True,
            "state": True
        },
    })

    if not task_run:
        raise ValueError(f"Invalid task run ID: {task_run_id}.")

    # ------------------------------------------------------
    # if the state is running, ensure the flow run is also running
    # ------------------------------------------------------
    if not force and state.is_running() and task_run.flow_run.state != "Running":
        raise ValueError(
            f"State update failed for task run ID {task_run_id}: provided "
            f"a running state but associated flow run {task_run.flow_run.id} is not "
            "in a running state.")

    # ------------------------------------------------------
    # if we have cached inputs on the old state, we need to carry them forward
    # ------------------------------------------------------
    if not state.cached_inputs and task_run.serialized_state.get(
            "cached_inputs", None):
        # load up the old state's cached inputs and apply them to the new state
        serialized_state = state_schema.load(task_run.serialized_state)
        state.cached_inputs = serialized_state.cached_inputs

    # --------------------------------------------------------
    # prepare the new state for the database
    # --------------------------------------------------------

    task_run_state = models.TaskRunState(
        task_run_id=task_run.id,
        version=(task_run.version or 0) + 1,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        start_time=getattr(state, "start_time", None),
        state=type(state).__name__,
        serialized_state=state.serialize(),
    )

    await task_run_state.insert()
    return {"status": "SUCCESS"}
Example #2
    def run_mapped_task(
        self,
        state: State,
        upstream_states: Dict[Edge, State],
        context: Dict[str, Any],
        executor: "prefect.engine.executors.Executor",
    ) -> State:
        """
        If the task is being mapped, submits children tasks for execution. Returns a `Mapped` state.

        Args:
            - state (State): the current task state
            - upstream_states (Dict[Edge, State]): the upstream states
            - context (dict, optional): prefect Context to use for execution
            - executor (Executor): executor to use when performing computation

        Returns:
            - State: the state of the task after running the check

        Raises:
            - ENDRUN: if the current state is not `Running`
        """

        map_upstream_states = []

        # we don't know how long the iterables are, but we want to iterate until we reach
        # the end of the shortest one
        counter = itertools.count()

        # loop until an IndexError signals the shortest iterable is exhausted
        # (skipped entirely if there are no upstream states)
        while upstream_states:
            i = next(counter)
            states = {}

            try:

                for edge, upstream_state in upstream_states.items():

                    # if the edge is not mapped over, then we take its state
                    if not edge.mapped:
                        states[edge] = upstream_state

                    # if the edge is mapped and the upstream state is Mapped, then we are mapping
                    # over a mapped task. In this case, we take the appropriately-indexed upstream
                    # state from the upstream task's `Mapped.map_states` array.
                    # Note that these "states" might actually be futures at this time; we aren't
                    # blocking until they finish.
                    elif edge.mapped and upstream_state.is_mapped():
                        states[edge] = upstream_state.map_states[i]  # type: ignore

                    # Otherwise, we are mapping over the result of a "vanilla" task. In this
                    # case, we create a copy of the upstream state but set the result to the
                    # appropriately-indexed item from the upstream task's `State.result`
                    # array.
                    else:
                        states[edge] = copy.copy(upstream_state)

                        # if the current state is already Mapped, then we might be executing
                        # a re-run of the mapping pipeline. In that case, the upstream states
                        # might not have `result` attributes (as any required results could be
                        # in the `cached_inputs` attribute of one of the child states).
                        # Therefore, we only try to get a result if EITHER this task's
                        # state is not already mapped OR the upstream result is not None.
                        if not state.is_mapped() or upstream_state._result != NoResult:
                            if not hasattr(upstream_state.result,
                                           "__getitem__"):
                                raise TypeError(
                                    "Cannot map over unsubscriptable object of type {t}: {preview}..."
                                    .format(
                                        t=type(upstream_state.result),
                                        preview=repr(
                                            upstream_state.result)[:10],
                                    ))
                            upstream_result = upstream_state._result.from_value(  # type: ignore
                                upstream_state.result[i])
                            states[edge].result = upstream_result
                        elif state.is_mapped():
                            if i >= len(state.map_states):  # type: ignore
                                raise IndexError()

                # only add this iteration if we made it through all iterables
                map_upstream_states.append(states)

            # index error means we reached the end of the shortest iterable
            except IndexError:
                break

        def run_fn(state: State, map_index: int,
                   upstream_states: Dict[Edge, State]) -> State:
            map_context = context.copy()
            map_context.update(map_index=map_index)
            with prefect.context(self.context):
                return self.run(
                    upstream_states=upstream_states,
                    # if we set the state here, then it will not be processed by `initialize_run()`
                    state=state,
                    context=map_context,
                    executor=executor,
                )

        # generate initial states, if available
        if isinstance(state, Mapped):
            initial_states = list(state.map_states)  # type: List[Optional[State]]
        else:
            initial_states = []
        initial_states.extend([None] *
                              (len(map_upstream_states) - len(initial_states)))

        current_state = Mapped(
            message="Preparing to submit {} mapped tasks.".format(
                len(initial_states)),
            map_states=initial_states,  # type: ignore
        )
        state = self.handle_state_change(old_state=state,
                                         new_state=current_state)
        if state is not current_state:
            return state

        # map over the initial states, a counter representing the map_index, and also the mapped upstream states
        map_states = executor.map(run_fn, initial_states,
                                  range(len(map_upstream_states)),
                                  map_upstream_states)

        self.logger.debug("{} mapped tasks submitted for execution.".format(
            len(map_states)))
        new_state = Mapped(message="Mapped tasks submitted for execution.",
                           map_states=map_states)
        return self.handle_state_change(old_state=state, new_state=new_state)
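The `itertools.count()` / `IndexError` pattern in the loop above is essentially a stop-at-shortest zip over indexable inputs. A distilled, standalone sketch of that pattern (the helper name is illustrative, not a Prefect API):

# Advance a counter and index each sequence until the shortest one raises
# IndexError, mimicking zip()'s stop-at-shortest behavior without knowing
# any lengths up front. Like the loop above, callers guard against zero
# sequences (an empty argument list would never hit IndexError).
import itertools

def zip_shortest_by_index(*seqs):
    combined = []
    for i in itertools.count():
        try:
            combined.append(tuple(seq[i] for seq in seqs))
        except IndexError:
            break
    return combined

assert zip_shortest_by_index([1, 2, 3], ["a", "b"]) == [(1, "a"), (2, "b")]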
Example #3
def test_state_equality_with_nested_states():
    s1 = State(result=Success(result=1))
    s2 = State(result=Success(result=2))
    s3 = State(result=Success(result=1))
    assert s1 != s2
    assert s1 == s3
Example #4
def test_states_with_mutable_attrs_are_hashable():
    assert {State(result=[1]), Pending(cached_inputs=dict(a=1))}
Example #5
def test_parent_method_on_base_state():
    assert State.parents() == []
Example #6
def test_state_equality_ignores_context():
    s, r = State(result=1), State(result=1)
    s.context["key"] = "value"
    assert s == r
Example #7
def test_state_pickle_with_unpicklable_result_raises():
    state = State(result=RLock())  # an unpicklable result type
    with pytest.raises(TypeError, match="pickle"):
        cloudpickle.dumps(state)
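The test relies on thread locks being unserializable; a standalone demonstration of that underlying fact (assumes cloudpickle is installed):

# Pickling a thread lock raises TypeError, and cloudpickle inherits that
# behavior, which is what the test above asserts for a State wrapping one.
import cloudpickle
from threading import RLock

try:
    cloudpickle.dumps(RLock())
except TypeError as exc:
    print(f"refused to serialize: {exc}")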
Example #8
    def get_task_run_state(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Runs the task and traps any signals or errors it raises.
        Also checkpoints the result of a successful task, if `task.checkpoint` is `True`.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: the state of the task after running the check

        Raises:
            - signals.PAUSE: if the task raises PAUSE
            - ENDRUN: if the task is not ready to run
        """
        if not state.is_running():
            self.logger.debug(
                "Task '{name}': Can't run task because it's not in a "
                "Running state; ending run.".format(
                    name=prefect.context.get("task_full_name", self.task.name)
                )
            )

            raise ENDRUN(state)

        value = None
        raw_inputs = {k: r.value for k, r in inputs.items()}
        new_state = None
        try:
            self.logger.debug(
                "Task '{name}': Calling task.run() method...".format(
                    name=prefect.context.get("task_full_name", self.task.name)
                )
            )

            # Create a stdout redirect if the task has log_stdout enabled
            log_context = (
                redirect_stdout(prefect.utilities.logging.RedirectToLog(self.logger))
                if getattr(self.task, "log_stdout", False)
                else nullcontext()
            )  # type: AbstractContextManager

            with log_context:
                value = prefect.utilities.executors.run_task_with_timeout(
                    task=self.task,
                    args=(),
                    kwargs=raw_inputs,
                    logger=self.logger,
                )

        # inform user of timeout
        except TimeoutError as exc:
            if prefect.context.get("raise_on_exception"):
                raise exc
            state = TimedOut("Task timed out during execution.", result=exc)
            return state

        except signals.LOOP as exc:
            new_state = exc.state
            assert isinstance(new_state, Looped)
            value = new_state.result
            new_state.message = exc.state.message or "Task is looping ({})".format(
                new_state.loop_count
            )

        # checkpoint tasks if a result is present, except for when the user has opted out by
        # disabling checkpointing
        if (
            prefect.context.get("checkpointing") is True
            and self.task.checkpoint is not False
            and value is not None
        ):
            try:
                formatting_kwargs = {
                    **prefect.context.get("parameters", {}).copy(),
                    **prefect.context,
                    **raw_inputs,
                }
                result = self.result.write(value, **formatting_kwargs)
            except ResultNotImplementedError:
                result = self.result.from_value(value=value)
        else:
            result = self.result.from_value(value=value)

        if new_state is not None:
            new_state.result = result
            return new_state

        state = Success(result=result, message="Task run succeeded.")
        return state
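The checkpointing guard near the end of this method reduces to a three-part predicate; a distilled restatement for clarity (the helper name is illustrative):

# Persist a result only when checkpointing is globally enabled, the task has
# not opted out, and there is actually a value to write.
def should_checkpoint(checkpointing, task_checkpoint, value) -> bool:
    return (
        checkpointing is True
        and task_checkpoint is not False
        and value is not None
    )

assert should_checkpoint(True, None, 42)        # task default: checkpoint
assert not should_checkpoint(True, False, 42)   # task explicitly opted out
assert not should_checkpoint(True, True, None)  # nothing to persist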
Example #9
def test_parent_method_on_base_state(include_self):
    assert State.parents(
        include_self=include_self) == ([State] if include_self else [])
Example #10
def test_parent_method_on_base_state_names_only(include_self):
    assert State.parents(
        include_self=include_self,
        names_only=True) == (["State"] if include_self else [])
Example #11
def test_children_method_on_base_state(include_self):
    all_states_set = set(all_states)
    if not include_self:
        all_states_set.remove(State)
    assert all_states_set == set(State.children(include_self=include_self))
Example #12
class TestTaskRunStates:
    async def test_set_task_run_state(self, task_run_id):
        result = await api.states.set_task_run_state(task_run_id=task_run_id,
                                                     state=Failed())

        assert result.task_run_id == task_run_id

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"})

        assert query.version == 2
        assert query.state == "Failed"
        assert query.serialized_state["type"] == "Failed"

    @pytest.mark.parametrize("state", [Failed(), Success()])
    async def test_set_task_run_state_fails_with_wrong_task_run_id(
            self, state):
        with pytest.raises(ValueError, match="State update failed"):
            await api.states.set_task_run_state(task_run_id=str(uuid.uuid4()),
                                                state=state)

    @pytest.mark.parametrize(
        "state", [s() for s in State.children() if not s().is_running()])
    async def test_state_does_not_set_heartbeat_unless_running(
            self, state, task_run_id):
        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"heartbeat"})
        assert task_run.heartbeat is None

        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=state)

        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"heartbeat"})
        assert task_run.heartbeat is None

    async def test_running_state_sets_heartbeat(self, task_run_id,
                                                running_flow_run_id):
        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"heartbeat"})
        assert task_run.heartbeat is None

        dt = pendulum.now("UTC")
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=Running())

        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"heartbeat"})
        assert task_run.heartbeat > dt

    async def test_trigger_failed_state_does_not_set_end_time(
            self, task_run_id):
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=TriggerFailed())
        task_run_info = await models.TaskRun.where(id=task_run_id).first(
            {"id", "start_time", "end_time"})
        assert not task_run_info.start_time
        assert not task_run_info.end_time

    @pytest.mark.parametrize(
        "state",
        [s() for s in State.children() if s not in _MetaState.children()],
        ids=[
            s.__name__ for s in State.children()
            if s not in _MetaState.children()
        ],
    )
    async def test_setting_a_task_run_state_pulls_cached_inputs_if_possible(
            self, task_run_id, state, running_flow_run_id):

        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Failed(cached_inputs=complex_result)
        await models.TaskRun.where(id=task_run_id).update(set=dict(
            serialized_state=cached_state.serialize()))

        # set the task run to the parametrized state
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=state)

        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"serialized_state"})

        # ensure the state change took place
        assert task_run.serialized_state["type"] == type(state).__name__
        assert task_run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert task_run.serialized_state["cached_inputs"]["y"]["value"] == {
            "z": 2
        }

    @pytest.mark.parametrize(
        "state",
        [
            s(cached_inputs=None)
            for s in State.children() if s not in _MetaState.children()
        ],
        ids=[
            s.__name__ for s in State.children()
            if s not in _MetaState.children()
        ],
    )
    async def test_task_runs_with_null_cached_inputs_do_not_overwrite_cache(
            self, state, task_run_id, running_flow_run_id):

        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=state)
        # set up a Retrying state with non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=cached_state)
        run = await models.TaskRun.where(id=task_run_id
                                         ).first({"serialized_state"})

        assert run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state_cls",
        [s for s in State.children() if s not in _MetaState.children()])
    async def test_task_runs_cached_inputs_give_preference_to_new_cached_inputs(
            self, state_cls, task_run_id, running_flow_run_id):

        # set up an initial state with cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"a": 2}, result_handler=JSONResultHandler())
        complex_result = {"b": res1, "c": res2}
        cached_state = state_cls(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=cached_state)
        # set up a Retrying state with non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=cached_state)
        run = Box(await models.TaskRun.where(id=task_run_id
                                             ).first({"serialized_state"}))

        # verify that we have cached inputs, and that preference has been given to the new
        # cached inputs
        assert run.serialized_state.cached_inputs
        assert run.serialized_state.cached_inputs.x.value == 1
        assert run.serialized_state.cached_inputs.y.value == {"z": 2}

    @pytest.mark.parametrize(
        "flow_run_state", [Pending(), Running(), Failed(), Success()])
    async def test_running_states_can_not_be_set_if_flow_run_is_not_running(
            self, flow_run_id, task_run_id, flow_run_state):

        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            state=flow_run_state)

        set_running_coroutine = api.states.set_task_run_state(
            task_run_id=task_run_id, state=Running())

        if flow_run_state.is_running():
            assert await set_running_coroutine
            assert (await models.TaskRun.where(id=task_run_id).first(
                {"state"})).state == "Running"
        else:

            with pytest.raises(ValueError, match="is not in a running state"):
                await set_running_coroutine
            assert (await models.TaskRun.where(id=task_run_id).first(
                {"state"})).state != "Running"
Example #13
class TestFlowRunStates:
    async def test_set_flow_run_state(self, flow_run_id):
        result = await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                                     state=Running())

        assert result.flow_run_id == flow_run_id

        query = await models.FlowRun.where(id=flow_run_id).first(
            {"version", "state", "serialized_state"})

        assert query.version == 3
        assert query.state == "Running"
        assert query.serialized_state["type"] == "Running"

    @pytest.mark.parametrize("state", [Running(), Success()])
    async def test_set_flow_run_state_fails_with_wrong_flow_run_id(
            self, state):
        with pytest.raises(ValueError, match="State update failed"):
            await api.states.set_flow_run_state(flow_run_id=str(uuid.uuid4()),
                                                state=state)

    async def test_trigger_failed_state_does_not_set_end_time(
            self, flow_run_id):
        # there is no logic in Prefect that would create this sequence of
        # events, but a user could manually do this
        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            state=TriggerFailed())
        flow_run_info = await models.FlowRun.where(id=flow_run_id).first(
            {"id", "start_time", "end_time"})
        assert not flow_run_info.start_time
        assert not flow_run_info.end_time

    @pytest.mark.parametrize(
        "state",
        [
            s() for s in State.children()
            if not s().is_running() and not s().is_submitted()
        ],
    )
    async def test_state_does_not_set_heartbeat_unless_running_or_submitted(
            self, state, flow_run_id):
        flow_run = await models.FlowRun.where(id=flow_run_id
                                              ).first({"heartbeat"})
        assert flow_run.heartbeat is None

        dt = pendulum.now("UTC")
        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            state=state)

        flow_run = await models.FlowRun.where(id=flow_run_id
                                              ).first({"heartbeat"})
        assert flow_run.heartbeat is None

    @pytest.mark.parametrize("state", [Running(), Submitted()])
    async def test_running_and_submitted_state_sets_heartbeat(
            self, state, flow_run_id):
        """
        Both Running and Submitted states need to set heartbeats for services like Lazarus to
        function properly.
        """
        flow_run = await models.FlowRun.where(id=flow_run_id
                                              ).first({"heartbeat"})
        assert flow_run.heartbeat is None

        dt = pendulum.now("UTC")
        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            state=state)

        flow_run = await models.FlowRun.where(id=flow_run_id
                                              ).first({"heartbeat"})
        assert flow_run.heartbeat > dt

    async def test_setting_flow_run_to_cancelled_state_sets_unfinished_task_runs_to_cancelled(
            self, flow_run_id):
        task_runs = await models.TaskRun.where({
            "flow_run_id": {
                "_eq": flow_run_id
            }
        }).get({"id"})
        task_run_ids = [run.id for run in task_runs]
        # update the state to Running
        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            state=Running())
        # Currently this flow_run_id fixture has at least 3 tasks; if this
        # changes, the test will need to be updated
        assert len(task_run_ids) >= 3, "flow_run_id fixture has changed"
        # Set one task run to pending, one to running, and the rest to success
        pending_task_run = task_run_ids[0]
        running_task_run = task_run_ids[1]
        rest = task_run_ids[2:]
        await api.states.set_task_run_state(task_run_id=pending_task_run,
                                            state=Pending())
        await api.states.set_task_run_state(task_run_id=running_task_run,
                                            state=Running())
        for task_run_id in rest:
            await api.states.set_task_run_state(task_run_id=task_run_id,
                                                state=Success())
        # set the flow run to a cancelled state
        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            state=Cancelled())
        # Confirm the unfinished task runs have been marked as cancelled
        task_runs = await models.TaskRun.where({
            "flow_run_id": {
                "_eq": flow_run_id
            }
        }).get({"id", "state"})
        new_states = {run.id: run.state for run in task_runs}
        assert new_states[pending_task_run] == "Cancelled"
        assert new_states[running_task_run] == "Cancelled"
        assert all(new_states[id] == "Success" for id in rest)
Example #14
    def get_task_run_state(
        self,
        state: State,
        inputs: Dict[str, Result],
        timeout_handler: Optional[Callable] = None,
    ) -> State:
        """
        Runs the task and traps any signals or errors it raises.
        Also checkpoints the result of a successful task, if `task.checkpoint` is `True`.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.
            - timeout_handler (Callable, optional): function for timing out
                task execution, with call signature `handler(fn, *args, **kwargs)`. Defaults to
                `prefect.utilities.executors.timeout_handler`

        Returns:
            - State: the state of the task after running the check

        Raises:
            - signals.PAUSE: if the task raises PAUSE
            - ENDRUN: if the task is not ready to run
        """
        if not state.is_running():
            self.logger.debug(
                "Task '{name}': can't run task because it's not in a "
                "Running state; ending run.".format(name=prefect.context.get(
                    "task_full_name", self.task.name)))

            raise ENDRUN(state)

        value = None
        try:
            self.logger.debug(
                "Task '{name}': Calling task.run() method...".format(
                    name=prefect.context.get("task_full_name",
                                             self.task.name)))
            timeout_handler = (timeout_handler
                               or prefect.utilities.executors.timeout_handler)
            raw_inputs = {k: r.value for k, r in inputs.items()}

            if getattr(self.task, "log_stdout", False):
                with redirect_stdout(
                        prefect.utilities.logging.RedirectToLog(
                            self.logger)):  # type: ignore
                    value = timeout_handler(self.task.run,
                                            timeout=self.task.timeout,
                                            **raw_inputs)
            else:
                value = timeout_handler(self.task.run,
                                        timeout=self.task.timeout,
                                        **raw_inputs)

        except KeyboardInterrupt:
            self.logger.debug("Interrupt signal raised, cancelling task run.")
            state = Cancelled(
                message="Interrupt signal raised, cancelling task run.")
            return state

        # inform user of timeout
        except TimeoutError as exc:
            if prefect.context.get("raise_on_exception"):
                raise exc
            state = TimedOut("Task timed out during execution.",
                             result=exc,
                             cached_inputs=inputs)
            return state

        except signals.LOOP as exc:
            new_state = exc.state
            assert isinstance(new_state, Looped)
            new_state.result = self.result.from_value(value=new_state.result)
            new_state.cached_inputs = inputs
            new_state.message = exc.state.message or "Task is looping ({})".format(
                new_state.loop_count)
            return new_state

        # checkpoint tasks if a result is present, except for when the user has
        # opted out by disabling checkpointing
        if (prefect.context.get("checkpointing") is True
                and self.task.checkpoint is not False and value is not None):
            try:
                result = self.result.write(value, **prefect.context)
            except NotImplementedError:
                result = self.result.from_value(value=value)
        else:
            result = self.result.from_value(value=value)

        state = Success(result=result,
                        message="Task run succeeded.",
                        cached_inputs=inputs)
        return state
Example #15
def test_state_pickle_with_exception():
    state = State(result=Exception("foo"))
    new_state = cloudpickle.loads(cloudpickle.dumps(state))

    assert isinstance(new_state.result, Exception)
    assert new_state.result.args == ("foo", )
Example #16
    def check_task_ready_to_map(
        self, state: State, upstream_states: Dict[Edge, State]
    ) -> State:
        """
        Checks if the parent task is ready to proceed with mapping.

        Args:
            - state (State): the current state of this task
            - upstream_states (Dict[Edge, State]): the upstream states

        Raises:
            - ENDRUN: either way, we don't continue past this point
        """
        if state.is_mapped():
            # this indicates we are executing a re-run of a mapped pipeline;
            # in this case, we populate both `map_states` and `cached_inputs`
            # to ensure the flow runner can properly regenerate the child tasks,
            # regardless of whether we mapped over an exchanged piece of data
            # or a non-data-exchanging upstream dependency
            if len(state.map_states) == 0 and state.n_map_states > 0:  # type: ignore
                state.map_states = [None] * state.n_map_states  # type: ignore
            state.cached_inputs = {
                edge.key: state._result  # type: ignore
                for edge, state in upstream_states.items()
                if edge.key
            }
            raise ENDRUN(state)

        # we can't map if there are no success states with iterables upstream
        if upstream_states and not any(
            [
                edge.mapped and state.is_successful()
                for edge, state in upstream_states.items()
            ]
        ):
            new_state = Failed("No upstream states can be mapped over.")  # type: State
            raise ENDRUN(new_state)
        elif not all(
            [
                hasattr(state.result, "__getitem__")
                for edge, state in upstream_states.items()
                if state.is_successful() and not state.is_mapped() and edge.mapped
            ]
        ):
            new_state = Failed("At least one upstream state has an unmappable result.")
            raise ENDRUN(new_state)
        else:
            # compute and set n_map_states
            n_map_states = min(
                [
                    len(s.result)
                    for e, s in upstream_states.items()
                    if e.mapped and s.is_successful() and not s.is_mapped()
                ]
                + [
                    s.n_map_states  # type: ignore
                    for e, s in upstream_states.items()
                    if e.mapped and s.is_mapped()
                ],
                default=0,
            )
            new_state = Mapped(
                "Ready to proceed with mapping.", n_map_states=n_map_states
            )
            raise ENDRUN(new_state)
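The `min(..., default=0)` expression above fixes the width of the mapped pipeline; a toy walk-through with made-up values:

# The width is the minimum over the lengths of mapped, successful, non-mapped
# upstream results plus the widths of upstreams that are themselves Mapped,
# so mapping stops at the shortest input, zip()-style.
lengths_of_mapped_results = [3, 2]  # two mapped upstream iterables
widths_of_mapped_upstreams = [4]    # one upstream that is already Mapped
assert min(lengths_of_mapped_results + widths_of_mapped_upstreams, default=0) == 2
assert min([], default=0) == 0      # no mapped upstreams at all -> width 0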
Example #17
async def set_task_run_state(
    task_run_id: str, state: State, version: int = None, flow_run_version: int = None
) -> models.TaskRunState:
    """
    Updates a task run state.

    Args:
        - task_run_id (str): the task run id to update
        - state (State): the new state
        - version (int): a version (only for Cloud API compatibility)
        - flow_run_version (int): a flow run version (only for Cloud API compatibility)

    Returns:
        - models.TaskRunState
    """

    if task_run_id is None:
        raise ValueError(f"Invalid task run ID.")

    task_run = await models.TaskRun.where(id=task_run_id).first(
        {
            "id": True,
            "tenant_id": True,
            "version": True,
            "state": True,
            "serialized_state": True,
            "flow_run": {"id": True, "state": True},
        }
    )

    if not task_run:
        raise ValueError(f"State update failed for task run ID {task_run_id}")

    # ------------------------------------------------------
    # if the state is running, ensure the flow run is also running
    # ------------------------------------------------------
    if state.is_running() and task_run.flow_run.state != "Running":
        raise ValueError(
            f"State update failed for task run ID {task_run_id}: provided "
            f"a running state but associated flow run {task_run.flow_run.id} is not "
            "in a running state."
        )

    # ------------------------------------------------------
    # if we have cached inputs on the old state, we need to carry them forward
    # ------------------------------------------------------
    if not state.cached_inputs and task_run.serialized_state.get("cached_inputs", None):
        # load up the old state's cached inputs and apply them to the new state
        serialized_state = state_schema.load(task_run.serialized_state)
        state.cached_inputs = serialized_state.cached_inputs

    # --------------------------------------------------------
    # prepare the new state for the database
    # --------------------------------------------------------

    task_run_state = models.TaskRunState(
        id=str(uuid.uuid4()),
        tenant_id=task_run.tenant_id,
        task_run_id=task_run.id,
        version=(task_run.version or 0) + 1,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        start_time=getattr(state, "start_time", None),
        state=type(state).__name__,
        serialized_state=state.serialize(),
    )

    await task_run_state.insert()

    # --------------------------------------------------------
    # apply downstream updates
    # --------------------------------------------------------

    # FOR RUNNING STATES:
    #   - update the task run heartbeat
    if state.is_running():
        await api.runs.update_task_run_heartbeat(task_run_id=task_run_id)

    return task_run_state
Example #18
    def get_flow_run_state(
        self,
        state: State,
        task_states: Dict[Task, State],
        task_contexts: Dict[Task, Dict[str, Any]],
        return_tasks: Set[Task],
        task_runner_state_handlers: Iterable[Callable],
        executor: "prefect.engine.executors.base.Executor",
    ) -> State:
        """
        Runs the flow.

        Args:
            - state (State): starting state for the Flow. Defaults to
                `Pending`
            - task_states (dict): dictionary of task states to begin
                computation with, with keys being Tasks and values their corresponding state
            - task_contexts (Dict[Task, Dict[str, Any]]): contexts that will be provided to each task
            - return_tasks ([Task], optional): list of Tasks to include in the
                final returned Flow state. Defaults to `None`
            - task_runner_state_handlers (Iterable[Callable]): A list of state change
                handlers that will be provided to the task_runner, and called whenever a task changes
                state.
            - executor (Executor): executor to use when performing
                computation; defaults to the executor provided in your prefect configuration

        Returns:
            - State: `State` representing the final post-run state of the `Flow`.

        """

        if not state.is_running():
            self.logger.info("Flow is not in a Running state.")
            raise ENDRUN(state)

        if return_tasks is None:
            return_tasks = set()
        if set(return_tasks).difference(self.flow.tasks):
            raise ValueError("Some tasks in return_tasks were not found in the flow.")

        # -- process each task in order

        with executor.start():

            for task in self.flow.sorted_tasks():

                task_state = task_states.get(task)
                if task_state is None and isinstance(
                    task, prefect.tasks.core.constants.Constant
                ):
                    task_states[task] = task_state = Success(result=task.value)

                # if the state is finished, don't run the task, just use the provided state
                if (
                    isinstance(task_state, State)
                    and task_state.is_finished()
                    and not task_state.is_cached()
                    and not task_state.is_mapped()
                ):
                    continue

                upstream_states = {}  # type: Dict[Edge, Union[State, Iterable]]

                # -- process each edge to the task
                for edge in self.flow.edges_to(task):
                    upstream_states[edge] = task_states.get(
                        edge.upstream_task, Pending(message="Task state not available.")
                    )

                # augment edges with upstream constants
                for key, val in self.flow.constants[task].items():
                    edge = Edge(
                        upstream_task=prefect.tasks.core.constants.Constant(val),
                        downstream_task=task,
                        key=key,
                    )
                    upstream_states[edge] = Success(
                        "Auto-generated constant value",
                        result=Result(val, result_handler=ConstantResultHandler(val)),
                    )

                # -- run the task

                with prefect.context(task_full_name=task.name, task_tags=task.tags):
                    task_states[task] = executor.submit(
                        self.run_task,
                        task=task,
                        state=task_state,
                        upstream_states=upstream_states,
                        context=dict(prefect.context, **task_contexts.get(task, {})),
                        task_runner_state_handlers=task_runner_state_handlers,
                        executor=executor,
                    )

            # ---------------------------------------------
            # Collect results
            # ---------------------------------------------

            # terminal tasks determine if the flow is finished
            terminal_tasks = self.flow.terminal_tasks()

            # reference tasks determine flow state
            reference_tasks = self.flow.reference_tasks()

            # wait until all terminal tasks are finished
            final_tasks = terminal_tasks.union(reference_tasks).union(return_tasks)
            final_states = executor.wait(
                {
                    t: task_states.get(t, Pending("Task not evaluated by FlowRunner."))
                    for t in final_tasks
                }
            )

            # also wait for any children of Mapped tasks to finish, and add them
            # to the dictionary to determine flow state
            all_final_states = final_states.copy()
            for t, s in list(final_states.items()):
                if s.is_mapped():
                    s.map_states = executor.wait(s.map_states)
                    s.result = [ms.result for ms in s.map_states]
                    all_final_states[t] = s.map_states

            assert isinstance(final_states, dict)

        key_states = set(flatten_seq([all_final_states[t] for t in reference_tasks]))
        terminal_states = set(
            flatten_seq([all_final_states[t] for t in terminal_tasks])
        )
        return_states = {t: final_states[t] for t in return_tasks}

        state = self.determine_final_state(
            state=state,
            key_states=key_states,
            return_states=return_states,
            terminal_states=terminal_states,
        )

        return state
Example #19
async def set_flow_run_state(
    flow_run_id: str, state: State, version: int = None, agent_id: str = None
) -> models.FlowRunState:
    """
    Updates a flow run state.

    Args:
        - flow_run_id (str): the flow run id to update
        - state (State): the new state
        - version (int): a version (only for Cloud API compatibility)
        - agent_id (str): the ID of an agent instance setting the state

    Returns:
        - models.FlowRunState
    """

    if flow_run_id is None:
        raise ValueError(f"Invalid flow run ID.")

    flow_run = await models.FlowRun.where(id=flow_run_id).first(
        {
            "id": True,
            "state": True,
            "name": True,
            "version": True,
            "flow": {"id", "name", "flow_group_id", "version_group_id"},
            "tenant": {"id", "slug"},
        }
    )

    if not flow_run:
        raise ValueError(f"State update failed for flow run ID {flow_run_id}")

    # --------------------------------------------------------
    # apply downstream updates
    # --------------------------------------------------------

    # FOR CANCELLED STATES:
    #   - set all non-finished task run states to Cancelled
    if isinstance(state, Cancelled):
        task_runs = await models.TaskRun.where(
            {"flow_run_id": {"_eq": flow_run_id}}
        ).get({"id", "serialized_state"})
        to_cancel = [
            t
            for t in task_runs
            if not state_schema.load(t.serialized_state).is_finished()
        ]
        # For a run with many tasks this may be a lot of tasks - at some point
        # we might want to batch this rather than kicking off lots of asyncio
        # tasks at once.
        await asyncio.gather(
            *(api.states.set_task_run_state(t.id, state) for t in to_cancel),
            return_exceptions=True,
        )

    # --------------------------------------------------------
    # insert the new state in the database
    # --------------------------------------------------------

    flow_run_state = models.FlowRunState(
        id=str(uuid.uuid4()),
        tenant_id=flow_run.tenant_id,
        flow_run_id=flow_run_id,
        version=(flow_run.version or 0) + 1,
        state=type(state).__name__,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        start_time=getattr(state, "start_time", None),
        serialized_state=state.serialize(),
    )

    await flow_run_state.insert()

    # --------------------------------------------------------
    # apply downstream updates
    # --------------------------------------------------------

    # FOR RUNNING STATES:
    #   - update the flow run heartbeat
    if state.is_running() or state.is_submitted():
        await api.runs.update_flow_run_heartbeat(flow_run_id=flow_run_id)

    # Set agent ID on flow run when submitted by agent
    if state.is_submitted() and agent_id:
        await api.runs.update_flow_run_agent(flow_run_id=flow_run_id, agent_id=agent_id)

    # --------------------------------------------------------
    # call cloud hooks
    # --------------------------------------------------------

    event = events.FlowRunStateChange(
        flow_run=flow_run,
        state=flow_run_state,
        flow=flow_run.flow,
        tenant=flow_run.tenant,
    )

    asyncio.create_task(api.cloud_hooks.call_hooks(event))

    return flow_run_state
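As the comment in the Cancelled branch notes, a large flow run can fan out into many concurrent cancellation coroutines. A possible bounded-concurrency variant (a sketch only: the semaphore-based batching and the `limit` value are assumptions, not part of this API):

# Cancel task runs with at most `limit` in flight at once, using the same
# api.states.set_task_run_state call as the gather above.
import asyncio

async def cancel_in_batches(task_runs, state, limit: int = 50):
    semaphore = asyncio.Semaphore(limit)

    async def _cancel(run):
        async with semaphore:
            return await api.states.set_task_run_state(run.id, state)

    return await asyncio.gather(
        *(_cancel(run) for run in task_runs), return_exceptions=True
    )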
Example #20
    def get_flow_run_state(
        self,
        state: State,
        task_states: Dict[Task, State],
        task_contexts: Dict[Task, Dict[str, Any]],
        return_tasks: Set[Task],
        task_runner_state_handlers: Iterable[Callable],
        executor: "prefect.engine.executors.base.Executor",
    ) -> State:
        """
        Runs the flow.

        Args:
            - state (State): starting state for the Flow. Defaults to
                `Pending`
            - task_states (dict): dictionary of task states to begin
                computation with, with keys being Tasks and values their corresponding state
            - task_contexts (Dict[Task, Dict[str, Any]]): contexts that will be provided to
                each task
            - return_tasks ([Task], optional): list of Tasks to include in the
                final returned Flow state. Defaults to `None`
            - task_runner_state_handlers (Iterable[Callable]): A list of state change handlers
                that will be provided to the task_runner, and called whenever a task changes
                state.
            - executor (Executor): executor to use when performing computation; defaults to the
                executor provided in your prefect configuration

        Returns:
            - State: `State` representing the final post-run state of the `Flow`.

        """
        # this dictionary is used for tracking the states of "children" mapped tasks;
        # when running on Dask, we want to avoid serializing futures, so instead
        # of storing child task states in the `map_states` attribute we instead store
        # in this dictionary and only after they are resolved do we attach them to the Mapped state
        mapped_children = dict()  # type: Dict[Task, list]

        if not state.is_running():
            self.logger.info("Flow is not in a Running state.")
            raise ENDRUN(state)

        if return_tasks is None:
            return_tasks = set()
        if set(return_tasks).difference(self.flow.tasks):
            raise ValueError(
                "Some tasks in return_tasks were not found in the flow.")

        def extra_context(task: Task, task_index: int = None) -> dict:
            return {
                "task_name": task.name,
                "task_tags": task.tags,
                "task_index": task_index,
            }

        # -- process each task in order

        with self.check_for_cancellation(), executor.start():

            for task in self.flow.sorted_tasks():
                task_state = task_states.get(task)

                # if a task is a constant task, we already know its return value
                # no need to use up resources by running it through a task runner
                if task_state is None and isinstance(
                        task, prefect.tasks.core.constants.Constant):
                    task_states[task] = task_state = Success(result=task.value)

                # Always restart completed resource setup/cleanup tasks unless
                # they were explicitly cached.
                # TODO: we only need to rerun these tasks if any pending
                # downstream tasks depend on them.
                if (isinstance(
                        task,
                        (
                            prefect.tasks.core.resource_manager.ResourceSetupTask,
                            prefect.tasks.core.resource_manager.ResourceCleanupTask,
                        ),
                ) and task_state is not None and task_state.is_finished()
                        and not task_state.is_cached()):
                    task_states[task] = task_state = Pending()

                # if the state is finished, don't run the task, just use the provided
                # state; if the state is cached / mapped, we still want to run the task
                # runner pipeline steps to either ensure the cache is still valid or to
                # recreate the mapped pipeline for possible retries
                if (isinstance(task_state, State) and task_state.is_finished()
                        and not task_state.is_cached()
                        and not task_state.is_mapped()):
                    continue

                upstream_states = {}  # type: Dict[Edge, State]

                # this dictionary is used exclusively for "reduce" tasks; in particular,
                # we store the states / futures corresponding to the upstream children,
                # and if running on Dask, let Dask resolve them at the appropriate time.
                # Note: this is an optimization that allows Dask to resolve the mapped
                # dependencies by "elevating" them to a function argument.
                upstream_mapped_states = {}  # type: Dict[Edge, list]

                # -- process each edge to the task
                for edge in self.flow.edges_to(task):

                    # load the upstream task states (supplying Pending as a default)
                    upstream_states[edge] = task_states.get(
                        edge.upstream_task,
                        Pending(message="Task state not available."))

                    # if the edge is flattened and not the result of a map, then we
                    # preprocess the upstream states. If it IS the result of a
                    # map, it will be handled in `prepare_upstream_states_for_mapping`
                    if edge.flattened:
                        if not isinstance(upstream_states[edge], Mapped):
                            upstream_states[edge] = executor.submit(
                                executors.flatten_upstream_state,
                                upstream_states[edge])

                    # this checks whether the task is a "reduce" task for a mapped pipeline
                    # and if so, collects the appropriate upstream children
                    if not edge.mapped and isinstance(upstream_states[edge],
                                                      Mapped):
                        children = mapped_children.get(edge.upstream_task, [])

                        # if the edge is flattened, then we need to wait for the mapped children
                        # to complete and then flatten them
                        if edge.flattened:
                            children = executors.flatten_mapped_children(
                                mapped_children=children, executor=executor)

                        upstream_mapped_states[edge] = children

                # augment edges with upstream constants
                for key, val in self.flow.constants[task].items():
                    edge = Edge(
                        upstream_task=prefect.tasks.core.constants.Constant(
                            val),
                        downstream_task=task,
                        key=key,
                    )
                    upstream_states[edge] = Success(
                        "Auto-generated constant value",
                        result=ConstantResult(value=val),
                    )

                # handle mapped tasks
                if any(edge.mapped for edge in upstream_states.keys()):

                    # wait on upstream states to determine the width of the pipeline
                    # this is the key to depth-first execution
                    upstream_states = executor.wait(
                        {e: state
                         for e, state in upstream_states.items()})
                    # we submit the task to the task runner to determine if
                    # we can proceed with mapping - if the new task state is not a Mapped
                    # state then we don't proceed
                    task_states[task] = executor.wait(
                        executor.submit(
                            run_task,
                            task=task,
                            state=task_state,  # original state
                            upstream_states=upstream_states,
                            context=dict(prefect.context,
                                         **task_contexts.get(task, {})),
                            flow_result=self.flow.result,
                            task_runner_cls=self.task_runner_cls,
                            task_runner_state_handlers=task_runner_state_handlers,
                            upstream_mapped_states=upstream_mapped_states,
                            is_mapped_parent=True,
                            extra_context=extra_context(task),
                        ))

                    # either way, we should now have enough resolved states to restructure
                    # the upstream states into a list of upstream state dictionaries to iterate over
                    list_of_upstream_states = (
                        executors.prepare_upstream_states_for_mapping(
                            task_states[task],
                            upstream_states,
                            mapped_children,
                            executor=executor,
                        ))

                    submitted_states = []

                    for idx, states in enumerate(list_of_upstream_states):
                        # if we are on a future rerun of a partially complete flow run,
                        # there might be mapped children in a retrying state; this check
                        # looks into the current task state's map_states for such info
                        if (isinstance(task_state, Mapped)
                                and len(task_state.map_states) >= idx + 1):
                            current_state = task_state.map_states[idx]  # type: Optional[State]
                        elif isinstance(task_state, Mapped):
                            current_state = None
                        else:
                            current_state = task_state

                        # this is where each child is submitted for actual work
                        submitted_states.append(
                            executor.submit(
                                run_task,
                                task=task,
                                state=current_state,
                                upstream_states=states,
                                context=dict(
                                    prefect.context,
                                    **task_contexts.get(task, {}),
                                    map_index=idx,
                                ),
                                flow_result=self.flow.result,
                                task_runner_cls=self.task_runner_cls,
                                task_runner_state_handlers=task_runner_state_handlers,
                                upstream_mapped_states=upstream_mapped_states,
                                extra_context=extra_context(task,
                                                            task_index=idx),
                            ))
                    if isinstance(task_states.get(task), Mapped):
                        mapped_children[task] = submitted_states  # type: ignore

                else:
                    task_states[task] = executor.submit(
                        run_task,
                        task=task,
                        state=task_state,
                        upstream_states=upstream_states,
                        context=dict(prefect.context,
                                     **task_contexts.get(task, {})),
                        flow_result=self.flow.result,
                        task_runner_cls=self.task_runner_cls,
                        task_runner_state_handlers=task_runner_state_handlers,
                        upstream_mapped_states=upstream_mapped_states,
                        extra_context=extra_context(task),
                    )

            # ---------------------------------------------
            # Collect results
            # ---------------------------------------------

            # terminal tasks determine if the flow is finished
            terminal_tasks = self.flow.terminal_tasks()

            # reference tasks determine flow state
            reference_tasks = self.flow.reference_tasks()

            # wait until all terminal tasks are finished
            final_tasks = terminal_tasks.union(reference_tasks).union(return_tasks)
            final_states = executor.wait({
                t: task_states.get(t, Pending("Task not evaluated by FlowRunner."))
                for t in final_tasks
            })

            # also wait for any children of Mapped tasks to finish, and add them
            # to the dictionary to determine flow state
            all_final_states = final_states.copy()
            for t, s in list(final_states.items()):
                if s.is_mapped():
                    # ensure we wait for any mapped children to complete
                    if t in mapped_children:
                        s.map_states = executor.wait(mapped_children[t])
                    s.result = [ms.result for ms in s.map_states]
                    all_final_states[t] = s.map_states

            assert isinstance(final_states, dict)

        key_states = set(
            flatten_seq([all_final_states[t] for t in reference_tasks]))
        terminal_states = set(
            flatten_seq([all_final_states[t] for t in terminal_tasks]))
        return_states = {t: final_states[t] for t in return_tasks}

        state = self.determine_final_state(
            state=state,
            key_states=key_states,
            return_states=return_states,
            terminal_states=terminal_states,
        )

        return state
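A note on the reduction just performed: terminal-task states decide whether the flow run is finished, while reference-task ("key") states decide whether it succeeded. Below is a minimal sketch of that logic under simplified stand-in state classes; it illustrates the pattern only and is not Prefect's actual determine_final_state.

# Minimal sketch of the final-state reduction; SketchState & co. are
# simplified stand-ins, not Prefect's State hierarchy.
from typing import Set


class SketchState:
    def is_finished(self) -> bool:
        return isinstance(self, (SketchSuccess, SketchFailed))

    def is_failed(self) -> bool:
        return isinstance(self, SketchFailed)


class SketchSuccess(SketchState):
    pass


class SketchFailed(SketchState):
    pass


class SketchRunning(SketchState):
    pass


def determine_final_state(key_states: Set[SketchState],
                          terminal_states: Set[SketchState]) -> SketchState:
    # the flow is only finished once every terminal task has finished
    if not all(s.is_finished() for s in terminal_states):
        return SketchRunning()
    # reference ("key") tasks then decide success vs. failure
    if any(s.is_failed() for s in key_states):
        return SketchFailed()
    return SketchSuccess()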
Example #21
    def check_task_is_cached(self, state: State, inputs: Dict[str, Result]) -> State:
        """
        Checks if task is cached in the DB and whether any of the caches are still valid.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.

        Returns:
            - State: the state of the task after running the check

        Raises:
            - ENDRUN: if the task is not ready to run
        """
        if state.is_cached() is True:
            assert isinstance(state, Cached)  # mypy assert
            sanitized_inputs = {key: res.value for key, res in inputs.items()}
            if self.task.cache_validator(
                state, sanitized_inputs, prefect.context.get("parameters")
            ):
                state = state.load_result(self.result)
                return state

        if self.task.cache_for is not None:
            oldest_valid_cache = datetime.datetime.utcnow() - self.task.cache_for
            cached_states = self.client.get_latest_cached_states(
                task_id=prefect.context.get("task_id", ""),
                cache_key=self.task.cache_key,
                created_after=oldest_valid_cache,
            )

            if not cached_states:
                self.logger.debug(
                    "Task '{name}': can't use cache because no Cached states were found".format(
                        name=prefect.context.get("task_full_name", self.task.name)
                    )
                )
            else:
                self.logger.debug(
                    "Task '{name}': {num} candidate cached states were found".format(
                        name=prefect.context.get("task_full_name", self.task.name),
                        num=len(cached_states),
                    )
                )

            for candidate_state in cached_states:
                assert isinstance(candidate_state, Cached)  # mypy assert
                candidate_state.load_cached_results(inputs)
                sanitized_inputs = {key: res.value for key, res in inputs.items()}
                if self.task.cache_validator(
                    candidate_state, sanitized_inputs, prefect.context.get("parameters")
                ):
                    return candidate_state.load_result(self.result)

            self.logger.debug(
                "Task '{name}': can't use cache because no candidate Cached states "
                "were valid".format(
                    name=prefect.context.get("task_full_name", self.task.name)
                )
            )

        return state
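Every validator invoked above is called as cache_validator(state, sanitized_inputs, parameters) and returns a bool. As an illustration of that contract, here is a minimal custom validator; the name inputs_unchanged is made up for this sketch and is not a Prefect built-in.

# Sketch of a cache validator matching the call signature used above:
# cache_validator(state, sanitized_inputs, parameters) -> bool.
# `inputs_unchanged` is an illustrative name, not a Prefect built-in.
from typing import Any, Dict, Optional


def inputs_unchanged(
    state: Any,
    inputs: Dict[str, Any],
    parameters: Optional[Dict[str, Any]],
) -> bool:
    # treat the cache as valid only when the Cached state recorded the same
    # raw input values the task is being called with now
    cached = getattr(state, "cached_inputs", None) or {}
    cached_values = {k: getattr(r, "value", r) for k, r in cached.items()}
    return cached_values == inputs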
Example #22
def test_state_equality_ignores_message():
    assert State(result=1, message="x") == State(result=1, message="y")
    assert State(result=1, message="x") != State(result=2, message="x")
Example #23
    def check_task_is_ready(self, state: State) -> State:
        """
        Checks to make sure the task is ready to run (Pending or Mapped).

        If the state is Paused, an ENDRUN is raised.

        Args:
            - state (State): the current state of this task

        Returns:
            - State: the state of the task after running the check

        Raises:
            - ENDRUN: if the task is not ready to run
        """

        # the task is paused
        if isinstance(state, Paused):
            self.logger.debug(
                "Task '{name}': task is paused; ending run.".format(
                    name=prefect.context.get("task_full_name", self.task.name)
                )
            )
            raise ENDRUN(state)

        # the task is ready
        elif state.is_pending():
            return state

        # the task is mapped, in which case we still proceed so that the
        # children tasks are generated
        elif state.is_mapped():
            self.logger.debug(
                "Task '{name}': task is mapped, but run will proceed so children are generated.".format(
                    name=prefect.context.get("task_full_name", self.task.name)
                )
            )
            return state

        # this task is already running
        elif state.is_running():
            self.logger.debug(
                "Task '{name}': task is already running.".format(
                    name=prefect.context.get("task_full_name", self.task.name)
                )
            )
            raise ENDRUN(state)

        elif state.is_cached():
            return state

        # this task is already finished
        elif state.is_finished():
            self.logger.debug(
                "Task '{name}': task is already finished.".format(
                    name=prefect.context.get("task_full_name", self.task.name)
                )
            )
            raise ENDRUN(state)

        # this task is not pending
        else:
            self.logger.debug(
                "Task '{name}' is not ready to run or state was unrecognized ({state}).".format(
                    name=prefect.context.get("task_full_name", self.task.name),
                    state=state,
                )
            )
            raise ENDRUN(state)
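ENDRUN is the signal each of these checks uses to short-circuit the runner pipeline: the state it carries becomes the result of the run. A minimal sketch of that pattern with a simplified stand-in exception (Prefect's real signal lives in prefect.engine.signals):

# Sketch of the pipeline pattern around checks like check_task_is_ready:
# any step may raise ENDRUN(state) to stop early, and the carried state
# stands as the result. Simplified stand-in, not Prefect's implementation.
class ENDRUN(Exception):
    def __init__(self, state):
        self.state = state


def run_pipeline(state, steps):
    try:
        for step in steps:
            state = step(state)  # each check returns the (possibly new) state
    except ENDRUN as exc:
        state = exc.state  # short-circuit: keep the signalled state
    return state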
Example #24
def test_states_are_hashable():
    assert {State(), Pending(), Success()}
Example #25
    def get_task_run_state(
        self,
        state: State,
        inputs: Dict[str, Result],
        timeout_handler: Optional[Callable],
    ) -> State:
        """
        Runs the task and traps any signals or errors it raises.
        Also checkpoints the result of a successful task, if `task.checkpoint` is `True`.

        Args:
            - state (State): the current state of this task
            - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
                to the task's `run()` arguments.
            - timeout_handler (Callable, optional): function for timing out
                task execution, with call signature `handler(fn, *args, **kwargs)`. Defaults to
                `prefect.utilities.executors.main_thread_timeout`

        Returns:
            - State: the state of the task after running the check

        Raises:
            - signals.PAUSE: if the task raises PAUSE
            - ENDRUN: if the task is not ready to run
        """
        if not state.is_running():
            self.logger.debug(
                "Task '{name}': can't run task because it's not in a "
                "Running state; ending run.".format(
                    name=prefect.context.get("task_full_name", self.task.name)
                )
            )

            raise ENDRUN(state)

        try:
            self.logger.debug(
                "Task '{name}': Calling task.run() method...".format(
                    name=prefect.context.get("task_full_name", self.task.name)
                )
            )
            timeout_handler = timeout_handler or main_thread_timeout
            raw_inputs = {k: r.value for k, r in inputs.items()}
            result = timeout_handler(
                self.task.run, timeout=self.task.timeout, **raw_inputs
            )

        # inform user of timeout
        except TimeoutError as exc:
            if prefect.context.get("raise_on_exception"):
                raise exc
            state = TimedOut(
                "Task timed out during execution.", result=exc, cached_inputs=inputs
            )
            return state

        result = Result(value=result, result_handler=self.result_handler)
        state = Success(result=result, message="Task run succeeded.")

        # only checkpoint tasks if running in cloud
        if (
            state.is_successful()
            and prefect.context.get("cloud") is True
            and self.task.checkpoint is True
        ):
            state._result.store_safe_value()

        return state
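The docstring above pins the timeout handler's contract to handler(fn, *args, **kwargs) with a timeout keyword. As a sketch of something honoring that contract, here is a thread-based variant; it is not Prefect's main_thread_timeout, and it cannot forcibly stop the worker thread once the timeout fires.

# Thread-based sketch of a handler matching the documented signature
# handler(fn, *args, timeout=..., **kwargs); raises TimeoutError on expiry
# but leaves the worker thread running in the background.
import concurrent.futures


def thread_timeout(fn, *args, timeout=None, **kwargs):
    if timeout is None:
        return fn(*args, **kwargs)
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = pool.submit(fn, *args, **kwargs)
    try:
        return future.result(timeout=timeout)
    except concurrent.futures.TimeoutError:
        raise TimeoutError("Execution timed out.")
    finally:
        pool.shutdown(wait=False)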
Example #26
def test_children_method_on_base_state():
    all_states_set = set(all_states)
    all_states_set.remove(State)
    assert all_states_set == set(State.children())
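The children() classmethod exercised here enumerates every State subclass. One plausible way to write such a helper is a recursive __subclasses__ walk; this is a sketch of the technique, not Prefect's actual implementation.

# Sketch of a recursive subclass walk like State.children() presumably
# performs; assumed implementation, not Prefect's code.
def children(cls):
    subclasses = []
    for sub in cls.__subclasses__():
        subclasses.append(sub)
        subclasses.extend(children(sub))
    return subclasses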
Example #27
class TestTaskRunStates:
    async def test_set_task_run_state(self, task_run_id):
        result = await api.states.set_task_run_state(task_run_id=task_run_id,
                                                     state=Failed())

        assert result.task_run_id == task_run_id

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"})

        assert query.version == 2
        assert query.state == "Failed"
        assert query.serialized_state["type"] == "Failed"

    @pytest.mark.parametrize("state", [Failed(), Success()])
    async def test_set_task_run_state_fails_with_wrong_task_run_id(
            self, state):
        with pytest.raises(ValueError, match="State update failed"):
            await api.states.set_task_run_state(task_run_id=str(uuid.uuid4()),
                                                state=state)

    @pytest.mark.parametrize(
        "state", [s() for s in State.children() if not s().is_running()])
    async def test_state_does_not_set_heartbeat_unless_running(
            self, state, task_run_id):
        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=state)

        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

    async def test_running_state_sets_heartbeat(self, task_run_id,
                                                running_flow_run_id):
        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat is None

        dt = pendulum.now("UTC")
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=Running())

        task_run = await models.TaskRun.where(id=task_run_id).first({"heartbeat"})
        assert task_run.heartbeat > dt

    async def test_trigger_failed_state_does_not_set_end_time(
            self, task_run_id):
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=TriggerFailed())
        task_run_info = await models.TaskRun.where(id=task_run_id).first(
            {"id", "start_time", "end_time"})
        assert not task_run_info.start_time
        assert not task_run_info.end_time

    @pytest.mark.parametrize(
        "flow_run_state", [Pending(), Running(), Failed(), Success()])
    async def test_running_states_can_not_be_set_if_flow_run_is_not_running(
            self, flow_run_id, task_run_id, flow_run_state):

        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            state=flow_run_state)

        set_running_coroutine = api.states.set_task_run_state(
            task_run_id=task_run_id, state=Running())

        if flow_run_state.is_running():
            assert await set_running_coroutine
            task_run = await models.TaskRun.where(id=task_run_id).first({"state"})
            assert task_run.state == "Running"
        else:
            with pytest.raises(ValueError, match="is not in a running state"):
                await set_running_coroutine
            task_run = await models.TaskRun.where(id=task_run_id).first({"state"})
            assert task_run.state != "Running"