def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[None, str, Iterable[str]]):
    """
    This method implements the logic for a branching operator; given a single
    task ID or list of task IDs to follow, this skips all other tasks
    immediately downstream of this operator.

    branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
    newly added tasks should be skipped when they are cleared.

    :param ti: the running task instance of the branching task. Its DAG run,
        its task and that task's DAG are used to find and skip the
        downstream tasks.
    :param branch_task_ids: a single task ID, an iterable of task IDs, or
        ``None``. ``None`` behaves like an empty iterable: no branch is
        followed, so every task immediately downstream is skipped.
    :raises TypeError: if ``branch_task_ids`` is not ``None``, a string or an
        iterable, or if the iterable contains non-string items.

    Task IDs that are not part of the DAG are looked up with ``dag.get_task``
    (only when this task has downstream tasks). That lookup raises for unknown
    IDs; presumably ``TaskNotFound``, to be confirmed against the DAG API.
    """
    self.log.info("Following branch %s", branch_task_ids)

    # Normalise the argument into a fresh set. A plain string is a single task
    # ID and must not be iterated character by character.
    if isinstance(branch_task_ids, str):
        branch_task_id_set = {branch_task_ids}
    elif branch_task_ids is None:
        # Following no branch is equivalent to an empty iterable: everything
        # immediately downstream gets skipped. Previously this crashed with an
        # opaque "'NoneType' object is not iterable".
        branch_task_id_set = set()
    elif isinstance(branch_task_ids, Iterable):
        branch_task_id_set = set(branch_task_ids)
    else:
        raise TypeError(
            "'branch_task_ids' must be None, a task ID, or an iterable of task IDs, "
            "but got {!r}.".format(type(branch_task_ids).__name__)
        )

    # Non-string IDs would otherwise only fail later inside dag.get_task, or
    # pass unnoticed when this task has nothing downstream. Reject them up front.
    non_str_ids = [tid for tid in branch_task_id_set if not isinstance(tid, str)]
    if non_str_ids:
        raise TypeError(
            "'branch_task_ids' must contain only task ID strings, "
            "but got non-string items: {!r}.".format(non_str_ids)
        )

    dag_run = ti.get_dagrun()
    task = ti.task
    dag = task.dag
    assert dag  # For Mypy.

    # At runtime, the downstream list will only be operators
    downstream_tasks = cast("List[BaseOperator]", task.downstream_list)

    if downstream_tasks:
        # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
        # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
        # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
        # we need a special case here for such empty branches: Check downstream tasks of branch_task_ids.
        # In case the task to skip is also downstream of branch_task_ids, we add it to branch_task_ids and
        # exclude it from skipping.
        #
        # branch  ----->  join
        #   \            ^
        #     v        /
        #       task1
        #
        # Iterate over a snapshot because the set is extended inside the loop.
        for branch_task_id in list(branch_task_id_set):
            branch_task_id_set.update(
                dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False)
            )

        skip_tasks = [t for t in downstream_tasks if t.task_id not in branch_task_id_set]
        follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]

        self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
        with create_session() as session:
            self._set_state_to_skipped(dag_run, skip_tasks, session=session)
            # For some reason, session.commit() needs to happen before xcom_push.
            # Otherwise the session is not committed.
            session.commit()
            ti.xcom_push(key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids})