def execute_callbacks(
    self,
    dagbag: DagBag,
    callback_requests: List[CallbackRequest],
    session: Session = NEW_SESSION,
) -> None:
    """
    Execute on failure callbacks. These objects can come from SchedulerJob or from
    DagFileProcessorManager.

    :param dagbag: Dag Bag of dags
    :param callback_requests: failure callbacks to execute
    :param session: DB session.
    """
    for request in callback_requests:
        self.log.debug("Processing Callback Request: %s", request)
        try:
            if isinstance(request, TaskCallbackRequest):
                self._execute_task_callbacks(dagbag, request)
            elif isinstance(request, SlaCallbackRequest):
                self.manage_slas(dagbag.get_dag(request.dag_id), session=session)
            elif isinstance(request, DagCallbackRequest):
                self._execute_dag_callbacks(dagbag, request, session)
        except Exception:
            self.log.exception(
                "Error executing %s callback for file: %s",
                request.__class__.__name__,
                request.full_filepath,
            )
    session.commit()
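
# Usage sketch (not from the original source): building a failure callback request
# and dispatching it through execute_callbacks. The `processor` object (whatever
# exposes the method above), the existing TaskInstance `ti`, and the file path are
# assumptions for illustration only.
from airflow.models.dagbag import DagBag
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.utils.callback_requests import TaskCallbackRequest

dagbag = DagBag(include_examples=False)
request = TaskCallbackRequest(
    full_filepath="/files/dags/example_dag.py",   # hypothetical DAG file path
    simple_task_instance=SimpleTaskInstance(ti),  # `ti` assumed to exist already
    msg="Task was killed externally",
)
processor.execute_callbacks(dagbag, [request])  # `processor` assumed to exist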
def test_dags_integrity():
    dag_bag = DagBag()
    # Assert that all DAGs can be imported, i.e., all parameters required
    # for DAGs are specified and no task cycles are present.
    assert dag_bag.import_errors == {}

    for dag_id in dag_bag.dag_ids:
        dag = dag_bag.get_dag(dag_id=dag_id)
        # Assert that each DAG has at least one task.
        assert dag is not None
        assert len(dag.tasks) > 0
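
# A per-DAG variant of the same check (a sketch, not part of the original tests):
# parametrizing over dag_ids makes each DAG show up as its own test case in the
# pytest output, so a single broken DAG doesn't mask the rest.
import pytest
from airflow.models import DagBag

_dag_bag = DagBag(include_examples=False)

def test_no_import_errors():
    assert _dag_bag.import_errors == {}

@pytest.mark.parametrize("dag_id", _dag_bag.dag_ids)
def test_dag_has_tasks(dag_id):
    dag = _dag_bag.get_dag(dag_id=dag_id)
    assert dag is not None
    assert len(dag.tasks) > 0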
def test_heartbeat_failed_fast(self):
    """Test that the task heartbeat will sleep when it fails fast."""
    self.mock_base_job_sleep.side_effect = time.sleep

    with create_session() as session:
        dagbag = DagBag(
            dag_folder=TEST_DAG_FOLDER,
            include_examples=False,
        )
        dag_id = 'test_heartbeat_failed_fast'
        task_id = 'test_heartbeat_failed_fast_op'
        dag = dagbag.get_dag(dag_id)
        task = dag.get_task(task_id)

        dag.create_dagrun(
            run_id="test_heartbeat_failed_fast_run",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        ti.state = State.RUNNING
        ti.hostname = get_hostname()
        ti.pid = 1
        session.commit()

        job = LocalTaskJob(task_instance=ti, executor=MockExecutor(do_update=False))
        job.heartrate = 2
        heartbeat_records = []
        job.heartbeat_callback = lambda session: heartbeat_records.append(job.latest_heartbeat)
        job._execute()
        self.assertGreater(len(heartbeat_records), 2)
        for i in range(1, len(heartbeat_records)):
            time1 = heartbeat_records[i - 1]
            time2 = heartbeat_records[i]
            # Assert that the gap between consecutive heartbeats is close to the heartrate.
            delta = (time2 - time1).total_seconds()
            self.assertAlmostEqual(delta, job.heartrate, delta=0.05)
def run_dag(self, dag_id: str, dag_folder: str = DEFAULT_DAG_FOLDER) -> None:
    """
    Runs an example DAG by its ID.

    :param dag_id: id of a DAG to be run
    :type dag_id: str
    :param dag_folder: directory where to look for the specific DAG. Relative to AIRFLOW_HOME.
    :type dag_folder: str
    """
    if os.environ.get("RUN_AIRFLOW_1_10") == "true":
        # For system test purposes we are changing airflow/providers
        # to the site-packages path of the installed providers package
        python = f"python{sys.version_info.major}.{sys.version_info.minor}"
        dag_folder = dag_folder.replace(
            "/opt/airflow/airflow/providers",
            f"/usr/local/lib/{python}/site-packages/airflow/providers",
        )
    self.log.info("Looking for DAG: %s in %s", dag_id, dag_folder)
    dag_bag = DagBag(dag_folder=dag_folder, include_examples=False)
    dag = dag_bag.get_dag(dag_id)
    if dag is None:
        raise AirflowException(
            "The Dag {dag_id} could not be found. It's either an import problem, "
            "a wrong dag_id, or the DAG is not in the provided dag_folder. "
            "The content of the {dag_folder} folder is {content}".format(
                dag_id=dag_id,
                dag_folder=dag_folder,
                content=os.listdir(dag_folder),
            )
        )

    self.log.info("Attempting to run DAG: %s", dag_id)
    if os.environ.get("RUN_AIRFLOW_1_10") == "true":
        dag.clear()
    else:
        dag.clear(dag_run_state=State.NONE)
    try:
        dag.run(ignore_first_depends_on_past=True, verbose=True)
    except Exception:
        self._print_all_log_files()
        raise
def run_dag(self, dag_id: str, dag_folder: str = DEFAULT_DAG_FOLDER) -> None:
    """
    Runs an example DAG by its ID.

    :param dag_id: id of a DAG to be run
    :type dag_id: str
    :param dag_folder: directory where to look for the specific DAG. Relative to AIRFLOW_HOME.
    :type dag_folder: str
    """
    if os.environ.get("RUN_AIRFLOW_1_10"):
        # For system test purposes we are mounting airflow/providers to the /providers folder
        # so that we can get example_dags from there
        dag_folder = dag_folder.replace("/opt/airflow/airflow/providers", "/providers")
        temp_dir = mkdtemp()
        os.rmdir(temp_dir)
        shutil.copytree(dag_folder, temp_dir)
        dag_folder = temp_dir
        self.correct_imports_for_airflow_1_10(temp_dir)
    self.log.info("Looking for DAG: %s in %s", dag_id, dag_folder)
    dag_bag = DagBag(dag_folder=dag_folder, include_examples=False)
    dag = dag_bag.get_dag(dag_id)
    if dag is None:
        raise AirflowException(
            "The Dag {dag_id} could not be found. It's either an import problem, "
            "a wrong dag_id, or the DAG is not in the provided dag_folder. "
            "The content of the {dag_folder} folder is {content}".format(
                dag_id=dag_id,
                dag_folder=dag_folder,
                content=os.listdir(dag_folder),
            )
        )

    self.log.info("Attempting to run DAG: %s", dag_id)
    dag.clear(reset_dag_runs=True)
    try:
        dag.run(ignore_first_depends_on_past=True, verbose=True)
    except Exception:
        self._print_all_log_files()
        raise
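
# Hypothetical caller of run_dag (the SystemTest base class name, dag_id, and
# folder are assumptions for illustration): a system test usually just points
# run_dag at a provider's example_dags directory.
class ExampleDagsSystemTest(SystemTest):
    def test_run_example_dag(self):
        self.run_dag(
            dag_id="example_gcs",
            dag_folder="airflow/providers/google/cloud/example_dags",
        )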
def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids):
    """
    This script can be used to measure the total "scheduler overhead" of Airflow.

    By overhead we mean: if the tasks executed instantly as soon as they were
    scheduled (i.e. they do nothing), how quickly could we schedule them?

    It will monitor the task completion of the Mock/stub executor (no actual
    tasks are run) and after the required number of dag runs for all the
    specified dags have completed all their tasks, it will cleanly shut down
    the scheduler.

    The dags you run with need to have an early enough start_date to create the
    desired number of runs.

    Care should be taken that other limits (DAG concurrency, pool size etc.) are
    not the bottleneck. This script doesn't help you in that regard.

    It is recommended to repeat the test at least 3 times (`--repeat=3`, the
    default) so that you can get somewhat-accurate variance on the reported
    timing numbers, but this can be disabled for longer runs if needed.
    """
    # Turn on unit test mode so that we don't do any sleep() in the scheduler
    # loop - not needed on master, but this script can run against older
    # releases too!
    os.environ['AIRFLOW__CORE__UNIT_TEST_MODE'] = 'True'

    os.environ['AIRFLOW__CORE__DAG_CONCURRENCY'] = '500'

    # Set this so that dags can dynamically configure their end_date
    os.environ['AIRFLOW_BENCHMARK_MAX_DAG_RUNS'] = str(num_runs)
    os.environ['PERF_MAX_RUNS'] = str(num_runs)

    if pre_create_dag_runs:
        os.environ['AIRFLOW__SCHEDULER__USE_JOB_SCHEDULE'] = 'False'

    from airflow.jobs.scheduler_job import SchedulerJob
    from airflow.models.dagbag import DagBag
    from airflow.utils import db

    dagbag = DagBag()

    dags = []

    with db.create_session() as session:
        pause_all_dags(session)
        for dag_id in dag_ids:
            dag = dagbag.get_dag(dag_id)
            dag.sync_to_db(session=session)
            dags.append(dag)

            reset_dag(dag, session)

            next_run_date = dag.normalize_schedule(dag.start_date or min(t.start_date for t in dag.tasks))

            for _ in range(num_runs - 1):
                next_run_date = dag.following_schedule(next_run_date)

            end_date = dag.end_date or dag.default_args.get('end_date')
            if end_date != next_run_date:
                message = (
                    f"DAG {dag_id} has incorrect end_date ({end_date}) for number of runs! "
                    f"It should be {next_run_date}"
                )
                sys.exit(message)

            if pre_create_dag_runs:
                create_dag_runs(dag, num_runs, session)

    ShortCircuitExecutor = get_executor_under_test(executor_class)

    executor = ShortCircuitExecutor(dag_ids_to_watch=dag_ids, num_runs=num_runs)
    scheduler_job = SchedulerJob(dag_ids=dag_ids, do_pickle=False, executor=executor)
    executor.scheduler_job = scheduler_job

    total_tasks = sum(len(dag.tasks) for dag in dags)

    if 'PYSPY' in os.environ:
        pid = str(os.getpid())
        filename = os.environ.get('PYSPY_O', 'flame-' + pid + '.html')
        os.spawnlp(os.P_NOWAIT, 'sudo', 'sudo', 'py-spy', 'record', '-o', filename, '-p', pid, '--idle')

    times = []

    # Need a lambda to refer to the _latest_ value for scheduler_job, not just
    # the initial one
    code_to_test = lambda: scheduler_job.run()  # pylint: disable=unnecessary-lambda

    for count in range(repeat):
        gc.disable()
        start = time.perf_counter()

        code_to_test()
        times.append(time.perf_counter() - start)
        gc.enable()
        print("Run %d time: %.5f" % (count + 1, times[-1]))

        if count + 1 != repeat:
            with db.create_session() as session:
                for dag in dags:
                    reset_dag(dag, session)

                executor.reset(dag_ids)
                scheduler_job = SchedulerJob(dag_ids=dag_ids, do_pickle=False, executor=executor)
                executor.scheduler_job = scheduler_job

    print()
    print()
    msg = "Time for %d dag runs of %d dags with %d total tasks: %.4fs"

    if len(times) > 1:
        print(
            (msg + " (±%.3fs)")
            % (num_runs, len(dags), total_tasks, statistics.mean(times), statistics.stdev(times))
        )
    else:
        print(msg % (num_runs, len(dags), total_tasks, times[0]))

    print()
    print()
class EventBasedScheduler(LoggingMixin):
    def __init__(self, id, mailbox: Mailbox, task_event_manager: DagRunEventManager,
                 executor: BaseExecutor, notification_client: NotificationClient, context=None):
        super().__init__(context)
        self.id = id
        self.mailbox = mailbox
        self.task_event_manager: DagRunEventManager = task_event_manager
        self.executor = executor
        self.notification_client = notification_client
        self.dagbag = DagBag(read_dags_from_db=True)
        self._timer_handler = None
        self.timers = sched.scheduler()

    def sync(self):
        def call_regular_interval(
            delay: float,
            action: Callable,
            arguments=(),
            kwargs={},  # pylint: disable=dangerous-default-value
        ):
            def repeat(*args, **kwargs):
                action(*args, **kwargs)
                # This is not perfect. If we want a timer every 60s, but action
                # takes 10s to run, this will run it every 70s.
                # Good enough for now
                self._timer_handler = self.timers.enter(delay, 1, repeat, args, kwargs)

            self._timer_handler = self.timers.enter(delay, 1, repeat, arguments, kwargs)

        call_regular_interval(delay=1.0, action=self.executor.sync)
        self.timers.run()

    def _stop_timer(self):
        if self.timers and self._timer_handler:
            self.timers.cancel(self._timer_handler)

    def submit_sync_thread(self):
        threading.Thread(target=self.sync).start()

    def schedule(self):
        self.log.info("Starting the scheduler.")
        self._restore_unfinished_dag_run()
        while True:
            identified_message = self.mailbox.get_identified_message()
            origin_event = identified_message.deserialize()
            self.log.debug("Event: {}".format(origin_event))
            if SchedulerInnerEventUtil.is_inner_event(origin_event):
                event = SchedulerInnerEventUtil.to_inner_event(origin_event)
            else:
                event = origin_event
            with create_session() as session:
                if isinstance(event, BaseEvent):
                    dagruns = self._find_dagruns_by_event(event, session)
                    for dagrun in dagruns:
                        dag_run_id = DagRunId(dagrun.dag_id, dagrun.run_id)
                        self.task_event_manager.handle_event(dag_run_id, event)
                elif isinstance(event, RequestEvent):
                    self._process_request_event(event)
                elif isinstance(event, ResponseEvent):
                    continue
                elif isinstance(event, TaskSchedulingEvent):
                    self._schedule_task(event)
                elif isinstance(event, TaskStatusChangedEvent):
                    dagrun = self._find_dagrun(event.dag_id, event.execution_date, session)
                    tasks = self._find_schedulable_tasks(dagrun, session)
                    self._send_scheduling_task_events(tasks, SchedulingAction.START)
                elif isinstance(event, DagExecutableEvent):
                    dagrun = self._create_dag_run(event.dag_id, session=session)
                    tasks = self._find_schedulable_tasks(dagrun, session)
                    self._send_scheduling_task_events(tasks, SchedulingAction.START)
                elif isinstance(event, EventHandleEvent):
                    dag_runs = DagRun.find(dag_id=event.dag_id, run_id=event.dag_run_id)
                    assert len(dag_runs) == 1
                    ti = dag_runs[0].get_task_instance(event.task_id)
                    self._send_scheduling_task_event(ti, event.action)
                elif isinstance(event, StopDagEvent):
                    self._stop_dag(event.dag_id, session)
                elif isinstance(event, ParseDagRequestEvent) or isinstance(event, ParseDagResponseEvent):
                    pass
                elif isinstance(event, StopSchedulerEvent):
                    self.log.info("{} {}".format(self.id, event.job_id))
                    if self.id == event.job_id or 0 == event.job_id:
                        self.log.info("break the scheduler event loop.")
                        identified_message.remove_handled_message()
                        session.expunge_all()
                        break
                else:
                    self.log.error("cannot handle the event {}".format(event))
                identified_message.remove_handled_message()
                session.expunge_all()
        self._stop_timer()

    def stop(self) -> None:
        self.mailbox.send_message(StopSchedulerEvent(self.id).to_event())
        self.log.info("Send stop event to the scheduler.")

    def recover(self, last_scheduling_id):
        self.log.info("Waiting for executor recovery...")
        self.executor.recover_state()
        unprocessed_messages = self.get_unprocessed_message(last_scheduling_id)
        self.log.info(
            "Recovering %s messages of last scheduler job with id: %s",
            len(unprocessed_messages),
            last_scheduling_id,
        )
        for msg in unprocessed_messages:
            self.mailbox.send_identified_message(msg)

    @staticmethod
    def get_unprocessed_message(last_scheduling_id: int) -> List[IdentifiedMessage]:
        with create_session() as session:
            results: List[MSG] = session.query(MSG).filter(
                MSG.scheduling_job_id == last_scheduling_id,
                MSG.state == MessageState.QUEUED,
            ).order_by(asc(MSG.id)).all()
            unprocessed: List[IdentifiedMessage] = []
            for msg in results:
                unprocessed.append(IdentifiedMessage(msg.data, msg.id))
            return unprocessed

    def _find_dagrun(self, dag_id, execution_date, session) -> DagRun:
        dagrun = session.query(DagRun).filter(
            DagRun.dag_id == dag_id,
            DagRun.execution_date == execution_date,
        ).first()
        return dagrun

    def _create_dag_run(self, dag_id, session, run_type=DagRunType.SCHEDULED) -> DagRun:
        with prohibit_commit(session) as guard:
            if settings.USE_JOB_SCHEDULE:
                """
                Unconditionally create a DAG run for the given DAG, and update the dag_model's fields
                to control if/when the next DAGRun should be created
                """
                try:
                    dag = self.dagbag.get_dag(dag_id, session=session)
                    dag_model = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
                    if dag_model is None:
                        return None
                    next_dagrun = dag_model.next_dagrun
                    dag_hash = self.dagbag.dags_hash.get(dag.dag_id)
                    run_id = None
                    if run_type == DagRunType.MANUAL:
                        run_id = f"{run_type}__{timezone.utcnow().isoformat()}"
                    dag_run = dag.create_dagrun(
                        run_type=run_type,
                        execution_date=next_dagrun,
                        run_id=run_id,
                        start_date=timezone.utcnow(),
                        state=State.RUNNING,
                        external_trigger=False,
                        session=session,
                        dag_hash=dag_hash,
                        creating_job_id=self.id,
                    )
                    if run_type == DagRunType.SCHEDULED:
                        self._update_dag_next_dagrun(dag_id, session)

                    # commit the session - Release the write lock on DagModel table.
                    guard.commit()
                    # END: create dagrun
                    return dag_run
                except SerializedDagNotFound:
                    self.log.exception("DAG '%s' not found in serialized_dag table", dag_id)
                    return None
                except Exception:
                    self.log.exception("Error occurred when creating dag_run of dag: %s", dag_id)

    def _update_dag_next_dagrun(self, dag_id, session):
        """
        Bulk update the next_dagrun and next_dagrun_create_after for all the dags.

        We batch the select queries to get info about all the dags at once.
        """
        active_runs_of_dag = session.query(func.count('*')).filter(
            DagRun.dag_id == dag_id,
            DagRun.state == State.RUNNING,
            DagRun.external_trigger.is_(False),
        ).scalar()
        dag_model = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()

        dag = self.dagbag.get_dag(dag_id, session=session)
        if dag.max_active_runs and active_runs_of_dag >= dag.max_active_runs:
            self.log.info(
                "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs",
                dag.dag_id,
                active_runs_of_dag,
                dag.max_active_runs,
            )
            dag_model.next_dagrun_create_after = None
        else:
            dag_model.next_dagrun, dag_model.next_dagrun_create_after = dag.next_dagrun_info(
                dag_model.next_dagrun
            )

    def _schedule_task(self, scheduling_event: TaskSchedulingEvent):
        task_key = TaskInstanceKey(
            scheduling_event.dag_id,
            scheduling_event.task_id,
            scheduling_event.execution_date,
            scheduling_event.try_number,
        )
        self.executor.schedule_task(task_key, scheduling_event.action)

    def _find_dagruns_by_event(self, event, session) -> Optional[List[DagRun]]:
        affect_dag_runs = []
        event_key = EventKey(event.key, event.event_type, event.namespace)
        dag_runs = session.query(DagRun).filter(DagRun.state == State.RUNNING).all()
        self.log.debug('dag_runs {}'.format(len(dag_runs)))
        if dag_runs is None or len(dag_runs) == 0:
            return affect_dag_runs
        dags = session.query(SerializedDagModel).filter(
            SerializedDagModel.dag_id.in_(dag_run.dag_id for dag_run in dag_runs)
        ).all()
        self.log.debug('dags {}'.format(len(dags)))
        affect_dags = set()
        for dag in dags:
            self.log.debug('dag config {}'.format(dag.event_relationships))
            self.log.debug('event key {} {} {}'.format(event.key, event.event_type, event.namespace))
            dep: DagEventDependencies = DagEventDependencies.from_json(dag.event_relationships)
            if dep.is_affect(event_key):
                affect_dags.add(dag.dag_id)
        if len(affect_dags) == 0:
            return affect_dag_runs
        for dag_run in dag_runs:
            if dag_run.dag_id in affect_dags:
                affect_dag_runs.append(dag_run)
        return affect_dag_runs

    def _find_schedulable_tasks(
        self,
        dag_run: DagRun,
        session: Session,
        check_execution_date=False,
    ) -> Optional[List[TI]]:
        """
        Make scheduling decisions about an individual dag run.

        :param dag_run: The DagRun to schedule
        :return: scheduled tasks
        """
        if not dag_run or dag_run.get_state() in State.finished:
            return
        try:
            dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
        except SerializedDagNotFound:
            self.log.exception("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
            return None

        if not dag:
            self.log.error("Couldn't find dag %s in DagBag/DB!", dag_run.dag_id)
            return None

        currently_active_runs = session.query(TI.execution_date).filter(
            TI.dag_id == dag_run.dag_id,
            TI.state.notin_(list(State.finished)),
        ).all()

        if (
            check_execution_date
            and dag_run.execution_date > timezone.utcnow()
            and not dag.allow_future_exec_dates
        ):
            self.log.warning("Execution date is in future: %s", dag_run.execution_date)
            return None

        if dag.max_active_runs:
            if (
                len(currently_active_runs) >= dag.max_active_runs
                and dag_run.execution_date not in currently_active_runs
            ):
                self.log.info(
                    "DAG %s already has %d active runs, not queuing any tasks for run %s",
                    dag.dag_id,
                    len(currently_active_runs),
                    dag_run.execution_date,
                )
                return None

        self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session)

        schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False)
        dag_run.schedule_tis(schedulable_tis, session)

        query = (
            session.query(TI)
            .outerjoin(TI.dag_run)
            .filter(or_(DR.run_id.is_(None), DR.run_type != DagRunType.BACKFILL_JOB))
            .join(TI.dag_model)
            .filter(not_(DM.is_paused))
            .filter(TI.state == State.SCHEDULED)
            .options(selectinload('dag_model'))
        )
        scheduled_tis: List[TI] = with_row_locks(
            query,
            of=TI,
            **skip_locked(session=session),
        ).all()

        # filter out tasks that depend on events
        serialized_dag = session.query(SerializedDagModel).filter(
            SerializedDagModel.dag_id == dag_run.dag_id
        ).first()
        dep: DagEventDependencies = DagEventDependencies.from_json(serialized_dag.event_relationships)
        event_task_set = dep.find_event_dependencies_tasks()
        final_scheduled_tis = []
        for ti in scheduled_tis:
            if ti.task_id not in event_task_set:
                final_scheduled_tis.append(ti)

        return final_scheduled_tis

    @provide_session
    def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None):
        """Only run DagRun.verify_integrity if the Serialized DAG has changed, since it is slow."""
        latest_version = SerializedDagModel.get_latest_version_hash(dag_run.dag_id, session=session)
        if dag_run.dag_hash == latest_version:
            self.log.debug("DAG %s not changed structure, skipping dagrun.verify_integrity", dag_run.dag_id)
            return

        dag_run.dag_hash = latest_version

        # Refresh the DAG
        dag_run.dag = self.dagbag.get_dag(dag_id=dag_run.dag_id, session=session)

        # Verify integrity also takes care of session.flush
        dag_run.verify_integrity(session=session)

    def _send_scheduling_task_event(self, ti: Optional[TI], action: SchedulingAction):
        if ti is None:
            return
        task_scheduling_event = TaskSchedulingEvent(
            ti.task_id,
            ti.dag_id,
            ti.execution_date,
            ti.try_number,
            action,
        )
        self.mailbox.send_message(task_scheduling_event.to_event())

    def _send_scheduling_task_events(self, tis: Optional[List[TI]], action: SchedulingAction):
        if tis is None:
            return
        for ti in tis:
            self._send_scheduling_task_event(ti, action)

    @provide_session
    def _emit_pool_metrics(self, session: Session = None) -> None:
        pools = models.Pool.slots_stats(session=session)
        for pool_name, slot_stats in pools.items():
            Stats.gauge(f'pool.open_slots.{pool_name}', slot_stats["open"])
            Stats.gauge(f'pool.queued_slots.{pool_name}', slot_stats[State.QUEUED])
            Stats.gauge(f'pool.running_slots.{pool_name}', slot_stats[State.RUNNING])

    @staticmethod
    def _reset_unfinished_task_state(dag_run):
        with create_session() as session:
            to_be_reset = [s for s in State.unfinished if s not in [State.RUNNING, State.QUEUED]]
            tis = dag_run.get_task_instances(to_be_reset, session)
            for ti in tis:
                ti.state = State.NONE
            session.commit()

    @provide_session
    def _restore_unfinished_dag_run(self, session):
        dag_runs = DagRun.next_dagruns_to_examine(session, max_number=sys.maxsize).all()
        if not dag_runs or len(dag_runs) == 0:
            return
        for dag_run in dag_runs:
            self._reset_unfinished_task_state(dag_run)
            tasks = self._find_schedulable_tasks(dag_run, session)
            self._send_scheduling_task_events(tasks, SchedulingAction.START)

    @provide_session
    def heartbeat_callback(self, session: Session = None) -> None:
        Stats.incr('scheduler_heartbeat', 1, 1)

    @provide_session
    def _process_request_event(self, event: RequestEvent, session: Session = None):
        try:
            message = BaseUserDefineMessage()
            message.from_json(event.body)
            if message.message_type == UserDefineMessageType.RUN_DAG:
                # todo make sure the dag file is parsed.
                dagrun = self._create_dag_run(message.dag_id, session=session, run_type=DagRunType.MANUAL)
                if not dagrun:
                    self.log.error("Failed to create dag_run.")
                    # TODO Need to add ret_code and error_msg in ExecutionContext in case of exception
                    self.notification_client.send_event(ResponseEvent(event.request_id, None).to_event())
                    return
                tasks = self._find_schedulable_tasks(dagrun, session, False)
                self._send_scheduling_task_events(tasks, SchedulingAction.START)
                self.notification_client.send_event(ResponseEvent(event.request_id, dagrun.run_id).to_event())
            elif message.message_type == UserDefineMessageType.STOP_DAG_RUN:
                dag_run = DagRun.get_run_by_id(session=session, dag_id=message.dag_id, run_id=message.dagrun_id)
                self._stop_dag_run(dag_run)
                self.notification_client.send_event(ResponseEvent(event.request_id, dag_run.run_id).to_event())
            elif message.message_type == UserDefineMessageType.EXECUTE_TASK:
                dagrun = DagRun.get_run_by_id(session=session, dag_id=message.dag_id, run_id=message.dagrun_id)
                ti: TI = dagrun.get_task_instance(task_id=message.task_id)
                self.mailbox.send_message(
                    TaskSchedulingEvent(
                        task_id=ti.task_id,
                        dag_id=ti.dag_id,
                        execution_date=ti.execution_date,
                        try_number=ti.try_number,
                        action=SchedulingAction(message.action),
                    ).to_event()
                )
                self.notification_client.send_event(ResponseEvent(event.request_id, dagrun.run_id).to_event())
        except Exception:
            self.log.exception("Error occurred when processing request event.")

    def _stop_dag(self, dag_id, session: Session):
        """
        Stop the dag. Pause the dag and cancel all running dag_runs and task_instances.
        """
        DagModel.get_dagmodel(dag_id, session).set_is_paused(
            is_paused=True, including_subdags=True, session=session
        )
        active_runs = DagRun.find(dag_id=dag_id, state=State.RUNNING)
        for dag_run in active_runs:
            self._stop_dag_run(dag_run)

    def _stop_dag_run(self, dag_run: DagRun):
        dag_run.stop_dag_run()
        for ti in dag_run.get_task_instances():
            if ti.state in State.unfinished:
                self.executor.schedule_task(ti.key, SchedulingAction.STOP)
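
# Lifecycle sketch for EventBasedScheduler (the wiring objects -- mailbox,
# task_event_manager, executor, notification_client -- are assumed to be built
# elsewhere in the application). schedule() blocks on the mailbox until a
# StopSchedulerEvent addressed to this scheduler id (or to id 0, i.e. "all
# schedulers") arrives.
scheduler = EventBasedScheduler(
    id=1,
    mailbox=mailbox,
    task_event_manager=task_event_manager,
    executor=executor,
    notification_client=notification_client,
)
scheduler.submit_sync_thread()  # keeps executor.sync() ticking every second
try:
    scheduler.schedule()        # event loop; exits on StopSchedulerEvent
finally:
    executor.end()              # hypothetical cleanup; depends on the executor

# From another thread, scheduler.stop() enqueues the StopSchedulerEvent that
# makes schedule() return.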
def execute(self, context: Context):
    if isinstance(self.execution_date, datetime.datetime):
        parsed_execution_date = self.execution_date
    elif isinstance(self.execution_date, str):
        parsed_execution_date = timezone.parse(self.execution_date)
    else:
        parsed_execution_date = timezone.utcnow()

    if self.trigger_run_id:
        run_id = self.trigger_run_id
    else:
        run_id = DagRun.generate_run_id(DagRunType.MANUAL, parsed_execution_date)
    try:
        dag_run = trigger_dag(
            dag_id=self.trigger_dag_id,
            run_id=run_id,
            conf=self.conf,
            execution_date=parsed_execution_date,
            replace_microseconds=False,
        )
    except DagRunAlreadyExists as e:
        if self.reset_dag_run:
            self.log.info("Clearing %s on %s", self.trigger_dag_id, parsed_execution_date)

            # Get target dag object and call clear()
            dag_model = DagModel.get_current(self.trigger_dag_id)
            if dag_model is None:
                raise DagNotFound(f"Dag id {self.trigger_dag_id} not found in DagModel")

            dag_bag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
            dag = dag_bag.get_dag(self.trigger_dag_id)
            dag.clear(start_date=parsed_execution_date, end_date=parsed_execution_date)
            dag_run = DagRun.find(dag_id=dag.dag_id, run_id=run_id)[0]
        else:
            raise e
    if dag_run is None:
        raise RuntimeError("The dag_run should be set here!")

    # Store the execution date from the dag run (either created or found above) to
    # be used when creating the extra link on the webserver.
    ti = context['task_instance']
    ti.xcom_push(key=XCOM_EXECUTION_DATE_ISO, value=dag_run.execution_date.isoformat())
    ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id)

    if self.wait_for_completion:
        # wait for dag to complete
        while True:
            self.log.info(
                'Waiting for %s on %s to become allowed state %s ...',
                self.trigger_dag_id,
                dag_run.execution_date,
                self.allowed_states,
            )
            time.sleep(self.poke_interval)

            dag_run.refresh_from_db()
            state = dag_run.state
            if state in self.failed_states:
                raise AirflowException(f"{self.trigger_dag_id} failed with failed states {state}")
            if state in self.allowed_states:
                self.log.info("%s finished with allowed state %s", self.trigger_dag_id, state)
                return
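
# Usage sketch for the operator whose execute() is shown above (assuming it is
# TriggerDagRunOperator from Airflow 2.x; the dag ids are illustrative):
from datetime import datetime

from airflow import DAG
from airflow.operators.trigger_dagrun import TriggerDagRunOperator

with DAG(dag_id="parent_dag", start_date=datetime(2021, 1, 1), schedule_interval=None) as dag:
    trigger = TriggerDagRunOperator(
        task_id="trigger_child",
        trigger_dag_id="child_dag",
        reset_dag_run=True,        # clear and reuse a run that already exists
        wait_for_completion=True,  # poll dag_run state until allowed/failed
        poke_interval=30,
    )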
class Dashboard(BaseView):
    template_folder = os.path.join(os.path.dirname(__file__), 'templates')
    DATETIME_FORMAT = '%m/%d/%y %I:%M %p'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.airflow_session = settings.Session()
        self.airflow_dag_bag = DagBag()

    @expose('/')
    def list(self):
        dag_info = self.get_dag_info()

        event_dags = [
            dag for dag in dag_info
            if 'event' in dag['name'] or dag['name'] == 'full_scrape'
        ]
        bill_dags = [
            dag for dag in dag_info
            if 'bill' in dag['name'] or dag['name'] == 'full_scrape'
        ]

        event_last_run = self.get_last_successful_dagrun(event_dags)
        bill_last_run = self.get_last_successful_dagrun(bill_dags)

        event_next_run = self.get_next_dagrun(event_dags)
        bill_next_run = self.get_next_dagrun(bill_dags)

        events_in_db, bills_in_db, bills_in_index = self.get_db_info()

        metadata = {
            'data': dag_info,
            'event_last_run': event_last_run,
            'event_next_run': event_next_run,
            'events_in_db': events_in_db,
            'bill_last_run': bill_last_run,
            'bill_next_run': bill_next_run,
            'bills_in_db': bills_in_db,
            'bills_in_index': bills_in_index,
            'datetime_format': self.DATETIME_FORMAT,
        }

        return self.render_template('dashboard.html', **metadata)

    def get_dag_info(self):
        dags = [
            self.airflow_dag_bag.get_dag(dag_id)
            for dag_id in self.airflow_dag_bag.dag_ids
            if not dag_id.startswith('airflow_')  # Filter meta-DAGs
        ]

        data = []
        for d in dags:
            last_run = d.get_last_dagrun(self.airflow_session, include_externally_triggered=True)

            if last_run:
                run_state = last_run.get_state()
                run_date_info = self._get_localized_time(last_run.execution_date)
                last_successful_info = self._get_last_succesful_run_date(d)
                next_scheduled = d.following_schedule(datetime.now(pytz.utc))
                next_scheduled_info = self._get_localized_time(next_scheduled)
            else:
                run_state = None
                run_date_info = {}
                last_successful_info = {}
                next_scheduled_info = {}

            dag_info = {
                'name': d.dag_id,
                'description': d.description,
                'run_state': run_state,
                'run_date': run_date_info,
                'last_successful_date': last_successful_info,
                'next_scheduled_date': next_scheduled_info,
            }

            data.append(dag_info)

        return data

    def get_last_successful_dagrun(self, dags):
        successful_runs = [
            dag for dag in dags
            if dag['last_successful_date'].get('pst_time')
        ]
        if successful_runs:
            return max(successful_runs, key=lambda x: x['last_successful_date']['pst_time'])

    def get_next_dagrun(self, dags):
        scheduled_runs = [
            dag for dag in dags
            if dag['next_scheduled_date'].get('pst_time')
        ]
        if scheduled_runs:
            return min(scheduled_runs, key=lambda x: x['next_scheduled_date']['pst_time'])

    def get_db_info(self):
        url_parts = {
            'hostname': os.getenv('LA_METRO_HOST', 'http://app:8000'),
            'api_key': os.getenv('LA_METRO_API_KEY', 'test key'),
        }

        endpoint = '{hostname}/object-counts/{api_key}'.format(**url_parts)
        response = requests.get(endpoint)

        try:
            response_json = response.json()
        except json.decoder.JSONDecodeError:
            print(response.text)
        else:
            if response_json['status_code'] == 200:
                return (response_json['event_count'],
                        response_json['bill_count'],
                        response_json['search_index_count'])

        return None, None, None

    def _get_localized_time(self, date):
        pst_time = date.astimezone(PACIFIC_TIMEZONE)
        cst_time = date.astimezone(CENTRAL_TIMEZONE)
        return {
            'pst_time': pst_time,
            'cst_time': cst_time,
        }

    def _get_last_succesful_run_date(self, dag):
        run = self.airflow_session.query(dagrun.DagRun)\
            .filter(dagrun.DagRun.dag_id == dag.dag_id)\
            .filter(dagrun.DagRun.state == 'success')\
            .order_by(dagrun.DagRun.execution_date.desc())\
            .first()

        if run:
            return self._get_localized_time(run.execution_date)
        else:
            return {}
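
# Hypothetical registration of the Dashboard view as an Airflow plugin (the
# plugin name and menu category are assumptions; the exact mechanism depends on
# whether the deployment registers views via flask-admin or Flask-AppBuilder):
from airflow.plugins_manager import AirflowPlugin

class DashboardPlugin(AirflowPlugin):
    name = "dashboard_plugin"
    appbuilder_views = [
        {"name": "Dashboard", "category": "Monitoring", "view": Dashboard()}
    ]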
def get_dag(dag_id: str) -> DAG:
    dag_bag = DagBag()
    dag = dag_bag.get_dag(dag_id=dag_id)
    if dag is None:
        raise KeyError(f"DAG with ID '{dag_id}' does not exist.")
    return dag
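
# Usage sketch for the helper above (the dag ids are illustrative):
dag = get_dag("example_bash_operator")
print(dag.dag_id, len(dag.tasks))

try:
    get_dag("no_such_dag")
except KeyError as err:
    print(err)  # "DAG with ID 'no_such_dag' does not exist."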
class SchedulerJob(BaseJob): """ This SchedulerJob runs for a specific time interval and schedules the jobs that are ready to run. It figures out the latest runs for each task and sees if the dependencies for the next schedules are met. If so, it creates appropriate TaskInstances and sends run commands to the executor. It does this for each task in each DAG and repeats. :param subdir: directory containing Python files with Airflow DAG definitions, or a specific path to a file :type subdir: str :param num_runs: The number of times to run the scheduling loop. If you have a large number of DAG files this could complete before each file has been parsed. -1 for unlimited times. :type num_runs: int :param num_times_parse_dags: The number of times to try to parse each DAG file. -1 for unlimited times. :type num_times_parse_dags: int :param processor_poll_interval: The number of seconds to wait between polls of running processors :type processor_poll_interval: int :param do_pickle: once a DAG object is obtained by executing the Python file, whether to serialize the DAG object to the DB :type do_pickle: bool :param log: override the default Logger :type log: logging.Logger """ __mapper_args__ = {'polymorphic_identity': 'SchedulerJob'} heartrate: int = conf.getint('scheduler', 'SCHEDULER_HEARTBEAT_SEC') def __init__( self, subdir: str = settings.DAGS_FOLDER, num_runs: int = conf.getint('scheduler', 'num_runs'), num_times_parse_dags: int = -1, processor_poll_interval: float = conf.getfloat('scheduler', 'processor_poll_interval'), do_pickle: bool = False, log: logging.Logger = None, *args, **kwargs, ): self.subdir = subdir self.num_runs = num_runs # In specific tests, we want to stop the parse loop after the _files_ have been parsed a certain # number of times. This is only to support testing, and isn't something a user is likely to want to # configure -- they'll want num_runs self.num_times_parse_dags = num_times_parse_dags self._processor_poll_interval = processor_poll_interval self.do_pickle = do_pickle super().__init__(*args, **kwargs) if log: self._log = log # Check what SQL backend we use sql_conn: str = conf.get('core', 'sql_alchemy_conn').lower() self.using_sqlite = sql_conn.startswith('sqlite') self.using_mysql = sql_conn.startswith('mysql') self.max_tis_per_query: int = conf.getint('scheduler', 'max_tis_per_query') self.processor_agent: Optional[DagFileProcessorAgent] = None self.dagbag = DagBag(dag_folder=self.subdir, read_dags_from_db=True, load_op_links=False) def register_signals(self) -> None: """Register signals that stop child processes""" signal.signal(signal.SIGINT, self._exit_gracefully) signal.signal(signal.SIGTERM, self._exit_gracefully) signal.signal(signal.SIGUSR2, self._debug_dump) def _exit_gracefully(self, signum, frame) -> None: """Helper method to clean up processor_agent to avoid leaving orphan processes.""" if not _is_parent_process(): # Only the parent process should perform the cleanup. return self.log.info("Exiting gracefully upon receiving signal %s", signum) if self.processor_agent: self.processor_agent.end() sys.exit(os.EX_OK) def _debug_dump(self, signum, frame): if not _is_parent_process(): # Only the parent process should perform the debug dump. 
return try: sig_name = signal.Signals(signum).name except Exception: sig_name = str(signum) self.log.info("%s\n%s received, printing debug\n%s", "-" * 80, sig_name, "-" * 80) self.executor.debug_dump() self.log.info("-" * 80) def is_alive(self, grace_multiplier: Optional[float] = None) -> bool: """ Is this SchedulerJob alive? We define alive as in a state of running and a heartbeat within the threshold defined in the ``scheduler_health_check_threshold`` config setting. ``grace_multiplier`` is accepted for compatibility with the parent class. :rtype: boolean """ if grace_multiplier is not None: # Accept the same behaviour as superclass return super().is_alive(grace_multiplier=grace_multiplier) scheduler_health_check_threshold: int = conf.getint('scheduler', 'scheduler_health_check_threshold') return ( self.state == State.RUNNING and (timezone.utcnow() - self.latest_heartbeat).total_seconds() < scheduler_health_check_threshold ) @provide_session def __get_concurrency_maps( self, states: List[TaskInstanceState], session: Session = None ) -> Tuple[DefaultDict[str, int], DefaultDict[Tuple[str, str], int]]: """ Get the concurrency maps. :param states: List of states to query for :type states: list[airflow.utils.state.State] :return: A map from (dag_id, task_id) to # of task instances and a map from (dag_id, task_id) to # of task instances in the given state list :rtype: tuple[dict[str, int], dict[tuple[str, str], int]] """ ti_concurrency_query: List[Tuple[str, str, int]] = ( session.query(TI.task_id, TI.dag_id, func.count('*')) .filter(TI.state.in_(states)) .group_by(TI.task_id, TI.dag_id) ).all() dag_map: DefaultDict[str, int] = defaultdict(int) task_map: DefaultDict[Tuple[str, str], int] = defaultdict(int) for result in ti_concurrency_query: task_id, dag_id, count = result dag_map[dag_id] += count task_map[(dag_id, task_id)] = count return dag_map, task_map @provide_session def _executable_task_instances_to_queued(self, max_tis: int, session: Session = None) -> List[TI]: """ Finds TIs that are ready for execution with respect to pool limits, dag max_active_tasks, executor state, and priority. :param max_tis: Maximum number of TIs to queue in this loop. :type max_tis: int :return: list[airflow.models.TaskInstance] """ executable_tis: List[TI] = [] # Get the pool settings. We get a lock on the pool rows, treating this as a "critical section" # Throws an exception if lock cannot be obtained, rather than blocking pools = models.Pool.slots_stats(lock_rows=True, session=session) # If the pools are full, there is no point doing anything! 
# If _somehow_ the pool is overfull, don't let the limit go negative - it breaks SQL pool_slots_free = max(0, sum(pool['open'] for pool in pools.values())) if pool_slots_free == 0: self.log.debug("All pools are full!") return executable_tis max_tis = min(max_tis, pool_slots_free) # Get all task instances associated with scheduled # DagRuns which are not backfilled, in the given states, # and the dag is not paused query = ( session.query(TI) .join(TI.dag_run) .options(eagerload(TI.dag_run)) .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state != DagRunState.QUEUED) .join(TI.dag_model) .filter(not_(DM.is_paused)) .filter(TI.state == State.SCHEDULED) .options(selectinload('dag_model')) .order_by(-TI.priority_weight, DR.execution_date) ) starved_pools = [pool_name for pool_name, stats in pools.items() if stats['open'] <= 0] if starved_pools: query = query.filter(not_(TI.pool.in_(starved_pools))) query = query.limit(max_tis) task_instances_to_examine: List[TI] = with_row_locks( query, of=TI, session=session, **skip_locked(session=session), ).all() # TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything. # Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine)) if len(task_instances_to_examine) == 0: self.log.debug("No tasks to consider for execution.") return executable_tis # Put one task instance on each line task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine) self.log.info("%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str) pool_to_task_instances: DefaultDict[str, List[models.Pool]] = defaultdict(list) for task_instance in task_instances_to_examine: pool_to_task_instances[task_instance.pool].append(task_instance) # dag_id to # of running tasks and (dag_id, task_id) to # of running tasks. dag_max_active_tasks_map: DefaultDict[str, int] task_concurrency_map: DefaultDict[Tuple[str, str], int] dag_max_active_tasks_map, task_concurrency_map = self.__get_concurrency_maps( states=list(EXECUTION_STATES), session=session ) num_tasks_in_executor = 0 # Number of tasks that cannot be scheduled because of no open slot in pool num_starving_tasks_total = 0 # Go through each pool, and queue up a task for execution if there are # any open slots in the pool. for pool, task_instances in pool_to_task_instances.items(): pool_name = pool if pool not in pools: self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool) continue open_slots = pools[pool]["open"] num_ready = len(task_instances) self.log.info( "Figuring out tasks to run in Pool(name=%s) with %s open slots " "and %s task instances ready to be queued", pool, open_slots, num_ready, ) priority_sorted_task_instances = sorted( task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date) ) num_starving_tasks = 0 for current_index, task_instance in enumerate(priority_sorted_task_instances): if open_slots <= 0: self.log.info("Not scheduling since there are %s open slots in pool %s", open_slots, pool) # Can't schedule any more since there are no more open slots. num_unhandled = len(priority_sorted_task_instances) - current_index num_starving_tasks += num_unhandled num_starving_tasks_total += num_unhandled break # Check to make sure that the task max_active_tasks of the DAG hasn't been # reached. 
dag_id = task_instance.dag_id current_max_active_tasks_per_dag = dag_max_active_tasks_map[dag_id] max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks self.log.info( "DAG %s has %s/%s running and queued tasks", dag_id, current_max_active_tasks_per_dag, max_active_tasks_per_dag_limit, ) if current_max_active_tasks_per_dag >= max_active_tasks_per_dag_limit: self.log.info( "Not executing %s since the number of tasks running or queued " "from DAG %s is >= to the DAG's max_active_tasks limit of %s", task_instance, dag_id, max_active_tasks_per_dag_limit, ) continue task_concurrency_limit: Optional[int] = None if task_instance.dag_model.has_task_concurrency_limits: # Many dags don't have a task_concurrency, so where we can avoid loading the full # serialized DAG the better. serialized_dag = self.dagbag.get_dag(dag_id, session=session) if serialized_dag.has_task(task_instance.task_id): task_concurrency_limit = serialized_dag.get_task( task_instance.task_id ).max_active_tis_per_dag if task_concurrency_limit is not None: current_task_concurrency = task_concurrency_map[ (task_instance.dag_id, task_instance.task_id) ] if current_task_concurrency >= task_concurrency_limit: self.log.info( "Not executing %s since the task concurrency for" " this task has been reached.", task_instance, ) continue if task_instance.pool_slots > open_slots: self.log.info( "Not executing %s since it requires %s slots " "but there are %s open slots in the pool %s.", task_instance, task_instance.pool_slots, open_slots, pool, ) num_starving_tasks += 1 num_starving_tasks_total += 1 # Though we can execute tasks with lower priority if there's enough room continue executable_tis.append(task_instance) open_slots -= task_instance.pool_slots dag_max_active_tasks_map[dag_id] += 1 task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1 Stats.gauge(f'pool.starving_tasks.{pool_name}', num_starving_tasks) Stats.gauge('scheduler.tasks.starving', num_starving_tasks_total) Stats.gauge('scheduler.tasks.running', num_tasks_in_executor) Stats.gauge('scheduler.tasks.executable', len(executable_tis)) task_instance_str = "\n\t".join(repr(x) for x in executable_tis) self.log.info("Setting the following tasks to queued state:\n\t%s", task_instance_str) if len(executable_tis) > 0: # set TIs to queued state filter_for_tis = TI.filter_for_tis(executable_tis) session.query(TI).filter(filter_for_tis).update( # TODO[ha]: should we use func.now()? How does that work with DB timezone # on mysql when it's not UTC? {TI.state: State.QUEUED, TI.queued_dttm: timezone.utcnow(), TI.queued_by_job_id: self.id}, synchronize_session=False, ) for ti in executable_tis: make_transient(ti) return executable_tis def _enqueue_task_instances_with_queued_state(self, task_instances: List[TI]) -> None: """ Takes task_instances, which should have been set to queued, and enqueues them with the executor. :param task_instances: TaskInstances to enqueue :type task_instances: list[TaskInstance] """ # actually enqueue them for ti in task_instances: command = ti.command_as_list( local=True, pickle_id=ti.dag_model.pickle_id, ) priority = ti.priority_weight queue = ti.queue self.log.info("Sending %s to executor with priority %s and queue %s", ti.key, priority, queue) self.executor.queue_command( ti, command, priority=priority, queue=queue, ) def _critical_section_execute_task_instances(self, session: Session) -> int: """ Attempts to execute TaskInstances that should be executed by the scheduler. There are three steps: 1. 
Pick TIs by priority with the constraint that they are in the expected states and that we do exceed max_active_runs or pool limits. 2. Change the state for the TIs above atomically. 3. Enqueue the TIs in the executor. HA note: This function is a "critical section" meaning that only a single executor process can execute this function at the same time. This is achieved by doing ``SELECT ... from pool FOR UPDATE``. For DBs that support NOWAIT, a "blocked" scheduler will skip this and continue on with other tasks (creating new DAG runs, progressing TIs from None to SCHEDULED etc.); DBs that don't support this (such as MariaDB or MySQL 5.x) the other schedulers will wait for the lock before continuing. :param session: :type session: sqlalchemy.orm.Session :return: Number of task instance with state changed. """ if self.max_tis_per_query == 0: max_tis = self.executor.slots_available else: max_tis = min(self.max_tis_per_query, self.executor.slots_available) queued_tis = self._executable_task_instances_to_queued(max_tis, session=session) self._enqueue_task_instances_with_queued_state(queued_tis) return len(queued_tis) @provide_session def _process_executor_events(self, session: Session = None) -> int: """Respond to executor events.""" if not self.processor_agent: raise ValueError("Processor agent is not started.") ti_primary_key_to_try_number_map: Dict[Tuple[str, str, datetime.datetime], int] = {} event_buffer = self.executor.get_event_buffer() tis_with_right_state: List[TaskInstanceKey] = [] # Report execution for ti_key, value in event_buffer.items(): state: str state, _ = value # We create map (dag_id, task_id, execution_date) -> in-memory try_number ti_primary_key_to_try_number_map[ti_key.primary] = ti_key.try_number self.log.info( "Executor reports execution of %s.%s run_id=%s exited with status %s for try_number %s", ti_key.dag_id, ti_key.task_id, ti_key.run_id, state, ti_key.try_number, ) if state in (State.FAILED, State.SUCCESS, State.QUEUED): tis_with_right_state.append(ti_key) # Return if no finished tasks if not tis_with_right_state: return len(event_buffer) # Check state of finished tasks filter_for_tis = TI.filter_for_tis(tis_with_right_state) tis: List[TI] = session.query(TI).filter(filter_for_tis).options(selectinload('dag_model')).all() for ti in tis: try_number = ti_primary_key_to_try_number_map[ti.key.primary] buffer_key = ti.key.with_try_number(try_number) state, info = event_buffer.pop(buffer_key) # TODO: should we fail RUNNING as well, as we do in Backfills? if state == State.QUEUED: ti.external_executor_id = info self.log.info("Setting external_id for %s to %s", ti, info) continue if ti.try_number == buffer_key.try_number and ti.state == State.QUEUED: Stats.incr('scheduler.tasks.killed_externally') msg = ( "Executor reports task instance %s finished (%s) although the " "task says its %s. (Info: %s) Was the task killed externally?" 
) self.log.error(msg, ti, state, ti.state, info) # Get task from the Serialized DAG try: dag = self.dagbag.get_dag(ti.dag_id) task = dag.get_task(ti.task_id) except Exception: self.log.exception("Marking task instance %s as %s", ti, state) ti.set_state(state) continue ti.task = task if task.on_retry_callback or task.on_failure_callback: request = TaskCallbackRequest( full_filepath=ti.dag_model.fileloc, simple_task_instance=SimpleTaskInstance(ti), msg=msg % (ti, state, ti.state, info), ) self.processor_agent.send_callback_to_execute(request) else: ti.handle_failure(error=msg % (ti, state, ti.state, info), session=session) return len(event_buffer) def _execute(self) -> None: self.log.info("Starting the scheduler") # DAGs can be pickled for easier remote execution by some executors pickle_dags = self.do_pickle and self.executor_class not in UNPICKLEABLE_EXECUTORS self.log.info("Processing each file at most %s times", self.num_times_parse_dags) # When using sqlite, we do not use async_mode # so the scheduler job and DAG parser don't access the DB at the same time. async_mode = not self.using_sqlite processor_timeout_seconds: int = conf.getint('core', 'dag_file_processor_timeout') processor_timeout = timedelta(seconds=processor_timeout_seconds) self.processor_agent = DagFileProcessorAgent( dag_directory=self.subdir, max_runs=self.num_times_parse_dags, processor_timeout=processor_timeout, dag_ids=[], pickle_dags=pickle_dags, async_mode=async_mode, ) try: self.executor.job_id = self.id self.executor.start() self.register_signals() self.processor_agent.start() execute_start_time = timezone.utcnow() self._run_scheduler_loop() # Stop any processors self.processor_agent.terminate() # Verify that all files were processed, and if so, deactivate DAGs that # haven't been touched by the scheduler as they likely have been # deleted. if self.processor_agent.all_files_processed: self.log.info( "Deactivating DAGs that haven't been touched since %s", execute_start_time.isoformat() ) models.DAG.deactivate_stale_dags(execute_start_time) settings.Session.remove() # type: ignore except Exception: self.log.exception("Exception when executing SchedulerJob._run_scheduler_loop") raise finally: try: self.executor.end() except Exception: self.log.exception("Exception when executing Executor.end") try: self.processor_agent.end() except Exception: self.log.exception("Exception when executing DagFileProcessorAgent.end") self.log.info("Exited execute loop") def _run_scheduler_loop(self) -> None: """ The actual scheduler loop. The main steps in the loop are: #. Harvest DAG parsing results through DagFileProcessorAgent #. Find and queue executable tasks #. Change task instance state in DB #. Queue tasks in executor #. Heartbeat executor #. Execute queued tasks in executor asynchronously #. Sync on the states of running tasks Following is a graphic representation of these steps. .. 
image:: ../docs/apache-airflow/img/scheduler_loop.jpg :rtype: None """ if not self.processor_agent: raise ValueError("Processor agent is not started.") is_unit_test: bool = conf.getboolean('core', 'unit_test_mode') timers = EventScheduler() # Check on start up, then every configured interval self.adopt_or_reset_orphaned_tasks() timers.call_regular_interval( conf.getfloat('scheduler', 'orphaned_tasks_check_interval', fallback=300.0), self.adopt_or_reset_orphaned_tasks, ) timers.call_regular_interval( conf.getfloat('scheduler', 'trigger_timeout_check_interval', fallback=15.0), self.check_trigger_timeouts, ) timers.call_regular_interval( conf.getfloat('scheduler', 'pool_metrics_interval', fallback=5.0), self._emit_pool_metrics, ) for loop_count in itertools.count(start=1): with Stats.timer() as timer: if self.using_sqlite: self.processor_agent.run_single_parsing_loop() # For the sqlite case w/ 1 thread, wait until the processor # is finished to avoid concurrent access to the DB. self.log.debug("Waiting for processors to finish since we're using sqlite") self.processor_agent.wait_until_finished() with create_session() as session: num_queued_tis = self._do_scheduling(session) self.executor.heartbeat() session.expunge_all() num_finished_events = self._process_executor_events(session=session) self.processor_agent.heartbeat() # Heartbeat the scheduler periodically self.heartbeat(only_if_necessary=True) # Run any pending timed events next_event = timers.run(blocking=False) self.log.debug("Next timed event is in %f", next_event) self.log.debug("Ran scheduling loop in %.2f seconds", timer.duration) if not is_unit_test and not num_queued_tis and not num_finished_events: # If the scheduler is doing things, don't sleep. This means when there is work to do, the # scheduler will run "as quick as possible", but when it's stopped, it can sleep, dropping CPU # usage when "idle" time.sleep(min(self._processor_poll_interval, next_event)) if loop_count >= self.num_runs > 0: self.log.info( "Exiting scheduler loop as requested number of runs (%d - got to %d) has been reached", self.num_runs, loop_count, ) break if self.processor_agent.done: self.log.info( "Exiting scheduler loop as requested DAG parse count (%d) has been reached after %d" " scheduler loops", self.num_times_parse_dags, loop_count, ) break def _do_scheduling(self, session) -> int: """ This function is where the main scheduling decisions take places. It: - Creates any necessary DAG runs by examining the next_dagrun_create_after column of DagModel Since creating Dag Runs is a relatively time consuming process, we select only 10 dags by default (configurable via ``scheduler.max_dagruns_to_create_per_loop`` setting) - putting this higher will mean one scheduler could spend a chunk of time creating dag runs, and not ever get around to scheduling tasks. - Finds the "next n oldest" running DAG Runs to examine for scheduling (n=20 by default, configurable via ``scheduler.max_dagruns_per_loop_to_schedule`` config setting) and tries to progress state (TIs to SCHEDULED, or DagRuns to SUCCESS/FAILURE etc) By "next oldest", we mean hasn't been examined/scheduled in the most time. The reason we don't select all dagruns at once because the rows are selected with row locks, meaning that only one scheduler can "process them", even it is waiting behind other dags. Increasing this limit will allow more throughput for smaller DAGs but will likely slow down throughput for larger (>500 tasks.) 
DAGs - Then, via a Critical Section (locking the rows of the Pool model) we queue tasks, and then send them to the executor. See docs of _critical_section_execute_task_instances for more. :return: Number of TIs enqueued in this iteration :rtype: int """ # Put a check in place to make sure we don't commit unexpectedly with prohibit_commit(session) as guard: if settings.USE_JOB_SCHEDULE: self._create_dagruns_for_dags(guard, session) self._start_queued_dagruns(session) guard.commit() dag_runs = self._get_next_dagruns_to_examine(State.RUNNING, session) # Bulk fetch the currently active dag runs for the dags we are # examining, rather than making one query per DagRun callback_tuples = [] for dag_run in dag_runs: callback_to_run = self._schedule_dag_run(dag_run, session) callback_tuples.append((dag_run, callback_to_run)) guard.commit() # Send the callbacks after we commit to ensure the context is up to date when it gets run for dag_run, callback_to_run in callback_tuples: self._send_dag_callbacks_to_processor(dag_run, callback_to_run) # Without this, the session has an invalid view of the DB session.expunge_all() # END: schedule TIs try: if self.executor.slots_available <= 0: # We know we can't do anything here, so don't even try! self.log.debug("Executor full, skipping critical section") return 0 timer = Stats.timer('scheduler.critical_section_duration') timer.start() # Find anything TIs in state SCHEDULED, try to QUEUE it (send it to the executor) num_queued_tis = self._critical_section_execute_task_instances(session=session) # Make sure we only sent this metric if we obtained the lock, otherwise we'll skew the # metric, way down timer.stop(send=True) except OperationalError as e: timer.stop(send=False) if is_lock_not_available_error(error=e): self.log.debug("Critical section lock held by another Scheduler") Stats.incr('scheduler.critical_section_busy') session.rollback() return 0 raise guard.commit() return num_queued_tis @retry_db_transaction def _get_next_dagruns_to_examine(self, state: DagRunState, session: Session): """Get Next DagRuns to Examine with retries""" return DagRun.next_dagruns_to_examine(state, session) @retry_db_transaction def _create_dagruns_for_dags(self, guard, session): """Find Dag Models needing DagRuns and Create Dag Runs with retries in case of OperationalError""" query = DagModel.dags_needing_dagruns(session) self._create_dag_runs(query.all(), session) # commit the session - Release the write lock on DagModel table. 
guard.commit() # END: create dagruns def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -> None: """ Unconditionally create a DAG run for the given DAG, and update the dag_model's fields to control if/when the next DAGRun should be created """ # Bulk Fetch DagRuns with dag_id and execution_date same # as DagModel.dag_id and DagModel.next_dagrun # This list is used to verify if the DagRun already exist so that we don't attempt to create # duplicate dag runs if session.bind.dialect.name == 'mssql': existing_dagruns_filter = or_( *( and_( DagRun.dag_id == dm.dag_id, DagRun.execution_date == dm.next_dagrun, ) for dm in dag_models ) ) else: existing_dagruns_filter = tuple_(DagRun.dag_id, DagRun.execution_date).in_( [(dm.dag_id, dm.next_dagrun) for dm in dag_models] ) existing_dagruns = ( session.query(DagRun.dag_id, DagRun.execution_date).filter(existing_dagruns_filter).all() ) max_queued_dagruns = conf.getint('core', 'max_queued_runs_per_dag') queued_runs_of_dags = defaultdict( int, session.query(DagRun.dag_id, func.count('*')) .filter( # We use `list` here because SQLA doesn't accept a set # We use set to avoid duplicate dag_ids DagRun.dag_id.in_(list({dm.dag_id for dm in dag_models})), DagRun.state == State.QUEUED, ) .group_by(DagRun.dag_id) .all(), ) for dag_model in dag_models: # Lets quickly check if we have exceeded the number of queued dagruns per dags total_queued = queued_runs_of_dags[dag_model.dag_id] if total_queued >= max_queued_dagruns: continue dag = self.dagbag.get_dag(dag_model.dag_id, session=session) if not dag: self.log.error("DAG '%s' not found in serialized_dag table", dag_model.dag_id) continue dag_hash = self.dagbag.dags_hash.get(dag.dag_id) data_interval = dag.get_next_data_interval(dag_model) # Explicitly check if the DagRun already exists. This is an edge case # where a Dag Run is created but `DagModel.next_dagrun` and `DagModel.next_dagrun_create_after` # are not updated. # We opted to check DagRun existence instead # of catching an Integrity error and rolling back the session i.e # we need to set dag.next_dagrun_info if the Dag Run already exists or if we # create a new one. This is so that in the next Scheduling loop we try to create new runs # instead of falling in a loop of Integrity Error. if (dag.dag_id, dag_model.next_dagrun) not in existing_dagruns: dag.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=dag_model.next_dagrun, state=State.QUEUED, data_interval=data_interval, external_trigger=False, session=session, dag_hash=dag_hash, creating_job_id=self.id, ) queued_runs_of_dags[dag_model.dag_id] += 1 dag_model.calculate_dagrun_date_fields(dag, data_interval) # TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in # memory for larger dags? 
or expunge_all() def _start_queued_dagruns( self, session: Session, ) -> int: """Find DagRuns in queued state and decide moving them to running state""" dag_runs = self._get_next_dagruns_to_examine(State.QUEUED, session) active_runs_of_dags = defaultdict( lambda: 0, session.query(DagRun.dag_id, func.count('*')) .filter( # We use `list` here because SQLA doesn't accept a set # We use set to avoid duplicate dag_ids DagRun.dag_id.in_(list({dr.dag_id for dr in dag_runs})), DagRun.state == State.RUNNING, ) .group_by(DagRun.dag_id) .all(), ) def _update_state(dag: DAG, dag_run: DagRun): dag_run.state = State.RUNNING dag_run.start_date = timezone.utcnow() if dag.timetable.periodic: # TODO: Logically, this should be DagRunInfo.run_after, but the # information is not stored on a DagRun, only before the actual # execution on DagModel.next_dagrun_create_after. We should add # a field on DagRun for this instead of relying on the run # always happening immediately after the data interval. expected_start_date = dag.get_run_data_interval(dag_run).end schedule_delay = dag_run.start_date - expected_start_date Stats.timing(f'dagrun.schedule_delay.{dag.dag_id}', schedule_delay) for dag_run in dag_runs: dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session) if not dag: self.log.error("DAG '%s' not found in serialized_dag table", dag_run.dag_id) continue active_runs = active_runs_of_dags[dag_run.dag_id] if dag.max_active_runs and active_runs >= dag.max_active_runs: self.log.debug( "DAG %s already has %d active runs, not moving any more runs to RUNNING state %s", dag.dag_id, active_runs, dag_run.execution_date, ) else: active_runs_of_dags[dag_run.dag_id] += 1 _update_state(dag, dag_run) def _schedule_dag_run( self, dag_run: DagRun, session: Session, ) -> Optional[DagCallbackRequest]: """ Make scheduling decisions about an individual dag run :param dag_run: The DagRun to schedule :return: Callback that needs to be executed """ dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session) if not dag: self.log.error("Couldn't find dag %s in DagBag/DB!", dag_run.dag_id) return 0 if ( dag_run.start_date and dag.dagrun_timeout and dag_run.start_date < timezone.utcnow() - dag.dagrun_timeout ): dag_run.set_state(State.FAILED) unfinished_task_instances = ( session.query(TI) .filter(TI.dag_id == dag_run.dag_id) .filter(TI.run_id == dag_run.run_id) .filter(TI.state.in_(State.unfinished)) ) for task_instance in unfinished_task_instances: task_instance.state = State.SKIPPED session.merge(task_instance) session.flush() self.log.info("Run %s of %s has timed-out", dag_run.run_id, dag_run.dag_id) callback_to_execute = DagCallbackRequest( full_filepath=dag.fileloc, dag_id=dag.dag_id, run_id=dag_run.run_id, is_failure_callback=True, msg='timed_out', ) # Send SLA & DAG Success/Failure Callbacks to be executed self._send_dag_callbacks_to_processor(dag_run, callback_to_execute) return 0 if dag_run.execution_date > timezone.utcnow() and not dag.allow_future_exec_dates: self.log.error("Execution date is in future: %s", dag_run.execution_date) return 0 self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session) # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else? schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False) # This will do one query per dag run. 
We "could" build up a complex # query to update all the TIs across all the execution dates and dag # IDs in a single query, but it turns out that can be _very very slow_ # see #11147/commit ee90807ac for more details dag_run.schedule_tis(schedulable_tis, session) return callback_to_run @provide_session def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None): """Only run DagRun.verify integrity if Serialized DAG has changed since it is slow""" latest_version = SerializedDagModel.get_latest_version_hash(dag_run.dag_id, session=session) if dag_run.dag_hash == latest_version: self.log.debug("DAG %s not changed structure, skipping dagrun.verify_integrity", dag_run.dag_id) return dag_run.dag_hash = latest_version # Refresh the DAG dag_run.dag = self.dagbag.get_dag(dag_id=dag_run.dag_id, session=session) # Verify integrity also takes care of session.flush dag_run.verify_integrity(session=session) def _send_dag_callbacks_to_processor( self, dag_run: DagRun, callback: Optional[DagCallbackRequest] = None ): if not self.processor_agent: raise ValueError("Processor agent is not started.") dag = dag_run.get_dag() self._send_sla_callbacks_to_processor(dag) if callback: self.processor_agent.send_callback_to_execute(callback) def _send_sla_callbacks_to_processor(self, dag: DAG): """Sends SLA Callbacks to DagFileProcessor if tasks have SLAs set and check_slas=True""" if not settings.CHECK_SLAS: return if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks): self.log.debug("Skipping SLA check for %s because no tasks in DAG have SLAs", dag) return if not self.processor_agent: raise ValueError("Processor agent is not started.") self.processor_agent.send_sla_callback_request_to_execute( full_filepath=dag.fileloc, dag_id=dag.dag_id ) @provide_session def _emit_pool_metrics(self, session: Session = None) -> None: pools = models.Pool.slots_stats(session=session) for pool_name, slot_stats in pools.items(): Stats.gauge(f'pool.open_slots.{pool_name}', slot_stats["open"]) Stats.gauge(f'pool.queued_slots.{pool_name}', slot_stats[State.QUEUED]) # type: ignore Stats.gauge(f'pool.running_slots.{pool_name}', slot_stats[State.RUNNING]) # type: ignore @provide_session def heartbeat_callback(self, session: Session = None) -> None: Stats.incr('scheduler_heartbeat', 1, 1) @provide_session def adopt_or_reset_orphaned_tasks(self, session: Session = None): """ Reset any TaskInstance still in QUEUED or SCHEDULED states that were enqueued by a SchedulerJob that is no longer running. :return: the number of TIs reset :rtype: int """ self.log.info("Resetting orphaned tasks for active dag runs") timeout = conf.getint('scheduler', 'scheduler_health_check_threshold') for attempt in run_with_db_retries(logger=self.log): with attempt: self.log.debug( "Running SchedulerJob.adopt_or_reset_orphaned_tasks with retries. 
Try %d of %d", attempt.retry_state.attempt_number, MAX_DB_RETRIES, ) self.log.debug("Calling SchedulerJob.adopt_or_reset_orphaned_tasks method") try: num_failed = ( session.query(SchedulerJob) .filter( SchedulerJob.state == State.RUNNING, SchedulerJob.latest_heartbeat < (timezone.utcnow() - timedelta(seconds=timeout)), ) .update({"state": State.FAILED}) ) if num_failed: self.log.info("Marked %d SchedulerJob instances as failed", num_failed) Stats.incr(self.__class__.__name__.lower() + '_end', num_failed) resettable_states = [State.QUEUED, State.RUNNING] query = ( session.query(TI) .filter(TI.state.in_(resettable_states)) # outerjoin is because we didn't use to have queued_by_job # set, so we need to pick up anything pre upgrade. This (and the # "or queued_by_job_id IS NONE") can go as soon as scheduler HA is # released. .outerjoin(TI.queued_by_job) .filter(or_(TI.queued_by_job_id.is_(None), SchedulerJob.state != State.RUNNING)) .join(TI.dag_run) .filter( DagRun.run_type != DagRunType.BACKFILL_JOB, DagRun.state == State.RUNNING, ) .options(load_only(TI.dag_id, TI.task_id, TI.run_id)) ) # Lock these rows, so that another scheduler can't try and adopt these too tis_to_reset_or_adopt = with_row_locks( query, of=TI, session=session, **skip_locked(session=session) ).all() to_reset = self.executor.try_adopt_task_instances(tis_to_reset_or_adopt) reset_tis_message = [] for ti in to_reset: reset_tis_message.append(repr(ti)) ti.state = State.NONE ti.queued_by_job_id = None for ti in set(tis_to_reset_or_adopt) - set(to_reset): ti.queued_by_job_id = self.id Stats.incr('scheduler.orphaned_tasks.cleared', len(to_reset)) Stats.incr('scheduler.orphaned_tasks.adopted', len(tis_to_reset_or_adopt) - len(to_reset)) if to_reset: task_instance_str = '\n\t'.join(reset_tis_message) self.log.info( "Reset the following %s orphaned TaskInstances:\n\t%s", len(to_reset), task_instance_str, ) # Issue SQL/finish "Unit of Work", but let @provide_session # commit (or if passed a session, let caller decide when to commit session.flush() except OperationalError: session.rollback() raise return len(to_reset) @provide_session def check_trigger_timeouts(self, session: Session = None): """ Looks at all tasks that are in the "deferred" state and whose trigger or execution timeout has passed, so they can be marked as failed. """ num_timed_out_tasks = ( session.query(TaskInstance) .filter(TaskInstance.state == State.DEFERRED, TaskInstance.trigger_timeout < timezone.utcnow()) .update( # We have to schedule these to fail themselves so it doesn't # happen inside the scheduler. { "state": State.SCHEDULED, "next_method": "__fail__", "next_kwargs": {"error": "Trigger/execution timeout"}, "trigger_id": None, } ) ) if num_timed_out_tasks: self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks)