def check_task_trigger(self, state: State, upstream_states: Dict[Edge, State]) -> State:
    """
    Evaluates this task's trigger against its upstream states.

    A trigger returning a falsy value is turned into a `TRIGGERFAIL` signal. Any
    Prefect state signal (including that one) ends the run with the signal's own
    state; any other exception ends the run with a `TriggerFailed` state. When
    `raise_on_exception` is set in context, the original error is re-raised instead.

    Args:
        - state (State): the current state of this task
        - upstream_states (Dict[Edge, State]): the upstream states

    Returns:
        - State: `state`, unchanged, when the trigger passes

    Raises:
        - ENDRUN: if the trigger fails or raises an error
    """
    try:
        trigger_passed = self.task.trigger(upstream_states)
        if not trigger_passed:
            raise signals.TRIGGERFAIL(message="Trigger failed")
    except signals.PrefectStateSignal as sig:
        task_name = prefect.context.get("task_full_name", self.task.name)
        self.logger.debug(
            f"Task '{task_name}': {type(sig).__name__} signal raised during execution."
        )
        if prefect.context.get("raise_on_exception"):
            raise sig
        raise ENDRUN(sig.state) from sig
    except Exception as err:
        # anything that is not a state signal becomes a TriggerFailed state
        task_name = prefect.context.get("task_full_name", self.task.name)
        self.logger.exception(
            f"Task '{task_name}': unexpected error while evaluating task trigger: {err!r}"
        )
        if prefect.context.get("raise_on_exception"):
            raise err
        failure = TriggerFailed(
            f"Unexpected error while checking task trigger: {err!r}", result=err
        )
        raise ENDRUN(failure) from err
    return state
def check_upstream_skipped(self, state: State, upstream_states: Dict[Edge, State]) -> State:
    """
    Ends the run with a `Skipped` state when a relevant upstream task skipped and
    this task has `skip_on_upstream_skip` enabled.

    A `Mapped` upstream reached through a NON-mapped edge contributes each of its
    child `map_states` (this task depends on the whole mapped output). A `Mapped`
    upstream reached through a mapped edge contributes only the parent state, so
    each mapped child of this task decides for itself based on its own upstream
    child.

    Args:
        - state (State): the current state of this task
        - upstream_states (Dict[Edge, State]): the upstream states

    Returns:
        - State: `state`, unchanged, if the task should not be skipped

    Raises:
        - ENDRUN: with a `Skipped` state if an upstream task skipped
    """
    collected = set()  # type: Set[State]
    for edge, upstream in upstream_states.items():
        expand_children = isinstance(upstream, Mapped) and not edge.mapped
        if expand_children:
            collected.update(upstream.map_states)
        else:
            collected.add(upstream)

    if not self.task.skip_on_upstream_skip:
        return state
    if not any(s.is_skipped() for s in collected):
        return state

    task_name = prefect.context.get("task_full_name", self.task.name)
    self.logger.debug(f"Task '{task_name}': Upstream states were skipped; ending run.")
    raise ENDRUN(
        state=Skipped(
            message=(
                "Upstream task was skipped; if this was not the intended "
                "behavior, consider changing `skip_on_upstream_skip=False` "
                "for this task."
            )
        )
    )
def set_task_to_running(self, state: State, inputs: Dict[str, Result]) -> State:
    """
    Transitions a Pending task into a new `Running` state.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments; not consulted by this check

    Returns:
        - State: a fresh `Running` state

    Raises:
        - ENDRUN: carrying `state` unchanged if the task is not Pending
    """
    if state.is_pending():
        return Running(message="Starting task run.")

    task_name = prefect.context.get("task_full_name", self.task.name)
    self.logger.debug(
        "Task '{}': can't set state to Running because it "
        "isn't Pending; ending run.".format(task_name)
    )
    raise ENDRUN(state)
def check_upstream_skipped(self, state: State, upstream_states: Dict[Edge, State]) -> State:
    """
    Checks if any of the relevant upstream tasks have skipped and, if this task has
    `skip_on_upstream_skip` enabled, ends the run with a `Skipped` state.

    A `Mapped` upstream reached through a NON-mapped edge is represented by its
    child `map_states`, since this task consumes the whole mapped output. A `Mapped`
    upstream reached through a mapped edge is represented by the parent state only.
    Expanding it there would skip this entire mapped task as soon as any single
    upstream child skipped; instead, each mapped child of this task decides for
    itself based on its own upstream child.

    Args:
        - state (State): the current state of this task
        - upstream_states (Dict[Edge, State]): the upstream states

    Returns:
        - State: the state of the task after running the check

    Raises:
        - ENDRUN: with a `Skipped` state if an upstream task skipped
    """
    all_states = set()  # type: Set[State]
    for edge, upstream_state in upstream_states.items():
        # only fan out over the children when this task is NOT mapped over that edge
        if isinstance(upstream_state, Mapped) and not edge.mapped:
            all_states.update(upstream_state.map_states)
        else:
            all_states.add(upstream_state)

    if self.task.skip_on_upstream_skip and any(s.is_skipped() for s in all_states):
        self.logger.debug(
            "Task '{name}': Upstream states were skipped; ending run.".format(
                name=prefect.context.get("task_full_name", self.task.name)
            )
        )
        raise ENDRUN(
            state=Skipped(
                message=(
                    "Upstream task was skipped; if this was not the intended "
                    "behavior, consider changing `skip_on_upstream_skip=False` "
                    "for this task."
                )
            )
        )
    return state
def get_task_run_state(
    self,
    state: State,
    inputs: Dict[str, Result],
    timeout_handler: Optional[Callable],
) -> State:
    """
    Runs the task through a timeout handler and wraps its return value in a
    `Success` state; a `TimeoutError` becomes a `TimedOut` state instead. The
    result's safe value is stored when `task.checkpoint` is exactly `True`.

    Args:
        - state (State): the current state of this task; must be Running
        - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys
            correspond to the task's `run()` arguments
        - timeout_handler (Callable, optional): function for timing out task
            execution, with call signature `handler(fn, *args, **kwargs)`; when
            falsy, `main_thread_timeout` is used

    Returns:
        - State: a `Success` state, or a `TimedOut` state if the task timed out

    Raises:
        - ENDRUN: if the task is not in a Running state
        - TimeoutError: re-raised if `raise_on_exception` is set in context
    """
    task_name = prefect.context.get("task_full_name", self.task.name)
    if not state.is_running():
        self.logger.debug(
            f"Task '{task_name}': can't run task because it's not in a "
            "Running state; ending run."
        )
        raise ENDRUN(state)

    handler = timeout_handler or main_thread_timeout
    try:
        self.logger.debug(f"Task '{task_name}': Calling task.run() method...")
        call_kwargs = {key: res.value for key, res in inputs.items()}
        value = handler(self.task.run, timeout=self.task.timeout, **call_kwargs)
    except TimeoutError as exc:
        # surface the timeout as a state unless the caller wants the raw error
        if prefect.context.get("raise_on_exception"):
            raise exc
        return TimedOut(
            "Task timed out during execution.", result=exc, cached_inputs=inputs
        )

    wrapped = Result(value=value, result_handler=self.result_handler)
    success = Success(result=wrapped, message="Task run succeeded.")
    if self.task.checkpoint is True:
        success._result.store_safe_value()
    return success
def load_results(
    self, state: State, upstream_states: Dict[Edge, State]
) -> Tuple[State, Dict[Edge, State]]:
    """
    Loads the results of every upstream state and the cached inputs of `state`.

    Each entry of `upstream_states` is replaced IN PLACE by the state returned from
    its `load_result`. The reader passed there is the upstream task's `result`, or
    `self.default_result` when that is falsy. The same readers, keyed by edge key,
    are handed to `state.load_cached_results`.

    Args:
        - state (State): the task's current state.
        - upstream_states (Dict[Edge, State]): the upstream states

    Returns:
        - Tuple[State, dict]: a tuple of (state, upstream_states)

    Raises:
        - ENDRUN: with the post-state-handler version of a `Failed` state if any
            result fails to load
    """
    readers_by_key = {}
    try:
        for edge, upstream in upstream_states.items():
            reader = edge.upstream_task.result or self.default_result
            upstream_states[edge] = upstream.load_result(reader)
            if edge.key is not None:
                readers_by_key[edge.key] = reader
        state.load_cached_results(readers_by_key)
        return state, upstream_states
    except Exception as exc:
        failed = Failed(message=f"Failed to retrieve task results: {exc}", result=exc)
        raise ENDRUN(self.handle_state_change(old_state=state, new_state=failed))
def check_upstream_finished(
    self, state: State, upstream_states: Dict[Edge, State]
) -> State:
    """
    Checks if the upstream tasks have all finished.

    A `Mapped` upstream reached through a non-mapped edge is represented by its
    child `map_states`, so every child must be finished. One reached through a
    mapped edge is represented by the parent state alone, so each mapped child of
    this task can decide whether to proceed based on its own upstream child.

    Args:
        - state (State): the current state of this task
        - upstream_states (Dict[Edge, State]): the upstream states

    Returns:
        - State: `state`, unchanged, if everything upstream is finished

    Raises:
        - ENDRUN: if upstream tasks are not finished.
    """
    relevant = set()  # type: Set[State]
    for edge, upstream in upstream_states.items():
        if not isinstance(upstream, Mapped) or edge.mapped:
            relevant.add(upstream)
        else:
            relevant.update(upstream.map_states)

    if any(not s.is_finished() for s in relevant):
        task_name = prefect.context.get("task_full_name", self.task.name)
        self.logger.debug(
            f"Task '{task_name}': not all upstream states are finished; ending run."
        )
        raise ENDRUN(state)

    return state
def call_runner_target_handlers(self, old_state: State, new_state: State) -> State:
    """
    A special state handler that the FlowRunner uses to call its flow's state
    handlers and then report the resulting state through `self.client`. This
    method is called as part of the base Runner's `handle_state_change()` method.

    If the flow's state handlers raise, the error is logged and, unless
    `raise_on_exception` is set in context, replaced by a `Failed` state; that
    state is what gets reported.

    Args:
        - old_state (State): the old (previous) state
        - new_state (State): the new (current) state

    Returns:
        - State: the new state

    Raises:
        - ENDRUN: carrying `new_state` if the client call fails, or carrying the
            returned state if the backend responds with a Queued state
    """
    raise_on_exception = prefect.context.get("raise_on_exception", False)

    try:
        new_state = super().call_runner_target_handlers(
            old_state=old_state, new_state=new_state
        )
    except Exception as exc:
        msg = "Exception raised while calling state handlers: {}".format(repr(exc))
        self.logger.debug(msg)
        if raise_on_exception:
            raise exc
        new_state = Failed(msg, result=exc)

    flow_run_id = prefect.context.get("flow_run_id", None)
    version = prefect.context.get("flow_run_version")

    try:
        state = self.client.set_flow_run_state(
            flow_run_id=flow_run_id, version=version, state=new_state
        )
    except Exception as exc:
        self.logger.debug("Failed to set flow state with error: {}".format(repr(exc)))
        raise ENDRUN(state=new_state) from exc

    if state.is_queued():
        # NOTE(review): assumes a Queued state keeps the state to resume in `.state`
        state.state = old_state  # type: ignore
        raise ENDRUN(state=state)

    # Only advance a version we actually have; `None + 1` would raise TypeError
    # when no `flow_run_version` is present in context.
    if version is not None:
        prefect.context.update(flow_run_version=version + 1)

    return new_state
def check_task_reached_start_time(self, state: State) -> State:
    """
    Ensures that a Scheduled task has reached its scheduled start time.

    Note: Scheduled states include Retry states. Scheduled states with no start
    time (`start_time = None`) are never considered ready; they must be manually
    placed in another state.

    Args:
        - state (State): the current state of this task

    Returns:
        - State: `state`, unchanged, if the task may proceed

    Raises:
        - ENDRUN: if the task is Scheduled with no start time or a future start time
    """
    if not isinstance(state, Scheduled):
        return state

    start = state.start_time
    if start is None:
        template = "Task '{name}' is scheduled without a known start_time; ending run."
    elif start > pendulum.now("utc"):
        template = "Task '{name}': start_time has not been reached; ending run."
    else:
        return state

    self.logger.debug(
        template.format(name=prefect.context.get("task_full_name", self.task.name))
    )
    raise ENDRUN(state)
def check_task_ready_to_map(
    self, state: State, upstream_states: Dict[Edge, State]
) -> State:
    """
    Checks if the parent task is ready to proceed with mapping.

    Args:
        - state (State): the current state of this task
        - upstream_states (Dict[Edge, State]): the upstream states

    Raises:
        - ENDRUN: always; carrying `state` if already Mapped, a `Failed` state if
            nothing (or something unmappable) is mapped over, otherwise a new
            `Mapped` state. Execution never continues past this point.
    """
    if state.is_mapped():
        raise ENDRUN(state)

    pairs = upstream_states.items()

    # mapping needs at least one successful upstream reached through a mapped edge
    if upstream_states and not any(
        [edge.mapped and upstream.is_successful() for edge, upstream in pairs]
    ):
        outcome = Failed("No upstream states can be mapped over.")  # type: State
    elif not all(
        [
            hasattr(upstream.result, "__getitem__")
            for edge, upstream in pairs
            if upstream.is_successful() and not upstream.is_mapped() and edge.mapped
        ]
    ):
        outcome = Failed("At least one upstream state has an unmappable result.")
    else:
        outcome = Mapped("Ready to proceed with mapping.")
    raise ENDRUN(outcome)
def call_runner_target_handlers(self, old_state: State, new_state: State) -> State:
    """
    A special state handler that the TaskRunner uses to call its task's state
    handlers and then report the result (after `prepare_state_for_cloud`) through
    `self.client`. This method is called as part of the base Runner's
    `handle_state_change()` method.

    Args:
        - old_state (State): the old (previous) state
        - new_state (State): the new (current) state

    Returns:
        - State: the new state (a `Failed` state if the handlers raised)

    Raises:
        - ENDRUN: carrying a `ClientFailed` state if the client call fails
    """
    raise_on_exception = prefect.context.get("raise_on_exception", False)

    try:
        new_state = super().call_runner_target_handlers(
            old_state=old_state, new_state=new_state
        )
    except Exception as err:
        message = f"Exception raised while calling state handlers: {err!r}"
        self.logger.debug(message)
        if raise_on_exception:
            raise err
        new_state = Failed(message, result=err)

    task_run_id = prefect.context.get("task_run_id")
    version = prefect.context.get("task_run_version")

    try:
        serialized = prepare_state_for_cloud(new_state)
        self.client.set_task_run_state(
            task_run_id=task_run_id,
            version=version,
            state=serialized,
            cache_for=self.task.cache_for,
        )
    except Exception as err:
        self.logger.debug(f"Failed to set task state with error: {err!r}")
        raise ENDRUN(state=ClientFailed(state=new_state))

    # without a known version there is nothing to advance
    if version is not None:
        prefect.context.update(task_run_version=version + 1)  # type: ignore

    return new_state
def initialize_run(  # type: ignore
    self, state: Optional[State], context: Dict[str, Any]
) -> TaskRunnerInitializeResult:
    """
    Initializes the Task run by initializing state and context appropriately.

    For a dynamically mapped child (a `map_index` in `context` other than -1 or
    None), the task run's id and version are fetched via
    `self.client.get_task_run_info`. The stored state is also fetched when no
    `state` was passed in. Checkpointing is then switched on in `context` before
    deferring to the parent implementation.

    Args:
        - state (Optional[State]): the initial state of the run
        - context (Dict[str, Any]): the context to be updated with relevant information

    Returns:
        - TaskRunnerInitializeResult: the parent `initialize_run` result (the updated
            state and context)

    Raises:
        - ENDRUN: if task run info cannot be retrieved; carries `state`, or a `Failed`
            state when no state is available
    """
    map_index = context.get("map_index")
    is_dynamic_child = map_index not in [-1, None]

    if is_dynamic_child:
        try:
            info = self.client.get_task_run_info(
                flow_run_id=context.get("flow_run_id", ""),
                task_id=context.get("task_id", ""),
                map_index=map_index,
            )
            # an explicitly provided state wins over the stored one
            state = state or info.state  # type: ignore
            context.update(
                task_run_id=info.id,  # type: ignore
                task_run_version=info.version,  # type: ignore
            )
        except Exception as err:
            self.logger.exception(f"Failed to retrieve task state with error: {err!r}")
            if state is None:
                state = Failed(
                    message="Could not retrieve state from Prefect Cloud",
                    result=err,
                )
            raise ENDRUN(state=state) from err

    # kept on the runner so the heartbeat thread can read it
    self.task_run_id = context.get("task_run_id", "")  # type: str
    context.update(checkpointing=True)
    return super().initialize_run(state=state, context=context)
def get_task_inputs(
    self, state: State, upstream_states: Dict[Edge, State]
) -> Dict[str, Result]:
    """
    Generates the inputs for this task, starting from the parent implementation.

    For a mapped child task (a `map_index` in context other than -1 or None), some
    inputs are written out individually, so that a retry does not have to reload
    the whole upstream array being mapped over. This applies to each input that
    arrives over a keyed, mapped edge whose upstream task does not have
    `checkpoint=False`. It is written via its `Result.write` under the filename
    `"<key>-<map_index>"`. Results whose `write` raises `NotImplementedError` are
    kept as they are.

    Args:
        - state (State): the task's current state.
        - upstream_states (Dict[Edge, State]): the upstream states

    Returns:
        - Dict[str, Result]: the task inputs

    Raises:
        - ENDRUN: with the post-state-handler version of a `Failed` state if saving
            a mapped input fails
    """
    task_inputs = super().get_task_inputs(state, upstream_states)

    try:
        map_index = prefect.context.get("map_index")
        if map_index not in [-1, None]:
            for edge in upstream_states:
                should_store = (
                    edge.key
                    and edge.mapped
                    and edge.upstream_task.checkpoint is not False
                )
                if not should_store:
                    continue
                current = task_inputs[edge.key]
                try:
                    task_inputs[edge.key] = current.write(  # type: ignore
                        current.value,
                        filename=f"{edge.key}-{map_index}",
                        **prefect.context,
                    )
                except NotImplementedError:
                    # this result type cannot be written; keep the in-memory input
                    pass
    except Exception as exc:
        failed = Failed(
            message=f"Failed to save inputs for mapped task: {exc}", result=exc
        )
        final_state = self.handle_state_change(old_state=state, new_state=failed)
        raise ENDRUN(final_state)

    return task_inputs
def set_flow_to_running(self, state: State) -> State:
    """
    Moves a Pending flow into a new `Running` state and passes a Running flow
    through untouched.

    Args:
        - state (State): the current state of this flow

    Returns:
        - State: a new `Running` state, or `state` itself if it is already Running

    Raises:
        - ENDRUN: carrying `state` if the flow is neither Pending nor Running
    """
    if state.is_pending():
        return Running(message="Running flow.")
    if not state.is_running():
        raise ENDRUN(state)
    return state
def set_task_to_running(self, state: State) -> State:
    """
    Moves a Pending task into a new `Running` state.

    Args:
        - state (State): the current state of this task

    Returns:
        - State: a fresh `Running` state

    Raises:
        - ENDRUN: carrying `state` unchanged if the task is not Pending
    """
    if not state.is_pending():
        name = prefect.context.get("task_full_name", self.task.name)
        self.logger.debug(
            f"Task '{name}': can't set state to Running because it "
            "isn't Pending; ending run."
        )
        raise ENDRUN(state)
    return Running(message="Starting task run.")
def check_flow_reached_start_time(self, state: State) -> State:
    """
    Ensures that a Scheduled flow has reached its scheduled start time.

    A Scheduled state without a start time is allowed to proceed.

    Args:
        - state (State): the current state of this Flow

    Returns:
        - State: `state`, unchanged, if the flow may proceed

    Raises:
        - ENDRUN: if the flow is Scheduled with a future scheduled time
    """
    if not isinstance(state, Scheduled):
        return state

    start_time = state.start_time
    if start_time and start_time > pendulum.now("utc"):
        self.logger.debug(
            f"Flow '{self.flow.name}': start_time has not been reached; ending run."
        )
        raise ENDRUN(state)
    return state
def check_upstream_finished(self, state: State, upstream_states: Dict[Edge, State]) -> State:
    """
    Checks if the upstream tasks have all finished.

    Each upstream state is tested directly; `Mapped` states are not expanded into
    their children here.

    Args:
        - state (State): the current state of this task
        - upstream_states (Dict[Edge, State]): the upstream states

    Returns:
        - State: `state`, unchanged, if every upstream state is finished

    Raises:
        - ENDRUN: if upstream tasks are not finished.
    """
    has_unfinished = any(not s.is_finished() for s in upstream_states.values())
    if not has_unfinished:
        return state

    name = prefect.context.get("task_full_name", self.task.name)
    self.logger.debug(f"Task '{name}': not all upstream states are finished; ending run.")
    raise ENDRUN(state)
def check_task_reached_start_time(self, state: State) -> State:
    """
    Ensures that a Scheduled task has reached its scheduled start time.

    Note: Scheduled states include Retry states. A Scheduled state without a start
    time is allowed to proceed.

    Args:
        - state (State): the current state of this task

    Returns:
        - State: `state`, unchanged, if the task may proceed

    Raises:
        - ENDRUN: if the task is Scheduled with a future scheduled time
    """
    if not isinstance(state, Scheduled):
        return state

    start_time = state.start_time
    if not (start_time and start_time > pendulum.now("utc")):
        return state

    name = prefect.context.get("task_full_name", self.task.name)
    self.logger.debug(f"Task '{name}': start_time has not been reached; ending run.")
    raise ENDRUN(state)
def get_task_run_state(self, state: State, inputs: Dict[str, Result]) -> State:
    """
    Runs the task and traps any signals or errors it raises. Also checkpoints the
    result of a successful task, if checkpointing is enabled in context and the
    task has not set `checkpoint=False`.

    Outcomes of the task call:
        - normal return: wrapped in a new `Success` state
        - `TaskTimeoutSignal`: converted to a `TimedOut` state
        - `signals.LOOP`: its `Looped` state is returned, carrying the written result
        - `signals.SUCCESS`: its `Success` state is returned, carrying the written result
        - any other exception: converted to a `Failed` state

    Args:
        - state (State): the current state of this task; must be Running
        - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys
            correspond to the task's `run()` arguments.

    Returns:
        - State: the state of the task after running the check

    Raises:
        - signals.PAUSE: if the task raises PAUSE
            (NOTE(review): in the handlers visible here, PAUSE and other non-LOOP,
            non-SUCCESS signals would fall into `except Exception` unless they do
            not subclass Exception — confirm against prefect.engine.signals)
        - ENDRUN: if the task is not ready to run
        - Exception: the task's own error (or the timeout signal) is re-raised when
            `raise_on_exception` is set in context
    """
    task_name = prefect.context.get("task_full_name", self.task.name)

    if not state.is_running():
        self.logger.debug(
            f"Task {task_name!r}: Can't run task because it's not in a Running "
            "state; ending run.")

        raise ENDRUN(state)

    # `value` is the raw return value; `new_state` is set only when a LOOP or
    # SUCCESS signal supplies the state to return instead of a fresh Success
    value = None
    raw_inputs = {k: r.value for k, r in inputs.items()}
    new_state = None
    try:
        self.logger.debug(
            f"Task {task_name!r}: Calling task.run() method...")

        # Create a stdout redirect if the task has log_stdout enabled
        log_context = (
            redirect_stdout(
                prefect.utilities.logging.RedirectToLog(self.logger))
            if getattr(self.task, "log_stdout", False)
            else nullcontext()
        )  # type: AbstractContextManager

        with log_context:
            value = prefect.utilities.executors.run_task_with_timeout(
                task=self.task,
                args=(),
                kwargs=raw_inputs,
                logger=self.logger,
            )

    except TaskTimeoutSignal as exc:  # Convert timeouts to a `TimedOut` state
        if prefect.context.get("raise_on_exception"):
            raise exc
        state = TimedOut("Task timed out during execution.", result=exc)
        return state

    except signals.LOOP as exc:  # Convert loop signals to a `Looped` state
        new_state = exc.state
        assert isinstance(new_state, Looped)
        value = new_state.result
        # keep a user-provided message; otherwise describe the loop iteration
        new_state.message = exc.state.message or "Task is looping ({})".format(
            new_state.loop_count)

    except signals.SUCCESS as exc:  # Success signals can be treated like a normal result
        new_state = exc.state
        assert isinstance(new_state, Success)
        value = new_state.result

    except Exception as exc:  # Handle exceptions in the task
        if prefect.context.get("raise_on_exception"):
            raise
        self.logger.error(
            f"Task {task_name!r}: Exception encountered during task execution!",
            exc_info=True,
        )
        state = Failed(f"Error during execution of task: {exc!r}", result=exc)
        return state

    # checkpoint tasks if a result is present, except for when the user has opted out by
    # disabling checkpointing
    if (prefect.context.get("checkpointing") is True
            and self.task.checkpoint is not False and value is not None):
        try:
            # values available for templating the result location; later entries
            # (context, then task inputs) override earlier ones (parameters)
            formatting_kwargs = {
                **prefect.context.get("parameters", {}).copy(),
                **prefect.context,
                **raw_inputs,
            }
            result = self.result.write(value, **formatting_kwargs)
        except ResultNotImplementedError:
            # this Result type cannot be written; keep the value in memory only
            result = self.result.from_value(value=value)
    else:
        result = self.result.from_value(value=value)

    # a LOOP/SUCCESS signal's state is returned with the result attached
    if new_state is not None:
        new_state.result = result
        return new_state

    state = Success(result=result, message="Task run succeeded.")
    return state
def run(
    self,
    state: State = None,
    upstream_states: Dict[Edge, State] = None,
    context: Dict[str, Any] = None,
    executor: "prefect.engine.executors.Executor" = None,
) -> State:
    """
    The main endpoint for TaskRunners.  Calling this method will conditionally execute
    `self.task.run` with any provided inputs, assuming the upstream dependencies are in a
    state which allow this Task to run.

    The run passes `state` through a fixed pipeline of checks (ready, start time,
    upstream finished, upstream skipped). A mapped parent then fans out into
    children and ends. Otherwise the pipeline continues: inputs, cache, trigger,
    running, task execution, result caching, retry and looping. Any check may end
    the pipeline early by raising `ENDRUN` or a Prefect state signal.

    Note: the passed `context` dict is mutated (`map_index` and `task_full_name`
    are written into it).

    Args:
        - state (State, optional): initial `State` to begin task run from;
            defaults to `Pending()`
        - upstream_states (Dict[Edge, State]): a dictionary
            representing the states of any tasks upstream of this one. The keys of the
            dictionary should correspond to the edges leading to the task.
        - context (dict, optional): prefect Context to use for execution
        - executor (Executor, optional): executor to use when performing
            computation; defaults to the executor specified in your prefect configuration

    Returns:
        - `State` object representing the final post-run state of the Task

    Raises:
        - RecursiveCall: always re-raised unchanged
        - Exception: any other unexpected error, re-raised after being logged when
            `raise_on_exception` is set in context
    """
    upstream_states = upstream_states or {}
    context = context or {}
    map_index = context.setdefault("map_index", None)
    # e.g. "my_task" for a parent / unmapped task, "my_task[3]" for a mapped child
    context["task_full_name"] = "{name}{index}".format(
        name=self.task.name,
        index=("" if map_index is None else "[{}]".format(map_index)),
    )

    if executor is None:
        executor = prefect.engine.get_default_executor_class()()

    # if mapped is true, this task run is going to generate a Mapped state. It won't
    # actually run, but rather spawn children tasks to map over its inputs. We
    # detect this case by checking for:
    #   - upstream edges that are `mapped`
    #   - no `map_index` (a map_index would indicate this is a child task, not the parent)
    mapped = any([e.mapped for e in upstream_states]) and map_index is None
    # initialized before the try so the signal handler below can attach it
    task_inputs = {}  # type: Dict[str, Any]

    try:
        # initialize the run
        state, context = self.initialize_run(state, context)

        # run state transformation pipeline
        with prefect.context(context):

            if prefect.context.get("task_loop_count") is None:
                self.logger.info(
                    "Task '{name}': Starting task run...".format(
                        name=context["task_full_name"]
                    )
                )

            # check to make sure the task is in a pending state
            state = self.check_task_is_ready(state)

            # check if the task has reached its scheduled time
            state = self.check_task_reached_start_time(state)

            # Tasks never run if the upstream tasks haven't finished
            state = self.check_upstream_finished(
                state, upstream_states=upstream_states
            )

            # check if any upstream tasks skipped (and if we need to skip)
            state = self.check_upstream_skipped(
                state, upstream_states=upstream_states
            )

            # if the task is mapped, process the mapped children and exit
            if mapped:
                state = self.run_mapped_task(
                    state=state,
                    upstream_states=upstream_states,
                    context=context,
                    executor=executor,
                )

                state = self.wait_for_mapped_task(state=state, executor=executor)

                self.logger.debug(
                    "Task '{name}': task has been mapped; ending run.".format(
                        name=context["task_full_name"]
                    )
                )
                raise ENDRUN(state)

            # retrieve task inputs from upstream and also explicitly passed inputs
            task_inputs = self.get_task_inputs(
                state=state, upstream_states=upstream_states
            )

            # check to see if the task has a cached result
            state = self.check_task_is_cached(state, inputs=task_inputs)

            # check if the task's trigger passes
            # triggers can raise Pauses, which require task_inputs to be available for caching
            # so we run this after the previous step
            state = self.check_task_trigger(state, upstream_states=upstream_states)

            # set the task state to running
            state = self.set_task_to_running(state)

            # run the task
            state = self.get_task_run_state(
                state, inputs=task_inputs, timeout_handler=executor.timeout_handler
            )

            # cache the output, if appropriate
            state = self.cache_result(state, inputs=task_inputs)

            # check if the task needs to be retried
            state = self.check_for_retry(state, inputs=task_inputs)

            state = self.check_task_is_looping(
                state,
                inputs=task_inputs,
                upstream_states=upstream_states,
                context=context,
                executor=executor,
            )

    # for pending signals (including retries and pauses) and failed states, we need
    # to make sure the task_inputs are set as the state's cached_inputs
    except (ENDRUN, signals.PrefectStateSignal) as exc:
        if exc.state.is_pending() or exc.state.is_failed():
            exc.state.cached_inputs = task_inputs or {}  # type: ignore
        state = exc.state
    except RecursiveCall as exc:
        raise exc

    except Exception as exc:
        msg = "Task '{name}': unexpected error while running task: {exc}".format(
            name=context["task_full_name"], exc=repr(exc)
        )
        self.logger.exception(msg)
        state = Failed(message=msg, result=exc)
        # NOTE(review): this runs after the `with prefect.context(context)` block has
        # exited, so a `raise_on_exception` passed only via `context` is not seen here
        if prefect.context.get("raise_on_exception"):
            raise exc

    # to prevent excessive repetition of this log
    # since looping relies on recursively calling self.run
    # TODO: figure out a way to only log this one single time instead of twice
    if prefect.context.get("task_loop_count") is None:
        # wrapping this final log in prefect.context(context) ensures
        # that any run-context, including task-run-ids, are respected
        with prefect.context(context):
            self.logger.info(
                "Task '{name}': finished task run for task with final state: '{state}'".format(
                    name=context["task_full_name"], state=type(state).__name__
                )
            )

    return state
def call_runner_target_handlers(self, old_state: State, new_state: State) -> State:
    """
    A special state handler that the TaskRunner uses to call its task's state
    handlers and then report the resulting state through `self.client`. This
    method is called as part of the base Runner's `handle_state_change()` method.

    Handler outcomes: a Prefect state signal raised by a handler replaces
    `new_state` with the signal's state; any other handler error replaces it with
    a `Failed` state. Either error is re-raised when `raise_on_exception` is set in
    context.

    Args:
        - old_state (State): the old (previous) state
        - new_state (State): the new (current) state

    Returns:
        - State: the new state; after a version-lock conflict this is the state
            read back from the backend with its result loaded

    Raises:
        - ENDRUN: if a version lock is hit while the stored task run is already
            Running, if the stored result cannot be loaded, if the client call
            otherwise fails (carrying `ClientFailed`), or if the backend responds
            with a Queued state
    """
    raise_on_exception = prefect.context.get("raise_on_exception", False)

    try:
        new_state = super().call_runner_target_handlers(
            old_state=old_state, new_state=new_state)

    # PrefectStateSignals are trapped and turned into States
    except prefect.engine.signals.PrefectStateSignal as exc:
        self.logger.info("{name} signal raised: {rep}".format(
            name=type(exc).__name__, rep=repr(exc)))
        if raise_on_exception:
            raise exc
        new_state = exc.state

    except Exception as exc:
        msg = "Exception raised while calling state handlers: {}".format(
            repr(exc))
        self.logger.exception(msg)
        if raise_on_exception:
            raise exc
        new_state = Failed(msg, result=exc)

    task_run_id = prefect.context.get("task_run_id")
    version = prefect.context.get("task_run_version")

    try:
        cloud_state = new_state
        state = self.client.set_task_run_state(
            task_run_id=task_run_id,
            # the version lock is only requested for transitions into Running
            version=version if cloud_state.is_running() else None,
            state=cloud_state,
            cache_for=self.task.cache_for,
        )
    except VersionLockError as exc:
        # another writer changed this task run; read back what the backend has
        state = self.client.get_task_run_state(task_run_id=task_run_id)

        if state.is_running():
            self.logger.debug(
                "Version lock encountered and task {} is already in a running state."
                .format(self.task.name))
            raise ENDRUN(state=state) from exc

        self.logger.debug(
            "Version lock encountered for task {}, proceeding with state {}..."
            .format(self.task.name, type(state).__name__))

        # adopt the backend's state (with its result loaded) as the new state
        try:
            new_state = state.load_result(self.result)
        except Exception as exc_inner:
            self.logger.debug(
                "Error encountered attempting to load result for state of {} task..."
                .format(self.task.name))
            self.logger.error(repr(exc_inner))
            raise ENDRUN(state=state) from exc_inner
    except Exception as exc:
        self.logger.exception(
            "Failed to set task state with error: {}".format(repr(exc)))
        raise ENDRUN(state=ClientFailed(state=new_state)) from exc

    if state.is_queued():
        # NOTE(review): assumes a Queued state keeps the state to resume in `.state`
        state.state = old_state  # type: ignore
        raise ENDRUN(state=state)

    # a missing version is treated as 0
    prefect.context.update(task_run_version=(version or 0) + 1)

    return new_state
def check_task_ready_to_map(
    self, state: State, upstream_states: Dict[Edge, State]
) -> State:
    """
    Checks if the parent task is ready to proceed with mapping.

    Args:
        - state (State): the current state of this task
        - upstream_states (Dict[Edge, State]): the upstream states

    Raises:
        - ENDRUN: always; execution never continues past this point. The state
            carried is `state` itself when it is already Mapped, a `Failed` state
            when nothing (or something unmappable) is mapped over, and otherwise a
            new `Mapped` state with `n_map_states` set to the smallest mapped
            upstream length (0 when there is none).
    """
    pairs = upstream_states.items()

    if state.is_mapped():
        # This is a re-run of an already-mapped pipeline. Repopulate `map_states`
        # and `cached_inputs` (from the upstream results) so the flow runner can
        # regenerate the child tasks, whether we mapped over exchanged data or a
        # non-data-exchanging upstream dependency.
        if len(state.map_states) == 0 and state.n_map_states > 0:  # type: ignore
            state.map_states = [None] * state.n_map_states  # type: ignore
        state.cached_inputs = {  # type: ignore
            edge.key: upstream._result  # type: ignore
            for edge, upstream in pairs
            if edge.key
        }
        raise ENDRUN(state)

    # we can't map if there are no success states with iterables upstream
    if upstream_states and not any(
        [edge.mapped and upstream.is_successful() for edge, upstream in pairs]
    ):
        raise ENDRUN(Failed("No upstream states can be mapped over."))

    # successful, not-yet-mapped upstream data reached through a mapped edge
    mapped_data = [
        upstream
        for edge, upstream in pairs
        if upstream.is_successful() and not upstream.is_mapped() and edge.mapped
    ]
    if not all([hasattr(upstream.result, "__getitem__") for upstream in mapped_data]):
        raise ENDRUN(Failed("At least one upstream state has an unmappable result."))

    lengths = [len(upstream.result) for upstream in mapped_data]
    lengths += [
        upstream.n_map_states  # type: ignore
        for edge, upstream in pairs
        if edge.mapped and upstream.is_mapped()
    ]
    n_map_states = min(lengths, default=0)
    raise ENDRUN(Mapped("Ready to proceed with mapping.", n_map_states=n_map_states))
def get_flow_run_state(
    self,
    state: State,
    task_states: Dict[Task, State],
    task_contexts: Dict[Task, Dict[str, Any]],
    return_tasks: Set[Task],
    task_runner_state_handlers: Iterable[Callable],
    executor: "prefect.executors.base.Executor",
) -> State:
    """
    Runs the flow.

    Tasks are visited in `self.flow.sorted_tasks()` order and submitted to
    `executor` via `run_task`, all inside `self.check_for_cancellation()` and
    `executor.start()`.

    - Constant tasks without a state become `Success(result=task.value)`.
    - Finished resource setup/cleanup and secret tasks are reset to `Pending`
      unless cached.
    - Other tasks that are finished (and neither cached nor mapped) are skipped.

    A task with at least one mapped upstream edge is first run as a "mapped
    parent" (`is_mapped_parent=True`, waited on synchronously). One child run is
    then submitted per entry returned by
    `executors.prepare_upstream_states_for_mapping`. The child futures are
    recorded in a local dictionary only when the parent ended in a `Mapped`
    state. They are attached to the parent's `map_states` only after being
    waited on at the end of the run.

    Args:
        - state (State): starting state for the Flow. Defaults to
            `Pending`
        - task_states (dict): dictionary of task states to begin
            computation with, with keys being Tasks and values their corresponding state
        - task_contexts (Dict[Task, Dict[str, Any]]): contexts that will be provided to
            each task
        - return_tasks ([Task], optional): list of Tasks to include in the
            final returned Flow state. Defaults to `None`
        - task_runner_state_handlers (Iterable[Callable]): A list of state change
            handlers that will be provided to the task_runner, and called whenever a
            task changes state.
        - executor (Executor): executor to use when performing
            computation; defaults to the executor provided in your prefect configuration

    Returns:
        - State: `State` representing the final post-run state of the `Flow`.

    Raises:
        - ENDRUN: if `state` is not a Running state
        - ValueError: if `return_tasks` contains tasks that are not in the flow
    """
    # this dictionary is used for tracking the states of "children" mapped tasks;
    # when running on Dask, we want to avoid serializing futures, so instead
    # of storing child task states in the `map_states` attribute we store them
    # in this dictionary and only attach them to the Mapped state after they
    # are resolved
    mapped_children = dict()  # type: Dict[Task, list]

    if not state.is_running():
        self.logger.info("Flow is not in a Running state.")
        raise ENDRUN(state)

    if return_tasks is None:
        return_tasks = set()
    if set(return_tasks).difference(self.flow.tasks):
        raise ValueError("Some tasks in return_tasks were not found in the flow.")

    def extra_context(task: Task, task_index: int = None) -> dict:
        # per-run extras passed to `run_task` as `extra_context`
        return {
            "task_name": task.name,
            "task_tags": task.tags,
            "task_index": task_index,
        }

    # -- process each task in order
    with self.check_for_cancellation(), executor.start():

        for task in self.flow.sorted_tasks():
            task_state = task_states.get(task)

            # if a task is a constant task, we already know its return value
            # no need to use up resources by running it through a task runner
            if task_state is None and isinstance(
                task, prefect.tasks.core.constants.Constant
            ):
                task_states[task] = task_state = Success(result=task.value)

            # Always restart completed resource setup/cleanup tasks and
            # secret tasks unless they were explicitly cached.
            # TODO: we only need to rerun these tasks if any pending
            # downstream tasks depend on them.
            if (
                isinstance(
                    task,
                    (
                        prefect.tasks.core.resource_manager.ResourceSetupTask,
                        prefect.tasks.core.resource_manager.ResourceCleanupTask,
                        prefect.tasks.secrets.SecretBase,
                    ),
                )
                and task_state is not None
                and task_state.is_finished()
                and not task_state.is_cached()
            ):
                task_states[task] = task_state = Pending()

            # if the state is finished, don't run the task, just use the provided state.
            # If the state is cached / mapped, we still want to run the task runner
            # pipeline steps to either ensure the cache is still valid or to recreate
            # the mapped pipeline for possible retries
            if (
                isinstance(task_state, State)
                and task_state.is_finished()
                and not task_state.is_cached()
                and not task_state.is_mapped()
            ):
                continue

            upstream_states = {}  # type: Dict[Edge, State]

            # this dictionary is used exclusively for "reduce" tasks; in particular we
            # store the states / futures corresponding to the upstream children, and if
            # running on Dask, let Dask resolve them at the appropriate time.
            # Note: this is an optimization that allows Dask to resolve the mapped
            # dependencies by "elevating" them to a function argument.
            upstream_mapped_states = {}  # type: Dict[Edge, list]

            # -- process each edge to the task
            for edge in self.flow.edges_to(task):

                # load the upstream task states (supplying Pending as a default)
                upstream_states[edge] = task_states.get(
                    edge.upstream_task, Pending(message="Task state not available.")
                )

                # if the edge is flattened and not the result of a map, then we
                # preprocess the upstream states. If it IS the result of a
                # map, it will be handled in `prepare_upstream_states_for_mapping`
                if edge.flattened:
                    if not isinstance(upstream_states[edge], Mapped):
                        upstream_states[edge] = executor.submit(
                            executors.flatten_upstream_state, upstream_states[edge]
                        )

                # this checks whether the task is a "reduce" task for a mapped pipeline
                # and if so, collects the appropriate upstream children
                if not edge.mapped and isinstance(upstream_states[edge], Mapped):
                    children = mapped_children.get(edge.upstream_task, [])

                    # if the edge is flattened, then we need to wait for the mapped children
                    # to complete and then flatten them
                    if edge.flattened:
                        children = executors.flatten_mapped_children(
                            mapped_children=children, executor=executor
                        )

                    upstream_mapped_states[edge] = children

            # augment edges with upstream constants
            for key, val in self.flow.constants[task].items():
                edge = Edge(
                    upstream_task=prefect.tasks.core.constants.Constant(val),
                    downstream_task=task,
                    key=key,
                )
                upstream_states[edge] = Success(
                    "Auto-generated constant value",
                    result=ConstantResult(value=val),
                )

            # handle mapped tasks
            if any(edge.mapped for edge in upstream_states.keys()):

                # wait on upstream states to determine the width of the pipeline
                # this is the key to depth-first execution
                # (the comprehension's `state` is local to it; the outer flow
                # `state` argument is not affected)
                upstream_states = executor.wait(
                    {e: state for e, state in upstream_states.items()}
                )

                # we submit the task to the task runner to determine if
                # we can proceed with mapping - if the new task state is not a Mapped
                # state then we don't proceed
                task_states[task] = executor.wait(
                    executor.submit(
                        run_task,
                        task=task,
                        state=task_state,  # original state
                        upstream_states=upstream_states,
                        context=dict(
                            prefect.context, **task_contexts.get(task, {})
                        ),
                        flow_result=self.flow.result,
                        task_runner_cls=self.task_runner_cls,
                        task_runner_state_handlers=task_runner_state_handlers,
                        upstream_mapped_states=upstream_mapped_states,
                        is_mapped_parent=True,
                        extra_context=extra_context(task),
                    )
                )

                # either way, we should now have enough resolved states to restructure
                # the upstream states into a list of upstream state dictionaries to iterate over
                list_of_upstream_states = (
                    executors.prepare_upstream_states_for_mapping(
                        task_states[task],
                        upstream_states,
                        mapped_children,
                        executor=executor,
                    )
                )

                submitted_states = []

                for idx, states in enumerate(list_of_upstream_states):
                    # if we are on a future rerun of a partially complete flow run,
                    # there might be mapped children in a retrying state; this check
                    # looks into the current task state's map_states for such info
                    if (
                        isinstance(task_state, Mapped)
                        and len(task_state.map_states) >= idx + 1
                    ):
                        current_state = task_state.map_states[
                            idx
                        ]  # type: Optional[State]
                    elif isinstance(task_state, Mapped):
                        current_state = None
                    else:
                        current_state = task_state

                    # this is where each child is submitted for actual work
                    submitted_states.append(
                        executor.submit(
                            run_task,
                            task=task,
                            state=current_state,
                            upstream_states=states,
                            context=dict(
                                prefect.context,
                                **task_contexts.get(task, {}),
                                map_index=idx,
                            ),
                            flow_result=self.flow.result,
                            task_runner_cls=self.task_runner_cls,
                            task_runner_state_handlers=task_runner_state_handlers,
                            upstream_mapped_states=upstream_mapped_states,
                            extra_context=extra_context(task, task_index=idx),
                        )
                    )
                # child futures are only recorded when the parent actually mapped
                if isinstance(task_states.get(task), Mapped):
                    mapped_children[task] = submitted_states  # type: ignore

            else:
                task_states[task] = executor.submit(
                    run_task,
                    task=task,
                    state=task_state,
                    upstream_states=upstream_states,
                    context=dict(prefect.context, **task_contexts.get(task, {})),
                    flow_result=self.flow.result,
                    task_runner_cls=self.task_runner_cls,
                    task_runner_state_handlers=task_runner_state_handlers,
                    upstream_mapped_states=upstream_mapped_states,
                    extra_context=extra_context(task),
                )

        # ---------------------------------------------
        # Collect results
        # ---------------------------------------------

        # terminal tasks determine if the flow is finished
        terminal_tasks = self.flow.terminal_tasks()

        # reference tasks determine flow state
        reference_tasks = self.flow.reference_tasks()

        # wait until all terminal tasks are finished
        final_tasks = terminal_tasks.union(reference_tasks).union(return_tasks)
        final_states = executor.wait(
            {
                t: task_states.get(t, Pending("Task not evaluated by FlowRunner."))
                for t in final_tasks
            }
        )

        # also wait for any children of Mapped tasks to finish, and add them
        # to the dictionary to determine flow state
        all_final_states = final_states.copy()
        for t, s in list(final_states.items()):
            if s.is_mapped():
                # ensure we wait for any mapped children to complete
                if t in mapped_children:
                    s.map_states = executor.wait(mapped_children[t])
                s.result = [ms.result for ms in s.map_states]
                all_final_states[t] = s.map_states

        assert isinstance(final_states, dict)

    key_states = set(flatten_seq([all_final_states[t] for t in reference_tasks]))
    terminal_states = set(
        flatten_seq([all_final_states[t] for t in terminal_tasks])
    )
    return_states = {t: final_states[t] for t in return_tasks}

    state = self.determine_final_state(
        state=state,
        key_states=key_states,
        return_states=return_states,
        terminal_states=terminal_states,
    )

    return state
def get_task_run_state(
    self,
    state: State,
    inputs: Dict[str, Result],
    timeout_handler: Optional[Callable] = None,
) -> State:
    """
    Execute `self.task.run()` with the given inputs and turn the outcome into a state.

    Outcomes:
        - normal return: `Success` wrapping the value in a `Result` bound to
          `self.result_handler`. The result's safe value is stored when
          checkpointing is enabled in context, `task.checkpoint` is not `False`
          and a result handler is configured.
        - `KeyboardInterrupt`: `Cancelled`
        - `TimeoutError`: `TimedOut`, unless `raise_on_exception` is set in
          context, in which case the error is re-raised
        - `signals.LOOP`: the signal's `Looped` state, with its result wrapped
          and the inputs cached

    Any other exception propagates to the caller.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments.
        - timeout_handler (Callable, optional): function for timing out
            task execution, with call signature `handler(fn, *args, **kwargs)`. Defaults to
            `prefect.utilities.executors.timeout_handler`

    Returns:
        - State: the state of the task after running the check

    Raises:
        - signals.PAUSE: if the task raises PAUSE
        - ENDRUN: if the task is not ready to run
    """

    def display_name():
        return prefect.context.get("task_full_name", self.task.name)

    if not state.is_running():
        self.logger.debug(
            "Task '{name}': can't run task because it's not in a "
            "Running state; ending run.".format(name=display_name())
        )
        raise ENDRUN(state)

    try:
        self.logger.debug(
            "Task '{name}': Calling task.run() method...".format(name=display_name())
        )
        handler = timeout_handler or prefect.utilities.executors.timeout_handler
        call_kwargs = {arg: res.value for arg, res in inputs.items()}

        if not getattr(self.task, "log_stdout", False):
            output = handler(self.task.run, timeout=self.task.timeout, **call_kwargs)
        else:
            # route anything the task prints into this runner's logger
            log_stream = prefect.utilities.logging.RedirectToLog(self.logger)
            with redirect_stdout(log_stream):  # type: ignore
                output = handler(
                    self.task.run, timeout=self.task.timeout, **call_kwargs
                )

    except KeyboardInterrupt:
        self.logger.debug("Interrupt signal raised, cancelling task run.")
        return Cancelled(message="Interrupt signal raised, cancelling task run.")

    except TimeoutError as exc:
        # inform the user of the timeout via a TimedOut state (or re-raise on demand)
        if prefect.context.get("raise_on_exception"):
            raise exc
        return TimedOut(
            "Task timed out during execution.", result=exc, cached_inputs=inputs
        )

    except signals.LOOP as exc:
        looped = exc.state
        assert isinstance(looped, Looped)
        looped.result = Result(value=looped.result, result_handler=self.result_handler)
        looped.cached_inputs = inputs
        looped.message = exc.state.message or "Task is looping ({})".format(
            looped.loop_count
        )
        return looped

    succeeded = Success(
        result=Result(value=output, result_handler=self.result_handler),
        message="Task run succeeded.",
        cached_inputs=inputs,
    )

    # checkpoint only when a result handler exists and the user has not opted out
    should_checkpoint = (
        succeeded.is_successful()
        and prefect.context.get("checkpointing") is True
        and self.task.checkpoint is not False
        and self.result_handler is not None
    )
    if should_checkpoint:
        succeeded._result.store_safe_value()

    return succeeded
def get_task_run_state(self, state: State, inputs: Dict[str, Result]) -> State:
    """
    Execute `self.task.run()` with the given inputs and turn the outcome into a state.

    Outcomes:
        - normal return: `Success`. When checkpointing is enabled in context,
          `task.checkpoint` is not `False` and the value is not `None`, the value
          is written with `self.result.write(...)`. The write is formatted with
          the flow parameters, the raw task inputs and the current context, in
          that order of precedence from lowest to highest. If the result type
          raises `NotImplementedError`, or checkpointing does not apply, the
          value is wrapped with `self.result.from_value(...)` instead.
        - `TimeoutError`: `TimedOut`, unless `raise_on_exception` is set in
          context, in which case the error is re-raised
        - `signals.LOOP`: the signal's `Looped` state, with its result wrapped

    Any other exception propagates to the caller.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result], optional): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments.

    Returns:
        - State: the state of the task after running the check

    Raises:
        - signals.PAUSE: if the task raises PAUSE
        - ENDRUN: if the task is not ready to run
    """
    if not state.is_running():
        self.logger.debug(
            "Task '{name}': can't run task because it's not in a "
            "Running state; ending run.".format(
                name=prefect.context.get("task_full_name", self.task.name)
            )
        )
        raise ENDRUN(state)

    value = None
    call_kwargs = {arg: res.value for arg, res in inputs.items()}

    try:
        self.logger.debug(
            "Task '{name}': Calling task.run() method...".format(
                name=prefect.context.get("task_full_name", self.task.name)
            )
        )
        runner = prefect.utilities.executors.timeout_handler
        if not getattr(self.task, "log_stdout", False):
            value = runner(self.task.run, timeout=self.task.timeout, **call_kwargs)
        else:
            # route anything the task prints into this runner's logger
            with redirect_stdout(
                prefect.utilities.logging.RedirectToLog(self.logger)  # type: ignore
            ):
                value = runner(self.task.run, timeout=self.task.timeout, **call_kwargs)

    except TimeoutError as exc:
        # inform the user of the timeout via a TimedOut state (or re-raise on demand)
        if prefect.context.get("raise_on_exception"):
            raise exc
        return TimedOut("Task timed out during execution.", result=exc)

    except signals.LOOP as exc:
        looped = exc.state
        assert isinstance(looped, Looped)
        looped.result = self.result.from_value(value=looped.result)
        looped.message = exc.state.message or "Task is looping ({})".format(
            looped.loop_count
        )
        return looped

    wants_checkpoint = (
        prefect.context.get("checkpointing") is True
        and self.task.checkpoint is not False
        and value is not None
    )
    if not wants_checkpoint:
        result = self.result.from_value(value=value)
    else:
        try:
            template_kwargs = {
                **prefect.context.get("parameters", {}).copy(),
                **call_kwargs,
                **prefect.context,
            }
            result = self.result.write(value, **template_kwargs)
        except NotImplementedError:
            # result types that cannot persist fall back to an in-memory wrapper
            result = self.result.from_value(value=value)

    return Success(result=result, message="Task run succeeded.")
def check_task_is_ready(self, state: State) -> State:
    """
    Decide whether the task's current state allows the run to continue.

    The checks are applied in this order:
        - Paused: the run ends (checked before the Pending test)
        - Pending: returned unchanged
        - Mapped: returned unchanged so the children tasks can still be generated
        - Running: the run ends
        - Cached: returned unchanged
        - any other Finished state: the run ends
        - anything else: the run ends

    Args:
        - state (State): the current state of this task

    Returns:
        - State: the unchanged state, when the run may continue

    Raises:
        - ENDRUN: carrying `state`, whenever the task is not ready to run
    """

    def task_label():
        return prefect.context.get("task_full_name", self.task.name)

    if isinstance(state, Paused):
        self.logger.debug(
            "Task '{name}': task is paused; ending run.".format(name=task_label())
        )
        raise ENDRUN(state)

    if state.is_pending():
        return state

    if state.is_mapped():
        self.logger.debug(
            "Task '{name}': task is mapped, but run will proceed so children are generated.".format(
                name=task_label()
            )
        )
        return state

    if state.is_running():
        self.logger.debug(
            "Task '{name}': task is already running.".format(name=task_label())
        )
        raise ENDRUN(state)

    if state.is_cached():
        return state

    if state.is_finished():
        self.logger.debug(
            "Task '{name}': task is already finished.".format(name=task_label())
        )
        raise ENDRUN(state)

    self.logger.debug(
        "Task '{name}' is not ready to run or state was unrecognized ({state}).".format(
            name=task_label(),
            state=state,
        )
    )
    raise ENDRUN(state)
def call_runner_target_handlers(self, old_state: State, new_state: State) -> State:
    """
    Run the flow's own state handlers, then record the result through `self.client`.

    The FlowRunner uses this special state handler to call its flow's state
    handlers. It is invoked from the base Runner's `handle_state_change()` method.

    Behavior:
        - A handler error is logged and turned into a `Failed` state. When
          `raise_on_exception` is set in context, the error is re-raised instead.
        - The state is sent with `self.client.set_flow_run_state`. The context's
          `flow_run_version` is passed only when the state is Running.
        - On `VersionLockError`, the server-side state is fetched. If it is
          Running the run ends; otherwise that state is adopted.
        - If the recorded state is queued, its `state` attribute is set to
          `old_state` and the run ends.
        - Otherwise the context's `flow_run_version` is incremented.

    Args:
        - old_state (State): the old (previous) state
        - new_state (State): the new (current) state

    Returns:
        - State: the new state

    Raises:
        - ENDRUN: on a version lock with a Running server state, on a failure
            to set the state, or when the recorded state is queued
    """
    should_reraise = prefect.context.get("raise_on_exception", False)

    try:
        new_state = super().call_runner_target_handlers(
            old_state=old_state, new_state=new_state
        )
    except Exception as exc:
        msg = "Exception raised while calling state handlers: {}".format(repr(exc))
        self.logger.exception(msg)
        if should_reraise:
            raise exc
        new_state = Failed(msg, result=exc)

    flow_run_id = prefect.context.get("flow_run_id", None)
    version = prefect.context.get("flow_run_version")

    try:
        recorded = self.client.set_flow_run_state(
            flow_run_id=flow_run_id,
            version=version if new_state.is_running() else None,
            state=new_state,
        )
    except VersionLockError as exc:
        # someone else advanced the run; look at what the server now holds
        recorded = self.client.get_flow_run_state(flow_run_id=flow_run_id)

        if recorded.is_running():
            self.logger.debug(
                "Version lock encountered and flow is already in a running state."
            )
            raise ENDRUN(state=recorded) from exc

        self.logger.debug(
            "Version lock encountered, proceeding with state {}...".format(
                type(recorded).__name__
            )
        )
        new_state = recorded
    except Exception as exc:
        self.logger.exception(
            "Failed to set flow state with error: {}".format(repr(exc))
        )
        raise ENDRUN(state=new_state) from exc

    if recorded.is_queued():
        recorded.state = old_state  # type: ignore
        raise ENDRUN(state=recorded)

    prefect.context.update(flow_run_version=(version or 0) + 1)
    return new_state
def initialize_run(  # type: ignore
    self,
    state: Optional[State],
    task_states: Dict[Task, State],
    context: Dict[str, Any],
    task_contexts: Dict[Task, Dict[str, Any]],
    parameters: Dict[str, Any],
) -> FlowRunnerInitializeResult:
    """
    Initializes the Task run by initializing state and context appropriately.

    The flow run info for `prefect.context["flow_run_id"]` is loaded from
    `self.client` and merged into the run:
        - Context: the run's stored context plus its flow/run identifiers are
          layered onto a copy of `context`. The caller's dict is not modified.
        - Task runs: each task run's state and identifiers are added to
          `task_states` / `task_contexts`, without overwriting existing
          `task_states` entries.
        - State: the provided `state` wins over the stored one.
        - Parameters: the stored parameters are combined with `parameters`,
          and the explicitly provided values win.

    If the provided state is a Submitted state, the state it wraps is extracted.

    Args:
        - state (Optional[State]): the initial state of the run
        - task_states (Dict[Task, State]): a dictionary of any initial task states
        - context (Dict[str, Any], optional): prefect.Context to use for execution
            to use for each Task run
        - task_contexts (Dict[Task, Dict[str, Any]], optional): contexts that will be
            provided to each task
        - parameters(dict): the parameter values for the run; `None` is treated
            as no parameters

    Returns:
        - NamedTuple: a tuple of initialized objects:
            `(state, task_states, context, task_contexts)`

    Raises:
        - ENDRUN: if the flow run info cannot be retrieved
        - KeyError: if a task run references a task slug not present in the flow
    """
    # load id from context
    flow_run_id = prefect.context.get("flow_run_id")

    try:
        flow_run_info = self.client.get_flow_run_info(flow_run_id)
    except Exception as exc:
        self.logger.debug(
            "Failed to retrieve flow state with error: {}".format(repr(exc))
        )
        if state is None:
            state = Failed(
                message="Could not retrieve state from Prefect Cloud", result=exc
            )
        raise ENDRUN(state=state) from exc

    # copy instead of updating in place: `context or {}` would otherwise mutate
    # the caller's dict whenever it is non-empty (but not when it is empty)
    updated_context = dict(context or {})
    updated_context.update(flow_run_info.context or {})
    updated_context.update(
        flow_id=flow_run_info.flow_id,
        flow_run_id=flow_run_info.id,
        flow_run_version=flow_run_info.version,
        flow_run_name=flow_run_info.name,
        scheduled_start_time=flow_run_info.scheduled_start_time,
    )

    tasks = {slug: t for t, slug in self.flow.slugs.items()}
    # update task states and contexts
    for task_run in flow_run_info.task_runs:
        try:
            task = tasks[task_run.task_slug]
        except KeyError as exc:
            msg = (
                f"Task slug {task_run.task_slug} not found in the current Flow; "
                f"this is usually caused by changing the Flow without reregistering "
                f"it with the Prefect API."
            )
            raise KeyError(msg) from exc
        task_states.setdefault(task, task_run.state)
        task_contexts.setdefault(task, {}).update(
            task_id=task_run.task_id,
            task_run_id=task_run.id,
            task_run_version=task_run.version,
        )

    # if state is set, keep it; otherwise load from Cloud
    state = state or flow_run_info.state  # type: ignore

    # update parameters, prioritizing kwarg-provided params; copy so the
    # client's flow_run_info object is left untouched
    updated_parameters = dict(flow_run_info.parameters or {})  # type: ignore
    updated_parameters.update(parameters or {})

    return super().initialize_run(
        state=state,
        task_states=task_states,
        context=updated_context,
        task_contexts=task_contexts,
        parameters=updated_parameters,
    )
def check_for_cancellation(self) -> Iterator:
    """
    Contextmanager used to wrap a cancellable section of a flow run.

    While the wrapped section runs, a daemon thread re-reads the flow run info
    from `self.client` every `prefect.config.cloud.check_cancellation_interval`
    seconds. If it sees a `Cancelling` state while the section is still running,
    it interrupts the main thread with SIGINT so that the section stops with a
    `KeyboardInterrupt`.

    When the section exits, the thread makes one final check and stops. If
    cancellation was detected, the `KeyboardInterrupt` is swallowed, the flow
    run version seen by the thread is stored in context, and ENDRUN with a
    `Cancelled` state is raised.

    NOTE(review): this is a generator; it is entered with `with` elsewhere, so
    it presumably carries a `contextmanager` decorator outside this view — confirm.

    Raises:
        - ENDRUN: with a `Cancelled` state if cancellation was detected
    """
    cancelling = False
    done = threading.Event()
    flow_run_version = None
    context = prefect.context.to_dict()

    def interrupt_if_cancelling() -> None:
        nonlocal cancelling
        nonlocal flow_run_version
        # We need to copy the context into this thread, since context is a
        # thread local.
        with prefect.context(context):
            flow_run_id = prefect.context["flow_run_id"]
            while True:
                # returns True (immediately, on every call) once `done` is set
                exiting_context = done.wait(
                    prefect.config.cloud.check_cancellation_interval
                )
                try:
                    self.logger.debug("Checking flow run state...")
                    flow_run_info = self.client.get_flow_run_info(flow_run_id)
                except Exception:
                    self.logger.warning("Error getting flow run info", exc_info=True)
                    if exiting_context:
                        # `done` is already set, so retrying would spin without
                        # any delay and block `thread.join()` below until the
                        # API recovers; give up on the final check instead
                        break
                    continue
                if not flow_run_info.state.is_running():
                    self.logger.warning(
                        "Flow run is no longer in a running state; the current state is: %r",
                        flow_run_info.state,
                    )
                if isinstance(flow_run_info.state, Cancelling):
                    self.logger.info(
                        "Flow run has been cancelled, cancelling active tasks"
                    )
                    cancelling = True
                    flow_run_version = flow_run_info.version
                    # If not already leaving context, raise KeyboardInterrupt in the main thread
                    if not exiting_context:
                        if hasattr(signal, "raise_signal"):
                            # New in python 3.8
                            signal.raise_signal(signal.SIGINT)  # type: ignore
                        elif os.name == "nt":
                            # This doesn't actually send a signal, so it will only
                            # interrupt the next Python bytecode instruction - if the
                            # main thread is blocked in a c extension the interrupt
                            # won't be seen until that returns.
                            from _thread import interrupt_main

                            interrupt_main()
                        else:
                            signal.pthread_kill(
                                threading.main_thread().ident, signal.SIGINT  # type: ignore
                            )
                    break
                elif exiting_context:
                    break

    thread = threading.Thread(target=interrupt_if_cancelling, daemon=True)
    thread.start()
    try:
        yield
    except KeyboardInterrupt:
        # an interrupt we caused ourselves is expected; anything else propagates
        if not cancelling:
            raise
    finally:
        done.set()
        thread.join()
        if cancelling:
            prefect.context.update(flow_run_version=flow_run_version)
            raise ENDRUN(state=Cancelled("Flow run is cancelled"))
def get_flow_run_state(
    self,
    state: State,
    task_states: Dict[Task, State],
    task_contexts: Dict[Task, Dict[str, Any]],
    return_tasks: Set[Task],
    task_runner_state_handlers: Iterable[Callable],
    executor: "prefect.engine.executors.base.Executor",
) -> State:
    """
    Runs the flow.

    Each task in `self.flow.sorted_tasks()` order is submitted to `executor`
    through `self.run_task`. The exception is a task that already holds a
    finished state that is neither cached nor mapped; that state is reused.
    Constant tasks without a state become `Success(result=task.value)`.

    Once every task is submitted, the terminal, reference and return tasks are
    waited on, together with the `map_states` of any Mapped results. The flow's
    final state then comes from `self.determine_final_state`.

    Args:
        - state (State): starting state for the Flow. Defaults to
            `Pending`
        - task_states (dict): dictionary of task states to begin
            computation with, with keys being Tasks and values their corresponding state
        - task_contexts (Dict[Task, Dict[str, Any]]): contexts that will be provided to
            each task
        - return_tasks ([Task], optional): list of Tasks to include in the
            final returned Flow state. Defaults to `None`
        - task_runner_state_handlers (Iterable[Callable]): A list of state change
            handlers that will be provided to the task_runner, and called whenever a
            task changes state.
        - executor (Executor): executor to use when performing
            computation; defaults to the executor provided in your prefect configuration

    Returns:
        - State: `State` representing the final post-run state of the `Flow`.

    Raises:
        - ENDRUN: if `state` is not a Running state
        - ValueError: if `return_tasks` contains tasks that are not in the flow
    """
    if not state.is_running():
        self.logger.info("Flow is not in a Running state.")
        raise ENDRUN(state)

    return_tasks = set() if return_tasks is None else return_tasks
    if set(return_tasks).difference(self.flow.tasks):
        raise ValueError(
            "Some tasks in return_tasks were not found in the flow.")

    with executor.start():
        for task in self.flow.sorted_tasks():
            prior = task_states.get(task)

            # a constant's value is already known; no task runner needed
            if prior is None and isinstance(task, prefect.tasks.core.constants.Constant):
                prior = Success(result=task.value)
                task_states[task] = prior

            # reuse finished states as-is; cached / mapped ones still go through
            # the runner
            reusable = (
                isinstance(prior, State)
                and prior.is_finished()
                and not prior.is_cached()
                and not prior.is_mapped()
            )
            if reusable:
                continue

            upstream_states = {
                edge: task_states.get(
                    edge.upstream_task, Pending(message="Task state not available.")
                )
                for edge in self.flow.edges_to(task)
            }  # type: Dict[Edge, Union[State, Iterable]]

            # the submitted context is built inside this block so it carries
            # the task's full name and tags
            with prefect.context(task_full_name=task.name, task_tags=task.tags):
                task_states[task] = executor.submit(
                    self.run_task,
                    task=task,
                    state=prior,
                    upstream_states=upstream_states,
                    context=dict(prefect.context, **task_contexts.get(task, {})),
                    task_runner_state_handlers=task_runner_state_handlers,
                    executor=executor,
                )

        # ---- collect results ----
        # terminal tasks decide whether the flow is finished, reference tasks
        # decide its state
        terminal_tasks = self.flow.terminal_tasks()
        reference_tasks = self.flow.reference_tasks()
        final_tasks = terminal_tasks.union(reference_tasks).union(return_tasks)

        final_states = executor.wait(
            {
                wanted: task_states.get(
                    wanted, Pending("Task not evaluated by FlowRunner.")
                )
                for wanted in final_tasks
            }
        )

        # Mapped results contribute their children's states to the flow state
        all_final_states = final_states.copy()
        for parent, parent_state in list(final_states.items()):
            if not parent_state.is_mapped():
                continue
            parent_state.map_states = executor.wait(parent_state.map_states)
            parent_state.result = [child.result for child in parent_state.map_states]
            all_final_states[parent] = parent_state.map_states

        assert isinstance(final_states, dict)

    key_states = set(flatten_seq([all_final_states[t] for t in reference_tasks]))
    terminal_states = set(flatten_seq([all_final_states[t] for t in terminal_tasks]))
    return_states = {t: final_states[t] for t in return_tasks}

    return self.determine_final_state(
        state=state,
        key_states=key_states,
        return_states=return_states,
        terminal_states=terminal_states,
    )