def verify_integrity(self, session=None):
    """
    Verifies the DagRun by checking for removed tasks or tasks that are not in the
    database yet. It will set state to removed or add the task if required.

    :param session: Sqlalchemy ORM Session; missing task instances are added to it
        and it is committed at the end of the check.
    """
    from airflow.models.taskinstance import TaskInstance  # Avoid circular import

    dag = self.get_dag()
    tis = self.get_task_instances(session=session)

    # check for removed or restored tasks
    # A set keeps the membership test in the "missing tasks" loop O(1).
    task_ids = set()
    for ti in tis:
        task_ids.add(ti.task_id)
        task = None
        try:
            task = dag.get_task(ti.task_id)
        except AirflowException:
            if ti.state == State.REMOVED:
                pass  # ti has already been removed, just ignore it
            # Compare states by value. ``is not`` checks object identity, which is
            # true for two equal strings that are distinct objects, and would then
            # mark tasks of a running DagRun as removed.
            elif self.state != State.RUNNING and not dag.partial:
                self.log.warning(
                    "Failed to get task '%s' for dag '%s'. Marking it as removed.",
                    ti, dag)
                Stats.incr(
                    "task_removed_from_dag.{}".format(dag.dag_id), 1, 1)
                ti.state = State.REMOVED

        is_task_in_dag = task is not None
        should_restore_task = is_task_in_dag and ti.state == State.REMOVED
        if should_restore_task:
            self.log.info(
                "Restoring task '%s' which was previously removed from DAG '%s'",
                ti, dag)
            Stats.incr("task_restored_to_dag.{}".format(dag.dag_id), 1, 1)
            ti.state = State.NONE

    # check for missing tasks
    for task in dag.task_dict.values():
        # Tasks that start after this run's execution date get no task instance,
        # unless the run is a backfill.
        if task.start_date > self.execution_date and not self.is_backfill:
            continue

        if task.task_id not in task_ids:
            Stats.incr(
                "task_instance_created-{}".format(task.__class__.__name__),
                1, 1)
            ti = TaskInstance(task, self.execution_date)
            session.add(ti)

    session.commit()
def verify_integrity(self, session: Session = NEW_SESSION):
    """
    Reconcile this DagRun's task instances with the tasks of its DAG.

    Task instances whose task can no longer be found in the DAG are marked as
    removed, previously removed ones whose task is back are reset, and task
    instances are created for tasks that have none yet.

    :param session: Sqlalchemy ORM Session
    :type session: Session
    """
    from airflow.settings import task_instance_mutation_hook

    dag = self.get_dag()
    seen_task_ids = set()

    # First pass: existing task instances -- detect removed or restored tasks.
    for ti in self.get_task_instances(session=session):
        task_instance_mutation_hook(ti)
        seen_task_ids.add(ti.task_id)

        try:
            task = dag.get_task(ti.task_id)
        except AirflowException:
            # Leave the TI alone when it is already marked removed; otherwise only
            # mark it while the run is not RUNNING and the DAG is not partial.
            if ti.state != State.REMOVED and self.state != State.RUNNING and not dag.partial:
                self.log.warning(
                    "Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag)
                Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
                ti.state = State.REMOVED
        else:
            if task is not None and ti.state == State.REMOVED:
                self.log.info(
                    "Restoring task '%s' which was previously removed from DAG '%s'", ti, dag)
                Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
                ti.state = State.NONE

        session.merge(ti)

    # Second pass: tasks of the DAG that have no task instance yet.
    for task in dag.task_dict.values():
        starts_after_run = task.start_date > self.execution_date and not self.is_backfill
        if starts_after_run or task.task_id in seen_task_ids:
            continue

        Stats.incr(f"task_instance_created-{task.task_type}", 1, 1)
        new_ti = TI(task, run_id=self.run_id)
        task_instance_mutation_hook(new_ti)
        session.add(new_ti)

    try:
        session.flush()
    except IntegrityError as err:
        self.log.info(str(err))
        self.log.info(
            'Hit IntegrityError while creating the TIs for %s- %s', dag.dag_id, self.run_id)
        self.log.info('Doing session rollback.')
        # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
        session.rollback()
def process_file(
    self,
    file_path: str,
    callback_requests: List[CallbackRequest],
    pickle_dags: bool = False,
    session: Session = None,
) -> Tuple[int, int]:
    """
    Process a Python file containing Airflow DAGs.

    The steps are:

    1. Load the file into a DagBag.
    2. Execute any callbacks passed to this method.
    3. Sync the DAGs found to the DB.
    4. Pickle the unpaused DAGs and save them to the DB (if requested).
    5. Record any errors importing the file into the ORM.

    :param file_path: the path to the Python file that should be executed
    :param callback_requests: failure callback to execute
    :param pickle_dags: whether serialize the DAGs found in the file and
        save them to the db
    :param session: Sqlalchemy ORM Session
    :return: number of dags found, count of import errors
    :rtype: Tuple[int, int]
    """
    self.log.info("Processing file %s for tasks to queue", file_path)

    try:
        dagbag = DagBag(file_path, include_examples=False, include_smart_sensor=False)
    except Exception:
        self.log.exception("Failed at reloading the DAG file %s", file_path)
        Stats.incr('dag_file_refresh_error', 1, 1)
        return 0, 0

    # Nothing usable in the file: record the import errors and stop early.
    if not dagbag.dags:
        self.log.warning("No viable dags retrieved from %s", file_path)
        self.update_import_errors(session, dagbag)
        return 0, len(dagbag.import_errors)

    self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(), file_path)

    self.execute_callbacks(dagbag, callback_requests)

    # Save individual DAGs in the ORM
    dagbag.sync_to_db()

    if pickle_dags:
        paused = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids)
        # Snapshot the unpaused DAGs before pickling any of them.
        to_pickle: List[DAG] = [
            candidate for candidate_id, candidate in dagbag.dags.items() if candidate_id not in paused
        ]
        for candidate in to_pickle:
            candidate.pickle(session)

    # Record import errors into the ORM; a failure here is logged, not raised.
    try:
        self.update_import_errors(session, dagbag)
    except Exception:
        self.log.exception("Error logging import errors!")

    return len(dagbag.dags), len(dagbag.import_errors)
def manage_slas(self, dag: DAG, session: Session = None) -> None:
    """
    Finding all tasks that have SLAs defined, and sending alert emails
    where needed. New SLA misses are also recorded in the database.

    We are assuming that the scheduler runs often, so we only check for
    tasks that should have succeeded in the past hour.

    :param dag: DAG whose tasks are checked for SLA misses
    :param session: Sqlalchemy ORM Session
    :raises TypeError: if a task has a truthy ``sla`` that is not a ``timedelta``
    """
    self.log.info("Running SLA Checks for %s", dag.dag_id)
    if not any(isinstance(task.sla, timedelta) for task in dag.tasks):
        self.log.info("Skipping SLA check for %s because no tasks in DAG have SLAs", dag)
        return

    # Per task: the latest execution date that has a SUCCESS or SKIPPED task instance.
    qry = (
        session.query(TI.task_id, func.max(DR.execution_date).label('max_ti'))
        .join(TI.dag_run)
        .with_hint(TI, 'USE INDEX (PRIMARY)', dialect_name='mysql')
        .filter(TI.dag_id == dag.dag_id)
        .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED))
        .filter(TI.task_id.in_(dag.task_ids))
        .group_by(TI.task_id)
        .subquery('sq')
    )
    # get recorded SlaMiss
    recorded_slas_query = set(
        session.query(SlaMiss.dag_id, SlaMiss.task_id, SlaMiss.execution_date).filter(
            SlaMiss.dag_id == dag.dag_id, SlaMiss.task_id.in_(dag.task_ids)
        )
    )

    max_tis: Iterator[TI] = (
        session.query(TI)
        .join(TI.dag_run)
        .filter(
            TI.dag_id == dag.dag_id,
            TI.task_id == qry.c.task_id,
            DR.execution_date == qry.c.max_ti,
        )
    )

    ts = timezone.utcnow()

    for ti in max_tis:
        task = dag.get_task(ti.task_id)
        if not task.sla:
            continue

        if not isinstance(task.sla, timedelta):
            raise TypeError(
                f"SLA is expected to be timedelta object, got "
                f"{type(task.sla)} in {task.dag_id}:{task.task_id}"
            )

        sla_misses = []
        next_info = dag.next_dagrun_info(dag.get_run_data_interval(ti.dag_run), restricted=False)
        if next_info is None:
            self.log.info("Skipping SLA check for %s because task does not have scheduled date", ti)
        else:
            # Walk forward through the following runs until we reach "now", run out
            # of runs, or hit a miss that has already been recorded.
            while next_info.logical_date < ts:
                next_info = dag.next_dagrun_info(next_info.data_interval, restricted=False)

                if next_info is None:
                    break
                if (ti.dag_id, ti.task_id, next_info.logical_date) in recorded_slas_query:
                    break
                if next_info.logical_date + task.sla < ts:
                    sla_miss = SlaMiss(
                        task_id=ti.task_id,
                        dag_id=ti.dag_id,
                        execution_date=next_info.logical_date,
                        timestamp=ts,
                    )
                    sla_misses.append(sla_miss)
        if sla_misses:
            session.add_all(sla_misses)
    session.commit()

    slas: List[SlaMiss] = (
        session.query(SlaMiss)
        .filter(SlaMiss.notification_sent == False, SlaMiss.dag_id == dag.dag_id)  # noqa
        .all()
    )
    if slas:
        sla_dates: List[datetime.datetime] = [sla.execution_date for sla in slas]
        fetched_tis: List[TI] = (
            session.query(TI)
            .filter(TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id)
            .all()
        )
        blocking_tis: List[TI] = []
        for ti in fetched_tis:
            if ti.task_id in dag.task_ids:
                ti.task = dag.get_task(ti.task_id)
                blocking_tis.append(ti)
            else:
                # The task is no longer part of the DAG: drop its task instance.
                session.delete(ti)
                session.commit()

        task_list = "\n".join(sla.task_id + ' on ' + sla.execution_date.isoformat() for sla in slas)
        blocking_task_list = "\n".join(
            ti.task_id + ' on ' + ti.execution_date.isoformat() for ti in blocking_tis
        )
        # Track whether email or any alert notification sent
        # We consider email or the alert callback as notifications
        email_sent = False
        notification_sent = False
        if dag.sla_miss_callback:
            # Execute the alert callback
            self.log.info('Calling SLA miss callback')
            try:
                dag.sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis)
                notification_sent = True
            except Exception:
                Stats.incr('sla_callback_notification_failure')
                self.log.exception("Could not call sla_miss_callback for DAG %s", dag.dag_id)
        # The <code> elements are closed with </code>; an opening tag in that
        # position leaves the email's HTML unbalanced.
        email_content = f"""\
        Here's a list of tasks that missed their SLAs:
        <pre><code>{task_list}\n</code></pre>
        Blocking tasks:
        <pre><code>{blocking_task_list}</code></pre>
        Airflow Webserver URL: {conf.get(section='webserver', key='base_url')}
        """

        tasks_missed_sla = []
        for sla in slas:
            try:
                task = dag.get_task(sla.task_id)
            except TaskNotFound:
                # task already deleted from DAG, skip it
                self.log.warning(
                    "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.", sla.task_id
                )
                continue
            tasks_missed_sla.append(task)

        # Collect the distinct recipients across all tasks that missed their SLA.
        emails: Set[str] = set()
        for task in tasks_missed_sla:
            if task.email:
                if isinstance(task.email, str):
                    emails |= set(get_email_address_list(task.email))
                elif isinstance(task.email, (list, tuple)):
                    emails |= set(task.email)
        if emails:
            try:
                send_email(emails, f"[airflow] SLA miss on DAG={dag.dag_id}", email_content)
                email_sent = True
                notification_sent = True
            except Exception:
                Stats.incr('sla_email_notification_failure')
                self.log.exception("Could not send SLA Miss email notification for DAG %s", dag.dag_id)
        # If we sent any notification, update the sla_miss table
        if notification_sent:
            for sla in slas:
                sla.email_sent = email_sent
                sla.notification_sent = True
                session.merge(sla)
        session.commit()
def heartbeat(self):
    """
    Heartbeats update the job's entry in the database with a timestamp
    for the latest_heartbeat and allows for the job to be killed
    externally. This allows at the system level to monitor what is
    actually active.

    For instance, an old heartbeat for SchedulerJob would mean something
    is wrong.

    This also allows for any job to be killed externally, regardless
    of who is running it or on which machine it is running.

    Note that if your heartbeat is set to 60 seconds and you call this
    method after 10 seconds of processing since the last heartbeat, it
    will sleep 50 seconds to complete the 60 seconds and keep a steady
    heart rate. If you go over 60 seconds before calling it, it won't
    sleep at all.
    """
    # Value to fall back to if the heartbeat cannot be written.
    last_good_heartbeat = self.latest_heartbeat

    try:
        with create_session() as session:
            # This will cause it to load from the db
            session.merge(self)
            last_good_heartbeat = self.latest_heartbeat

        if self.state == State.SHUTDOWN:
            self.kill()

        if not conf.getboolean('core', 'unit_test_mode'):
            # Sleep for whatever is left of the heartrate interval, if anything.
            if self.latest_heartbeat:
                elapsed = (timezone.utcnow() - self.latest_heartbeat).total_seconds()
                sleep_for = max(0, self.heartrate - elapsed)
            else:
                sleep_for = 0
            sleep(sleep_for)

        # Update last heartbeat time
        with create_session() as session:
            # Make the session aware of this object
            session.merge(self)
            self.latest_heartbeat = timezone.utcnow()
            session.commit()
            # At this point, the DB has updated.
            last_good_heartbeat = self.latest_heartbeat

            self.heartbeat_callback(session=session)
            self.log.debug('[heartbeat]')
    except OperationalError:
        failure_stat = convert_camel_to_snake(self.__class__.__name__) + '_heartbeat_failure'
        Stats.incr(failure_stat, 1, 1)
        self.log.exception("%s heartbeat got an exception", self.__class__.__name__)
        # We didn't manage to heartbeat, so make sure that the timestamp isn't updated
        self.latest_heartbeat = last_good_heartbeat
def _execute(self):
    """
    Start the task runner and supervise it until the task exits, this job is
    told to terminate, or heartbeats have been failing for too long.

    :raises AirflowException: when the time since the last heartbeat exceeds
        ``scheduler_zombie_task_threshold``
    """
    self.task_runner = get_task_runner(self)

    def signal_handler(signum, frame):
        """Setting kill signal handler"""
        self.log.error("Received SIGTERM. Terminating subprocesses")
        self.task_runner.terminate()
        self.handle_task_exit(128 + signum)

    signal.signal(signal.SIGTERM, signal_handler)

    can_run = self.task_instance.check_and_change_state_before_execution(
        mark_success=self.mark_success,
        ignore_all_deps=self.ignore_all_deps,
        ignore_depends_on_past=self.ignore_depends_on_past,
        ignore_task_deps=self.ignore_task_deps,
        ignore_ti_state=self.ignore_ti_state,
        job_id=self.id,
        pool=self.pool,
        external_executor_id=self.external_executor_id,
    )
    if not can_run:
        self.log.info("Task is not able to be run")
        return

    try:
        self.task_runner.start()

        heartbeat_time_limit = conf.getint('scheduler', 'scheduler_zombie_task_threshold')

        # task callback invocation happens either here or in
        # self.heartbeat() instead of taskinstance._run_raw_task to
        # avoid race conditions
        #
        # When self.terminating is set to True by heartbeat_callback, this
        # loop should not be restarted. Otherwise self.handle_task_exit
        # will be invoked and we will end up with duplicated callbacks
        while not self.terminating:
            # Monitor the task to see if it's done. Wait in a syscall
            # (`os.wait`) for as long as possible so we notice the
            # subprocess finishing as quick as we can
            since_heartbeat = (timezone.utcnow() - self.latest_heartbeat).total_seconds()
            # Note: the 0.75 factor applies to the elapsed time only.
            wait_budget = min(heartbeat_time_limit - since_heartbeat * 0.75, self.heartrate)
            # Make sure this value is never negative
            max_wait_time = max(0, wait_budget)

            return_code = self.task_runner.return_code(timeout=max_wait_time)
            if return_code is not None:
                self.handle_task_exit(return_code)
                return

            self.heartbeat()

            # If it's been too long since we've heartbeat, then it's possible that
            # the scheduler rescheduled this task, so kill launched processes.
            # This can only really happen if the worker can't read the DB for a long time
            time_since_last_heartbeat = (timezone.utcnow() - self.latest_heartbeat).total_seconds()
            if time_since_last_heartbeat <= heartbeat_time_limit:
                continue

            Stats.incr('local_task_job_prolonged_heartbeat_failure', 1, 1)
            self.log.error("Heartbeat time limit exceeded!")
            raise AirflowException(
                "Time since last heartbeat({:.2f}s) "
                "exceeded limit ({}s).".format(time_since_last_heartbeat, heartbeat_time_limit)
            )
    finally:
        self.on_kill()
def _execute(self):
    """
    Start the task runner and supervise it until the task exits or heartbeats
    have been failing for too long.

    :raises AirflowException: on SIGTERM, or when the time since the last
        heartbeat exceeds ``scheduler_zombie_task_threshold``
    """
    self.task_runner = get_task_runner(self)

    # pylint: disable=unused-argument
    def signal_handler(signum, frame):
        """Setting kill signal handler"""
        self.log.error("Received SIGTERM. Terminating subprocesses")
        self.on_kill()
        raise AirflowException("LocalTaskJob received SIGTERM signal")
    # pylint: enable=unused-argument

    signal.signal(signal.SIGTERM, signal_handler)

    can_run = self.task_instance.check_and_change_state_before_execution(
        mark_success=self.mark_success,
        ignore_all_deps=self.ignore_all_deps,
        ignore_depends_on_past=self.ignore_depends_on_past,
        ignore_task_deps=self.ignore_task_deps,
        ignore_ti_state=self.ignore_ti_state,
        job_id=self.id,
        pool=self.pool,
    )
    if not can_run:
        self.log.info("Task is not able to be run")
        return

    try:
        self.task_runner.start()

        heartbeat_time_limit = conf.getint('scheduler', 'scheduler_zombie_task_threshold')

        while True:
            # Monitor the task to see if it's done. Wait in a syscall
            # (`os.wait`) for as long as possible so we notice the
            # subprocess finishing as quick as we can
            since_heartbeat = (timezone.utcnow() - self.latest_heartbeat).total_seconds()
            # Note: the 0.75 factor applies to the elapsed time only.
            wait_budget = min(heartbeat_time_limit - since_heartbeat * 0.75, self.heartrate)
            # Make sure this value is never negative
            max_wait_time = max(0, wait_budget)

            return_code = self.task_runner.return_code(timeout=max_wait_time)
            if return_code is not None:
                self.log.info("Task exited with return code %s", return_code)
                return

            self.heartbeat()

            # If it's been too long since we've heartbeat, then it's possible that
            # the scheduler rescheduled this task, so kill launched processes.
            # This can only really happen if the worker can't read the DB for a long time
            time_since_last_heartbeat = (timezone.utcnow() - self.latest_heartbeat).total_seconds()
            if time_since_last_heartbeat <= heartbeat_time_limit:
                continue

            Stats.incr('local_task_job_prolonged_heartbeat_failure', 1, 1)
            self.log.error("Heartbeat time limit exceeded!")
            raise AirflowException(
                "Time since last heartbeat({:.2f}s) "
                "exceeded limit ({}s).".format(time_since_last_heartbeat, heartbeat_time_limit)
            )
    finally:
        self.on_kill()
def _execute(self):
    """
    Start the task runner and supervise it until the task exits or heartbeats
    have been failing for too long.

    A heartbeat that fails with ``OperationalError`` is logged and followed by a
    sleep of ``self.heartrate`` seconds; ``on_kill`` always runs once the task
    runner has been started.

    :raises AirflowException: on SIGTERM, or when the time since the last
        successful heartbeat exceeds ``scheduler_zombie_task_threshold``
    """
    self.task_runner = get_task_runner(self)

    def signal_handler(signum, frame):
        """Setting kill signal handler"""
        self.log.error("Received SIGTERM. Terminating subprocesses")
        self.on_kill()
        raise AirflowException("LocalTaskJob received SIGTERM signal")

    signal.signal(signal.SIGTERM, signal_handler)

    if not self.task_instance._check_and_change_state_before_execution(
            mark_success=self.mark_success,
            ignore_all_deps=self.ignore_all_deps,
            ignore_depends_on_past=self.ignore_depends_on_past,
            ignore_task_deps=self.ignore_task_deps,
            ignore_ti_state=self.ignore_ti_state,
            job_id=self.id,
            pool=self.pool):
        self.log.info("Task is not able to be run")
        return

    try:
        self.task_runner.start()

        last_heartbeat_time = time.time()
        heartbeat_time_limit = conf.getint('scheduler',
                                           'scheduler_zombie_task_threshold')
        while True:
            # Monitor the task to see if it's done
            return_code = self.task_runner.return_code()
            if return_code is not None:
                self.log.info("Task exited with return code %s", return_code)
                return

            # Periodically heartbeat so that the scheduler doesn't think this
            # is a zombie
            try:
                self.heartbeat()
                last_heartbeat_time = time.time()
            except OperationalError:
                Stats.incr('local_task_job_heartbeat_failure', 1, 1)
                self.log.exception(
                    "Exception while trying to heartbeat! Sleeping for %s seconds",
                    self.heartrate
                )
                time.sleep(self.heartrate)

            # If it's been too long since we've heartbeat, then it's possible that
            # the scheduler rescheduled this task, so kill launched processes.
            time_since_last_heartbeat = time.time() - last_heartbeat_time
            if time_since_last_heartbeat > heartbeat_time_limit:
                Stats.incr('local_task_job_prolonged_heartbeat_failure', 1, 1)
                # Message typo fixed ("limited" -> "limit").
                self.log.error("Heartbeat time limit exceeded!")
                raise AirflowException("Time since last heartbeat({:.2f}s) "
                                       "exceeded limit ({}s)."
                                       .format(time_since_last_heartbeat,
                                               heartbeat_time_limit))
    finally:
        self.on_kill()
def heartbeat(self):
    """
    This should be periodically called by the manager loop. This method will
    kick off new processes to process DAG definition files and read the
    results from the finished processors.

    :return: a list of SimpleDags that were produced by processors that
        have finished since the last time this was called
    :rtype: list[airflow.utils.dag_processing.SimpleDag]
    """
    simple_dags = self.collect_results()

    # Generate more file paths to process if we processed all the files
    # already.
    if not self._file_path_queue:
        self.emit_metrics()
        self._parsing_start_time = timezone.utcnow()

        # If the file path is already being processed, or if a file was
        # processed recently, wait until the next batch
        in_progress = self._processors.keys()
        now = timezone.utcnow()

        recently_processed = []
        for path in self._file_paths:
            finished_at = self.get_last_finish_time(path)
            if finished_at is None:
                continue
            if (now - finished_at).total_seconds() < self._file_process_interval:
                recently_processed.append(path)

        at_run_limit = [
            path for path, stat in self._file_stats.items()
            if stat.run_count == self._max_runs
        ]

        files_paths_to_queue = list(
            set(self._file_paths)
            - set(in_progress)
            - set(recently_processed)
            - set(at_run_limit)
        )

        for processor in self._processors.values():
            self.log.debug(
                "File path %s is still being processed (started: %s)",
                processor.file_path, processor.start_time.isoformat()
            )

        self.log.debug(
            "Queuing the following files for processing:\n\t%s",
            "\n\t".join(files_paths_to_queue)
        )

        # Make sure every queued file has a stats entry.
        for path in files_paths_to_queue:
            if path not in self._file_stats:
                self._file_stats[path] = DagFileStat(0, 0, None, None, 0)

        self._file_path_queue.extend(files_paths_to_queue)

    # Start more processors if we have enough slots and files to process
    while len(self._processors) < self._parallelism and self._file_path_queue:
        next_path = self._file_path_queue.pop(0)
        processor = self._processor_factory(next_path, self._zombies)
        Stats.incr('dag_processing.processes')

        processor.start()
        self.log.debug(
            "Started a process (PID: %s) to generate tasks for %s",
            processor.pid, next_path
        )
        self._processors[next_path] = processor

    # Update heartbeat count.
    self._heartbeat_count += 1

    return simple_dags
def verify_integrity(self, session: Session = None):
    """
    Verifies the DagRun by checking for removed tasks or tasks that are not in the
    database yet. It will set state to removed or add the task if required.

    An ``IntegrityError`` raised by the final commit is logged and the session
    is rolled back instead of propagating.

    :param session: Sqlalchemy ORM Session
    :type session: Session
    """
    dag = self.get_dag()
    tis = self.get_task_instances(session=session)

    # check for removed or restored tasks
    task_ids = set()
    for ti in tis:
        task_instance_mutation_hook(ti)
        task_ids.add(ti.task_id)
        task = None
        try:
            task = dag.get_task(ti.task_id)
        except AirflowException:
            if ti.state == State.REMOVED:
                pass  # ti has already been removed, just ignore it
            # Compare states by value. ``is not`` checks object identity, which is
            # true for two equal strings that are distinct objects, and would then
            # mark tasks of a running DagRun as removed.
            elif self.state != State.RUNNING and not dag.partial:
                self.log.warning(
                    "Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag)
                Stats.incr("task_removed_from_dag.{}".format(dag.dag_id), 1, 1)
                ti.state = State.REMOVED

        should_restore_task = (task is not None) and ti.state == State.REMOVED
        if should_restore_task:
            self.log.info(
                "Restoring task '%s' which was previously removed from DAG '%s'", ti, dag)
            Stats.incr("task_restored_to_dag.{}".format(dag.dag_id), 1, 1)
            ti.state = State.NONE
        session.merge(ti)

    # check for missing tasks
    for task in dag.task_dict.values():
        # Tasks that start after this run's execution date get no task instance,
        # unless the run is a backfill.
        if task.start_date > self.execution_date and not self.is_backfill:
            continue

        if task.task_id not in task_ids:
            Stats.incr(
                "task_instance_created-{}".format(task.__class__.__name__),
                1, 1)
            ti = TI(task, self.execution_date)
            task_instance_mutation_hook(ti)
            session.add(ti)

    try:
        session.commit()
    except IntegrityError as err:
        self.log.info(str(err))
        # Lazy %-style arguments: the message is only built if the record is emitted.
        self.log.info(
            'Hit IntegrityError while creating the TIs for %s - %s.',
            dag.dag_id, self.execution_date)
        self.log.info('Doing session rollback.')
        session.rollback()
def heartbeat_callback(self, session: Session = None) -> None:
    """
    Increment the ``scheduler_heartbeat`` stat.

    :param session: accepted but not used by this implementation
    """
    Stats.incr('scheduler_heartbeat', 1, 1)
def verify_integrity(self, session: Session = NEW_SESSION):
    """
    Reconcile this DagRun's task instances with the tasks of its DAG.

    Task instances whose task can no longer be found in the DAG are marked as
    removed, previously removed ones whose task is back are reset, and the
    missing task instances are created in bulk.

    :param session: Sqlalchemy ORM Session
    """
    from airflow.settings import task_instance_mutation_hook

    dag = self.get_dag()
    seen_task_ids = set()

    # First pass: existing task instances -- detect removed or restored tasks.
    for ti in self.get_task_instances(session=session):
        task_instance_mutation_hook(ti)
        seen_task_ids.add(ti.task_id)

        try:
            task = dag.get_task(ti.task_id)
        except AirflowException:
            # Leave the TI alone when it is already marked removed; otherwise only
            # mark it while the run is not RUNNING and the DAG is not partial.
            if ti.state != State.REMOVED and self.state != State.RUNNING and not dag.partial:
                self.log.warning(
                    "Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag)
                Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
                ti.state = State.REMOVED
        else:
            if task is not None and ti.state == State.REMOVED:
                self.log.info(
                    "Restoring task '%s' which was previously removed from DAG '%s'", ti, dag)
                Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
                ti.state = State.NONE

        session.merge(ti)

    created_counts: Dict[str, int] = defaultdict(int)

    # Set for the empty default in airflow.settings -- if it's not set this means it has been changed
    hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False)

    def build_mapping(task: "BaseOperator"):
        """Return an insert mapping for ``task``; used when the mutation hook is a no-op."""
        created_counts[task.task_type] += 1
        return TI.insert_mapping(self.run_id, task)

    def build_instance(task: "BaseOperator") -> TI:
        """Build a TI for ``task`` and run the mutation hook on it."""
        new_ti = TI(task, run_id=self.run_id)
        task_instance_mutation_hook(new_ti)
        created_counts[new_ti.operator] += 1
        return new_ti

    # Create missing tasks
    missing_tasks = [
        task
        for task in dag.task_dict.values()
        if task.task_id not in seen_task_ids
        and (self.is_backfill or task.start_date <= self.execution_date)
    ]
    try:
        # The lazy ``map`` objects are consumed by the bulk calls, which is what
        # fills ``created_counts`` before the stats are emitted.
        if hook_is_noop:
            session.bulk_insert_mappings(TI, map(build_mapping, missing_tasks))
        else:
            session.bulk_save_objects(map(build_instance, missing_tasks))

        for task_type, count in created_counts.items():
            Stats.incr(f"task_instance_created-{task_type}", count)
        session.flush()
    except IntegrityError:
        self.log.info(
            'Hit IntegrityError while creating the TIs for %s- %s',
            dag.dag_id,
            self.run_id,
            exc_info=True,
        )
        self.log.info('Doing session rollback.')
        # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
        session.rollback()
def _do_scheduling(self, session) -> int: """ This function is where the main scheduling decisions take places. It: - Creates any necessary DAG runs by examining the next_dagrun_create_after column of DagModel Since creating Dag Runs is a relatively time consuming process, we select only 10 dags by default (configurable via ``scheduler.max_dagruns_to_create_per_loop`` setting) - putting this higher will mean one scheduler could spend a chunk of time creating dag runs, and not ever get around to scheduling tasks. - Finds the "next n oldest" running DAG Runs to examine for scheduling (n=20 by default, configurable via ``scheduler.max_dagruns_per_loop_to_schedule`` config setting) and tries to progress state (TIs to SCHEDULED, or DagRuns to SUCCESS/FAILURE etc) By "next oldest", we mean hasn't been examined/scheduled in the most time. The reason we don't select all dagruns at once because the rows are selected with row locks, meaning that only one scheduler can "process them", even it is waiting behind other dags. Increasing this limit will allow more throughput for smaller DAGs but will likely slow down throughput for larger (>500 tasks.) DAGs - Then, via a Critical Section (locking the rows of the Pool model) we queue tasks, and then send them to the executor. See docs of _critical_section_execute_task_instances for more. 
:return: Number of TIs enqueued in this iteration :rtype: int """ # Put a check in place to make sure we don't commit unexpectedly with prohibit_commit(session) as guard: if settings.USE_JOB_SCHEDULE: self._create_dagruns_for_dags(guard, session) self._start_queued_dagruns(session) guard.commit() dag_runs = self._get_next_dagruns_to_examine(State.RUNNING, session) # Bulk fetch the currently active dag runs for the dags we are # examining, rather than making one query per DagRun callback_tuples = [] for dag_run in dag_runs: callback_to_run = self._schedule_dag_run(dag_run, session) callback_tuples.append((dag_run, callback_to_run)) guard.commit() # Send the callbacks after we commit to ensure the context is up to date when it gets run for dag_run, callback_to_run in callback_tuples: self._send_dag_callbacks_to_processor(dag_run, callback_to_run) # Without this, the session has an invalid view of the DB session.expunge_all() # END: schedule TIs try: if self.executor.slots_available <= 0: # We know we can't do anything here, so don't even try! self.log.debug("Executor full, skipping critical section") return 0 timer = Stats.timer('scheduler.critical_section_duration') timer.start() # Find anything TIs in state SCHEDULED, try to QUEUE it (send it to the executor) num_queued_tis = self._critical_section_execute_task_instances(session=session) # Make sure we only sent this metric if we obtained the lock, otherwise we'll skew the # metric, way down timer.stop(send=True) except OperationalError as e: timer.stop(send=False) if is_lock_not_available_error(error=e): self.log.debug("Critical section lock held by another Scheduler") Stats.incr('scheduler.critical_section_busy') session.rollback() return 0 raise guard.commit() return num_queued_tis
def _process_executor_events(self, session: Session = None) -> int:
    """
    Respond to executor events.

    Drains the executor's event buffer, stores the reported external executor id for
    QUEUED events, and handles task instances the executor reports as finished while
    their state is still QUEUED.

    :param session: Sqlalchemy ORM Session
    :return: the length of the event buffer when the method returns (entries matched to
        a task instance have been popped from it by then)
    :raises ValueError: if the processor agent has not been started
    """
    if not self.processor_agent:
        raise ValueError("Processor agent is not started.")
    ti_primary_key_to_try_number_map: Dict[Tuple[str, str, datetime.datetime], int] = {}
    event_buffer = self.executor.get_event_buffer()
    tis_with_right_state: List[TaskInstanceKey] = []

    # Report execution
    for ti_key, value in event_buffer.items():
        state: str
        state, _ = value
        # Map ``ti_key.primary`` -> the try_number the executor reported (in-memory value)
        ti_primary_key_to_try_number_map[ti_key.primary] = ti_key.try_number

        self.log.info(
            "Executor reports execution of %s.%s run_id=%s exited with status %s for try_number %s",
            ti_key.dag_id,
            ti_key.task_id,
            ti_key.run_id,
            state,
            ti_key.try_number,
        )
        if state in (State.FAILED, State.SUCCESS, State.QUEUED):
            tis_with_right_state.append(ti_key)

    # Return if no finished tasks
    if not tis_with_right_state:
        return len(event_buffer)

    # Check state of finished tasks
    filter_for_tis = TI.filter_for_tis(tis_with_right_state)
    tis: List[TI] = session.query(TI).filter(filter_for_tis).options(selectinload('dag_model')).all()
    for ti in tis:
        # Rebuild the key the executor used, so the matching event can be popped from the buffer.
        try_number = ti_primary_key_to_try_number_map[ti.key.primary]
        buffer_key = ti.key.with_try_number(try_number)
        state, info = event_buffer.pop(buffer_key)

        # TODO: should we fail RUNNING as well, as we do in Backfills?
        if state == State.QUEUED:
            # For QUEUED events the ``info`` value is stored as the external executor id.
            ti.external_executor_id = info
            self.log.info("Setting external_id for %s to %s", ti, info)
            continue

        if ti.try_number == buffer_key.try_number and ti.state == State.QUEUED:
            Stats.incr('scheduler.tasks.killed_externally')
            msg = (
                "Executor reports task instance %s finished (%s) although the "
                "task says its %s. (Info: %s) Was the task killed externally?"
            )
            self.log.error(msg, ti, state, ti.state, info)

            # Get task from the Serialized DAG
            try:
                dag = self.dagbag.get_dag(ti.dag_id)
                task = dag.get_task(ti.task_id)
            except Exception:
                # The task could not be loaded: just record the state the executor reported.
                self.log.exception("Marking task instance %s as %s", ti, state)
                ti.set_state(state)
                continue
            ti.task = task
            if task.on_retry_callback or task.on_failure_callback:
                # Callbacks are defined: hand the failure over to the processor agent.
                request = TaskCallbackRequest(
                    full_filepath=ti.dag_model.fileloc,
                    simple_task_instance=SimpleTaskInstance(ti),
                    msg=msg % (ti, state, ti.state, info),
                )
                self.processor_agent.send_callback_to_execute(request)
            else:
                ti.handle_failure(error=msg % (ti, state, ti.state, info), session=session)

    return len(event_buffer)
def adopt_or_reset_orphaned_tasks(self, session: Session = None):
    """
    Reset any TaskInstance still in QUEUED or RUNNING states that were
    enqueued by a SchedulerJob that is no longer running.

    Task instances the executor is able to adopt are re-assigned to this job
    instead of being reset.

    :param session: Sqlalchemy ORM Session
    :return: the number of TIs reset
    :rtype: int
    """
    self.log.info("Resetting orphaned tasks for active dag runs")
    timeout = conf.getint('scheduler', 'scheduler_health_check_threshold')

    for attempt in run_with_db_retries(logger=self.log):
        with attempt:
            self.log.debug(
                "Running SchedulerJob.adopt_or_reset_orphaned_tasks with retries. Try %d of %d",
                attempt.retry_state.attempt_number,
                MAX_DB_RETRIES,
            )
            self.log.debug("Calling SchedulerJob.adopt_or_reset_orphaned_tasks method")
            try:
                # Mark RUNNING scheduler jobs whose last heartbeat is older than ``timeout`` seconds as failed.
                num_failed = (
                    session.query(SchedulerJob)
                    .filter(
                        SchedulerJob.state == State.RUNNING,
                        SchedulerJob.latest_heartbeat < (timezone.utcnow() - timedelta(seconds=timeout)),
                    )
                    .update({"state": State.FAILED})
                )

                if num_failed:
                    self.log.info("Marked %d SchedulerJob instances as failed", num_failed)
                    Stats.incr(self.__class__.__name__.lower() + '_end', num_failed)

                resettable_states = [State.QUEUED, State.RUNNING]
                query = (
                    session.query(TI)
                    .filter(TI.state.in_(resettable_states))
                    # outerjoin is because we didn't use to have queued_by_job
                    # set, so we need to pick up anything pre upgrade. This (and the
                    # "or queued_by_job_id IS NONE") can go as soon as scheduler HA is
                    # released.
                    .outerjoin(TI.queued_by_job)
                    .filter(or_(TI.queued_by_job_id.is_(None), SchedulerJob.state != State.RUNNING))
                    .join(TI.dag_run)
                    .filter(
                        DagRun.run_type != DagRunType.BACKFILL_JOB,
                        DagRun.state == State.RUNNING,
                    )
                    .options(load_only(TI.dag_id, TI.task_id, TI.run_id))
                )

                # Lock these rows, so that another scheduler can't try and adopt these too
                tis_to_reset_or_adopt = with_row_locks(
                    query, of=TI, session=session, **skip_locked(session=session)
                ).all()
                # The executor returns the task instances it could not adopt; those get reset.
                to_reset = self.executor.try_adopt_task_instances(tis_to_reset_or_adopt)

                reset_tis_message = []
                for ti in to_reset:
                    reset_tis_message.append(repr(ti))
                    ti.state = State.NONE
                    ti.queued_by_job_id = None

                # Everything that was adopted is now owned by this scheduler job.
                for ti in set(tis_to_reset_or_adopt) - set(to_reset):
                    ti.queued_by_job_id = self.id

                Stats.incr('scheduler.orphaned_tasks.cleared', len(to_reset))
                Stats.incr('scheduler.orphaned_tasks.adopted', len(tis_to_reset_or_adopt) - len(to_reset))

                if to_reset:
                    task_instance_str = '\n\t'.join(reset_tis_message)
                    self.log.info(
                        "Reset the following %s orphaned TaskInstances:\n\t%s",
                        len(to_reset),
                        task_instance_str,
                    )

                # Issue SQL/finish "Unit of Work", but let @provide_session
                # commit (or, if passed a session, let the caller decide when to commit)
                session.flush()
            except OperationalError:
                # Roll back and re-raise so the error reaches the surrounding ``attempt`` context.
                session.rollback()
                raise

    return len(to_reset)