Example #1
    def slots_stats(
        *,
        lock_rows: bool = False,
        session: Session = None,
    ) -> Dict[str, PoolStats]:
        """
        Get Pool stats (Number of Running, Queued, Open & Total tasks)

        If ``lock_rows`` is True, and the database engine in use supports the ``NOWAIT`` syntax, then a
        non-blocking lock will be attempted -- if the lock is not available then SQLAlchemy will raise an
        OperationalError.

        :param lock_rows: Should we attempt to obtain a row-level lock on all the Pool rows returned
        :param session: SQLAlchemy ORM Session
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        pools: Dict[str, PoolStats] = {}

        query = session.query(Pool.pool, Pool.slots)

        if lock_rows:
            query = with_row_locks(query, session=session, **nowait(session))

        pool_rows: Iterable[Tuple[str, int]] = query.all()
        for (pool_name, total_slots) in pool_rows:
            pools[pool_name] = PoolStats(total=total_slots, running=0, queued=0, open=0)

        state_count_by_pool = (
            session.query(TaskInstance.pool, TaskInstance.state, func.count())
            .filter(TaskInstance.state.in_(list(EXECUTION_STATES)))
            .group_by(TaskInstance.pool, TaskInstance.state)
        ).all()

        # calculate queued and running metrics
        count: int
        for (pool_name, state, count) in state_count_by_pool:
            stats_dict: Optional[PoolStats] = pools.get(pool_name)
            if not stats_dict:
                continue
            # TypedDict key must be a string literal, so we use if-statements to set value
            if state == "running":
                stats_dict["running"] = count
            elif state == "queued":
                stats_dict["queued"] = count
            else:
                raise AirflowException(
                    f"Unexpected state. Expected values: {EXECUTION_STATES}."
                )

        # calculate open metric
        for pool_name, stats_dict in pools.items():
            if stats_dict["total"] == -1:
                # -1 means infinite
                stats_dict["open"] = -1
            else:
                stats_dict["open"] = stats_dict["total"] - stats_dict["running"] - stats_dict["queued"]

        return pools
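
The snippet relies on a ``nowait(session)`` helper to translate "give me a non-blocking lock" into per-dialect ``with_for_update()`` keyword arguments. A minimal sketch of such a helper follows; the dialect check is an assumption modelled on the MySQL limitation (MariaDB and MySQL < 8 lack NOWAIT / FOR UPDATE OF support), not a copy of Airflow's implementation.

    from typing import Any, Dict

    from sqlalchemy.orm import Session

    def nowait(session: Session) -> Dict[str, Any]:
        """Return kwargs for with_for_update() requesting a non-blocking lock.

        On engines without NOWAIT support, return no kwargs so the caller
        falls back to a plain (blocking) SELECT ... FOR UPDATE.
        """
        if session.bind.dialect.name == "mysql" and not session.bind.dialect.supports_for_update_of:
            return {}
        return {"nowait": True}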
Example #2
    def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
        try:
            # Re-select the row with a lock
            dag_run = with_row_locks(
                session.query(DagRun).filter_by(
                    dag_id=self.dag_id,
                    run_id=self.task_instance.run_id,
                ),
                session=session,
            ).one()

            task = self.task_instance.task
            if TYPE_CHECKING:
                assert task.dag

            # Get a partial DAG with just the specific tasks we want to examine.
            # In order for dep checks to work correctly, we include ourself (so
            # TriggerRuleDep can check the state of the task we just executed).
            partial_dag = task.dag.partial_subset(
                task.downstream_task_ids,
                include_downstream=True,
                include_upstream=False,
                include_direct_upstream=True,
            )

            dag_run.dag = partial_dag
            info = dag_run.task_instance_scheduling_decisions(session)

            skippable_task_ids = {
                task_id
                for task_id in partial_dag.task_ids
                if task_id not in task.downstream_task_ids
            }

            schedulable_tis = [
                ti for ti in info.schedulable_tis
                if ti.task_id not in skippable_task_ids
            ]
            for schedulable_ti in schedulable_tis:
                if not hasattr(schedulable_ti, "task"):
                    schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)

            num = dag_run.schedule_tis(schedulable_tis)
            self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)

            session.commit()
        except OperationalError as e:
            # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
            self.log.info(
                "Skipping mini scheduling run due to exception: %s",
                e.statement,
                exc_info=True,
            )
            session.rollback()
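
The try/except pair above is the general "lock or skip" recipe: attempt a NOWAIT row lock, do the optional work, and treat a lock failure as a signal to back off rather than an error. A condensed, hypothetical sketch of the same pattern (``do_work`` is a stand-in, not an Airflow API):

    from sqlalchemy.exc import OperationalError

    def run_if_lock_available(session, query, do_work):
        """Run do_work(row) only if the row can be locked without waiting."""
        try:
            row = with_row_locks(query, session=session, nowait=True).one()
            do_work(row)
            session.commit()  # committing releases the row lock
        except OperationalError:
            # Another process holds the lock; this work is a best-effort
            # optimisation, so roll back and move on.
            session.rollback()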
Example #3
    def next_dagruns_to_examine(
        cls,
        state: DagRunState,
        session: Session,
        max_number: Optional[int] = None,
    ):
        """
        Return the next DagRuns that the scheduler should attempt to schedule.

        This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... FOR UPDATE"
        query. You should ensure that any scheduling decisions are made in a single transaction -- as soon
        as the transaction is committed, the rows will be unlocked.

        :rtype: list[airflow.models.DagRun]
        """
        from airflow.models.dag import DagModel

        if max_number is None:
            max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE

        # TODO: Bake this query, it is run _A lot_
        query = (
            session.query(cls)
            .filter(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
            .join(
                DagModel,
                DagModel.dag_id == cls.dag_id,
            )
            .filter(
                DagModel.is_paused == expression.false(),
                DagModel.is_active == expression.true(),
            )
        )
        if state == State.QUEUED:
            # For dag runs in the queued state, we check if they have reached the max_active_runs limit
            # and if so we drop them
            running_drs = (
                session.query(DagRun.dag_id, func.count(DagRun.state).label('num_running'))
                .filter(DagRun.state == DagRunState.RUNNING)
                .group_by(DagRun.dag_id)
                .subquery()
            )
            query = query.outerjoin(running_drs, running_drs.c.dag_id == DagRun.dag_id).filter(
                func.coalesce(running_drs.c.num_running, 0) < DagModel.max_active_runs
            )
        query = query.order_by(
            nulls_first(cls.last_scheduling_decision, session=session),
            cls.execution_date,
        )

        if not settings.ALLOW_FUTURE_EXEC_DATES:
            query = query.filter(DagRun.execution_date <= func.now())

        return with_row_locks(
            query.limit(max_number), of=cls, session=session, **skip_locked(session=session)
        )
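
``skip_locked(session=session)`` plays the same role for SKIP LOCKED that ``nowait`` plays for NOWAIT in Example #1: it returns ``with_for_update()`` kwargs only when the dialect can honour them, so two schedulers never examine the same DagRun rows. A hedged sketch, with the same assumed MySQL check as before:

    from typing import Any, Dict

    from sqlalchemy.orm import Session

    def skip_locked(session: Session) -> Dict[str, Any]:
        """Return kwargs asking with_for_update() to silently skip rows
        that another transaction already holds locks on."""
        if session.bind.dialect.name == "mysql" and not session.bind.dialect.supports_for_update_of:
            return {}
        return {"skip_locked": True}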
Example #4
    def test_with_row_locks(self, dialect, supports_for_update_of,
                            use_row_level_lock_conf,
                            expected_use_row_level_lock):
        query = mock.Mock()
        session = mock.Mock()
        session.bind.dialect.name = dialect
        session.bind.dialect.supports_for_update_of = supports_for_update_of
        with mock.patch("airflow.utils.sqlalchemy.USE_ROW_LEVEL_LOCKING",
                        use_row_level_lock_conf):
            returned_value = with_row_locks(query=query,
                                            session=session,
                                            nowait=True)

        if expected_use_row_level_lock:
            query.with_for_update.assert_called_once_with(nowait=True)
        else:
            assert returned_value == query
            query.with_for_update.assert_not_called()
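
The parametrize decorator is elided above. A plausible (hypothetical) case table for those four arguments, following the branch logic the test asserts on, might look like:

    @pytest.mark.parametrize(
        "dialect, supports_for_update_of, use_row_level_lock_conf, expected_use_row_level_lock",
        [
            ("postgresql", True, True, True),
            ("postgresql", True, False, False),  # locking disabled by config
            ("mysql", True, True, True),
            ("mysql", False, True, False),  # MariaDB / MySQL < 8: no FOR UPDATE OF
        ],
    )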
Example #5
    def next_dagruns_to_examine(
        cls,
        state: DagRunState,
        session: Session,
        max_number: Optional[int] = None,
    ):
        """
        Return the next DagRuns that the scheduler should attempt to schedule.

        This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... FOR UPDATE"
        query. You should ensure that any scheduling decisions are made in a single transaction -- as soon
        as the transaction is committed, the rows will be unlocked.

        :rtype: list[airflow.models.DagRun]
        """
        from airflow.models.dag import DagModel

        if max_number is None:
            max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE

        # TODO: Bake this query, it is run _A lot_
        query = (
            session.query(cls)
            .filter(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
            .join(
                DagModel,
                DagModel.dag_id == cls.dag_id,
            )
            .filter(
                DagModel.is_paused == expression.false(),
                DagModel.is_active == expression.true(),
            )
            .order_by(
                nulls_first(cls.last_scheduling_decision, session=session),
                cls.execution_date,
            )
        )

        if not settings.ALLOW_FUTURE_EXEC_DATES:
            query = query.filter(DagRun.execution_date <= func.now())

        return with_row_locks(
            query.limit(max_number), of=cls, session=session, **skip_locked(session=session)
        )
Example #6
    def _fetch_callbacks(self, max_callbacks: int, session: Session = NEW_SESSION):
        """Fetch callbacks from the database and add them to the internal queue for execution."""
        self.log.debug("Fetching callbacks from the database.")
        with prohibit_commit(session) as guard:
            query = (
                session.query(DbCallbackRequest)
                .order_by(DbCallbackRequest.priority_weight.asc())
                .limit(max_callbacks)
            )
            callbacks = with_row_locks(
                query, of=DbCallbackRequest, session=session, **skip_locked(session=session)
            ).all()
            for callback in callbacks:
                try:
                    self._add_callback_to_queue(callback.get_callback_request())
                    session.delete(callback)
                except Exception as e:
                    self.log.warning("Error adding callback for execution: %s, %s", callback, e)
            guard.commit()
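
``prohibit_commit(session)`` guards the window between taking the SKIP LOCKED row locks and ``guard.commit()``: an accidental ``session.commit()`` in between would release the locks early and let a second scheduler fetch the same callbacks. One way such a guard could be built (an illustrative sketch using a SQLAlchemy session event, not Airflow's implementation):

    from contextlib import contextmanager

    from sqlalchemy import event

    @contextmanager
    def prohibit_commit(session):
        class Guard:
            expected_commit = False

            def commit(self):
                # The one sanctioned commit: flag it so the hook lets it through.
                self.expected_commit = True
                try:
                    session.commit()
                finally:
                    self.expected_commit = False

        guard = Guard()

        @event.listens_for(session, "before_commit")
        def _veto(sess):
            if not guard.expected_commit:
                raise RuntimeError("Unexpected commit inside prohibit_commit block")

        try:
            yield guard
        finally:
            event.remove(session, "before_commit", _veto)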
Example #7
    def _find_schedulable_tasks(
        self,
        dag_run: DagRun,
        session: Session,
        check_execution_date: bool = False,
    ) -> Optional[List[TI]]:
        """
        Make scheduling decisions for an individual dag run.

        ``currently_active_runs`` is fetched with a single query below so the
        active-run check does not become an n+1 query across task instances.

        :param dag_run: The DagRun to schedule
        :return: scheduled tasks
        """
        if not dag_run or dag_run.get_state() in State.finished:
            return None
        try:
            dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
        except SerializedDagNotFound:
            self.log.exception("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
            return None

        if not dag:
            self.log.error("Couldn't find dag %s in DagBag/DB!", dag_run.dag_id)
            return None

        currently_active_runs = (
            session.query(TI.execution_date)
            .filter(
                TI.dag_id == dag_run.dag_id,
                TI.state.notin_(list(State.finished)),
            )
            .all()
        )

        if (
            check_execution_date
            and dag_run.execution_date > timezone.utcnow()
            and not dag.allow_future_exec_dates
        ):
            self.log.warning("Execution date is in future: %s", dag_run.execution_date)
            return None

        if dag.max_active_runs:
            if (len(currently_active_runs) >= dag.max_active_runs
                    and dag_run.execution_date not in currently_active_runs):
                self.log.info(
                    "DAG %s already has %d active runs, not queuing any tasks for run %s",
                    dag.dag_id,
                    len(currently_active_runs),
                    dag_run.execution_date,
                )
                return None

        self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session)

        schedulable_tis, callback_to_run = dag_run.update_state(
            session=session, execute_callbacks=False)
        dag_run.schedule_tis(schedulable_tis, session)

        query = (
            session.query(TI)
            .outerjoin(TI.dag_run)
            .filter(or_(DR.run_id.is_(None), DR.run_type != DagRunType.BACKFILL_JOB))
            .join(TI.dag_model)
            .filter(not_(DM.is_paused))
            .filter(TI.state == State.SCHEDULED)
            .options(selectinload('dag_model'))
        )
        scheduled_tis: List[TI] = with_row_locks(
            query,
            of=TI,
            session=session,
            **skip_locked(session=session),
        ).all()

        # Filter out tasks that are triggered by events rather than by the scheduler
        serialized_dag = (
            session.query(SerializedDagModel)
            .filter(SerializedDagModel.dag_id == dag_run.dag_id)
            .first()
        )
        dep: DagEventDependencies = DagEventDependencies.from_json(serialized_dag.event_relationships)
        event_task_set = dep.find_event_dependencies_tasks()
        final_scheduled_tis = [ti for ti in scheduled_tis if ti.task_id not in event_task_set]

        return final_scheduled_tis
Example #8
    def _executable_task_instances_to_queued(self, max_tis: int, session: Session = None) -> List[TI]:
        """
        Finds TIs that are ready for execution with respect to pool limits,
        dag max_active_tasks, executor state, and priority.

        :param max_tis: Maximum number of TIs to queue in this loop.
        :type max_tis: int
        :return: list[airflow.models.TaskInstance]
        """
        executable_tis: List[TI] = []

        # Get the pool settings. We get a lock on the pool rows, treating this as a "critical section"
        # Throws an exception if lock cannot be obtained, rather than blocking
        pools = models.Pool.slots_stats(lock_rows=True, session=session)

        # If the pools are full, there is no point doing anything!
        # If _somehow_ the pool is overfull, don't let the limit go negative - it breaks SQL
        pool_slots_free = max(0, sum(pool['open'] for pool in pools.values()))

        if pool_slots_free == 0:
            self.log.debug("All pools are full!")
            return executable_tis

        max_tis = min(max_tis, pool_slots_free)

        # Get all task instances associated with scheduled
        # DagRuns which are not backfilled, in the given states,
        # and the dag is not paused
        query = (
            session.query(TI)
            .join(TI.dag_run)
            .options(eagerload(TI.dag_run))
            .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state != DagRunState.QUEUED)
            .join(TI.dag_model)
            .filter(not_(DM.is_paused))
            .filter(TI.state == State.SCHEDULED)
            .options(selectinload('dag_model'))
            .order_by(-TI.priority_weight, DR.execution_date)
        )
        starved_pools = [pool_name for pool_name, stats in pools.items() if stats['open'] <= 0]
        if starved_pools:
            query = query.filter(not_(TI.pool.in_(starved_pools)))

        query = query.limit(max_tis)

        task_instances_to_examine: List[TI] = with_row_locks(
            query,
            of=TI,
            session=session,
            **skip_locked(session=session),
        ).all()
        # TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything.
        # Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine))

        if len(task_instances_to_examine) == 0:
            self.log.debug("No tasks to consider for execution.")
            return executable_tis

        # Put one task instance on each line
        task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine)
        self.log.info("%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str)

        pool_to_task_instances: DefaultDict[str, List[TI]] = defaultdict(list)
        for task_instance in task_instances_to_examine:
            pool_to_task_instances[task_instance.pool].append(task_instance)

        # dag_id to # of running tasks and (dag_id, task_id) to # of running tasks.
        dag_max_active_tasks_map: DefaultDict[str, int]
        task_concurrency_map: DefaultDict[Tuple[str, str], int]
        dag_max_active_tasks_map, task_concurrency_map = self.__get_concurrency_maps(
            states=list(EXECUTION_STATES), session=session
        )

        num_tasks_in_executor = 0
        # Number of tasks that cannot be scheduled because of no open slot in pool
        num_starving_tasks_total = 0

        # Go through each pool, and queue up a task for execution if there are
        # any open slots in the pool.

        for pool, task_instances in pool_to_task_instances.items():
            pool_name = pool
            if pool not in pools:
                self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool)
                continue

            open_slots = pools[pool]["open"]

            num_ready = len(task_instances)
            self.log.info(
                "Figuring out tasks to run in Pool(name=%s) with %s open slots "
                "and %s task instances ready to be queued",
                pool,
                open_slots,
                num_ready,
            )

            priority_sorted_task_instances = sorted(
                task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date)
            )

            num_starving_tasks = 0
            for current_index, task_instance in enumerate(priority_sorted_task_instances):
                if open_slots <= 0:
                    self.log.info("Not scheduling since there are %s open slots in pool %s", open_slots, pool)
                    # Can't schedule any more since there are no more open slots.
                    num_unhandled = len(priority_sorted_task_instances) - current_index
                    num_starving_tasks += num_unhandled
                    num_starving_tasks_total += num_unhandled
                    break

                # Check to make sure that the task max_active_tasks of the DAG hasn't been
                # reached.
                dag_id = task_instance.dag_id

                current_max_active_tasks_per_dag = dag_max_active_tasks_map[dag_id]
                max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
                self.log.info(
                    "DAG %s has %s/%s running and queued tasks",
                    dag_id,
                    current_max_active_tasks_per_dag,
                    max_active_tasks_per_dag_limit,
                )
                if current_max_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
                    self.log.info(
                        "Not executing %s since the number of tasks running or queued "
                        "from DAG %s is >= to the DAG's max_active_tasks limit of %s",
                        task_instance,
                        dag_id,
                        max_active_tasks_per_dag_limit,
                    )
                    continue

                task_concurrency_limit: Optional[int] = None
                if task_instance.dag_model.has_task_concurrency_limits:
                    # Many dags don't have a task_concurrency, so wherever we can avoid
                    # loading the full serialized DAG, the better.
                    serialized_dag = self.dagbag.get_dag(dag_id, session=session)
                    if serialized_dag.has_task(task_instance.task_id):
                        task_concurrency_limit = serialized_dag.get_task(
                            task_instance.task_id
                        ).max_active_tis_per_dag

                    if task_concurrency_limit is not None:
                        current_task_concurrency = task_concurrency_map[
                            (task_instance.dag_id, task_instance.task_id)
                        ]

                        if current_task_concurrency >= task_concurrency_limit:
                            self.log.info(
                                "Not executing %s since the task concurrency for"
                                " this task has been reached.",
                                task_instance,
                            )
                            continue

                if task_instance.pool_slots > open_slots:
                    self.log.info(
                        "Not executing %s since it requires %s slots "
                        "but there are %s open slots in the pool %s.",
                        task_instance,
                        task_instance.pool_slots,
                        open_slots,
                        pool,
                    )
                    num_starving_tasks += 1
                    num_starving_tasks_total += 1
                    # Though we can execute tasks with lower priority if there's enough room
                    continue

                executable_tis.append(task_instance)
                open_slots -= task_instance.pool_slots
                dag_max_active_tasks_map[dag_id] += 1
                task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1

            Stats.gauge(f'pool.starving_tasks.{pool_name}', num_starving_tasks)

        Stats.gauge('scheduler.tasks.starving', num_starving_tasks_total)
        Stats.gauge('scheduler.tasks.running', num_tasks_in_executor)
        Stats.gauge('scheduler.tasks.executable', len(executable_tis))

        task_instance_str = "\n\t".join(repr(x) for x in executable_tis)
        self.log.info("Setting the following tasks to queued state:\n\t%s", task_instance_str)
        if len(executable_tis) > 0:
            # set TIs to queued state
            filter_for_tis = TI.filter_for_tis(executable_tis)
            session.query(TI).filter(filter_for_tis).update(
                # TODO[ha]: should we use func.now()? How does that work with DB timezone
                # on mysql when it's not UTC?
                {TI.state: State.QUEUED, TI.queued_dttm: timezone.utcnow(), TI.queued_by_job_id: self.id},
                synchronize_session=False,
            )

        for ti in executable_tis:
            make_transient(ti)
        return executable_tis
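
Everything above happens inside one "critical section" transaction: the Pool rows are locked NOWAIT at the top and the candidate TaskInstance rows SKIP LOCKED in the middle, and both sets of locks are held until the caller commits. A hedged sketch of how a caller might drive it (the method name is from the example; the wrapper is illustrative):

    from sqlalchemy.exc import OperationalError

    def enqueue_one_batch(scheduler_job, session, max_tis: int):
        try:
            queued_tis = scheduler_job._executable_task_instances_to_queued(
                max_tis, session=session
            )
            session.commit()  # releases the Pool and TaskInstance row locks
            return queued_tis
        except OperationalError:
            # Pool rows were locked by another scheduler (NOWAIT failed);
            # skip this iteration and retry on the next scheduler loop.
            session.rollback()
            return []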
Example #9
    def adopt_or_reset_orphaned_tasks(self, session: Session = None):
        """
        Reset any TaskInstance still in QUEUED or RUNNING state that was
        enqueued by a SchedulerJob that is no longer running.

        :return: the number of TIs reset
        :rtype: int
        """
        self.log.info("Resetting orphaned tasks for active dag runs")
        timeout = conf.getint('scheduler', 'scheduler_health_check_threshold')

        for attempt in run_with_db_retries(logger=self.log):
            with attempt:
                self.log.debug(
                    "Running SchedulerJob.adopt_or_reset_orphaned_tasks with retries. Try %d of %d",
                    attempt.retry_state.attempt_number,
                    MAX_DB_RETRIES,
                )
                self.log.debug("Calling SchedulerJob.adopt_or_reset_orphaned_tasks method")
                try:
                    num_failed = (
                        session.query(SchedulerJob)
                        .filter(
                            SchedulerJob.state == State.RUNNING,
                            SchedulerJob.latest_heartbeat < (timezone.utcnow() - timedelta(seconds=timeout)),
                        )
                        .update({"state": State.FAILED})
                    )

                    if num_failed:
                        self.log.info("Marked %d SchedulerJob instances as failed", num_failed)
                        Stats.incr(self.__class__.__name__.lower() + '_end', num_failed)

                    resettable_states = [State.QUEUED, State.RUNNING]
                    query = (
                        session.query(TI)
                        .filter(TI.state.in_(resettable_states))
                        # outerjoin is because we didn't use to have queued_by_job
                        # set, so we need to pick up anything pre-upgrade. This (and the
                        # "or queued_by_job_id IS NULL") can go as soon as scheduler HA is
                        # released.
                        .outerjoin(TI.queued_by_job)
                        .filter(or_(TI.queued_by_job_id.is_(None), SchedulerJob.state != State.RUNNING))
                        .join(TI.dag_run)
                        .filter(
                            DagRun.run_type != DagRunType.BACKFILL_JOB,
                            DagRun.state == State.RUNNING,
                        )
                        .options(load_only(TI.dag_id, TI.task_id, TI.run_id))
                    )

                    # Lock these rows, so that another scheduler can't try and adopt these too
                    tis_to_reset_or_adopt = with_row_locks(
                        query, of=TI, session=session, **skip_locked(session=session)
                    ).all()
                    to_reset = self.executor.try_adopt_task_instances(tis_to_reset_or_adopt)

                    reset_tis_message = []
                    for ti in to_reset:
                        reset_tis_message.append(repr(ti))
                        ti.state = State.NONE
                        ti.queued_by_job_id = None

                    for ti in set(tis_to_reset_or_adopt) - set(to_reset):
                        ti.queued_by_job_id = self.id

                    Stats.incr('scheduler.orphaned_tasks.cleared', len(to_reset))
                    Stats.incr('scheduler.orphaned_tasks.adopted', len(tis_to_reset_or_adopt) - len(to_reset))

                    if to_reset:
                        task_instance_str = '\n\t'.join(reset_tis_message)
                        self.log.info(
                            "Reset the following %s orphaned TaskInstances:\n\t%s",
                            len(to_reset),
                            task_instance_str,
                        )

                    # Issue SQL/finish the "Unit of Work", but let @provide_session
                    # commit (or, if passed a session, let the caller decide when to commit)
                    session.flush()
                except OperationalError:
                    session.rollback()
                    raise

        return len(to_reset)
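
``run_with_db_retries`` yields retry attempts so that transient database errors (for example, a deadlock while several schedulers contend for the same rows) re-run the whole block. A plausible sketch of such a helper built on tenacity; the exact backoff numbers are assumptions:

    import logging

    import tenacity
    from sqlalchemy.exc import OperationalError

    MAX_DB_RETRIES = 3

    def run_with_db_retries(logger: logging.Logger, max_retries: int = MAX_DB_RETRIES):
        """Return a tenacity.Retrying iterator for wrapping DB transactions."""
        return tenacity.Retrying(
            retry=tenacity.retry_if_exception_type(OperationalError),
            wait=tenacity.wait_random_exponential(multiplier=0.5, max=5),
            stop=tenacity.stop_after_attempt(max_retries),
            before_sleep=tenacity.before_sleep_log(logger, logging.DEBUG),
            reraise=True,
        )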