def checkpoint_handler(task_runner: DSTaskRunner, old_state: State, new_state: State) -> State:
    """
    A handler designed to implement result caching by filename.

    If the result handler's ``read`` method can be successfully run, this handler
    loads the result of that method as the task result and sets the task state to
    ``Success``. Similarly, on successful completion of the task, if the task was
    actually run and not loaded from cache, this handler will apply the result
    handler's ``write`` method to the task.

    Parameters
    ----------
    task_runner : instance of DSTaskRunner
        The task runner associated with the flow the handler is used in.
    old_state : instance of prefect.engine.state.State
        The current state of the task.
    new_state : instance of prefect.engine.state.State
        The expected new state of the task.

    Returns
    -------
    new_state : instance of prefect.engine.state.State
        The actual new state of the task.
    """
    if "PREFECT__FLOWS__CHECKPOINTING" in os.environ and os.environ["PREFECT__FLOWS__CHECKPOINTING"] == "true":
        raise AttributeError("Cannot use standard prefect checkpointing with this handler")

    if task_runner.result_handler is not None and old_state.is_pending() and new_state.is_running():
        if not hasattr(task_runner, "upstream_states"):
            raise TypeError(
                "upstream_states not found in task runner. Make sure to use "
                "prefect_ds.task_runner.DSTaskRunner."
            )
        input_mapping = _create_input_mapping(task_runner.upstream_states)
        try:
            data = task_runner.task.result_handler.read(input_mapping=input_mapping)
        except FileNotFoundError:
            return new_state
        except TypeError:  # unexpected argument input_mapping
            raise TypeError(
                "Result handler could not accept input_mapping argument. "
                "Please ensure that you are using a handler from prefect_ds."
            )
        result = Result(value=data, result_handler=task_runner.task.result_handler)
        state = Success(result=result, message="Task loaded from disk.")
        return state

    if task_runner.result_handler is not None and old_state.is_running() and new_state.is_successful():
        input_mapping = _create_input_mapping(task_runner.upstream_states)
        task_runner.task.result_handler.write(new_state.result, input_mapping=input_mapping)

    return new_state
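# Hedged usage sketch: one way `checkpoint_handler` above might be wired into a flow run.
# Only DSTaskRunner and checkpoint_handler come from the snippet itself; the toy task,
# file name, and PickleFileHandler class are hypothetical, written against the
# read/write(input_mapping=...) interface implied by the handler's docstring.
import pickle

from prefect import Flow, task
from prefect.engine.flow_runner import FlowRunner
from prefect.engine.result_handlers import ResultHandler
from prefect_ds.task_runner import DSTaskRunner


class PickleFileHandler(ResultHandler):
    """Toy result handler exposing the assumed read/write(input_mapping=...) interface."""

    def __init__(self, path: str):
        self.path = path
        super().__init__()

    def read(self, input_mapping=None):
        # raises FileNotFoundError when no cache exists, which checkpoint_handler treats
        # as "run the task normally"
        with open(self.path, "rb") as f:
            return pickle.load(f)

    def write(self, result, input_mapping=None):
        with open(self.path, "wb") as f:
            pickle.dump(result, f)
        return self.path


@task(result_handler=PickleFileHandler("data.pkl"))
def build_dataset():
    return {"rows": 3}


with Flow("cached-flow") as flow:
    build_dataset()

# DSTaskRunner exposes `upstream_states`, which checkpoint_handler relies on.
FlowRunner(flow=flow, task_runner_cls=DSTaskRunner).run(
    task_runner_state_handlers=[checkpoint_handler]
)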
def check_flow_is_pending_or_running(self, state: State) -> State:
    """
    Checks if the flow is in either a Pending state or Running state. Either is a
    valid starting point (because we allow simultaneous runs of the same flow run).

    Args:
        - state (State): the current state of this flow

    Returns:
        - State: the state of the flow after running the check

    Raises:
        - ENDRUN: if the flow is not pending or running
    """
    # the flow run has already finished
    if state.is_finished() is True:
        self.logger.info("Flow run has already finished.")
        raise ENDRUN(state)

    # the flow run must be either pending or running (possibly redundant with above)
    elif not (state.is_pending() or state.is_running()):
        self.logger.info("Flow is not ready to run.")
        raise ENDRUN(state)

    return state
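# Hedged sketch of how pipeline checks such as check_flow_is_pending_or_running are
# consumed: the runner's run() pipeline catches ENDRUN and takes the state attached to it
# as the final state, so raising ENDRUN(state) simply short-circuits the remaining checks.
# The `runner` object below is assumed to be an already-constructed FlowRunner.
from prefect.engine.runner import ENDRUN
from prefect.engine.state import Success

try:
    runner.check_flow_is_pending_or_running(Success(message="already done"))
except ENDRUN as exc:
    final_state = exc.state  # the finished state is passed through unchanged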
def get_task_run_state(
    self,
    state: State,
    inputs: Dict[str, Result],
    timeout_handler: Optional[Callable],
) -> State:
    """
    Runs the task and traps any signals or errors it raises.
    Also checkpoints the result of a successful task, if `task.checkpoint` is `True`.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments.
        - timeout_handler (Callable, optional): function for timing out
            task execution, with call signature `handler(fn, *args, **kwargs)`.
            Defaults to `prefect.utilities.executors.main_thread_timeout`

    Returns:
        - State: the state of the task after running the check

    Raises:
        - signals.PAUSE: if the task raises PAUSE
        - ENDRUN: if the task is not ready to run
    """
    if not state.is_running():
        self.logger.debug(
            "Task '{name}': can't run task because it's not in a "
            "Running state; ending run.".format(
                name=prefect.context.get("task_full_name", self.task.name)
            )
        )
        raise ENDRUN(state)

    try:
        self.logger.debug(
            "Task '{name}': Calling task.run() method...".format(
                name=prefect.context.get("task_full_name", self.task.name)
            )
        )
        timeout_handler = timeout_handler or main_thread_timeout
        raw_inputs = {k: r.value for k, r in inputs.items()}
        result = timeout_handler(self.task.run, timeout=self.task.timeout, **raw_inputs)

    # inform user of timeout
    except TimeoutError as exc:
        if prefect.context.get("raise_on_exception"):
            raise exc
        state = TimedOut(
            "Task timed out during execution.", result=exc, cached_inputs=inputs
        )
        return state

    result = Result(value=result, result_handler=self.result_handler)
    state = Success(result=result, message="Task run succeeded.")

    if state.is_successful() and self.task.checkpoint is True:
        state._result.store_safe_value()

    return state
async def set_task_run_state(task_run_id: str, state: State, force=False) -> None:
    """
    Updates a task run state.

    Args:
        - task_run_id (str): the task run id to update
        - state (State): the new state
        - force (bool): if True, avoids pipeline checks
    """
    if task_run_id is None:
        raise ValueError("Invalid task run ID.")

    task_run = await models.TaskRun.where({"id": {"_eq": task_run_id}}).first(
        {
            "id": True,
            "version": True,
            "state": True,
            "serialized_state": True,
            "flow_run": {"id": True, "state": True},
        }
    )

    if not task_run:
        raise ValueError(f"Invalid task run ID: {task_run_id}.")

    # ------------------------------------------------------
    # if the state is running, ensure the flow run is also running
    # ------------------------------------------------------
    if not force and state.is_running() and task_run.flow_run.state != "Running":
        raise ValueError(
            f"State update failed for task run ID {task_run_id}: provided "
            f"a running state but associated flow run {task_run.flow_run.id} is not "
            "in a running state."
        )

    # ------------------------------------------------------
    # if we have cached inputs on the old state, we need to carry them forward
    # ------------------------------------------------------
    if not state.cached_inputs and task_run.serialized_state.get("cached_inputs", None):
        # load up the old state's cached inputs and apply them to the new state
        serialized_state = state_schema.load(task_run.serialized_state)
        state.cached_inputs = serialized_state.cached_inputs

    # --------------------------------------------------------
    # prepare the new state for the database
    # --------------------------------------------------------
    task_run_state = models.TaskRunState(
        task_run_id=task_run.id,
        version=(task_run.version or 0) + 1,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        start_time=getattr(state, "start_time", None),
        state=type(state).__name__,
        serialized_state=state.serialize(),
    )

    await task_run_state.insert()
def check_task_is_ready(self, state: State) -> State:
    """
    Checks to make sure the task is ready to run (Pending or Mapped).

    Args:
        - state (State): the current state of this task

    Returns:
        - State: the state of the task after running the check

    Raises:
        - ENDRUN: if the task is not ready to run
    """
    # the task is ready
    if state.is_pending():
        return state

    # the task is mapped, in which case we still proceed so that the children tasks
    # are generated
    elif state.is_mapped():
        self.logger.debug(
            "Task '{name}': task is mapped, but run will proceed so children are generated.".format(
                name=prefect.context.get("task_full_name", self.task.name)
            )
        )
        return state

    # this task is already running
    elif state.is_running():
        self.logger.debug(
            "Task '{name}': task is already running.".format(
                name=prefect.context.get("task_full_name", self.task.name)
            )
        )
        raise ENDRUN(state)

    elif state.is_cached():
        return state

    # this task is already finished
    elif state.is_finished():
        self.logger.debug(
            "Task '{name}': task is already finished.".format(
                name=prefect.context.get("task_full_name", self.task.name)
            )
        )
        raise ENDRUN(state)

    # this task is not pending
    else:
        self.logger.debug(
            "Task '{name}' is not ready to run or state was unrecognized ({state}).".format(
                name=prefect.context.get("task_full_name", self.task.name),
                state=state,
            )
        )
        raise ENDRUN(state)
def set_flow_to_running(self, state: State) -> State:
    """
    Puts Pending flows in a Running state; leaves Running flows Running.

    Args:
        - state (State): the current state of this flow

    Returns:
        - State: the state of the flow after running the check

    Raises:
        - ENDRUN: if the flow is not pending or running
    """
    if state.is_pending():
        return Running(message="Running flow.")
    elif state.is_running():
        return state
    else:
        raise ENDRUN(state)
def get_flow_run_state(
    self,
    state: State,
    task_states: Dict[Task, State],
    task_contexts: Dict[Task, Dict[str, Any]],
    return_tasks: Set[Task],
    task_runner_state_handlers: Iterable[Callable],
    executor: "prefect.executors.base.Executor",
) -> State:
    """
    Runs the flow.

    Args:
        - state (State): starting state for the Flow. Defaults to `Pending`
        - task_states (dict): dictionary of task states to begin
            computation with, with keys being Tasks and values their corresponding state
        - task_contexts (Dict[Task, Dict[str, Any]]): contexts that will be provided to each task
        - return_tasks ([Task], optional): list of Tasks to include in the
            final returned Flow state. Defaults to `None`
        - task_runner_state_handlers (Iterable[Callable]): A list of state change handlers
            that will be provided to the task_runner, and called whenever a task changes state.
        - executor (Executor): executor to use when performing computation; defaults to the
            executor provided in your prefect configuration

    Returns:
        - State: `State` representing the final post-run state of the `Flow`.
    """
    # this dictionary is used for tracking the states of "children" mapped tasks;
    # when running on Dask, we want to avoid serializing futures, so instead
    # of storing child task states in the `map_states` attribute we instead store
    # in this dictionary and only after they are resolved do we attach them to the Mapped state
    mapped_children = dict()  # type: Dict[Task, list]

    if not state.is_running():
        self.logger.info("Flow is not in a Running state.")
        raise ENDRUN(state)

    if return_tasks is None:
        return_tasks = set()
    if set(return_tasks).difference(self.flow.tasks):
        raise ValueError("Some tasks in return_tasks were not found in the flow.")

    def extra_context(task: Task, task_index: int = None) -> dict:
        return {
            "task_name": task.name,
            "task_tags": task.tags,
            "task_index": task_index,
        }

    # -- process each task in order

    with self.check_for_cancellation(), executor.start():

        for task in self.flow.sorted_tasks():
            task_state = task_states.get(task)

            # if a task is a constant task, we already know its return value
            # no need to use up resources by running it through a task runner
            if task_state is None and isinstance(
                task, prefect.tasks.core.constants.Constant
            ):
                task_states[task] = task_state = Success(result=task.value)

            # Always restart completed resource setup/cleanup tasks and
            # secret tasks unless they were explicitly cached.
            # TODO: we only need to rerun these tasks if any pending
            # downstream tasks depend on them.
            if (
                isinstance(
                    task,
                    (
                        prefect.tasks.core.resource_manager.ResourceSetupTask,
                        prefect.tasks.core.resource_manager.ResourceCleanupTask,
                        prefect.tasks.secrets.SecretBase,
                    ),
                )
                and task_state is not None
                and task_state.is_finished()
                and not task_state.is_cached()
            ):
                task_states[task] = task_state = Pending()

            # if the state is finished, don't run the task, just use the provided state; if
            # the state is cached / mapped, we still want to run the task runner pipeline
            # steps to either ensure the cache is still valid / or to recreate the mapped
            # pipeline for possible retries
            if (
                isinstance(task_state, State)
                and task_state.is_finished()
                and not task_state.is_cached()
                and not task_state.is_mapped()
            ):
                continue

            upstream_states = {}  # type: Dict[Edge, State]

            # this dictionary is used exclusively for "reduce" tasks; in particular we store
            # the states / futures corresponding to the upstream children, and if running
            # on Dask, let Dask resolve them at the appropriate time.
            # Note: this is an optimization that allows Dask to resolve the mapped
            # dependencies by "elevating" them to a function argument.
            upstream_mapped_states = {}  # type: Dict[Edge, list]

            # -- process each edge to the task
            for edge in self.flow.edges_to(task):

                # load the upstream task states (supplying Pending as a default)
                upstream_states[edge] = task_states.get(
                    edge.upstream_task, Pending(message="Task state not available.")
                )

                # if the edge is flattened and not the result of a map, then we
                # preprocess the upstream states. If it IS the result of a
                # map, it will be handled in `prepare_upstream_states_for_mapping`
                if edge.flattened:
                    if not isinstance(upstream_states[edge], Mapped):
                        upstream_states[edge] = executor.submit(
                            executors.flatten_upstream_state, upstream_states[edge]
                        )

                # this checks whether the task is a "reduce" task for a mapped pipeline
                # and if so, collects the appropriate upstream children
                if not edge.mapped and isinstance(upstream_states[edge], Mapped):
                    children = mapped_children.get(edge.upstream_task, [])

                    # if the edge is flattened, then we need to wait for the mapped children
                    # to complete and then flatten them
                    if edge.flattened:
                        children = executors.flatten_mapped_children(
                            mapped_children=children, executor=executor
                        )

                    upstream_mapped_states[edge] = children

            # augment edges with upstream constants
            for key, val in self.flow.constants[task].items():
                edge = Edge(
                    upstream_task=prefect.tasks.core.constants.Constant(val),
                    downstream_task=task,
                    key=key,
                )
                upstream_states[edge] = Success(
                    "Auto-generated constant value",
                    result=ConstantResult(value=val),
                )

            # handle mapped tasks
            if any(edge.mapped for edge in upstream_states.keys()):

                # wait on upstream states to determine the width of the pipeline
                # this is the key to depth-first execution
                upstream_states = executor.wait(
                    {e: state for e, state in upstream_states.items()}
                )

                # we submit the task to the task runner to determine if
                # we can proceed with mapping - if the new task state is not a Mapped
                # state then we don't proceed
                task_states[task] = executor.wait(
                    executor.submit(
                        run_task,
                        task=task,
                        state=task_state,  # original state
                        upstream_states=upstream_states,
                        context=dict(prefect.context, **task_contexts.get(task, {})),
                        flow_result=self.flow.result,
                        task_runner_cls=self.task_runner_cls,
                        task_runner_state_handlers=task_runner_state_handlers,
                        upstream_mapped_states=upstream_mapped_states,
                        is_mapped_parent=True,
                        extra_context=extra_context(task),
                    )
                )

                # either way, we should now have enough resolved states to restructure
                # the upstream states into a list of upstream state dictionaries to iterate over
                list_of_upstream_states = executors.prepare_upstream_states_for_mapping(
                    task_states[task],
                    upstream_states,
                    mapped_children,
                    executor=executor,
                )

                submitted_states = []

                for idx, states in enumerate(list_of_upstream_states):
                    # if we are on a future rerun of a partially complete flow run,
                    # there might be mapped children in a retrying state; this check
                    # looks into the current task state's map_states for such info
                    if (
                        isinstance(task_state, Mapped)
                        and len(task_state.map_states) >= idx + 1
                    ):
                        current_state = task_state.map_states[idx]  # type: Optional[State]
                    elif isinstance(task_state, Mapped):
                        current_state = None
                    else:
                        current_state = task_state

                    # this is where each child is submitted for actual work
                    submitted_states.append(
                        executor.submit(
                            run_task,
                            task=task,
                            state=current_state,
                            upstream_states=states,
                            context=dict(
                                prefect.context,
                                **task_contexts.get(task, {}),
                                map_index=idx,
                            ),
                            flow_result=self.flow.result,
                            task_runner_cls=self.task_runner_cls,
                            task_runner_state_handlers=task_runner_state_handlers,
                            upstream_mapped_states=upstream_mapped_states,
                            extra_context=extra_context(task, task_index=idx),
                        )
                    )

                if isinstance(task_states.get(task), Mapped):
                    mapped_children[task] = submitted_states  # type: ignore

            else:
                task_states[task] = executor.submit(
                    run_task,
                    task=task,
                    state=task_state,
                    upstream_states=upstream_states,
                    context=dict(prefect.context, **task_contexts.get(task, {})),
                    flow_result=self.flow.result,
                    task_runner_cls=self.task_runner_cls,
                    task_runner_state_handlers=task_runner_state_handlers,
                    upstream_mapped_states=upstream_mapped_states,
                    extra_context=extra_context(task),
                )

        # ---------------------------------------------
        # Collect results
        # ---------------------------------------------

        # terminal tasks determine if the flow is finished
        terminal_tasks = self.flow.terminal_tasks()

        # reference tasks determine flow state
        reference_tasks = self.flow.reference_tasks()

        # wait until all terminal tasks are finished
        final_tasks = terminal_tasks.union(reference_tasks).union(return_tasks)
        final_states = executor.wait(
            {
                t: task_states.get(t, Pending("Task not evaluated by FlowRunner."))
                for t in final_tasks
            }
        )

        # also wait for any children of Mapped tasks to finish, and add them
        # to the dictionary to determine flow state
        all_final_states = final_states.copy()
        for t, s in list(final_states.items()):
            if s.is_mapped():
                # ensure we wait for any mapped children to complete
                if t in mapped_children:
                    s.map_states = executor.wait(mapped_children[t])
                s.result = [ms.result for ms in s.map_states]
                all_final_states[t] = s.map_states

        assert isinstance(final_states, dict)

    key_states = set(flatten_seq([all_final_states[t] for t in reference_tasks]))
    terminal_states = set(flatten_seq([all_final_states[t] for t in terminal_tasks]))
    return_states = {t: final_states[t] for t in return_tasks}

    state = self.determine_final_state(
        state=state,
        key_states=key_states,
        return_states=return_states,
        terminal_states=terminal_states,
    )

    return state
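# Hedged illustration of the mapped-task path handled above: a hypothetical flow where
# `plus_one.map(...)` produces a Mapped parent whose children are resolved depth-first
# before the reducing `total` task consumes them via upstream_mapped_states.
from prefect import Flow, task


@task
def numbers():
    return [1, 2, 3]


@task
def plus_one(x):
    return x + 1


@task
def total(xs):
    return sum(xs)


with Flow("mapped-example") as flow:
    result = total(plus_one.map(numbers()))

# flow.run()  # the reduce task `total` receives [2, 3, 4]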
def get_task_run_state(self, state: State, inputs: Dict[str, Result]) -> State:
    """
    Runs the task and traps any signals or errors it raises.
    Also checkpoints the result of a successful task, if `task.checkpoint` is `True`.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments.

    Returns:
        - State: the state of the task after running the check

    Raises:
        - signals.PAUSE: if the task raises PAUSE
        - ENDRUN: if the task is not ready to run
    """
    if not state.is_running():
        self.logger.debug(
            "Task '{name}': can't run task because it's not in a "
            "Running state; ending run.".format(
                name=prefect.context.get("task_full_name", self.task.name)
            )
        )
        raise ENDRUN(state)

    value = None
    raw_inputs = {k: r.value for k, r in inputs.items()}
    try:
        self.logger.debug(
            "Task '{name}': Calling task.run() method...".format(
                name=prefect.context.get("task_full_name", self.task.name)
            )
        )
        timeout_handler = prefect.utilities.executors.timeout_handler
        if getattr(self.task, "log_stdout", False):
            with redirect_stdout(
                prefect.utilities.logging.RedirectToLog(self.logger)  # type: ignore
            ):
                value = timeout_handler(
                    self.task.run, timeout=self.task.timeout, **raw_inputs
                )
        else:
            value = timeout_handler(
                self.task.run, timeout=self.task.timeout, **raw_inputs
            )

    # inform user of timeout
    except TimeoutError as exc:
        if prefect.context.get("raise_on_exception"):
            raise exc
        state = TimedOut("Task timed out during execution.", result=exc)
        return state

    except signals.LOOP as exc:
        new_state = exc.state
        assert isinstance(new_state, Looped)
        new_state.result = self.result.from_value(value=new_state.result)
        new_state.message = exc.state.message or "Task is looping ({})".format(
            new_state.loop_count
        )
        return new_state

    # checkpoint tasks if a result is present, except for when the user has opted out by
    # disabling checkpointing
    if (
        prefect.context.get("checkpointing") is True
        and self.task.checkpoint is not False
        and value is not None
    ):
        try:
            formatting_kwargs = {
                **prefect.context.get("parameters", {}).copy(),
                **raw_inputs,
                **prefect.context,
            }
            result = self.result.write(value, **formatting_kwargs)
        except NotImplementedError:
            result = self.result.from_value(value=value)
    else:
        result = self.result.from_value(value=value)

    state = Success(result=result, message="Task run succeeded.")
    return state
def get_flow_run_state(
    self,
    state: State,
    task_states: Dict[Task, State],
    task_contexts: Dict[Task, Dict[str, Any]],
    return_tasks: Set[Task],
    task_runner_state_handlers: Iterable[Callable],
    executor: "prefect.engine.executors.base.Executor",
) -> State:
    """
    Runs the flow.

    Args:
        - state (State): starting state for the Flow. Defaults to `Pending`
        - task_states (dict): dictionary of task states to begin
            computation with, with keys being Tasks and values their corresponding state
        - task_contexts (Dict[Task, Dict[str, Any]]): contexts that will be provided to each task
        - return_tasks ([Task], optional): list of Tasks to include in the
            final returned Flow state. Defaults to `None`
        - task_runner_state_handlers (Iterable[Callable]): A list of state change handlers
            that will be provided to the task_runner, and called whenever a task changes state.
        - executor (Executor): executor to use when performing computation; defaults to the
            executor provided in your prefect configuration

    Returns:
        - State: `State` representing the final post-run state of the `Flow`.
    """
    if not state.is_running():
        self.logger.info("Flow is not in a Running state.")
        raise ENDRUN(state)

    if return_tasks is None:
        return_tasks = set()
    if set(return_tasks).difference(self.flow.tasks):
        raise ValueError("Some tasks in return_tasks were not found in the flow.")

    # -- process each task in order

    with executor.start():

        for task in self.flow.sorted_tasks():
            task_state = task_states.get(task)
            if task_state is None and isinstance(
                task, prefect.tasks.core.constants.Constant
            ):
                task_states[task] = task_state = Success(result=task.value)

            # if the state is finished, don't run the task, just use the provided state
            if (
                isinstance(task_state, State)
                and task_state.is_finished()
                and not task_state.is_cached()
                and not task_state.is_mapped()
            ):
                continue

            upstream_states = {}  # type: Dict[Edge, Union[State, Iterable]]

            # -- process each edge to the task
            for edge in self.flow.edges_to(task):
                upstream_states[edge] = task_states.get(
                    edge.upstream_task, Pending(message="Task state not available.")
                )

            # -- run the task
            with prefect.context(task_full_name=task.name, task_tags=task.tags):
                task_states[task] = executor.submit(
                    self.run_task,
                    task=task,
                    state=task_state,
                    upstream_states=upstream_states,
                    context=dict(prefect.context, **task_contexts.get(task, {})),
                    task_runner_state_handlers=task_runner_state_handlers,
                    executor=executor,
                )

        # ---------------------------------------------
        # Collect results
        # ---------------------------------------------

        # terminal tasks determine if the flow is finished
        terminal_tasks = self.flow.terminal_tasks()

        # reference tasks determine flow state
        reference_tasks = self.flow.reference_tasks()

        # wait until all terminal tasks are finished
        final_tasks = terminal_tasks.union(reference_tasks).union(return_tasks)
        final_states = executor.wait(
            {
                t: task_states.get(t, Pending("Task not evaluated by FlowRunner."))
                for t in final_tasks
            }
        )

        # also wait for any children of Mapped tasks to finish, and add them
        # to the dictionary to determine flow state
        all_final_states = final_states.copy()
        for t, s in list(final_states.items()):
            if s.is_mapped():
                s.map_states = executor.wait(s.map_states)
                s.result = [ms.result for ms in s.map_states]
                all_final_states[t] = s.map_states

        assert isinstance(final_states, dict)

    key_states = set(flatten_seq([all_final_states[t] for t in reference_tasks]))
    terminal_states = set(flatten_seq([all_final_states[t] for t in terminal_tasks]))
    return_states = {t: final_states[t] for t in return_tasks}

    state = self.determine_final_state(
        state=state,
        key_states=key_states,
        return_states=return_states,
        terminal_states=terminal_states,
    )

    return state
def get_task_run_state(self, state: State, inputs: Dict[str, Result]) -> State:
    """
    Runs the task and traps any signals or errors it raises.
    Also checkpoints the result of a successful task, if `task.checkpoint` is `True`.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments.

    Returns:
        - State: the state of the task after running the check

    Raises:
        - signals.PAUSE: if the task raises PAUSE
        - ENDRUN: if the task is not ready to run
    """
    task_name = prefect.context.get("task_full_name", self.task.name)

    if not state.is_running():
        self.logger.debug(
            f"Task {task_name!r}: Can't run task because it's not in a Running "
            "state; ending run."
        )
        raise ENDRUN(state)

    value = None
    raw_inputs = {k: r.value for k, r in inputs.items()}
    new_state = None
    try:
        self.logger.debug(f"Task {task_name!r}: Calling task.run() method...")

        # Create a stdout redirect if the task has log_stdout enabled
        log_context = (
            redirect_stdout(prefect.utilities.logging.RedirectToLog(self.logger))
            if getattr(self.task, "log_stdout", False)
            else nullcontext()
        )  # type: AbstractContextManager

        with log_context:
            value = prefect.utilities.executors.run_task_with_timeout(
                task=self.task,
                args=(),
                kwargs=raw_inputs,
                logger=self.logger,
            )

    except TaskTimeoutSignal as exc:
        # Convert timeouts to a `TimedOut` state
        if prefect.context.get("raise_on_exception"):
            raise exc
        state = TimedOut("Task timed out during execution.", result=exc)
        return state

    except signals.LOOP as exc:
        # Convert loop signals to a `Looped` state
        new_state = exc.state
        assert isinstance(new_state, Looped)
        value = new_state.result
        new_state.message = exc.state.message or "Task is looping ({})".format(
            new_state.loop_count
        )

    except signals.SUCCESS as exc:
        # Success signals can be treated like a normal result
        new_state = exc.state
        assert isinstance(new_state, Success)
        value = new_state.result

    except Exception as exc:
        # Handle exceptions in the task
        if prefect.context.get("raise_on_exception"):
            raise
        self.logger.error(
            f"Task {task_name!r}: Exception encountered during task execution!",
            exc_info=True,
        )
        state = Failed(f"Error during execution of task: {exc!r}", result=exc)
        return state

    # checkpoint tasks if a result is present, except for when the user has opted out by
    # disabling checkpointing
    if (
        prefect.context.get("checkpointing") is True
        and self.task.checkpoint is not False
        and value is not None
    ):
        try:
            formatting_kwargs = {
                **prefect.context.get("parameters", {}).copy(),
                **prefect.context,
                **raw_inputs,
            }
            result = self.result.write(value, **formatting_kwargs)
        except ResultNotImplementedError:
            result = self.result.from_value(value=value)
    else:
        result = self.result.from_value(value=value)

    if new_state is not None:
        new_state.result = result
        return new_state

    state = Success(result=result, message="Task run succeeded.")
    return state
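# Hedged illustration of the checkpointing branch above: when checkpointing is enabled
# and the task has a Result with a templated location, `self.result.write` receives the
# parameters, the context, and the raw task inputs as formatting kwargs. The flow, task,
# and file name below are hypothetical.
from prefect import Flow, task
from prefect.engine.results import LocalResult


@task(checkpoint=True, result=LocalResult(location="add_one-{x}.prefect"))
def add_one(x):
    return x + 1


with Flow("checkpoint-example") as flow:
    add_one(1)

# flow.run()  # with checkpointing enabled, the result is written to "add_one-1.prefect"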
def get_task_run_state(
    self,
    state: State,
    inputs: Dict[str, Result],
    timeout_handler: Optional[Callable] = None,
) -> State:
    """
    Runs the task and traps any signals or errors it raises.
    Also checkpoints the result of a successful task, if `task.checkpoint` is `True`.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments.
        - timeout_handler (Callable, optional): function for timing out
            task execution, with call signature `handler(fn, *args, **kwargs)`.
            Defaults to `prefect.utilities.executors.timeout_handler`

    Returns:
        - State: the state of the task after running the check

    Raises:
        - signals.PAUSE: if the task raises PAUSE
        - ENDRUN: if the task is not ready to run
    """
    if not state.is_running():
        self.logger.debug(
            "Task '{name}': can't run task because it's not in a "
            "Running state; ending run.".format(
                name=prefect.context.get("task_full_name", self.task.name)
            )
        )
        raise ENDRUN(state)

    try:
        self.logger.debug(
            "Task '{name}': Calling task.run() method...".format(
                name=prefect.context.get("task_full_name", self.task.name)
            )
        )
        timeout_handler = (
            timeout_handler or prefect.utilities.executors.timeout_handler
        )
        raw_inputs = {k: r.value for k, r in inputs.items()}

        if getattr(self.task, "log_stdout", False):
            with redirect_stdout(
                prefect.utilities.logging.RedirectToLog(self.logger)
            ):  # type: ignore
                result = timeout_handler(
                    self.task.run, timeout=self.task.timeout, **raw_inputs
                )
        else:
            result = timeout_handler(
                self.task.run, timeout=self.task.timeout, **raw_inputs
            )

    except KeyboardInterrupt:
        self.logger.debug("Interrupt signal raised, cancelling task run.")
        state = Cancelled(message="Interrupt signal raised, cancelling task run.")
        return state

    # inform user of timeout
    except TimeoutError as exc:
        if prefect.context.get("raise_on_exception"):
            raise exc
        state = TimedOut(
            "Task timed out during execution.", result=exc, cached_inputs=inputs
        )
        return state

    except signals.LOOP as exc:
        new_state = exc.state
        assert isinstance(new_state, Looped)
        new_state.result = Result(
            value=new_state.result, result_handler=self.result_handler
        )
        new_state.cached_inputs = inputs
        new_state.message = exc.state.message or "Task is looping ({})".format(
            new_state.loop_count
        )
        return new_state

    result = Result(value=result, result_handler=self.result_handler)
    state = Success(
        result=result, message="Task run succeeded.", cached_inputs=inputs
    )

    # checkpoint tasks if a result_handler is present, except for when the user has
    # opted out by disabling checkpointing
    if (
        state.is_successful()
        and prefect.context.get("checkpointing") is True
        and self.task.checkpoint is not False
        and self.result_handler is not None
    ):
        state._result.store_safe_value()

    return state
async def set_flow_run_state(
    flow_run_id: str, state: State, version: int = None, agent_id: str = None
) -> models.FlowRunState:
    """
    Updates a flow run state.

    Args:
        - flow_run_id (str): the flow run id to update
        - state (State): the new state
        - version (int): a version to enforce version-locking
        - agent_id (str): the ID of an agent instance setting the state

    Returns:
        - models.FlowRunState
    """
    if flow_run_id is None:
        raise ValueError("Invalid flow run ID.")

    where = {
        "id": {"_eq": flow_run_id},
        "_or": [
            # EITHER version locking is enabled and versions match
            {
                "version": {"_eq": version},
                "flow": {
                    "flow_group": {
                        "settings": {"_contains": {"version_locking_enabled": True}}
                    }
                },
            },
            # OR version locking is not enabled
            {
                "flow": {
                    "flow_group": {
                        "_not": {
                            "settings": {
                                "_contains": {"version_locking_enabled": True}
                            }
                        }
                    }
                }
            },
        ],
    }

    flow_run = await models.FlowRun.where(where).first(
        {
            "id": True,
            "state": True,
            "name": True,
            "version": True,
            "flow": {"id", "name", "flow_group_id", "version_group_id"},
            "tenant": {"id", "slug"},
        }
    )

    if not flow_run:
        raise ValueError(f"State update failed for flow run ID {flow_run_id}")

    # --------------------------------------------------------
    # apply downstream updates
    # --------------------------------------------------------

    # FOR CANCELLED STATES:
    # - set all non-finished task run states to Cancelled
    if isinstance(state, Cancelled):
        task_runs = await models.TaskRun.where(
            {"flow_run_id": {"_eq": flow_run_id}}
        ).get({"id", "serialized_state"})
        to_cancel = [
            t
            for t in task_runs
            if not state_schema.load(t.serialized_state).is_finished()
        ]
        # For a run with many tasks this may be a lot of tasks - at some point
        # we might want to batch this rather than kicking off lots of asyncio
        # tasks at once.
        await asyncio.gather(
            *(api.states.set_task_run_state(t.id, state) for t in to_cancel),
            return_exceptions=True,
        )

    # --------------------------------------------------------
    # insert the new state in the database
    # --------------------------------------------------------
    flow_run_state = models.FlowRunState(
        id=str(uuid.uuid4()),
        tenant_id=flow_run.tenant_id,
        flow_run_id=flow_run_id,
        version=(flow_run.version or 0) + 1,
        state=type(state).__name__,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        start_time=getattr(state, "start_time", None),
        serialized_state=state.serialize(),
    )

    await flow_run_state.insert()

    # --------------------------------------------------------
    # apply downstream updates
    # --------------------------------------------------------

    # FOR RUNNING STATES:
    # - update the flow run heartbeat
    if state.is_running() or state.is_submitted():
        await api.runs.update_flow_run_heartbeat(flow_run_id=flow_run_id)

    # Set agent ID on flow run when submitted by agent
    if state.is_submitted() and agent_id:
        await api.runs.update_flow_run_agent(flow_run_id=flow_run_id, agent_id=agent_id)

    # --------------------------------------------------------
    # call cloud hooks
    # --------------------------------------------------------
    event = events.FlowRunStateChange(
        flow_run=flow_run,
        state=flow_run_state,
        flow=flow_run.flow,
        tenant=flow_run.tenant,
    )

    asyncio.create_task(api.cloud_hooks.call_hooks(event))

    return flow_run_state
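# Hedged usage sketch for the Cancelled branch above: cancelling a flow run through this
# API also flips every unfinished task run to Cancelled, because set_flow_run_state fans
# out set_task_run_state calls for the non-finished task runs. The ID is a placeholder.
from prefect.engine.state import Cancelled

# await set_flow_run_state(
#     flow_run_id="11111111-2222-3333-4444-555555555555",
#     state=Cancelled(message="Cancelled by user."),
# )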
async def set_task_run_state(
    task_run_id: str,
    state: State,
    version: int = None,
    flow_run_version: int = None,
) -> models.TaskRunState:
    """
    Updates a task run state.

    Args:
        - task_run_id (str): the task run id to update
        - state (State): the new state
        - version (int): a version to enforce version-locking
        - flow_run_version (int): a flow run version to enforce version-locking

    Returns:
        - models.TaskRunState
    """
    if task_run_id is None:
        raise ValueError("Invalid task run ID.")

    where = {
        "id": {"_eq": task_run_id},
        "_or": [
            # EITHER version locking is enabled and the versions match
            {
                "version": {"_eq": version},
                "flow_run": {
                    "version": {"_eq": flow_run_version},
                    "flow": {
                        "flow_group": {
                            "settings": {
                                "_contains": {"version_locking_enabled": True}
                            }
                        }
                    },
                },
            },
            # OR version locking is not enabled
            {
                "flow_run": {
                    "flow": {
                        "flow_group": {
                            "_not": {
                                "settings": {
                                    "_contains": {"version_locking_enabled": True}
                                }
                            }
                        }
                    }
                }
            },
        ],
    }

    task_run = await models.TaskRun.where(where).first(
        {
            "id": True,
            "tenant_id": True,
            "version": True,
            "state": True,
            "serialized_state": True,
            "flow_run": {"id": True, "state": True},
        }
    )

    if not task_run:
        raise ValueError(f"State update failed for task run ID {task_run_id}")

    # ------------------------------------------------------
    # if the state is running, ensure the flow run is also running
    # ------------------------------------------------------
    if state.is_running() and task_run.flow_run.state != "Running":
        raise ValueError(
            f"State update failed for task run ID {task_run_id}: provided "
            f"a running state but associated flow run {task_run.flow_run.id} is not "
            "in a running state."
        )

    # ------------------------------------------------------
    # if we have cached inputs on the old state, we need to carry them forward
    # ------------------------------------------------------
    if not state.cached_inputs and task_run.serialized_state.get("cached_inputs", None):
        # load up the old state's cached inputs and apply them to the new state
        serialized_state = state_schema.load(task_run.serialized_state)
        state.cached_inputs = serialized_state.cached_inputs

    # --------------------------------------------------------
    # prepare the new state for the database
    # --------------------------------------------------------
    task_run_state = models.TaskRunState(
        id=str(uuid.uuid4()),
        tenant_id=task_run.tenant_id,
        task_run_id=task_run.id,
        version=(task_run.version or 0) + 1,
        timestamp=pendulum.now("UTC"),
        message=state.message,
        result=state.result,
        start_time=getattr(state, "start_time", None),
        state=type(state).__name__,
        serialized_state=state.serialize(),
    )

    await task_run_state.insert()

    # --------------------------------------------------------
    # apply downstream updates
    # --------------------------------------------------------

    # FOR RUNNING STATES:
    # - update the task run heartbeat
    if state.is_running():
        await api.runs.update_task_run_heartbeat(task_run_id=task_run_id)

    return task_run_state
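# Hedged usage sketch: marking a task run as Running through the API above. The IDs and
# versions are placeholders; per the pipeline check, this raises ValueError unless the
# parent flow run is already in a Running state.
import asyncio

from prefect.engine.state import Running


async def mark_running(task_run_id: str, version: int, flow_run_version: int):
    return await set_task_run_state(
        task_run_id=task_run_id,
        state=Running(message="Marked running by worker."),
        version=version,
        flow_run_version=flow_run_version,
    )

# asyncio.run(mark_running("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", 3, 7))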