def slots_stats( *, lock_rows: bool = False, session: Session = None, ) -> Dict[str, PoolStats]: """ Get Pool stats (Number of Running, Queued, Open & Total tasks) If ``lock_rows`` is True, and the database engine in use supports the ``NOWAIT`` syntax, then a non-blocking lock will be attempted -- if the lock is not available then SQLAlchemy will throw an OperationalError. :param lock_rows: Should we attempt to obtain a row-level lock on all the Pool rows returns :param session: SQLAlchemy ORM Session """ from airflow.models.taskinstance import TaskInstance # Avoid circular import pools: Dict[str, PoolStats] = {} query = session.query(Pool.pool, Pool.slots) if lock_rows: query = with_row_locks(query, **nowait(session)) pool_rows: Iterable[Tuple[str, int]] = query.all() for (pool_name, total_slots) in pool_rows: pools[pool_name] = PoolStats(total=total_slots, running=0, queued=0, open=0) state_count_by_pool = ( session.query(TaskInstance.pool, TaskInstance.state, func.count()) .filter(TaskInstance.state.in_(list(EXECUTION_STATES))) .group_by(TaskInstance.pool, TaskInstance.state) ).all() # calculate queued and running metrics count: int for (pool_name, state, count) in state_count_by_pool: stats_dict: Optional[PoolStats] = pools.get(pool_name) if not stats_dict: continue # TypedDict key must be a string literal, so we use if-statements to set value if state == "running": stats_dict["running"] = count elif state == "queued": stats_dict["queued"] = count else: raise AirflowException( f"Unexpected state. Expected values: {EXECUTION_STATES}." ) # calculate open metric for pool_name, stats_dict in pools.items(): if stats_dict["total"] == -1: # -1 means infinite stats_dict["open"] = -1 else: stats_dict["open"] = stats_dict["total"] - stats_dict["running"] - stats_dict["queued"] return pools
def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
    """
    Run a "mini scheduler" pass over the tasks downstream of this task instance.

    Re-selects this TI's DagRun under a row-level lock, builds a partial DAG
    restricted to the just-executed task plus its downstream tasks, computes
    scheduling decisions on that subset, and schedules any TIs that became
    runnable. This is purely an optimisation: any database ``OperationalError``
    is logged and rolled back instead of being propagated.

    :param session: SQLAlchemy ORM session used for all queries and the final commit
    """
    try:
        # Re-select the row with a lock
        dag_run = with_row_locks(
            session.query(DagRun).filter_by(
                dag_id=self.dag_id,
                run_id=self.task_instance.run_id,
            ),
            session=session,
        ).one()
        task = self.task_instance.task
        if TYPE_CHECKING:
            assert task.dag

        # Get a partial DAG with just the specific tasks we want to examine.
        # In order for dep checks to work correctly, we include ourself (so
        # TriggerRuleDep can check the state of the task we just executed).
        partial_dag = task.dag.partial_subset(
            task.downstream_task_ids,
            include_downstream=True,
            include_upstream=False,
            include_direct_upstream=True,
        )

        dag_run.dag = partial_dag
        info = dag_run.task_instance_scheduling_decisions(session)

        # Tasks that are in the partial DAG only so the dependency checks work
        # (i.e. not actual downstream tasks) must not be scheduled from here.
        skippable_task_ids = {
            task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
        }

        schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
        for schedulable_ti in schedulable_tis:
            # schedule_tis needs the concrete task object; attach it if missing.
            if not hasattr(schedulable_ti, "task"):
                schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)

        num = dag_run.schedule_tis(schedulable_tis)
        self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)

        session.commit()
    except OperationalError as e:
        # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
        self.log.info(
            "Skipping mini scheduling run due to exception: %s",
            e.statement,
            exc_info=True,
        )
        session.rollback()
def next_dagruns_to_examine(
    cls,
    state: DagRunState,
    session: Session,
    max_number: Optional[int] = None,
):
    """
    Return the next DagRuns that the scheduler should attempt to schedule.

    This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... FOR UPDATE"
    query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as
    the transaction is committed it will be unlocked.

    :param state: only DagRuns currently in this state are considered
    :param session: SQLAlchemy ORM session used for the query
    :param max_number: maximum rows to return; defaults to ``cls.DEFAULT_DAGRUNS_TO_EXAMINE``
    :rtype: list[airflow.models.DagRun]
    """
    from airflow.models.dag import DagModel  # local import to avoid a circular dependency

    if max_number is None:
        max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE

    # TODO: Bake this query, it is run _A lot_
    query = (
        session.query(cls)
        # Backfill runs are driven by the backfill job, not the scheduler.
        .filter(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
        .join(
            DagModel,
            DagModel.dag_id == cls.dag_id,
        )
        .filter(
            DagModel.is_paused == expression.false(),
            DagModel.is_active == expression.true(),
        )
    )
    if state == State.QUEUED:
        # For dag runs in the queued state, we check if they have reached the max_active_runs limit
        # and if so we drop them
        running_drs = (
            session.query(DagRun.dag_id, func.count(DagRun.state).label('num_running'))
            .filter(DagRun.state == DagRunState.RUNNING)
            .group_by(DagRun.dag_id)
            .subquery()
        )
        # outerjoin + coalesce so dags with zero running runs still qualify.
        query = query.outerjoin(running_drs, running_drs.c.dag_id == DagRun.dag_id).filter(
            func.coalesce(running_drs.c.num_running, 0) < DagModel.max_active_runs
        )
    # Runs with a NULL last_scheduling_decision sort first, then oldest decision first.
    query = query.order_by(
        nulls_first(cls.last_scheduling_decision, session=session),
        cls.execution_date,
    )

    if not settings.ALLOW_FUTURE_EXEC_DATES:
        query = query.filter(DagRun.execution_date <= func.now())

    # SKIP LOCKED (where supported) so concurrent schedulers don't grab the same rows.
    return with_row_locks(
        query.limit(max_number), of=cls, session=session, **skip_locked(session=session)
    )
def test_with_row_locks(self, dialect, supports_for_update_of, use_row_level_lock_conf, expected_use_row_level_lock):
    """FOR UPDATE is applied only when both config and dialect support allow it."""
    mock_query = mock.Mock()
    mock_session = mock.Mock()
    mock_session.bind.dialect.name = dialect
    mock_session.bind.dialect.supports_for_update_of = supports_for_update_of

    with mock.patch("airflow.utils.sqlalchemy.USE_ROW_LEVEL_LOCKING", use_row_level_lock_conf):
        result = with_row_locks(query=mock_query, session=mock_session, nowait=True)

    if expected_use_row_level_lock:
        mock_query.with_for_update.assert_called_once_with(nowait=True)
    else:
        # Locking skipped: the query is returned untouched.
        assert result == mock_query
        mock_query.with_for_update.assert_not_called()
def next_dagruns_to_examine(
    cls,
    state: DagRunState,
    session: Session,
    max_number: Optional[int] = None,
):
    """
    Return the next DagRuns that the scheduler should attempt to schedule.

    This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... FOR UPDATE"
    query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as
    the transaction is committed it will be unlocked.

    :param state: only DagRuns currently in this state are considered
    :param session: SQLAlchemy ORM session used for the query
    :param max_number: maximum rows to return; defaults to ``cls.DEFAULT_DAGRUNS_TO_EXAMINE``
    :rtype: list[airflow.models.DagRun]
    """
    from airflow.models.dag import DagModel  # local import to avoid a circular dependency

    if max_number is None:
        max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE

    # TODO: Bake this query, it is run _A lot_
    query = (
        session.query(cls)
        # Backfill runs are driven by the backfill job, not the scheduler.
        .filter(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
        .join(
            DagModel,
            DagModel.dag_id == cls.dag_id,
        )
        .filter(
            DagModel.is_paused == expression.false(),
            DagModel.is_active == expression.true(),
        )
        # Runs with a NULL last_scheduling_decision sort first, then oldest first.
        .order_by(
            nulls_first(cls.last_scheduling_decision, session=session),
            cls.execution_date,
        )
    )

    if not settings.ALLOW_FUTURE_EXEC_DATES:
        query = query.filter(DagRun.execution_date <= func.now())

    # SKIP LOCKED (where supported) so concurrent schedulers don't grab the same rows.
    return with_row_locks(
        query.limit(max_number), of=cls, session=session, **skip_locked(session=session)
    )
def _fetch_callbacks(self, max_callbacks: int, session: Session = NEW_SESSION):
    """Fetches callbacks from database and add them to the internal queue for execution."""
    self.log.debug("Fetching callbacks from the database.")

    # prohibit_commit ensures only the explicit guard.commit() below can commit.
    with prohibit_commit(session) as guard:
        pending = (
            session.query(DbCallbackRequest)
            .order_by(DbCallbackRequest.priority_weight.asc())
            .limit(max_callbacks)
        )
        # Row-lock (SKIP LOCKED where supported) so other schedulers skip these rows.
        locked = with_row_locks(
            pending, of=DbCallbackRequest, session=session, **skip_locked(session=session)
        )
        for request_row in locked.all():
            try:
                # Queue the callback, then remove the row so it is not re-fetched.
                self._add_callback_to_queue(request_row.get_callback_request())
                session.delete(request_row)
            except Exception as e:
                self.log.warning("Error adding callback for execution: %s, %s", request_row, e)
        guard.commit()
def _find_schedulable_tasks(
    self, dag_run: DagRun, session: Session, check_execution_date=False
) -> Optional[List[TI]]:
    """
    Make scheduling decisions about an individual dag run

    :param dag_run: The DagRun to schedule
    :param session: SQLAlchemy ORM session used for all queries
    :param check_execution_date: if True, skip runs whose execution_date is in
        the future (unless the dag allows future execution dates)
    :return: scheduled tasks, excluding tasks that have event dependencies;
        None (or bare return) when the run cannot/should not be scheduled
    """
    # Nothing to do for a missing or already-finished run.
    if not dag_run or dag_run.get_state() in State.finished:
        return
    try:
        dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
    except SerializedDagNotFound:
        self.log.exception("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
        return None

    if not dag:
        self.log.error("Couldn't find dag %s in DagBag/DB!", dag_run.dag_id)
        return None

    # NOTE(review): despite the name, this counts unfinished *TaskInstances* for the
    # dag (one Row tuple per TI), not active DagRuns -- confirm this is intended.
    currently_active_runs = session.query(TI.execution_date, ).filter(
        TI.dag_id == dag_run.dag_id,
        TI.state.notin_(list(State.finished)),
    ).all()

    if check_execution_date and dag_run.execution_date > timezone.utcnow(
    ) and not dag.allow_future_exec_dates:
        self.log.warning("Execution date is in future: %s", dag_run.execution_date)
        return None

    if dag.max_active_runs:
        # NOTE(review): `dag_run.execution_date not in currently_active_runs` compares a
        # datetime against a list of single-column Row tuples, so it looks always True
        # -- verify against the original upstream implementation.
        if (len(currently_active_runs) >= dag.max_active_runs
                and dag_run.execution_date not in currently_active_runs):
            self.log.info(
                "DAG %s already has %d active runs, not queuing any tasks for run %s",
                dag.dag_id,
                len(currently_active_runs),
                dag_run.execution_date,
            )
            return None

    self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session)

    # callback_to_run is intentionally ignored here (callbacks not executed from this path).
    schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False)
    dag_run.schedule_tis(schedulable_tis, session)

    # All SCHEDULED TIs of non-paused dags, excluding backfill runs.
    query = (session.query(TI).outerjoin(TI.dag_run).filter(
        or_(DR.run_id.is_(None), DR.run_type != DagRunType.BACKFILL_JOB)).join(
            TI.dag_model).filter(not_(DM.is_paused)).filter(
                TI.state == State.SCHEDULED).options(
                    selectinload('dag_model')))
    # NOTE(review): with_row_locks is called without session= here, unlike the other
    # call sites in this file -- confirm the helper's signature allows this.
    scheduled_tis: List[TI] = with_row_locks(
        query,
        of=TI,
        **skip_locked(session=session),
    ).all()

    # filter need event tasks
    serialized_dag = session.query(SerializedDagModel).filter(
        SerializedDagModel.dag_id == dag_run.dag_id).first()
    dep: DagEventDependencies = DagEventDependencies.from_json(
        serialized_dag.event_relationships)
    event_task_set = dep.find_event_dependencies_tasks()
    # Drop TIs whose task has event dependencies; those are scheduled elsewhere.
    final_scheduled_tis = []
    for ti in scheduled_tis:
        if ti.task_id not in event_task_set:
            final_scheduled_tis.append(ti)

    return final_scheduled_tis
def _executable_task_instances_to_queued(self, max_tis: int, session: Session = None) -> List[TI]:
    """
    Finds TIs that are ready for execution with respect to pool limits,
    dag max_active_tasks, executor state, and priority.

    :param max_tis: Maximum number of TIs to queue in this loop.
    :type max_tis: int
    :return: list[airflow.models.TaskInstance]
    """
    executable_tis: List[TI] = []

    # Get the pool settings. We get a lock on the pool rows, treating this as a "critical section"
    # Throws an exception if lock cannot be obtained, rather than blocking
    pools = models.Pool.slots_stats(lock_rows=True, session=session)

    # If the pools are full, there is no point doing anything!
    # If _somehow_ the pool is overfull, don't let the limit go negative - it breaks SQL
    pool_slots_free = max(0, sum(pool['open'] for pool in pools.values()))

    if pool_slots_free == 0:
        self.log.debug("All pools are full!")
        return executable_tis

    max_tis = min(max_tis, pool_slots_free)

    # Get all task instances associated with scheduled
    # DagRuns which are not backfilled, in the given states,
    # and the dag is not paused
    query = (
        session.query(TI)
        .join(TI.dag_run)
        .options(eagerload(TI.dag_run))
        .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state != DagRunState.QUEUED)
        .join(TI.dag_model)
        .filter(not_(DM.is_paused))
        .filter(TI.state == State.SCHEDULED)
        .options(selectinload('dag_model'))
        .order_by(-TI.priority_weight, DR.execution_date)
    )
    # Pools with no open slots can't run anything; exclude their TIs up front.
    starved_pools = [pool_name for pool_name, stats in pools.items() if stats['open'] <= 0]
    if starved_pools:
        query = query.filter(not_(TI.pool.in_(starved_pools)))

    query = query.limit(max_tis)

    # Row-lock the candidate TIs (SKIP LOCKED where supported) so concurrent
    # schedulers don't queue the same TIs.
    task_instances_to_examine: List[TI] = with_row_locks(
        query,
        of=TI,
        session=session,
        **skip_locked(session=session),
    ).all()
    # TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything.
    # Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine))

    if len(task_instances_to_examine) == 0:
        self.log.debug("No tasks to consider for execution.")
        return executable_tis

    # Put one task instance on each line
    task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine)
    self.log.info("%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str)

    # Group candidate TIs by the pool they run in.
    pool_to_task_instances: DefaultDict[str, List[TI]] = defaultdict(list)
    for task_instance in task_instances_to_examine:
        pool_to_task_instances[task_instance.pool].append(task_instance)

    # dag_id to # of running tasks and (dag_id, task_id) to # of running tasks.
    dag_max_active_tasks_map: DefaultDict[str, int]
    task_concurrency_map: DefaultDict[Tuple[str, str], int]
    dag_max_active_tasks_map, task_concurrency_map = self.__get_concurrency_maps(
        states=list(EXECUTION_STATES), session=session
    )

    # Never incremented in this method, so the 'scheduler.tasks.running' gauge
    # below always reports 0 from here.
    num_tasks_in_executor = 0
    # Number of tasks that cannot be scheduled because of no open slot in pool
    num_starving_tasks_total = 0

    # Go through each pool, and queue up a task for execution if there are
    # any open slots in the pool.
    for pool, task_instances in pool_to_task_instances.items():
        pool_name = pool  # alias kept for the per-pool Stats gauge name below
        if pool not in pools:
            self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool)
            continue

        open_slots = pools[pool]["open"]

        num_ready = len(task_instances)
        self.log.info(
            "Figuring out tasks to run in Pool(name=%s) with %s open slots "
            "and %s task instances ready to be queued",
            pool,
            open_slots,
            num_ready,
        )

        # Highest priority first; ties broken by earliest execution_date.
        priority_sorted_task_instances = sorted(
            task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date)
        )

        num_starving_tasks = 0
        for current_index, task_instance in enumerate(priority_sorted_task_instances):
            if open_slots <= 0:
                self.log.info("Not scheduling since there are %s open slots in pool %s", open_slots, pool)
                # Can't schedule any more since there are no more open slots.
                num_unhandled = len(priority_sorted_task_instances) - current_index
                num_starving_tasks += num_unhandled
                num_starving_tasks_total += num_unhandled
                break

            # Check to make sure that the task max_active_tasks of the DAG hasn't been
            # reached.
            dag_id = task_instance.dag_id

            current_max_active_tasks_per_dag = dag_max_active_tasks_map[dag_id]
            max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
            self.log.info(
                "DAG %s has %s/%s running and queued tasks",
                dag_id,
                current_max_active_tasks_per_dag,
                max_active_tasks_per_dag_limit,
            )
            if current_max_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
                self.log.info(
                    "Not executing %s since the number of tasks running or queued "
                    "from DAG %s is >= to the DAG's max_active_tasks limit of %s",
                    task_instance,
                    dag_id,
                    max_active_tasks_per_dag_limit,
                )
                continue

            task_concurrency_limit: Optional[int] = None
            if task_instance.dag_model.has_task_concurrency_limits:
                # Many dags don't have a task_concurrency, so where we can avoid loading the full
                # serialized DAG the better.
                serialized_dag = self.dagbag.get_dag(dag_id, session=session)
                if serialized_dag.has_task(task_instance.task_id):
                    task_concurrency_limit = serialized_dag.get_task(
                        task_instance.task_id
                    ).max_active_tis_per_dag

                if task_concurrency_limit is not None:
                    current_task_concurrency = task_concurrency_map[
                        (task_instance.dag_id, task_instance.task_id)
                    ]

                    if current_task_concurrency >= task_concurrency_limit:
                        self.log.info(
                            "Not executing %s since the task concurrency for"
                            " this task has been reached.",
                            task_instance,
                        )
                        continue

            if task_instance.pool_slots > open_slots:
                self.log.info(
                    "Not executing %s since it requires %s slots "
                    "but there are %s open slots in the pool %s.",
                    task_instance,
                    task_instance.pool_slots,
                    open_slots,
                    pool,
                )
                num_starving_tasks += 1
                num_starving_tasks_total += 1
                # Though we can execute tasks with lower priority if there's enough room
                continue

            executable_tis.append(task_instance)
            # Account for this TI in all the running-count maps so later
            # candidates in this same batch see the updated totals.
            open_slots -= task_instance.pool_slots
            dag_max_active_tasks_map[dag_id] += 1
            task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1

        Stats.gauge(f'pool.starving_tasks.{pool_name}', num_starving_tasks)

    Stats.gauge('scheduler.tasks.starving', num_starving_tasks_total)
    Stats.gauge('scheduler.tasks.running', num_tasks_in_executor)
    Stats.gauge('scheduler.tasks.executable', len(executable_tis))

    task_instance_str = "\n\t".join(repr(x) for x in executable_tis)
    self.log.info("Setting the following tasks to queued state:\n\t%s", task_instance_str)
    if len(executable_tis) > 0:
        # set TIs to queued state
        filter_for_tis = TI.filter_for_tis(executable_tis)
        session.query(TI).filter(filter_for_tis).update(
            # TODO[ha]: should we use func.now()? How does that work with DB timezone
            # on mysql when it's not UTC?
            {TI.state: State.QUEUED, TI.queued_dttm: timezone.utcnow(), TI.queued_by_job_id: self.id},
            synchronize_session=False,
        )

        for ti in executable_tis:
            make_transient(ti)
    return executable_tis
def adopt_or_reset_orphaned_tasks(self, session: Session = None):
    """
    Reset any TaskInstance still in QUEUED or SCHEDULED states that were
    enqueued by a SchedulerJob that is no longer running.

    :return: the number of TIs reset
    :rtype: int
    """
    self.log.info("Resetting orphaned tasks for active dag runs")
    timeout = conf.getint('scheduler', 'scheduler_health_check_threshold')

    for attempt in run_with_db_retries(logger=self.log):
        with attempt:
            self.log.debug(
                "Running SchedulerJob.adopt_or_reset_orphaned_tasks with retries. Try %d of %d",
                attempt.retry_state.attempt_number,
                MAX_DB_RETRIES,
            )
            self.log.debug("Calling SchedulerJob.adopt_or_reset_orphaned_tasks method")
            try:
                # Mark scheduler jobs as failed if their heartbeat is older
                # than the configured health-check threshold.
                num_failed = (
                    session.query(SchedulerJob)
                    .filter(
                        SchedulerJob.state == State.RUNNING,
                        SchedulerJob.latest_heartbeat < (timezone.utcnow() - timedelta(seconds=timeout)),
                    )
                    .update({"state": State.FAILED})
                )

                if num_failed:
                    self.log.info("Marked %d SchedulerJob instances as failed", num_failed)
                    Stats.incr(self.__class__.__name__.lower() + '_end', num_failed)

                # NOTE(review): docstring says QUEUED or SCHEDULED but this list is
                # QUEUED and RUNNING -- confirm which is intended.
                resettable_states = [State.QUEUED, State.RUNNING]
                query = (
                    session.query(TI)
                    .filter(TI.state.in_(resettable_states))
                    # outerjoin is because we didn't use to have queued_by_job
                    # set, so we need to pick up anything pre upgrade. This (and the
                    # "or queued_by_job_id IS NONE") can go as soon as scheduler HA is
                    # released.
                    .outerjoin(TI.queued_by_job)
                    .filter(or_(TI.queued_by_job_id.is_(None), SchedulerJob.state != State.RUNNING))
                    .join(TI.dag_run)
                    .filter(
                        DagRun.run_type != DagRunType.BACKFILL_JOB,
                        DagRun.state == State.RUNNING,
                    )
                    .options(load_only(TI.dag_id, TI.task_id, TI.run_id))
                )

                # Lock these rows, so that another scheduler can't try and adopt these too
                tis_to_reset_or_adopt = with_row_locks(
                    query, of=TI, session=session, **skip_locked(session=session)
                ).all()
                # The executor keeps the TIs it can adopt; the rest must be reset.
                to_reset = self.executor.try_adopt_task_instances(tis_to_reset_or_adopt)

                reset_tis_message = []
                for ti in to_reset:
                    reset_tis_message.append(repr(ti))
                    ti.state = State.NONE
                    ti.queued_by_job_id = None

                # Adopted TIs are re-tagged as queued by this scheduler job.
                for ti in set(tis_to_reset_or_adopt) - set(to_reset):
                    ti.queued_by_job_id = self.id

                Stats.incr('scheduler.orphaned_tasks.cleared', len(to_reset))
                Stats.incr('scheduler.orphaned_tasks.adopted', len(tis_to_reset_or_adopt) - len(to_reset))

                if to_reset:
                    task_instance_str = '\n\t'.join(reset_tis_message)
                    self.log.info(
                        "Reset the following %s orphaned TaskInstances:\n\t%s",
                        len(to_reset),
                        task_instance_str,
                    )

                # Issue SQL/finish "Unit of Work", but let @provide_session
                # commit (or if passed a session, let caller decide when to commit
                session.flush()
            except OperationalError:
                session.rollback()
                raise

    return len(to_reset)