async def set_task_run_state(task_run_id: str, state: State, force: bool = False) -> Dict[str, str]:
    """
    Updates a task run state.

    Args:
        - task_run_id (str): the task run id to update
        - state (State): the new state
        - force (bool): if True, avoids pipeline checks

    Returns:
        - Dict[str, str]: Mapping indicating status of the state change operation.
    """
    if task_run_id is None:
        # guard against a null id before issuing a database query
        raise ValueError("Invalid task run ID.")

    task_run = await models.TaskRun.where({
        "id": {
            "_eq": task_run_id
        },
    }).first({
        "id": True,
        "version": True,
        "state": True,
        "serialized_state": True,
        "flow_run": {
            "id": True,
            "state": True
        },
    })

    if not task_run:
        raise ValueError(f"Invalid task run ID: {task_run_id}.")

    # ------------------------------------------------------
    # if the state is running, ensure the flow run is also running
    # ------------------------------------------------------
    if not force and state.is_running() and task_run.flow_run.state != "Running":
        raise ValueError(
            f"State update failed for task run ID {task_run_id}: provided "
            f"a running state but associated flow run {task_run.flow_run.id} is not "
            "in a running state.")

    # ------------------------------------------------------
    # if we have cached inputs on the old state, we need to carry them forward
    # ------------------------------------------------------
    if not state.cached_inputs and task_run.serialized_state.get(
            "cached_inputs", None):
        # load up the old state's cached inputs and apply them to the new state
        serialized_state = state_schema.load(task_run.serialized_state)
        state.cached_inputs = serialized_state.cached_inputs

    # --------------------------------------------------------
    # prepare the new state for the database
    # --------------------------------------------------------
    task_run_state = models.TaskRunState(
        task_run_id=task_run.id,
        version=(task_run.version or 0) + 1,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        start_time=getattr(state, "start_time", None),
        state=type(state).__name__,
        serialized_state=state.serialize(),
    )

    await task_run_state.insert()
    return {"status": "SUCCESS"}
def run_mapped_task(
    self,
    state: State,
    upstream_states: Dict[Edge, State],
    context: Dict[str, Any],
    executor: "prefect.engine.executors.Executor",
) -> State:
    """
    If the task is being mapped, submits children tasks for execution. Returns a `Mapped` state.

    Args:
        - state (State): the current task state
        - upstream_states (Dict[Edge, State]): the upstream states
        - context (dict, optional): prefect Context to use for execution
        - executor (Executor): executor to use when performing computation

    Returns:
        - State: the state of the task after running the check

    Raises:
        - ENDRUN: if the current state is not `Running`
    """
    map_upstream_states = []

    # we don't know how long the iterables are, but we want to iterate until we reach
    # the end of the shortest one
    counter = itertools.count()

    # infinite loop, if upstream_states has any entries
    # NOTE(review): `True and upstream_states` is equivalent to just
    # `upstream_states`; the `True and` is redundant.
    while True and upstream_states:
        i = next(counter)
        states = {}

        try:

            for edge, upstream_state in upstream_states.items():

                # if the edge is not mapped over, then we take its state
                if not edge.mapped:
                    states[edge] = upstream_state

                # if the edge is mapped and the upstream state is Mapped, then we are mapping
                # over a mapped task. In this case, we take the appropriately-indexed upstream
                # state from the upstream tasks's `Mapped.map_states` array.
                # Note that these "states" might actually be futures at this time; we aren't
                # blocking until they finish.
                elif edge.mapped and upstream_state.is_mapped():
                    states[edge] = upstream_state.map_states[
                        i]  # type: ignore

                # Otherwise, we are mapping over the result of a "vanilla" task. In this
                # case, we create a copy of the upstream state but set the result to the
                # appropriately-indexed item from the upstream task's `State.result`
                # array.
                else:
                    states[edge] = copy.copy(upstream_state)

                    # if the current state is already Mapped, then we might be executing
                    # a re-run of the mapping pipeline. In that case, the upstream states
                    # might not have `result` attributes (as any required results could be
                    # in the `cached_inputs` attribute of one of the child states).
                    # Therefore, we only try to get a result if EITHER this task's
                    # state is not already mapped OR the upstream result is not None.
                    if not state.is_mapped(
                    ) or upstream_state._result != NoResult:
                        if not hasattr(upstream_state.result, "__getitem__"):
                            raise TypeError(
                                "Cannot map over unsubscriptable object of type {t}: {preview}..."
                                .format(
                                    t=type(upstream_state.result),
                                    preview=repr(
                                        upstream_state.result)[:10],
                                ))
                        upstream_result = upstream_state._result.from_value(  # type: ignore
                            upstream_state.result[i])
                        states[edge].result = upstream_result
                    elif state.is_mapped():
                        # re-run case with no upstream result: stop once we run
                        # past the number of existing child states
                        if i >= len(state.map_states):  # type: ignore
                            raise IndexError()

            # only add this iteration if we made it through all iterables
            map_upstream_states.append(states)

        # index error means we reached the end of the shortest iterable
        except IndexError:
            break

    def run_fn(state: State, map_index: int,
               upstream_states: Dict[Edge, State]) -> State:
        # each child run gets its own copy of the context carrying its map_index
        map_context = context.copy()
        map_context.update(map_index=map_index)
        with prefect.context(self.context):
            return self.run(
                upstream_states=upstream_states,
                # if we set the state here, then it will not be processed by `initialize_run()`
                state=state,
                context=map_context,
                executor=executor,
            )

    # generate initial states, if available
    if isinstance(state, Mapped):
        initial_states = list(
            state.map_states)  # type: List[Optional[State]]
    else:
        initial_states = []
    # pad with None so there is one initial state per child task
    initial_states.extend([None] * (len(map_upstream_states) -
                                    len(initial_states)))

    current_state = Mapped(
        message="Preparing to submit {} mapped tasks.".format(
            len(initial_states)),
        map_states=initial_states,  # type: ignore
    )
    state = self.handle_state_change(old_state=state,
                                     new_state=current_state)
    # a state handler may have vetoed the transition; if so, return its state
    if state is not current_state:
        return state

    # map over the initial states, a counter representing the map_index, and also the mapped upstream states
    map_states = executor.map(run_fn, initial_states,
                              range(len(map_upstream_states)),
                              map_upstream_states)

    self.logger.debug("{} mapped tasks submitted for execution.".format(
        len(map_states)))
    new_state = Mapped(message="Mapped tasks submitted for execution.",
                       map_states=map_states)
    return self.handle_state_change(old_state=state, new_state=new_state)
def test_state_equality_with_nested_states():
    """State equality recurses into nested result states."""
    wrapped_one = State(result=Success(result=1))
    wrapped_two = State(result=Success(result=2))
    wrapped_one_again = State(result=Success(result=1))

    assert wrapped_one != wrapped_two
    assert wrapped_one == wrapped_one_again
def test_states_with_mutable_attrs_are_hashable():
    """States carrying mutable attribute values can still live in a set."""
    state_set = {State(result=[1]), Pending(cached_inputs=dict(a=1))}
    assert state_set
def test_parent_method_on_base_state():
    """The root State class reports no parent states."""
    parents = State.parents()
    assert parents == []
def test_state_equality_ignores_context():
    """Mutating one state's context does not break equality with an equal state."""
    left = State(result=1)
    right = State(result=1)
    left.context["key"] = "value"
    assert left == right
def test_state_pickle_with_unpicklable_result_raises():
    """Serializing a state whose result cannot be pickled raises TypeError."""
    # RLock instances are a canonical unpicklable payload
    unpicklable_state = State(result=RLock())
    with pytest.raises(TypeError, match="pickle"):
        cloudpickle.dumps(unpicklable_state)
def get_task_run_state(self, state: State, inputs: Dict[str, Result]) -> State:
    """
    Runs the task and traps any signals or errors it raises.
    Also checkpoints the result of a successful task, if `task.checkpoint` is `True`.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments.

    Returns:
        - State: the state of the task after running the check

    Raises:
        - signals.PAUSE: if the task raises PAUSE
        - ENDRUN: if the task is not ready to run
    """
    # the task may only execute from a Running state
    if not state.is_running():
        self.logger.debug(
            "Task '{name}': Can't run task because it's not in a "
            "Running state; ending run.".format(
                name=prefect.context.get("task_full_name", self.task.name)
            )
        )

        raise ENDRUN(state)

    value = None
    # unwrap Result containers into the raw values task.run() expects
    raw_inputs = {k: r.value for k, r in inputs.items()}
    new_state = None
    try:
        self.logger.debug(
            "Task '{name}': Calling task.run() method...".format(
                name=prefect.context.get("task_full_name", self.task.name)
            )
        )

        # Create a stdout redirect if the task has log_stdout enabled
        log_context = (
            redirect_stdout(prefect.utilities.logging.RedirectToLog(self.logger))
            if getattr(self.task, "log_stdout", False)
            else nullcontext()
        )  # type: AbstractContextManager

        with log_context:
            value = prefect.utilities.executors.run_task_with_timeout(
                task=self.task,
                args=(),
                kwargs=raw_inputs,
                logger=self.logger,
            )

    # inform user of timeout
    except TimeoutError as exc:
        if prefect.context.get("raise_on_exception"):
            raise exc
        state = TimedOut("Task timed out during execution.", result=exc)
        return state

    except signals.LOOP as exc:
        # LOOP is not a failure: carry the Looped state forward so the task
        # re-runs, checkpointing its intermediate result below
        new_state = exc.state
        assert isinstance(new_state, Looped)
        value = new_state.result
        new_state.message = exc.state.message or "Task is looping ({})".format(
            new_state.loop_count
        )

    # checkpoint tasks if a result is present, except for when the user has opted out by
    # disabling checkpointing
    if (
        prefect.context.get("checkpointing") is True
        and self.task.checkpoint is not False
        and value is not None
    ):
        try:
            # templating values available to the result's location formatting
            formatting_kwargs = {
                **prefect.context.get("parameters", {}).copy(),
                **prefect.context,
                **raw_inputs,
            }
            result = self.result.write(value, **formatting_kwargs)
        except ResultNotImplementedError:
            # result type has no write support; fall back to an in-memory result
            result = self.result.from_value(value=value)
    else:
        result = self.result.from_value(value=value)

    if new_state is not None:
        # looping: attach the (possibly checkpointed) result and keep looping
        new_state.result = result
        return new_state

    state = Success(result=result, message="Task run succeeded.")
    return state
def test_parent_method_on_base_state(include_self):
    """`parents()` on State is empty unless the class itself is included."""
    expected = [State] if include_self else []
    assert State.parents(include_self=include_self) == expected
def test_parent_method_on_base_state_names_only(include_self):
    """With names_only, `parents()` yields class names rather than classes."""
    expected = ["State"] if include_self else []
    result = State.parents(include_self=include_self, names_only=True)
    assert result == expected
def test_children_method_on_base_state(include_self):
    """`children()` on State covers every known state class, minus State itself when excluded."""
    expected = set(all_states)
    if not include_self:
        expected.remove(State)
    actual = set(State.children(include_self=include_self))
    assert actual == expected
class TestTaskRunStates:
    """Integration tests for `api.states.set_task_run_state` against the database."""

    async def test_set_task_run_state(self, task_run_id):
        result = await api.states.set_task_run_state(task_run_id=task_run_id,
                                                     state=Failed())

        assert result.task_run_id == task_run_id

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"})

        # the fixture run starts at version 1; a single update bumps it to 2
        assert query.version == 2
        assert query.state == "Failed"
        assert query.serialized_state["type"] == "Failed"

    @pytest.mark.parametrize("state", [Failed(), Success()])
    async def test_set_task_run_state_fails_with_wrong_task_run_id(
            self, state):
        # a random UUID should never match an existing task run
        with pytest.raises(ValueError, match="State update failed"):
            await api.states.set_task_run_state(task_run_id=str(uuid.uuid4()),
                                                state=state)

    @pytest.mark.parametrize(
        "state", [s() for s in State.children() if not s().is_running()])
    async def test_state_does_not_set_heartbeat_unless_running(
            self, state, task_run_id):
        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"heartbeat"})
        assert task_run.heartbeat is None

        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=state)

        # heartbeat stays unset for every non-Running state
        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"heartbeat"})
        assert task_run.heartbeat is None

    async def test_running_state_sets_heartbeat(self, task_run_id,
                                                running_flow_run_id):
        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"heartbeat"})
        assert task_run.heartbeat is None

        dt = pendulum.now("UTC")
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=Running())

        # heartbeat must be stamped after the captured timestamp
        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"heartbeat"})
        assert task_run.heartbeat > dt

    async def test_trigger_failed_state_does_not_set_end_time(
            self, task_run_id):
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=TriggerFailed())
        task_run_info = await models.TaskRun.where(id=task_run_id).first(
            {"id", "start_time", "end_time"})
        # TriggerFailed runs never started, so neither timestamp is set
        assert not task_run_info.start_time
        assert not task_run_info.end_time

    @pytest.mark.parametrize(
        "state",
        [s() for s in State.children() if s not in _MetaState.children()],
        ids=[
            s.__name__ for s in State.children()
            if s not in _MetaState.children()
        ],
    )
    async def test_setting_a_task_run_state_pulls_cached_inputs_if_possible(
            self, task_run_id, state, running_flow_run_id):

        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Failed(cached_inputs=complex_result)
        await models.TaskRun.where(id=task_run_id).update(set=dict(
            serialized_state=cached_state.serialize()))

        # try to schedule the task run to scheduled
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=state)

        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"serialized_state"})

        # ensure the state change took place
        assert task_run.serialized_state["type"] == type(state).__name__
        assert task_run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert task_run.serialized_state["cached_inputs"]["y"]["value"] == {
            "z": 2
        }

    @pytest.mark.parametrize(
        "state",
        [
            s(cached_inputs=None) for s in State.children()
            if s not in _MetaState.children()
        ],
        ids=[
            s.__name__ for s in State.children()
            if s not in _MetaState.children()
        ],
    )
    async def test_task_runs_with_null_cached_inputs_do_not_overwrite_cache(
            self, state, task_run_id, running_flow_run_id):

        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=state)

        # set up a Retrying state with non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=cached_state)
        run = await models.TaskRun.where(id=task_run_id
                                         ).first({"serialized_state"})

        assert run.serialized_state["cached_inputs"]["x"]["value"] == 1
        assert run.serialized_state["cached_inputs"]["y"]["value"] == {"z": 2}

    @pytest.mark.parametrize(
        "state_cls",
        [s for s in State.children() if s not in _MetaState.children()])
    async def test_task_runs_cached_inputs_give_preference_to_new_cached_inputs(
            self, state_cls, task_run_id, running_flow_run_id):

        # set up a Failed state with null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"a": 2}, result_handler=JSONResultHandler())
        complex_result = {"b": res1, "c": res2}
        cached_state = state_cls(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=cached_state)

        # set up a Retrying state with non-null cached inputs
        res1 = SafeResult(1, result_handler=JSONResultHandler())
        res2 = SafeResult({"z": 2}, result_handler=JSONResultHandler())
        complex_result = {"x": res1, "y": res2}
        cached_state = Retrying(cached_inputs=complex_result)
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=cached_state)
        run = Box(await models.TaskRun.where(id=task_run_id
                                             ).first({"serialized_state"}))

        # verify that we have cached inputs, and that preference has been given to the new
        # cached inputs
        assert run.serialized_state.cached_inputs
        assert run.serialized_state.cached_inputs.x.value == 1
        assert run.serialized_state.cached_inputs.y.value == {"z": 2}

    @pytest.mark.parametrize(
        "flow_run_state", [Pending(), Running(), Failed(), Success()])
    async def test_running_states_can_not_be_set_if_flow_run_is_not_running(
            self, flow_run_id, task_run_id, flow_run_state):

        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            state=flow_run_state)

        # build (but don't await) the coroutine so each branch can decide
        # whether it should succeed or raise
        set_running_coroutine = api.states.set_task_run_state(
            task_run_id=task_run_id, state=Running())

        if flow_run_state.is_running():
            assert await set_running_coroutine
            assert (await models.TaskRun.where(id=task_run_id
                                               ).first({"state"
                                                        })).state == "Running"
        else:
            with pytest.raises(ValueError, match="is not in a running state"):
                await set_running_coroutine
            assert (await models.TaskRun.where(id=task_run_id).first(
                {"state"})).state != "Running"
class TestFlowRunStates:
    """Integration tests for `api.states.set_flow_run_state` against the database."""

    async def test_set_flow_run_state(self, flow_run_id):
        result = await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                                     state=Running())

        assert result.flow_run_id == flow_run_id

        query = await models.FlowRun.where(id=flow_run_id).first(
            {"version", "state", "serialized_state"})

        # the fixture flow run has already been versioned twice; this update
        # bumps it to 3
        assert query.version == 3
        assert query.state == "Running"
        assert query.serialized_state["type"] == "Running"

    @pytest.mark.parametrize("state", [Running(), Success()])
    async def test_set_flow_run_state_fails_with_wrong_flow_run_id(
            self, state):
        # a random UUID should never match an existing flow run
        with pytest.raises(ValueError, match="State update failed"):
            await api.states.set_flow_run_state(flow_run_id=str(uuid.uuid4()),
                                                state=state)

    async def test_trigger_failed_state_does_not_set_end_time(
            self, flow_run_id):
        # there is no logic in Prefect that would create this sequence of
        # events, but a user could manually do this
        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            state=TriggerFailed())
        flow_run_info = await models.FlowRun.where(id=flow_run_id).first(
            {"id", "start_time", "end_time"})
        assert not flow_run_info.start_time
        assert not flow_run_info.end_time

    @pytest.mark.parametrize(
        "state",
        [
            s() for s in State.children()
            if not s().is_running() and not s().is_submitted()
        ],
    )
    async def test_state_does_not_set_heartbeat_unless_running_or_submitted(
            self, state, flow_run_id):
        flow_run = await models.FlowRun.where(id=flow_run_id
                                              ).first({"heartbeat"})
        assert flow_run.heartbeat is None

        dt = pendulum.now("UTC")
        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            state=state)

        # heartbeat stays unset for every state that is neither Running nor Submitted
        flow_run = await models.FlowRun.where(id=flow_run_id
                                              ).first({"heartbeat"})
        assert flow_run.heartbeat is None

    @pytest.mark.parametrize("state", [Running(), Submitted()])
    async def test_running_and_submitted_state_sets_heartbeat(
            self, state, flow_run_id):
        """
        Both Running and Submitted states need to set heartbeats for services
        like Lazarus to function properly.
        """
        flow_run = await models.FlowRun.where(id=flow_run_id
                                              ).first({"heartbeat"})
        assert flow_run.heartbeat is None

        dt = pendulum.now("UTC")
        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            state=state)

        flow_run = await models.FlowRun.where(id=flow_run_id
                                              ).first({"heartbeat"})
        assert flow_run.heartbeat > dt

    async def test_setting_flow_run_to_cancelled_state_sets_unfinished_task_runs_to_cancelled(
            self, flow_run_id):
        task_runs = await models.TaskRun.where({
            "flow_run_id": {
                "_eq": flow_run_id
            }
        }).get({"id"})
        task_run_ids = [run.id for run in task_runs]
        # update the state to Running
        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            state=Running())
        # Currently this flow_run_id fixture has at least 3 tasks, if this
        # changes the test will need to be updated
        assert len(task_run_ids) >= 3, "flow_run_id fixture has changed"
        # Set one task run to pending, one to running, and the rest to success
        pending_task_run = task_run_ids[0]
        running_task_run = task_run_ids[1]
        rest = task_run_ids[2:]
        await api.states.set_task_run_state(task_run_id=pending_task_run,
                                            state=Pending())
        await api.states.set_task_run_state(task_run_id=running_task_run,
                                            state=Running())
        for task_run_id in rest:
            await api.states.set_task_run_state(task_run_id=task_run_id,
                                                state=Success())
        # set the flow run to a cancelled state
        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            state=Cancelled())
        # Confirm the unfinished task runs have been marked as cancelled
        task_runs = await models.TaskRun.where({
            "flow_run_id": {
                "_eq": flow_run_id
            }
        }).get({"id", "state"})
        new_states = {run.id: run.state for run in task_runs}
        assert new_states[pending_task_run] == "Cancelled"
        assert new_states[running_task_run] == "Cancelled"
        # finished task runs must be left untouched
        assert all(new_states[id] == "Success" for id in rest)
def get_task_run_state(
    self,
    state: State,
    inputs: Dict[str, Result],
    timeout_handler: Optional[Callable] = None,
) -> State:
    """
    Runs the task and traps any signals or errors it raises.
    Also checkpoints the result of a successful task, if `task.checkpoint` is `True`.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments.
        - timeout_handler (Callable, optional): function for timing out
            task execution, with call signature `handler(fn, *args, **kwargs)`.
            Defaults to `prefect.utilities.executors.timeout_handler`

    Returns:
        - State: the state of the task after running the check

    Raises:
        - signals.PAUSE: if the task raises PAUSE
        - ENDRUN: if the task is not ready to run
    """
    # the task may only execute from a Running state
    if not state.is_running():
        self.logger.debug(
            "Task '{name}': can't run task because it's not in a "
            "Running state; ending run.".format(name=prefect.context.get(
                "task_full_name", self.task.name)))

        raise ENDRUN(state)

    value = None
    try:
        self.logger.debug(
            "Task '{name}': Calling task.run() method...".format(
                name=prefect.context.get("task_full_name", self.task.name)))
        timeout_handler = (timeout_handler
                           or prefect.utilities.executors.timeout_handler)
        # unwrap Result containers into the raw values task.run() expects
        raw_inputs = {k: r.value for k, r in inputs.items()}

        if getattr(self.task, "log_stdout", False):
            with redirect_stdout(
                    prefect.utilities.logging.RedirectToLog(
                        self.logger)):  # type: ignore
                # BUG FIX: previously bound to `result`, which was
                # unconditionally overwritten after this try block — the
                # task's return value was discarded and checkpointing never
                # fired when log_stdout was enabled. Bind to `value` to match
                # the non-redirected branch.
                value = timeout_handler(self.task.run,
                                        timeout=self.task.timeout,
                                        **raw_inputs)
        else:
            value = timeout_handler(self.task.run,
                                    timeout=self.task.timeout,
                                    **raw_inputs)

    except KeyboardInterrupt:
        self.logger.debug("Interrupt signal raised, cancelling task run.")
        state = Cancelled(
            message="Interrupt signal raised, cancelling task run.")
        return state

    # inform user of timeout
    except TimeoutError as exc:
        if prefect.context.get("raise_on_exception"):
            raise exc
        state = TimedOut("Task timed out during execution.",
                         result=exc,
                         cached_inputs=inputs)
        return state

    except signals.LOOP as exc:
        # LOOP is not a failure: carry the Looped state forward so the task re-runs
        new_state = exc.state
        assert isinstance(new_state, Looped)
        new_state.result = self.result.from_value(value=new_state.result)
        new_state.cached_inputs = inputs
        new_state.message = exc.state.message or "Task is looping ({})".format(
            new_state.loop_count)
        return new_state

    ## checkpoint tasks if a result is present, except for when the user has opted out by disabling checkpointing
    if (prefect.context.get("checkpointing") is True
            and self.task.checkpoint is not False and value is not None):
        try:
            result = self.result.write(value, **prefect.context)
        except NotImplementedError:
            # result type has no write support; fall back to an in-memory result
            result = self.result.from_value(value=value)
    else:
        result = self.result.from_value(value=value)

    state = Success(result=result,
                    message="Task run succeeded.",
                    cached_inputs=inputs)
    return state
def test_state_pickle_with_exception():
    """Exception results survive a cloudpickle round trip intact."""
    original = State(result=Exception("foo"))
    restored = cloudpickle.loads(cloudpickle.dumps(original))
    assert isinstance(restored.result, Exception)
    assert restored.result.args == ("foo", )
def check_task_ready_to_map(
    self, state: State, upstream_states: Dict[Edge, State]
) -> State:
    """
    Checks if the parent task is ready to proceed with mapping.

    Args:
        - state (State): the current state of this task
        - upstream_states (Dict[Edge, Union[State, List[State]]]): the upstream states

    Raises:
        - ENDRUN: either way, we dont continue past this point
    """
    if state.is_mapped():
        # this indicates we are executing a re-run of a mapped pipeline;
        # in this case, we populate both `map_states` and `cached_inputs`
        # to ensure the flow runner can properly regenerate the child tasks,
        # regardless of whether we mapped over an exchanged piece of data
        # or a non-data-exchanging upstream dependency
        if len(state.map_states) == 0 and state.n_map_states > 0:  # type: ignore
            state.map_states = [None] * state.n_map_states  # type: ignore
        # NOTE: the comprehension's `state` variable deliberately shadows the
        # `state` parameter within the comprehension scope
        state.cached_inputs = {
            edge.key: state._result  # type: ignore
            for edge, state in upstream_states.items()
            if edge.key
        }
        raise ENDRUN(state)

    # we can't map if there are no success states with iterables upstream
    if upstream_states and not any(
        [
            edge.mapped and state.is_successful()
            for edge, state in upstream_states.items()
        ]
    ):
        new_state = Failed("No upstream states can be mapped over.")  # type: State
        raise ENDRUN(new_state)
    elif not all(
        [
            hasattr(state.result, "__getitem__")
            for edge, state in upstream_states.items()
            if state.is_successful() and not state.is_mapped() and edge.mapped
        ]
    ):
        # a mapped edge whose upstream result is not subscriptable cannot be mapped over
        new_state = Failed("At least one upstream state has an unmappable result.")
        raise ENDRUN(new_state)
    else:
        # compute and set n_map_states: the shortest length among all mapped
        # upstream iterables / child-state counts (0 when there are none)
        n_map_states = min(
            [
                len(s.result)
                for e, s in upstream_states.items()
                if e.mapped and s.is_successful() and not s.is_mapped()
            ]
            + [
                s.n_map_states  # type: ignore
                for e, s in upstream_states.items()
                if e.mapped and s.is_mapped()
            ],
            default=0,
        )
        new_state = Mapped(
            "Ready to proceed with mapping.", n_map_states=n_map_states
        )
        raise ENDRUN(new_state)
async def set_task_run_state(
    task_run_id: str, state: State, version: int = None, flow_run_version: int = None
) -> models.TaskRunState:
    """
    Updates a task run state.

    Args:
        - task_run_id (str): the task run id to update
        - state (State): the new state
        - version (int): a version (only for Cloud API compatibility)
        - flow_run_version (int): a flow run version (only for Cloud API compatibility)

    Returns:
        - models.TaskRunState
    """

    if task_run_id is None:
        # guard against a null id before issuing a database query
        raise ValueError("Invalid task run ID.")

    task_run = await models.TaskRun.where(id=task_run_id).first(
        {
            "id": True,
            "tenant_id": True,
            "version": True,
            "state": True,
            "serialized_state": True,
            "flow_run": {"id": True, "state": True},
        }
    )

    if not task_run:
        raise ValueError(f"State update failed for task run ID {task_run_id}")

    # ------------------------------------------------------
    # if the state is running, ensure the flow run is also running
    # ------------------------------------------------------
    if state.is_running() and task_run.flow_run.state != "Running":
        raise ValueError(
            f"State update failed for task run ID {task_run_id}: provided "
            f"a running state but associated flow run {task_run.flow_run.id} is not "
            "in a running state."
        )

    # ------------------------------------------------------
    # if we have cached inputs on the old state, we need to carry them forward
    # ------------------------------------------------------
    if not state.cached_inputs and task_run.serialized_state.get("cached_inputs", None):
        # load up the old state's cached inputs and apply them to the new state
        serialized_state = state_schema.load(task_run.serialized_state)
        state.cached_inputs = serialized_state.cached_inputs

    # --------------------------------------------------------
    # prepare the new state for the database
    # --------------------------------------------------------

    task_run_state = models.TaskRunState(
        id=str(uuid.uuid4()),
        tenant_id=task_run.tenant_id,
        task_run_id=task_run.id,
        version=(task_run.version or 0) + 1,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        start_time=getattr(state, "start_time", None),
        state=type(state).__name__,
        serialized_state=state.serialize(),
    )

    await task_run_state.insert()

    # --------------------------------------------------------
    # apply downstream updates
    # --------------------------------------------------------

    # FOR RUNNING STATES:
    #   - update the task run heartbeat
    if state.is_running():
        await api.runs.update_task_run_heartbeat(task_run_id=task_run_id)

    return task_run_state
def get_flow_run_state(
    self,
    state: State,
    task_states: Dict[Task, State],
    task_contexts: Dict[Task, Dict[str, Any]],
    return_tasks: Set[Task],
    task_runner_state_handlers: Iterable[Callable],
    executor: "prefect.engine.executors.base.Executor",
) -> State:
    """
    Runs the flow.

    Args:
        - state (State): starting state for the Flow. Defaults to
            `Pending`
        - task_states (dict): dictionary of task states to begin
            computation with, with keys being Tasks and values their corresponding state
        - task_contexts (Dict[Task, Dict[str, Any]]): contexts that will be provided to each task
        - return_tasks ([Task], optional): list of Tasks to include in the
            final returned Flow state. Defaults to `None`
        - task_runner_state_handlers (Iterable[Callable]): A list of state change
            handlers that will be provided to the task_runner, and called whenever a task changes
            state.
        - executor (Executor): executor to use when performing
            computation; defaults to the executor provided in your prefect configuration

    Returns:
        - State: `State` representing the final post-run state of the `Flow`.
    """
    # the flow may only execute from a Running state
    if not state.is_running():
        self.logger.info("Flow is not in a Running state.")
        raise ENDRUN(state)

    if return_tasks is None:
        return_tasks = set()
    if set(return_tasks).difference(self.flow.tasks):
        raise ValueError("Some tasks in return_tasks were not found in the flow.")

    # -- process each task in order

    with executor.start():

        for task in self.flow.sorted_tasks():

            task_state = task_states.get(task)
            # constants are treated as already-successful tasks carrying their value
            if task_state is None and isinstance(
                task, prefect.tasks.core.constants.Constant
            ):
                task_states[task] = task_state = Success(result=task.value)

            # if the state is finished, don't run the task, just use the provided state
            if (
                isinstance(task_state, State)
                and task_state.is_finished()
                and not task_state.is_cached()
                and not task_state.is_mapped()
            ):
                continue

            upstream_states = {}  # type: Dict[Edge, Union[State, Iterable]]

            # -- process each edge to the task
            for edge in self.flow.edges_to(task):
                upstream_states[edge] = task_states.get(
                    edge.upstream_task, Pending(message="Task state not available.")
                )

            # augment edges with upstream constants
            for key, val in self.flow.constants[task].items():
                edge = Edge(
                    upstream_task=prefect.tasks.core.constants.Constant(val),
                    downstream_task=task,
                    key=key,
                )
                upstream_states[edge] = Success(
                    "Auto-generated constant value",
                    result=Result(val, result_handler=ConstantResultHandler(val)),
                )

            # -- run the task
            # submission is non-blocking; task_states may hold futures until
            # executor.wait() below

            with prefect.context(task_full_name=task.name, task_tags=task.tags):
                task_states[task] = executor.submit(
                    self.run_task,
                    task=task,
                    state=task_state,
                    upstream_states=upstream_states,
                    context=dict(prefect.context, **task_contexts.get(task, {})),
                    task_runner_state_handlers=task_runner_state_handlers,
                    executor=executor,
                )

        # ---------------------------------------------
        # Collect results
        # ---------------------------------------------

        # terminal tasks determine if the flow is finished
        terminal_tasks = self.flow.terminal_tasks()

        # reference tasks determine flow state
        reference_tasks = self.flow.reference_tasks()

        # wait until all terminal tasks are finished
        final_tasks = terminal_tasks.union(reference_tasks).union(return_tasks)
        final_states = executor.wait(
            {
                t: task_states.get(t, Pending("Task not evaluated by FlowRunner."))
                for t in final_tasks
            }
        )

        # also wait for any children of Mapped tasks to finish, and add them
        # to the dictionary to determine flow state
        all_final_states = final_states.copy()
        for t, s in list(final_states.items()):
            if s.is_mapped():
                s.map_states = executor.wait(s.map_states)
                s.result = [ms.result for ms in s.map_states]
                all_final_states[t] = s.map_states

        assert isinstance(final_states, dict)

    key_states = set(flatten_seq([all_final_states[t] for t in reference_tasks]))
    terminal_states = set(
        flatten_seq([all_final_states[t] for t in terminal_tasks])
    )
    return_states = {t: final_states[t] for t in return_tasks}

    state = self.determine_final_state(
        state=state,
        key_states=key_states,
        return_states=return_states,
        terminal_states=terminal_states,
    )

    return state
async def set_flow_run_state(
    flow_run_id: str, state: State, version: int = None, agent_id: str = None
) -> models.FlowRunState:
    """
    Updates a flow run state.

    Args:
        - flow_run_id (str): the flow run id to update
        - state (State): the new state
        - version (int): a version (unused here; accepted only for Cloud API
            compatibility)
        - agent_id (str): the ID of an agent instance setting the state

    Returns:
        - models.FlowRunState: the newly inserted state record

    Raises:
        - ValueError: if no flow run ID is provided, or if the ID does not
            match an existing flow run
    """
    if flow_run_id is None:
        # fixed: message has no placeholders, so no f-string prefix is needed
        raise ValueError("Invalid flow run ID.")

    flow_run = await models.FlowRun.where(id=flow_run_id).first(
        {
            "id": True,
            "state": True,
            "name": True,
            "version": True,
            "flow": {"id", "name", "flow_group_id", "version_group_id"},
            "tenant": {"id", "slug"},
        }
    )

    if not flow_run:
        raise ValueError(f"State update failed for flow run ID {flow_run_id}")

    # --------------------------------------------------------
    # apply downstream updates
    # --------------------------------------------------------

    # FOR CANCELLED STATES:
    #   - set all non-finished task run states to Cancelled
    if isinstance(state, Cancelled):
        task_runs = await models.TaskRun.where(
            {"flow_run_id": {"_eq": flow_run_id}}
        ).get({"id", "serialized_state"})
        to_cancel = [
            t
            for t in task_runs
            if not state_schema.load(t.serialized_state).is_finished()
        ]
        # For a run with many tasks this may be a lot of tasks - at some point
        # we might want to batch this rather than kicking off lots of asyncio
        # tasks at once.
        await asyncio.gather(
            *(api.states.set_task_run_state(t.id, state) for t in to_cancel),
            return_exceptions=True,
        )

    # --------------------------------------------------------
    # insert the new state in the database
    # --------------------------------------------------------

    flow_run_state = models.FlowRunState(
        id=str(uuid.uuid4()),
        tenant_id=flow_run.tenant_id,
        flow_run_id=flow_run_id,
        version=(flow_run.version or 0) + 1,
        state=type(state).__name__,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        start_time=getattr(state, "start_time", None),
        serialized_state=state.serialize(),
    )

    await flow_run_state.insert()

    # --------------------------------------------------------
    # apply downstream updates
    # --------------------------------------------------------

    # FOR RUNNING STATES:
    #   - update the flow run heartbeat
    if state.is_running() or state.is_submitted():
        await api.runs.update_flow_run_heartbeat(flow_run_id=flow_run_id)

    # Set agent ID on flow run when submitted by agent
    if state.is_submitted() and agent_id:
        await api.runs.update_flow_run_agent(
            flow_run_id=flow_run_id, agent_id=agent_id
        )

    # --------------------------------------------------------
    # call cloud hooks
    # --------------------------------------------------------

    event = events.FlowRunStateChange(
        flow_run=flow_run,
        state=flow_run_state,
        flow=flow_run.flow,
        tenant=flow_run.tenant,
    )

    # fire-and-forget: a hook failure must not block the state change itself
    asyncio.create_task(api.cloud_hooks.call_hooks(event))

    return flow_run_state
def get_flow_run_state(
    self,
    state: State,
    task_states: Dict[Task, State],
    task_contexts: Dict[Task, Dict[str, Any]],
    return_tasks: Set[Task],
    task_runner_state_handlers: Iterable[Callable],
    executor: "prefect.engine.executors.base.Executor",
) -> State:
    """
    Runs the flow.

    Args:
        - state (State): starting state for the Flow. Defaults to `Pending`
        - task_states (dict): dictionary of task states to begin computation with, with
            keys being Tasks and values their corresponding state
        - task_contexts (Dict[Task, Dict[str, Any]]): contexts that will be provided to
            each task
        - return_tasks ([Task], optional): list of Tasks to include in the final returned
            Flow state. Defaults to `None`
        - task_runner_state_handlers (Iterable[Callable]): A list of state change handlers
            that will be provided to the task_runner, and called whenever a task changes state.
        - executor (Executor): executor to use when performing computation; defaults to
            the executor provided in your prefect configuration

    Returns:
        - State: `State` representing the final post-run state of the `Flow`.
    """
    # this dictionary is used for tracking the states of "children" mapped tasks;
    # when running on Dask, we want to avoid serializing futures, so instead
    # of storing child task states in the `map_states` attribute we instead store
    # in this dictionary and only after they are resolved do we attach them to the Mapped state
    mapped_children = dict()  # type: Dict[Task, list]

    # the flow itself must already be Running before its tasks are run
    if not state.is_running():
        self.logger.info("Flow is not in a Running state.")
        raise ENDRUN(state)

    if return_tasks is None:
        return_tasks = set()
    # every requested return task must actually belong to this flow
    if set(return_tasks).difference(self.flow.tasks):
        raise ValueError(
            "Some tasks in return_tasks were not found in the flow.")

    def extra_context(task: Task, task_index: int = None) -> dict:
        # per-task metadata passed through to the task runner for logging/context
        return {
            "task_name": task.name,
            "task_tags": task.tags,
            "task_index": task_index,
        }

    # -- process each task in order

    with self.check_for_cancellation(), executor.start():

        for task in self.flow.sorted_tasks():
            task_state = task_states.get(task)

            # if a task is a constant task, we already know its return value
            # no need to use up resources by running it through a task runner
            if task_state is None and isinstance(
                    task, prefect.tasks.core.constants.Constant):
                task_states[task] = task_state = Success(result=task.value)

            # Always restart completed resource setup/cleanup tasks unless
            # they were explicitly cached.
            # TODO: we only need to rerun these tasks if any pending
            # downstream tasks depend on them.
            if (isinstance(
                    task,
                    (
                        prefect.tasks.core.resource_manager.ResourceSetupTask,
                        prefect.tasks.core.resource_manager.ResourceCleanupTask,
                    ),
            ) and task_state is not None and task_state.is_finished()
                    and not task_state.is_cached()):
                task_states[task] = task_state = Pending()

            # if the state is finished, don't run the task, just use the provided state if
            # the state is cached / mapped, we still want to run the task runner pipeline
            # steps to either ensure the cache is still valid / or to recreate the mapped
            # pipeline for possible retries
            if (isinstance(task_state, State) and task_state.is_finished()
                    and not task_state.is_cached()
                    and not task_state.is_mapped()):
                continue

            upstream_states = {}  # type: Dict[Edge, State]

            # this dictionary is used exclusively for "reduce" tasks in particular we store
            # the states / futures corresponding to the upstream children, and if running
            # on Dask, let Dask resolve them at the appropriate time.
            # Note: this is an optimization that allows Dask to resolve the mapped
            # dependencies by "elevating" them to a function argument.
            upstream_mapped_states = {}  # type: Dict[Edge, list]

            # -- process each edge to the task
            for edge in self.flow.edges_to(task):
                # load the upstream task states (supplying Pending as a default)
                upstream_states[edge] = task_states.get(
                    edge.upstream_task,
                    Pending(message="Task state not available."))

                # if the edge is flattened and not the result of a map, then we
                # preprocess the upstream states. If it IS the result of a
                # map, it will be handled in `prepare_upstream_states_for_mapping`
                if edge.flattened:
                    if not isinstance(upstream_states[edge], Mapped):
                        upstream_states[edge] = executor.submit(
                            executors.flatten_upstream_state,
                            upstream_states[edge])

                # this checks whether the task is a "reduce" task for a mapped pipeline
                # and if so, collects the appropriate upstream children
                if not edge.mapped and isinstance(upstream_states[edge],
                                                  Mapped):
                    children = mapped_children.get(edge.upstream_task, [])

                    # if the edge is flattened, then we need to wait for the mapped children
                    # to complete and then flatten them
                    if edge.flattened:
                        children = executors.flatten_mapped_children(
                            mapped_children=children, executor=executor)

                    upstream_mapped_states[edge] = children

            # augment edges with upstream constants
            for key, val in self.flow.constants[task].items():
                edge = Edge(
                    upstream_task=prefect.tasks.core.constants.Constant(val),
                    downstream_task=task,
                    key=key,
                )
                upstream_states[edge] = Success(
                    "Auto-generated constant value",
                    result=ConstantResult(value=val),
                )

            # handle mapped tasks
            if any(edge.mapped for edge in upstream_states.keys()):
                # wait on upstream states to determine the width of the pipeline
                # this is the key to depth-first execution
                upstream_states = executor.wait(
                    {e: state
                     for e, state in upstream_states.items()})

                # we submit the task to the task runner to determine if
                # we can proceed with mapping - if the new task state is not a Mapped
                # state then we don't proceed
                task_states[task] = executor.wait(
                    executor.submit(
                        run_task,
                        task=task,
                        state=task_state,  # original state
                        upstream_states=upstream_states,
                        context=dict(prefect.context,
                                     **task_contexts.get(task, {})),
                        flow_result=self.flow.result,
                        task_runner_cls=self.task_runner_cls,
                        task_runner_state_handlers=task_runner_state_handlers,
                        upstream_mapped_states=upstream_mapped_states,
                        is_mapped_parent=True,
                        extra_context=extra_context(task),
                    ))

                # either way, we should now have enough resolved states to restructure
                # the upstream states into a list of upstream state dictionaries to iterate over
                list_of_upstream_states = (
                    executors.prepare_upstream_states_for_mapping(
                        task_states[task],
                        upstream_states,
                        mapped_children,
                        executor=executor,
                    ))

                submitted_states = []

                for idx, states in enumerate(list_of_upstream_states):
                    # if we are on a future rerun of a partially complete flow run,
                    # there might be mapped children in a retrying state; this check
                    # looks into the current task state's map_states for such info
                    if (isinstance(task_state, Mapped)
                            and len(task_state.map_states) >= idx + 1):
                        current_state = task_state.map_states[
                            idx]  # type: Optional[State]
                    elif isinstance(task_state, Mapped):
                        current_state = None
                    else:
                        current_state = task_state

                    # this is where each child is submitted for actual work
                    submitted_states.append(
                        executor.submit(
                            run_task,
                            task=task,
                            state=current_state,
                            upstream_states=states,
                            context=dict(
                                prefect.context,
                                **task_contexts.get(task, {}),
                                map_index=idx,
                            ),
                            flow_result=self.flow.result,
                            task_runner_cls=self.task_runner_cls,
                            task_runner_state_handlers=task_runner_state_handlers,
                            upstream_mapped_states=upstream_mapped_states,
                            extra_context=extra_context(task, task_index=idx),
                        ))
                # stash the children for any downstream "reduce" task to collect
                if isinstance(task_states.get(task), Mapped):
                    mapped_children[task] = submitted_states  # type: ignore

            else:
                # non-mapped path: a single submission to the task runner
                task_states[task] = executor.submit(
                    run_task,
                    task=task,
                    state=task_state,
                    upstream_states=upstream_states,
                    context=dict(prefect.context,
                                 **task_contexts.get(task, {})),
                    flow_result=self.flow.result,
                    task_runner_cls=self.task_runner_cls,
                    task_runner_state_handlers=task_runner_state_handlers,
                    upstream_mapped_states=upstream_mapped_states,
                    extra_context=extra_context(task),
                )

        # ---------------------------------------------
        # Collect results
        # ---------------------------------------------

        # terminal tasks determine if the flow is finished
        terminal_tasks = self.flow.terminal_tasks()

        # reference tasks determine flow state
        reference_tasks = self.flow.reference_tasks()

        # wait until all terminal tasks are finished
        final_tasks = terminal_tasks.union(reference_tasks).union(return_tasks)
        final_states = executor.wait({
            t: task_states.get(t, Pending("Task not evaluated by FlowRunner."))
            for t in final_tasks
        })

        # also wait for any children of Mapped tasks to finish, and add them
        # to the dictionary to determine flow state
        all_final_states = final_states.copy()
        for t, s in list(final_states.items()):
            if s.is_mapped():
                # ensure we wait for any mapped children to complete
                if t in mapped_children:
                    s.map_states = executor.wait(mapped_children[t])
                s.result = [ms.result for ms in s.map_states]
                all_final_states[t] = s.map_states

        assert isinstance(final_states, dict)

    # NOTE(review): the remaining steps run after the executor context exits,
    # since all futures above have already been resolved by `executor.wait`
    key_states = set(
        flatten_seq([all_final_states[t] for t in reference_tasks]))
    terminal_states = set(
        flatten_seq([all_final_states[t] for t in terminal_tasks]))
    return_states = {t: final_states[t] for t in return_tasks}

    state = self.determine_final_state(
        state=state,
        key_states=key_states,
        return_states=return_states,
        terminal_states=terminal_states,
    )

    return state
def check_task_is_cached(self, state: State, inputs: Dict[str, Result]) -> State: """ Checks if task is cached in the DB and whether any of the caches are still valid. Args: - state (State): the current state of this task - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond to the task's `run()` arguments. Returns: - State: the state of the task after running the check Raises: - ENDRUN: if the task is not ready to run """ if state.is_cached() is True: assert isinstance(state, Cached) # mypy assert sanitized_inputs = {key: res.value for key, res in inputs.items()} if self.task.cache_validator( state, sanitized_inputs, prefect.context.get("parameters") ): state = state.load_result(self.result) return state if self.task.cache_for is not None: oldest_valid_cache = datetime.datetime.utcnow() - self.task.cache_for cached_states = self.client.get_latest_cached_states( task_id=prefect.context.get("task_id", ""), cache_key=self.task.cache_key, created_after=oldest_valid_cache, ) if not cached_states: self.logger.debug( "Task '{name}': can't use cache because no Cached states were found".format( name=prefect.context.get("task_full_name", self.task.name) ) ) else: self.logger.debug( "Task '{name}': {num} candidate cached states were found".format( name=prefect.context.get("task_full_name", self.task.name), num=len(cached_states), ) ) for candidate_state in cached_states: assert isinstance(candidate_state, Cached) # mypy assert candidate_state.load_cached_results(inputs) sanitized_inputs = {key: res.value for key, res in inputs.items()} if self.task.cache_validator( candidate_state, sanitized_inputs, prefect.context.get("parameters") ): return candidate_state.load_result(self.result) self.logger.debug( "Task '{name}': can't use cache because no candidate Cached states " "were valid".format( name=prefect.context.get("task_full_name", self.task.name) ) ) return state
def test_state_equality_ignores_message(): assert State(result=1, message="x") == State(result=1, message="y") assert State(result=1, message="x") != State(result=2, message="x")
def check_task_is_ready(self, state: State) -> State: """ Checks to make sure the task is ready to run (Pending or Mapped). If the state is Paused, an ENDRUN is raised. Args: - state (State): the current state of this task Returns: - State: the state of the task after running the check Raises: - ENDRUN: if the task is not ready to run """ # the task is paused if isinstance(state, Paused): self.logger.debug( "Task '{name}': task is paused; ending run.".format( name=prefect.context.get("task_full_name", self.task.name) ) ) raise ENDRUN(state) # the task is ready elif state.is_pending(): return state # the task is mapped, in which case we still proceed so that the children tasks # are generated (note that if the children tasks) elif state.is_mapped(): self.logger.debug( "Task '{name}': task is mapped, but run will proceed so children are generated.".format( name=prefect.context.get("task_full_name", self.task.name) ) ) return state # this task is already running elif state.is_running(): self.logger.debug( "Task '{name}': task is already running.".format( name=prefect.context.get("task_full_name", self.task.name) ) ) raise ENDRUN(state) elif state.is_cached(): return state # this task is already finished elif state.is_finished(): self.logger.debug( "Task '{name}': task is already finished.".format( name=prefect.context.get("task_full_name", self.task.name) ) ) raise ENDRUN(state) # this task is not pending else: self.logger.debug( "Task '{name}' is not ready to run or state was unrecognized ({state}).".format( name=prefect.context.get("task_full_name", self.task.name), state=state, ) ) raise ENDRUN(state)
def test_states_are_hashable(): assert {State(), Pending(), Success()}
def get_task_run_state( self, state: State, inputs: Dict[str, Result], timeout_handler: Optional[Callable], ) -> State: """ Runs the task and traps any signals or errors it raises. Also checkpoints the result of a successful task, if `task.checkpoint` is `True`. Args: - state (State): the current state of this task - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond to the task's `run()` arguments. - timeout_handler (Callable, optional): function for timing out task execution, with call signature `handler(fn, *args, **kwargs)`. Defaults to `prefect.utilities.executors.main_thread_timeout` Returns: - State: the state of the task after running the check Raises: - signals.PAUSE: if the task raises PAUSE - ENDRUN: if the task is not ready to run """ if not state.is_running(): self.logger.debug( "Task '{name}': can't run task because it's not in a " "Running state; ending run.".format( name=prefect.context.get("task_full_name", self.task.name) ) ) raise ENDRUN(state) try: self.logger.debug( "Task '{name}': Calling task.run() method...".format( name=prefect.context.get("task_full_name", self.task.name) ) ) timeout_handler = timeout_handler or main_thread_timeout raw_inputs = {k: r.value for k, r in inputs.items()} result = timeout_handler( self.task.run, timeout=self.task.timeout, **raw_inputs ) # inform user of timeout except TimeoutError as exc: if prefect.context.get("raise_on_exception"): raise exc state = TimedOut( "Task timed out during execution.", result=exc, cached_inputs=inputs ) return state result = Result(value=result, result_handler=self.result_handler) state = Success(result=result, message="Task run succeeded.") ## only checkpoint tasks if running in cloud if ( state.is_successful() and prefect.context.get("cloud") is True and self.task.checkpoint is True ): state._result.store_safe_value() return state
def test_children_method_on_base_state(): all_states_set = set(all_states) all_states_set.remove(State) assert all_states_set == set(State.children())
class TestTaskRunStates:
    # Integration tests for the task-run state-setting API.

    async def test_set_task_run_state(self, task_run_id):
        """Setting a state returns the new record and bumps the stored version."""
        result = await api.states.set_task_run_state(task_run_id=task_run_id,
                                                     state=Failed())

        assert result.task_run_id == task_run_id

        query = await models.TaskRun.where(id=task_run_id).first(
            {"version", "state", "serialized_state"})

        assert query.version == 2
        assert query.state == "Failed"
        assert query.serialized_state["type"] == "Failed"

    @pytest.mark.parametrize("state", [Failed(), Success()])
    async def test_set_task_run_state_fails_with_wrong_task_run_id(
            self, state):
        """An unknown task run ID must be rejected regardless of target state."""
        # NOTE(review): match pattern assumes the API raises "State update
        # failed" for a missing run -- confirm against the implementation's
        # actual error message
        with pytest.raises(ValueError, match="State update failed"):
            await api.states.set_task_run_state(task_run_id=str(uuid.uuid4()),
                                                state=state)

    @pytest.mark.parametrize(
        "state", [s() for s in State.children() if not s().is_running()])
    async def test_state_does_not_set_heartbeat_unless_running(
            self, state, task_run_id):
        """Only Running states should touch the heartbeat column."""
        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"heartbeat"})
        assert task_run.heartbeat is None

        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=state)

        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"heartbeat"})
        assert task_run.heartbeat is None

    async def test_running_state_sets_heartbeat(self, task_run_id,
                                                running_flow_run_id):
        """A Running state must stamp a fresh heartbeat on the task run."""
        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"heartbeat"})
        assert task_run.heartbeat is None

        dt = pendulum.now("UTC")
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=Running())

        task_run = await models.TaskRun.where(id=task_run_id
                                              ).first({"heartbeat"})
        assert task_run.heartbeat > dt

    async def test_trigger_failed_state_does_not_set_end_time(
            self, task_run_id):
        """TriggerFailed is terminal without the run ever starting, so no times are set."""
        await api.states.set_task_run_state(task_run_id=task_run_id,
                                            state=TriggerFailed())
        task_run_info = await models.TaskRun.where(id=task_run_id).first(
            {"id", "start_time", "end_time"})
        assert not task_run_info.start_time
        assert not task_run_info.end_time

    @pytest.mark.parametrize(
        "flow_run_state", [Pending(), Running(), Failed(), Success()])
    async def test_running_states_can_not_be_set_if_flow_run_is_not_running(
            self, flow_run_id, task_run_id, flow_run_state):
        """A task may enter Running only while its parent flow run is Running."""
        await api.states.set_flow_run_state(flow_run_id=flow_run_id,
                                            state=flow_run_state)

        # build the coroutine up front; each branch below awaits it exactly once
        set_running_coroutine = api.states.set_task_run_state(
            task_run_id=task_run_id, state=Running())

        if flow_run_state.is_running():
            assert await set_running_coroutine
            assert (await models.TaskRun.where(id=task_run_id
                                               ).first({"state"})).state == "Running"
        else:
            with pytest.raises(ValueError, match="is not in a running state"):
                await set_running_coroutine
            assert (await models.TaskRun.where(id=task_run_id).first(
                {"state"})).state != "Running"