    def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[str, Iterable[str]]):
        """
        This method implements the logic for a branching operator; given a single
        task ID or list of task IDs to follow, this skips all other tasks
        immediately downstream of this operator.
        """
        self.log.info("Following branch %s", branch_task_ids)
        if isinstance(branch_task_ids, str):
            branch_task_ids = [branch_task_ids]

        dag_run = ti.get_dagrun()
        task = ti.task
        dag = task.dag

        downstream_tasks = task.downstream_list

        if downstream_tasks:
            # Also check downstream tasks of the branch task. In case the task to skip
            # is also a downstream task of the branch task, we exclude it from skipping.
            branch_downstream_task_ids = set()  # type: Set[str]
            for b in branch_task_ids:
                branch_downstream_task_ids.update(
                    dag.get_task(b).get_flat_relative_ids(upstream=False)
                )

            skip_tasks = [t for t in downstream_tasks
                          if t.task_id not in branch_task_ids and
                          t.task_id not in branch_downstream_task_ids]

            self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
            self.skip(dag_run, ti.execution_date, skip_tasks)
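A minimal usage sketch (not part of the example above; the operator class, task IDs and callable are hypothetical) of how an operator built on SkipMixin typically calls skip_all_except from its execute method:

from airflow.models import BaseOperator
from airflow.models.skipmixin import SkipMixin


class WeekdayBranchOperator(BaseOperator, SkipMixin):
    """Follow "weekday_task" on weekdays and "weekend_task" otherwise (illustrative)."""

    def execute(self, context):
        branch = "weekday_task" if context["execution_date"].weekday() < 5 else "weekend_task"
        # Skip every task immediately downstream of this operator except the chosen branch.
        self.skip_all_except(context["ti"], branch)
        return branch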
Example #2
    def skip_all_except(self, ti: TaskInstance,
                        branch_task_ids: Union[str, Iterable[str]]):
        """
        This method implements the logic for a branching operator; given a single
        task ID or list of task IDs to follow, this skips all other tasks
        immediately downstream of this operator.

        branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows that skipped tasks
        or newly added tasks should be skipped when they are cleared.
        """
        self.log.info("Following branch %s", branch_task_ids)
        if isinstance(branch_task_ids, str):
            branch_task_ids = {branch_task_ids}

        branch_task_ids = set(branch_task_ids)

        dag_run = ti.get_dagrun()
        task = ti.task
        dag = task.dag
        assert dag  # For Mypy.

        # At runtime, the downstream list will only be operators
        downstream_tasks = cast("List[BaseOperator]", task.downstream_list)

        if downstream_tasks:
            # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
            # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
            # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
            # we need a special case here for such empty branches: Check downstream tasks of branch_task_ids.
            # In case the task to skip is also downstream of branch_task_ids, we add it to branch_task_ids and
            # exclude it from skipping.
            #
            # branch  ----->  join
            #   \            ^
            #     v        /
            #       task1
            #
            for branch_task_id in list(branch_task_ids):
                branch_task_ids.update(
                    dag.get_task(branch_task_id).get_flat_relative_ids(
                        upstream=False))

            skip_tasks = [
                t for t in downstream_tasks if t.task_id not in branch_task_ids
            ]
            follow_task_ids = [
                t.task_id for t in downstream_tasks
                if t.task_id in branch_task_ids
            ]

            self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
            with create_session() as session:
                self._set_state_to_skipped(dag_run,
                                           skip_tasks,
                                           session=session)
                # For some reason, session.commit() needs to happen before xcom_push.
                # Otherwise the session is not committed.
                session.commit()
                ti.xcom_push(key=XCOM_SKIPMIXIN_KEY,
                             value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids})
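The branch/join shape in the comment above can be reproduced with a small DAG; a minimal sketch (assuming a recent Airflow 2.x; task IDs and the trigger rule are illustrative):

from pendulum import datetime
from airflow import DAG
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import BranchPythonOperator

with DAG(dag_id="branch_join_example", start_date=datetime(2023, 1, 1), schedule=None):
    branch = BranchPythonOperator(task_id="branch", python_callable=lambda: "task1")
    task1 = EmptyOperator(task_id="task1")
    # "join" needs at least one non-failed upstream, so following "task1" keeps it runnable.
    join = EmptyOperator(task_id="join", trigger_rule="none_failed_min_one_success")

    branch >> task1 >> join
    branch >> join  # "join" is also immediately downstream of "branch" (the empty branch).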
Example #3
    def skip_all_except(
        self, ti: TaskInstance, branch_task_ids: Union[str, Iterable[str]]
    ):
        """
        This method implements the logic for a branching operator; given a single
        task ID or list of task IDs to follow, this skips all other tasks
        immediately downstream of this operator.

        branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows that skipped tasks
        or newly added tasks should be skipped when they are cleared.
        """
        self.log.info("Following branch %s", branch_task_ids)
        if isinstance(branch_task_ids, str):
            branch_task_ids = [branch_task_ids]

        dag_run = ti.get_dagrun()
        task = ti.task
        dag = task.dag

        downstream_tasks = task.downstream_list

        if downstream_tasks:
            # Also check downstream tasks of the branch task. In case the task to skip
            # is also a downstream task of the branch task, we exclude it from skipping.
            branch_downstream_task_ids = set()  # type: Set[str]
            for branch_task_id in branch_task_ids:
                branch_downstream_task_ids.update(
                    dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False)
                )

            skip_tasks = [
                t
                for t in downstream_tasks
                if t.task_id not in branch_task_ids
                and t.task_id not in branch_downstream_task_ids
            ]

            self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
            with create_session() as session:
                self._set_state_to_skipped(
                    dag_run, ti.execution_date, skip_tasks, session=session
                )
                ti.xcom_push(
                    key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: branch_task_ids}
                )
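A minimal sketch (the key constants come from airflow.models.skipmixin; the pulling function and task IDs are hypothetical) of reading back the value pushed above, which is what NotPreviouslySkippedDep consults when cleared downstream tasks are re-evaluated:

from airflow.models.skipmixin import XCOM_SKIPMIXIN_FOLLOWED, XCOM_SKIPMIXIN_KEY


def followed_branches(context):
    # Pull what the branching task pushed; it looks like {XCOM_SKIPMIXIN_FOLLOWED: ["task1"]}.
    pushed = context["ti"].xcom_pull(task_ids="branch", key=XCOM_SKIPMIXIN_KEY)
    return (pushed or {}).get(XCOM_SKIPMIXIN_FOLLOWED, [])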
Example #4
    def _render_log_id(self, ti: TaskInstance, try_number: int) -> str:
        with create_session() as session:
            dag_run = ti.get_dagrun(session=session)
            if USE_PER_RUN_LOG_ID:
                log_id_template = dag_run.get_log_template(
                    session=session).elasticsearch_id
            else:
                log_id_template = self.log_id_template

        try:
            dag = ti.task.dag
            assert dag is not None  # For Mypy.
            data_interval: Tuple[datetime, datetime] = dag.get_run_data_interval(dag_run)
        except AttributeError:  # ti.task is not always set.
            data_interval = (dag_run.data_interval_start, dag_run.data_interval_end)

        if self.json_format:
            data_interval_start = self._clean_date(data_interval[0])
            data_interval_end = self._clean_date(data_interval[1])
            execution_date = self._clean_date(dag_run.execution_date)
        else:
            if data_interval[0]:
                data_interval_start = data_interval[0].isoformat()
            else:
                data_interval_start = ""
            if data_interval[1]:
                data_interval_end = data_interval[1].isoformat()
            else:
                data_interval_end = ""
            execution_date = dag_run.execution_date.isoformat()

        return log_id_template.format(
            dag_id=ti.dag_id,
            task_id=ti.task_id,
            run_id=getattr(ti, "run_id", ""),
            data_interval_start=data_interval_start,
            data_interval_end=data_interval_end,
            execution_date=execution_date,
            try_number=try_number,
            map_index=getattr(ti, "map_index", ""),
        )
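The log ID above is rendered with plain str.format; a minimal sketch (the template string and field values are illustrative, not Airflow defaults) showing that keyword arguments the template does not reference are simply ignored:

log_id_template = "{dag_id}-{task_id}-{run_id}-{try_number}"

log_id = log_id_template.format(
    dag_id="example_dag",
    task_id="extract",
    run_id="scheduled__2023-01-01T00:00:00+00:00",
    data_interval_start="2023-01-01T00:00:00+00:00",
    data_interval_end="2023-01-02T00:00:00+00:00",
    execution_date="2023-01-01T00:00:00+00:00",
    try_number=1,
    map_index="",
)
# log_id == "example_dag-extract-scheduled__2023-01-01T00:00:00+00:00-1"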
Example #5
    def render_log_filename(
        self,
        ti: TaskInstance,
        try_number: Optional[int] = None,
        *,
        session: Session = NEW_SESSION,
    ):
        """
        Renders the log attachment filename.

        :param ti: The task instance
        :param try_number: The task try number
        :rtype: str
        """
        dagrun = ti.get_dagrun(session=session)
        attachment_filename = render_log_filename(
            ti=ti,
            try_number="all" if try_number is None else try_number,
            filename_template=dagrun.get_log_filename_template(
                session=session),
        )
        return attachment_filename
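A minimal calling sketch (the handler variable and resulting path are illustrative; the actual shape depends on the DAG run's log filename template):

# Attachment name for a single try...
filename = log_handler.render_log_filename(ti, try_number=2)
# ...and for all tries of the task instance, which renders try_number as "all".
filename_all = log_handler.render_log_filename(ti, try_number=None)
# e.g. "example_dag/extract/2023-01-01T00:00:00+00:00/all.log"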
Example #6
            def _per_task_process(key, ti: TaskInstance, session=None):
                ti.refresh_from_db(lock_for_update=True, session=session)

                task = self.dag.get_task(ti.task_id, include_subdags=True)
                ti.task = task

                self.log.debug("Task instance to run %s state %s", ti,
                               ti.state)

                # The task was already marked successful or skipped by a
                # different Job. Don't rerun it.
                if ti.state == State.SUCCESS and not self.rerun_succeeded_tasks:
                    ti_status.succeeded.add(key)
                    self.log.debug("Task instance %s succeeded. Don't rerun.",
                                   ti)
                    ti_status.to_run.pop(key)
                    if key in ti_status.running:
                        ti_status.running.pop(key)
                    return
                elif ti.state == State.SKIPPED:
                    ti_status.skipped.add(key)
                    self.log.debug("Task instance %s skipped. Don't rerun.",
                                   ti)
                    ti_status.to_run.pop(key)
                    if key in ti_status.running:
                        ti_status.running.pop(key)
                    return

                # guard against externally modified task instances or
                # the case where max concurrency has been reached at task runtime
                elif ti.state == State.NONE:
                    self.log.warning(
                        "FIXME: Task instance %s state was set to None externally. This should not happen",
                        ti)
                    ti.set_state(State.SCHEDULED, session=session)
                if self.rerun_failed_tasks:
                    # Rerun failed tasks or upstream-failed tasks
                    if ti.state in (State.FAILED, State.UPSTREAM_FAILED):
                        self.log.error("Task instance %s with state %s", ti,
                                       ti.state)
                        if key in ti_status.running:
                            ti_status.running.pop(key)
                        # Reset the failed task in backfill to scheduled state
                        ti.set_state(State.SCHEDULED, session=session)
                elif self.rerun_succeeded_tasks and ti.state == State.SUCCESS:
                    # Rerun succeeded tasks
                    self.log.info(
                        "Task instance %s with state %s, rerunning succeeded task",
                        ti, ti.state)
                    if key in ti_status.running:
                        ti_status.running.pop(key)
                    # Reset the succeeded task in backfill to scheduled state
                    ti.set_state(State.SCHEDULED, session=session)
                else:
                    # Default behaviour which works for subdag.
                    if ti.state in (State.FAILED, State.UPSTREAM_FAILED):
                        self.log.error("Task instance %s with state %s", ti,
                                       ti.state)
                        ti_status.failed.add(key)
                        ti_status.to_run.pop(key)
                        if key in ti_status.running:
                            ti_status.running.pop(key)
                        return

                if self.ignore_first_depends_on_past:
                    dagrun = ti.get_dagrun(session=session)
                    ignore_depends_on_past = dagrun.execution_date == (
                        start_date or ti.start_date)
                else:
                    ignore_depends_on_past = False

                backfill_context = DepContext(
                    deps=BACKFILL_QUEUED_DEPS,
                    ignore_depends_on_past=ignore_depends_on_past,
                    ignore_task_deps=self.ignore_task_deps,
                    flag_upstream_failed=True,
                )

                # Is the task runnable? -- then run it
                # the dependency checker can change states of tis
                if ti.are_dependencies_met(dep_context=backfill_context,
                                           session=session,
                                           verbose=self.verbose):
                    if executor.has_task(ti):
                        self.log.debug(
                            "Task Instance %s already in executor waiting for queue to clear",
                            ti)
                    else:
                        self.log.debug('Sending %s to executor', ti)
                        # Skip scheduled state, we are executing immediately
                        ti.state = State.QUEUED
                        ti.queued_by_job_id = self.id
                        ti.queued_dttm = timezone.utcnow()
                        session.merge(ti)

                        cfg_path = None
                        if self.executor_class in (
                                executor_constants.LOCAL_EXECUTOR,
                                executor_constants.SEQUENTIAL_EXECUTOR,
                        ):
                            cfg_path = tmp_configuration_copy()

                        executor.queue_task_instance(
                            ti,
                            mark_success=self.mark_success,
                            pickle_id=pickle_id,
                            ignore_task_deps=self.ignore_task_deps,
                            ignore_depends_on_past=ignore_depends_on_past,
                            pool=self.pool,
                            cfg_path=cfg_path,
                        )
                        ti_status.running[key] = ti
                        ti_status.to_run.pop(key)
                    session.commit()
                    return

                if ti.state == State.UPSTREAM_FAILED:
                    self.log.error("Task instance %s upstream failed", ti)
                    ti_status.failed.add(key)
                    ti_status.to_run.pop(key)
                    if key in ti_status.running:
                        ti_status.running.pop(key)
                    return

                # special case
                if ti.state == State.UP_FOR_RETRY:
                    self.log.debug(
                        "Task instance %s retry period not expired yet", ti)
                    if key in ti_status.running:
                        ti_status.running.pop(key)
                    ti_status.to_run[key] = ti
                    return

                # special case
                if ti.state == State.UP_FOR_RESCHEDULE:
                    self.log.debug(
                        "Task instance %s reschedule period not expired yet",
                        ti)
                    if key in ti_status.running:
                        ti_status.running.pop(key)
                    ti_status.to_run[key] = ti
                    return

                # all remaining tasks
                self.log.debug('Adding %s to not_ready', ti)
                ti_status.not_ready.add(key)
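A minimal sketch (illustrative, assuming a recent Airflow 2.x module layout, with ti and session already bound) of the dependency-check pattern used above: build a DepContext from the backfill dependency set, then ask the task instance whether those deps pass before handing it to the executor:

from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import BACKFILL_QUEUED_DEPS

dep_context = DepContext(
    deps=BACKFILL_QUEUED_DEPS,
    ignore_depends_on_past=False,
    flag_upstream_failed=True,
)
if ti.are_dependencies_met(dep_context=dep_context, session=session, verbose=True):
    ...  # safe to queue the task instance with the executor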
Example #7
    def _get_dep_statuses(self, ti: TI, session, dep_context):
        if dep_context.ignore_depends_on_past:
            reason = "The context specified that the state of past DAGs could be ignored."
            yield self._passing_status(reason=reason)
            return

        if not ti.task.depends_on_past:
            yield self._passing_status(
                reason="The task did not have depends_on_past set.")
            return

        dr = ti.get_dagrun(session=session)
        if not dr:
            yield self._passing_status(
                reason="This task instance does not belong to a DAG.")
            return

        # Don't depend on the previous task instance if we are the first task.
        catchup = ti.task.dag and ti.task.dag.catchup
        if catchup:
            last_dagrun = dr.get_previous_scheduled_dagrun(session)
        else:
            last_dagrun = dr.get_previous_dagrun(session=session)

        # First ever run for this DAG.
        if not last_dagrun:
            yield self._passing_status(
                reason=
                "This task instance was the first task instance for its task.")
            return

        # There was a DAG run, but the task wasn't active back then.
        if catchup and last_dagrun.execution_date < ti.task.start_date:
            yield self._passing_status(
                reason=
                "This task instance was the first task instance for its task.")
            return

        previous_ti = last_dagrun.get_task_instance(ti.task_id,
                                                    session=session)
        if not previous_ti:
            if ti.task.ignore_first_depends_on_past:
                has_historical_ti = (
                    session.query(func.count(TI.dag_id))
                    .filter(
                        TI.dag_id == ti.dag_id,
                        TI.task_id == ti.task_id,
                        TI.execution_date < ti.execution_date,
                    )
                    .scalar()
                    > 0
                )
                if not has_historical_ti:
                    yield self._passing_status(
                        reason=
                        "ignore_first_depends_on_past is true for this task "
                        "and it is the first task instance for its task.")
                    return

            yield self._failing_status(
                reason=
                "depends_on_past is true for this task's DAG, but the previous "
                "task instance has not run yet.")
            return

        if previous_ti.state not in {State.SKIPPED, State.SUCCESS}:
            yield self._failing_status(reason=(
                f"depends_on_past is true for this task, but the previous task instance {previous_ti} "
                f"is in the state '{previous_ti.state}' which is not a successful state."
            ))

        previous_ti.task = ti.task
        if ti.task.wait_for_downstream and not previous_ti.are_dependents_done(
                session=session):
            yield self._failing_status(reason=(
                f"The tasks downstream of the previous task instance {previous_ti} haven't completed "
                f"(and wait_for_downstream is True)."))