def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[str, Iterable[str]]):
    """
    Implement the logic for a branching operator.

    Given a single task ID or an iterable of task IDs to follow, skip all other
    tasks immediately downstream of this operator. A directly-downstream task
    is also kept (not skipped) when it is itself downstream of one of the
    followed tasks.

    :param ti: task instance of the branching operator
    :param branch_task_ids: a task ID, or an iterable of task IDs, to follow
    """
    self.log.info("Following branch %s", branch_task_ids)
    if isinstance(branch_task_ids, str):
        branch_task_ids = [branch_task_ids]
    # Materialize once: the IDs are iterated below AND used for membership
    # tests afterwards. A one-shot iterable (e.g. a generator) would be
    # exhausted by the loop, so the followed tasks themselves would be skipped.
    followed_task_ids = set(branch_task_ids)

    dag_run = ti.get_dagrun()
    task = ti.task
    dag = task.dag

    downstream_tasks = task.downstream_list
    if not downstream_tasks:
        return

    # Also check downstream tasks of the branch task. In case the task to skip
    # is also a downstream task of the branch task, we exclude it from skipping.
    branch_downstream_task_ids = set()  # type: Set[str]
    for branch_task_id in followed_task_ids:
        branch_downstream_task_ids.update(
            dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False)
        )

    skip_tasks = [
        t
        for t in downstream_tasks
        if t.task_id not in followed_task_ids
        and t.task_id not in branch_downstream_task_ids
    ]

    self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
    self.skip(dag_run, ti.execution_date, skip_tasks)
def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[str, Iterable[str]]):
    """
    Skip every task directly downstream of this operator except the followed ones.

    ``branch_task_ids`` is a single task ID or an iterable of task IDs to
    follow. Tasks downstream of a followed task are treated as followed too,
    so a "join" task reachable through a followed branch is not skipped.

    After the skipped states are written, the IDs of the directly-downstream
    tasks that are followed are pushed to XCom under ``XCOM_SKIPMIXIN_KEY``, so
    that NotPreviouslySkippedDep knows skipped tasks or newly added tasks
    should be skipped when they are cleared.

    :param ti: task instance of the branching operator
    :param branch_task_ids: a task ID, or an iterable of task IDs, to follow
    """
    self.log.info("Following branch %s", branch_task_ids)
    followed_ids = (
        {branch_task_ids} if isinstance(branch_task_ids, str) else set(branch_task_ids)
    )

    dag_run = ti.get_dagrun()
    task = ti.task
    dag = task.dag
    assert dag  # For Mypy.

    # At runtime, the downstream list will only be operators
    downstream_tasks = cast("List[BaseOperator]", task.downstream_list)
    if not downstream_tasks:
        return

    # Empty-branch special case: when "branch" follows only "task1", both
    # "task1" and "join" are expected to run, even though "join" is also
    # directly downstream of "branch". So everything downstream of a followed
    # task is added to the followed set and excluded from skipping.
    #
    # branch -----> join
    #   \            ^
    #     v        /
    #       task1
    #
    for branch_id in tuple(followed_ids):  # snapshot: the set grows in the loop
        followed_ids.update(
            dag.get_task(branch_id).get_flat_relative_ids(upstream=False)
        )

    skip_tasks = []
    follow_task_ids = []
    for downstream in downstream_tasks:
        if downstream.task_id in followed_ids:
            follow_task_ids.append(downstream.task_id)
        else:
            skip_tasks.append(downstream)

    self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
    with create_session() as session:
        self._set_state_to_skipped(dag_run, skip_tasks, session=session)
        # The commit must come before xcom_push; otherwise the skipped states
        # are not committed.
        session.commit()
        ti.xcom_push(key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids})
def skip_all_except(
    self, ti: TaskInstance, branch_task_ids: Union[str, Iterable[str]]
):
    """
    Implement the logic for a branching operator.

    Given a single task ID or an iterable of task IDs to follow, skip all
    other tasks immediately downstream of this operator. A directly-downstream
    task that is also downstream of a followed task is not skipped.

    The followed task IDs are stored to XCom so that NotPreviouslySkippedDep
    knows skipped tasks or newly added tasks should be skipped when they are
    cleared.

    :param ti: task instance of the branching operator
    :param branch_task_ids: a task ID, or an iterable of task IDs, to follow
    """
    self.log.info("Following branch %s", branch_task_ids)
    if isinstance(branch_task_ids, str):
        branch_task_ids = [branch_task_ids]
    else:
        # Materialize once: the IDs are iterated, used for membership tests and
        # pushed to XCom. A one-shot iterable (e.g. a generator) would be
        # exhausted by the first pass and is not serializable for XCom.
        branch_task_ids = list(branch_task_ids)
    followed_task_ids = set(branch_task_ids)

    dag_run = ti.get_dagrun()
    task = ti.task
    dag = task.dag

    downstream_tasks = task.downstream_list
    if not downstream_tasks:
        return

    # Also check downstream tasks of the branch task. In case the task to skip
    # is also a downstream task of the branch task, we exclude it from skipping.
    branch_downstream_task_ids = set()  # type: Set[str]
    for branch_task_id in branch_task_ids:
        branch_downstream_task_ids.update(
            dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False)
        )

    skip_tasks = [
        t
        for t in downstream_tasks
        if t.task_id not in followed_task_ids
        and t.task_id not in branch_downstream_task_ids
    ]

    self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
    with create_session() as session:
        self._set_state_to_skipped(
            dag_run, ti.execution_date, skip_tasks, session=session
        )
        # Commit the skipped states explicitly before xcom_push: without this
        # commit the skipped states have been observed not to be committed.
        session.commit()
        ti.xcom_push(
            key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: branch_task_ids}
        )
def _render_log_id(self, ti: TaskInstance, try_number: int) -> str:
    """
    Render the log ID for one try of a task instance.

    The template is the DAG run's log template ``elasticsearch_id`` when
    ``USE_PER_RUN_LOG_ID`` is set, otherwise ``self.log_id_template``. Dates are
    passed through ``self._clean_date`` in JSON format, and rendered as ISO
    strings (empty when the interval bound is falsy) otherwise.

    :param ti: task instance whose log ID is rendered
    :param try_number: try number substituted into the template
    :return: the formatted log ID
    """
    with create_session() as session:
        dag_run = ti.get_dagrun(session=session)
        if USE_PER_RUN_LOG_ID:
            log_id_template = dag_run.get_log_template(session=session).elasticsearch_id
        else:
            log_id_template = self.log_id_template

    try:
        # Reading ``ti.task`` must happen inside the ``try``: when it is not set
        # the AttributeError is raised right here, and the DAG run's own
        # interval fallback below has to handle it.
        dag = ti.task.dag
        assert dag is not None  # For Mypy.
        data_interval: Tuple[datetime, datetime] = dag.get_run_data_interval(dag_run)
    except AttributeError:  # ti.task is not always set.
        data_interval = (dag_run.data_interval_start, dag_run.data_interval_end)

    if self.json_format:
        data_interval_start = self._clean_date(data_interval[0])
        data_interval_end = self._clean_date(data_interval[1])
        execution_date = self._clean_date(dag_run.execution_date)
    else:
        data_interval_start = data_interval[0].isoformat() if data_interval[0] else ""
        data_interval_end = data_interval[1].isoformat() if data_interval[1] else ""
        execution_date = dag_run.execution_date.isoformat()

    return log_id_template.format(
        dag_id=ti.dag_id,
        task_id=ti.task_id,
        run_id=getattr(ti, "run_id", ""),
        data_interval_start=data_interval_start,
        data_interval_end=data_interval_end,
        execution_date=execution_date,
        try_number=try_number,
        map_index=getattr(ti, "map_index", ""),
    )
def render_log_filename(
    self,
    ti: TaskInstance,
    try_number: Optional[int] = None,
    *,
    session: Session = NEW_SESSION,
):
    """
    Render the filename used for a task instance's log attachment.

    :param ti: the task instance whose log filename is rendered
    :param try_number: the try number to render; ``None`` renders ``"all"``
    :param session: database session used for the DAG run lookups
    :rtype: str
    """
    try_label = "all" if try_number is None else try_number
    filename_template = ti.get_dagrun(session=session).get_log_filename_template(
        session=session
    )
    return render_log_filename(
        ti=ti, try_number=try_label, filename_template=filename_template
    )
def _per_task_process(key, ti: TaskInstance, session=None):
    """
    Decide what the backfill does next with one task instance, and act on it.

    The task instance is re-read from the database with ``lock_for_update``,
    bound to its task from ``self.dag``, and then routed by state. It is moved
    between the bookkeeping collections of ``ti_status`` (``to_run``,
    ``running``, ``succeeded``, ``skipped``, ``failed``, ``not_ready``). When
    its dependencies are met, it is queued on ``executor``.

    This is a closure: ``self``, ``ti_status``, ``executor``, ``start_date``
    and ``pickle_id`` are free variables from the enclosing scope (not visible
    here).

    :param key: key identifying ``ti`` in the ``ti_status`` collections
    :param ti: the task instance to process
    :param session: database session used for the refresh, state changes and
        dependency checks; committed when the dependencies are met
    :return: None
    """
    ti.refresh_from_db(lock_for_update=True, session=session)

    # Bind the task object (subdags included) to the refreshed instance.
    task = self.dag.get_task(ti.task_id, include_subdags=True)
    ti.task = task

    self.log.debug("Task instance to run %s state %s", ti, ti.state)

    # The task was already marked successful or skipped by a
    # different Job. Don't rerun it.
    if ti.state == State.SUCCESS and not self.rerun_succeeded_tasks:
        ti_status.succeeded.add(key)
        self.log.debug("Task instance %s succeeded. Don't rerun.", ti)
        ti_status.to_run.pop(key)
        if key in ti_status.running:
            ti_status.running.pop(key)
        return
    elif ti.state == State.SKIPPED:
        ti_status.skipped.add(key)
        self.log.debug("Task instance %s skipped. Don't rerun.", ti)
        ti_status.to_run.pop(key)
        if key in ti_status.running:
            ti_status.running.pop(key)
        return

    # guard against externally modified tasks instances or
    # in case max concurrency has been reached at task runtime
    elif ti.state == State.NONE:
        self.log.warning(
            "FIXME: Task instance %s state was set to None externally. This should not happen", ti
        )
        ti.set_state(State.SCHEDULED, session=session)
    if self.rerun_failed_tasks:
        # Rerun failed tasks or upstreamed failed tasks
        if ti.state in (State.FAILED, State.UPSTREAM_FAILED):
            self.log.error("Task instance %s with state %s", ti, ti.state)
            if key in ti_status.running:
                ti_status.running.pop(key)
            # Reset the failed task in backfill to scheduled state
            ti.set_state(State.SCHEDULED, session=session)
    # NOTE(review): when rerun_failed_tasks is also set, this branch is never
    # reached, so a SUCCESS instance is not reset to SCHEDULED here — confirm
    # that this is intended.
    elif self.rerun_succeeded_tasks and ti.state == State.SUCCESS:
        # Rerun succeeded tasks
        self.log.info(
            "Task instance %s with state %s, rerunning succeeded task ", ti, ti.state
        )
        if key in ti_status.running:
            ti_status.running.pop(key)
        # Reset the succeeded task in backfill to scheduled state
        ti.set_state(State.SCHEDULED, session=session)
    else:
        # Default behaviour which works for subdag.
        # Failed / upstream-failed instances are final: record and stop.
        if ti.state in (State.FAILED, State.UPSTREAM_FAILED):
            self.log.error("Task instance %s with state %s", ti, ti.state)
            ti_status.failed.add(key)
            ti_status.to_run.pop(key)
            if key in ti_status.running:
                ti_status.running.pop(key)
            return

    # depends_on_past is ignored only for the run whose execution_date equals
    # start_date (or ti.start_date when start_date is falsy).
    if self.ignore_first_depends_on_past:
        dagrun = ti.get_dagrun(session=session)
        ignore_depends_on_past = dagrun.execution_date == (start_date or ti.start_date)
    else:
        ignore_depends_on_past = False

    backfill_context = DepContext(
        deps=BACKFILL_QUEUED_DEPS,
        ignore_depends_on_past=ignore_depends_on_past,
        ignore_task_deps=self.ignore_task_deps,
        flag_upstream_failed=True,
    )

    # Is the task runnable? -- then run it
    # the dependency checker can change states of tis
    if ti.are_dependencies_met(
        dep_context=backfill_context, session=session, verbose=self.verbose
    ):
        if executor.has_task(ti):
            self.log.debug(
                "Task Instance %s already in executor waiting for queue to clear", ti
            )
        else:
            self.log.debug('Sending %s to executor', ti)
            # Skip scheduled state, we are executing immediately
            ti.state = State.QUEUED
            ti.queued_by_job_id = self.id
            ti.queued_dttm = timezone.utcnow()
            session.merge(ti)

            # Only the local and sequential executors get a temporary
            # configuration copy (presumably because they run the task on
            # this host — TODO confirm).
            cfg_path = None
            if self.executor_class in (
                executor_constants.LOCAL_EXECUTOR,
                executor_constants.SEQUENTIAL_EXECUTOR,
            ):
                cfg_path = tmp_configuration_copy()

            executor.queue_task_instance(
                ti,
                mark_success=self.mark_success,
                pickle_id=pickle_id,
                ignore_task_deps=self.ignore_task_deps,
                ignore_depends_on_past=ignore_depends_on_past,
                pool=self.pool,
                cfg_path=cfg_path,
            )
            ti_status.running[key] = ti
            ti_status.to_run.pop(key)
        # Commit whether the instance was just queued or was already in the
        # executor; the dependency check above may also have changed state.
        session.commit()
        return

    # Dependencies not met: the checker may have flagged the instance as
    # upstream-failed (flag_upstream_failed=True above).
    if ti.state == State.UPSTREAM_FAILED:
        self.log.error("Task instance %s upstream failed", ti)
        ti_status.failed.add(key)
        ti_status.to_run.pop(key)
        if key in ti_status.running:
            ti_status.running.pop(key)
        return

    # special case
    # Waiting on retry: keep it in to_run so it is tried again later.
    if ti.state == State.UP_FOR_RETRY:
        self.log.debug(
            "Task instance %s retry period not expired yet", ti
        )
        if key in ti_status.running:
            ti_status.running.pop(key)
        ti_status.to_run[key] = ti
        return

    # special case
    # Waiting on reschedule: keep it in to_run so it is tried again later.
    if ti.state == State.UP_FOR_RESCHEDULE:
        self.log.debug(
            "Task instance %s reschedule period not expired yet", ti
        )
        if key in ti_status.running:
            ti_status.running.pop(key)
        ti_status.to_run[key] = ti
        return

    # all remaining tasks
    self.log.debug('Adding %s to not_ready', ti)
    ti_status.not_ready.add(key)
def _get_dep_statuses(self, ti: TI, session, dep_context):
    """
    Yield statuses for the ``depends_on_past`` / ``wait_for_downstream`` dependency.

    A single passing status is yielded, and the check stops, in these cases:

    * the context ignores past runs;
    * the task does not set ``depends_on_past``;
    * the task instance has no DAG run;
    * there is no earlier instance to depend on.

    Otherwise a failing status is yielded in these cases:

    * the previous task instance is missing, unless
      ``ignore_first_depends_on_past`` applies and no earlier instance exists;
    * the previous task instance is not SUCCESS/SKIPPED;
    * ``wait_for_downstream`` is set and its downstream tasks are not done.

    The last two checks are independent and may both yield.

    :param ti: the task instance being evaluated
    :param session: database session used for every lookup
    :param dep_context: dependency context; ``ignore_depends_on_past`` is read
    """
    first_instance_reason = "This task instance was the first task instance for its task."

    if dep_context.ignore_depends_on_past:
        yield self._passing_status(
            reason="The context specified that the state of past DAGs could be ignored."
        )
        return

    task = ti.task
    if not task.depends_on_past:
        yield self._passing_status(reason="The task did not have depends_on_past set.")
        return

    dag_run = ti.get_dagrun(session=session)
    if not dag_run:
        yield self._passing_status(reason="This task instance does not belong to a DAG.")
        return

    # Don't depend on the previous task instance if we are the first task.
    catchup = task.dag and task.dag.catchup
    previous_dag_run = (
        dag_run.get_previous_scheduled_dagrun(session)
        if catchup
        else dag_run.get_previous_dagrun(session=session)
    )

    # Either this is the first ever run, or (with catchup) the earlier run
    # predates the task's start_date, so the task was not active back then.
    if not previous_dag_run or (
        catchup and previous_dag_run.execution_date < task.start_date
    ):
        yield self._passing_status(reason=first_instance_reason)
        return

    previous_ti = previous_dag_run.get_task_instance(ti.task_id, session=session)
    if not previous_ti:
        if task.ignore_first_depends_on_past:
            earlier_ti_query = session.query(func.count(TI.dag_id)).filter(
                TI.dag_id == ti.dag_id,
                TI.task_id == ti.task_id,
                TI.execution_date < ti.execution_date,
            )
            has_earlier_ti = earlier_ti_query.scalar() > 0
            if not has_earlier_ti:
                yield self._passing_status(
                    reason="ignore_first_depends_on_past is true for this task "
                    "and it is the first task instance for its task."
                )
                return

        yield self._failing_status(
            reason="depends_on_past is true for this task's DAG, but the previous "
            "task instance has not run yet."
        )
        return

    if previous_ti.state not in {State.SKIPPED, State.SUCCESS}:
        yield self._failing_status(
            reason=(
                f"depends_on_past is true for this task, but the previous task instance {previous_ti} "
                f"is in the state '{previous_ti.state}' which is not a successful state."
            )
        )

    previous_ti.task = task
    if task.wait_for_downstream and not previous_ti.are_dependents_done(session=session):
        yield self._failing_status(
            reason=(
                f"The tasks downstream of the previous task instance {previous_ti} haven't completed "
                "(and wait_for_downstream is True)."
            )
        )