def test_backfill_pooled_tasks(self):
    """
    Test that queued tasks are executed by BackfillJob

    Test for https://github.com/airbnb/airflow/pull/1225
    """
    session = settings.Session()
    pool = Pool(pool='test_backfill_pooled_task_pool', slots=1)
    session.add(pool)
    session.commit()

    dag = self.dagbag.get_dag('test_backfill_pooled_task_dag')
    dag.clear()

    job = BackfillJob(
        dag=dag,
        start_date=DEFAULT_DATE,
        end_date=DEFAULT_DATE)

    # run with timeout because this creates an infinite loop if not
    # caught
    with timeout(seconds=30):
        job.run()

    ti = TI(
        task=dag.get_task('test_backfill_pooled_task'),
        execution_date=DEFAULT_DATE)
    ti.refresh_from_db()
    self.assertEqual(ti.state, State.SUCCESS)
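Every snippet in this collection wraps potentially hanging work in a `timeout(...)` context manager (in Airflow this lives in `airflow.utils.timeout`). For orientation, here is a minimal sketch of how such a SIGALRM-based context manager can be built; the local `AirflowTaskTimeout` stand-in and the exact constructor arguments are illustrative assumptions, not the project's actual implementation.

import os
import signal


class AirflowTaskTimeout(Exception):
    """Stand-in for the real exception raised when the alarm fires."""


class timeout:
    """Raise after `seconds` using SIGALRM (main thread, Unix only) -- a sketch."""

    def __init__(self, seconds=1, error_message='Timeout'):
        self.seconds = seconds
        self.error_message = error_message + ', PID: ' + str(os.getpid())

    def handle_timeout(self, signum, frame):
        raise AirflowTaskTimeout(self.error_message)

    def __enter__(self):
        try:
            signal.signal(signal.SIGALRM, self.handle_timeout)
            signal.setitimer(signal.ITIMER_REAL, self.seconds)
        except ValueError:
            # signal only works in the main thread of the main interpreter
            pass

    def __exit__(self, exc_type, exc_value, exc_traceback):
        try:
            # cancel the pending alarm on the way out
            signal.setitimer(signal.ITIMER_REAL, 0)
        except ValueError:
            pass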
def send_task_to_executor(task_tuple):
    key, simple_ti, command, queue, task = task_tuple
    try:
        with timeout(seconds=2):
            result = task.apply_async(args=[command], queue=queue)
    except Exception as e:
        exception_traceback = "Celery Task ID: {}\n{}".format(key, traceback.format_exc())
        result = ExceptionWithTraceback(e, exception_traceback)

    return key, command, result
def test_mark_failure_on_failure_callback(self):
    """
    Test that ensures that mark_failure in the UI fails
    the task, and executes on_failure_callback
    """
    data = {'called': False}

    def check_failure(context):
        self.assertEqual(context['dag_run'].dag_id, 'test_mark_failure')
        data['called'] = True

    def task_function(ti):
        print("python_callable run in pid %s", os.getpid())
        with create_session() as session:
            self.assertEqual(State.RUNNING, ti.state)
            ti.log.info("Marking TI as failed 'externally'")
            ti.state = State.FAILED
            session.merge(ti)
            session.commit()

        time.sleep(60)
        # This should not happen -- the state change should be noticed and the task should get killed
        data['reached_end_of_sleep'] = True

    with DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE) as dag:
        task = PythonOperator(
            task_id='test_state_succeeded1',
            python_callable=task_function,
            on_failure_callback=check_failure)

    session = settings.Session()

    dag.clear()
    dag.create_dagrun(run_id="test",
                      state=State.RUNNING,
                      execution_date=DEFAULT_DATE,
                      start_date=DEFAULT_DATE,
                      session=session)
    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti.refresh_from_db()
    job1 = LocalTaskJob(task_instance=ti,
                        ignore_ti_state=True,
                        executor=SequentialExecutor())
    with timeout(30):
        # This should be _much_ shorter to run.
        # If you change this limit, make the timeout in the callable above bigger
        job1.run()

    ti.refresh_from_db()
    self.assertEqual(ti.state, State.FAILED)
    self.assertTrue(data['called'])
    self.assertNotIn(
        'reached_end_of_sleep', data,
        'Task should not have been allowed to run to completion')
def send_task_to_executor(task_tuple: TaskInstanceInCelery) \
        -> Tuple[TaskInstanceKey, CommandType, Union[AsyncResult, ExceptionWithTraceback]]:
    """Sends task to executor."""
    key, _, command, queue, task_to_run = task_tuple
    try:
        with timeout(seconds=OPERATION_TIMEOUT):
            result = task_to_run.apply_async(args=[command], queue=queue)
    except Exception as e:  # pylint: disable=broad-except
        exception_traceback = "Celery Task ID: {}\n{}".format(key, traceback.format_exc())
        result = ExceptionWithTraceback(e, exception_traceback)

    return key, command, result
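When the broker call fails, `send_task_to_executor` returns an `ExceptionWithTraceback` instead of re-raising, because exceptions raised inside a multiprocessing pool lose their traceback by the time they reach the parent process. A minimal sketch of such a wrapper is shown below; the attribute names are assumptions based on how the snippets construct it, not a verified copy of the project's class.

class ExceptionWithTraceback:
    """
    Carry an exception plus its pre-formatted traceback back to the parent process.

    :param exception: the original exception caught in the worker
    :param exception_traceback: the traceback, already rendered as a string
    """

    def __init__(self, exception: Exception, exception_traceback: str):
        self.exception = exception
        self.traceback = exception_traceback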
def _load_modules_from_file(self, filepath, safe_mode):
    if not might_contain_dag(filepath, safe_mode):
        # Don't want to spam user with skip messages
        if not self.has_logged:
            self.has_logged = True
            self.log.info("File %s assumed to contain no DAGs. Skipping.", filepath)
        return []

    self.log.debug("Importing %s", filepath)
    org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
    path_hash = hashlib.sha1(filepath.encode('utf-8')).hexdigest()
    mod_name = f'unusual_prefix_{path_hash}_{org_mod_name}'

    if mod_name in sys.modules:
        del sys.modules[mod_name]

    def parse(mod_name, filepath):
        try:
            loader = importlib.machinery.SourceFileLoader(mod_name, filepath)
            spec = importlib.util.spec_from_loader(mod_name, loader)
            new_module = importlib.util.module_from_spec(spec)
            sys.modules[spec.name] = new_module
            loader.exec_module(new_module)
            return [new_module]
        except Exception as e:
            self.log.exception("Failed to import: %s", filepath)
            if self.dagbag_import_error_tracebacks:
                self.import_errors[filepath] = traceback.format_exc(
                    limit=-self.dagbag_import_error_traceback_depth)
            else:
                self.import_errors[filepath] = str(e)
            return []

    dagbag_import_timeout = settings.get_dagbag_import_timeout(filepath)

    if not isinstance(dagbag_import_timeout, (int, float)):
        raise TypeError(
            f'Value ({dagbag_import_timeout}) from get_dagbag_import_timeout must be int or float')

    if dagbag_import_timeout <= 0:  # no parsing timeout
        return parse(mod_name, filepath)

    timeout_msg = (
        f"DagBag import timeout for {filepath} after {dagbag_import_timeout}s.\n"
        "Please take a look at these docs to improve your DAG import time:\n"
        f"* {get_docs_url('best-practices.html#top-level-python-code')}\n"
        f"* {get_docs_url('best-practices.html#reducing-dag-complexity')}")
    with timeout(dagbag_import_timeout, error_message=timeout_msg):
        return parse(mod_name, filepath)
def test_subdag_clear_parentdag_downstream_clear(self):
    dag = self.dagbag.get_dag('example_subdag_operator')
    subdag_op_task = dag.get_task('section-1')

    subdag = subdag_op_task.subdag
    subdag.schedule_interval = '@daily'

    executor = TestExecutor()
    job = BackfillJob(dag=subdag,
                      start_date=DEFAULT_DATE,
                      end_date=DEFAULT_DATE,
                      executor=executor,
                      donot_pickle=True)

    with timeout(seconds=30):
        job.run()

    ti0 = TI(
        task=subdag.get_task('section-1-task-1'),
        execution_date=DEFAULT_DATE)
    ti0.refresh_from_db()
    self.assertEqual(ti0.state, State.SUCCESS)

    sdag = subdag.sub_dag(
        task_regex='section-1-task-1',
        include_downstream=True,
        include_upstream=False)

    sdag.clear(
        start_date=DEFAULT_DATE,
        end_date=DEFAULT_DATE,
        include_parentdag=True)

    ti0.refresh_from_db()
    self.assertEqual(State.NONE, ti0.state)

    ti1 = TI(
        task=dag.get_task('some-other-task'),
        execution_date=DEFAULT_DATE)
    self.assertEqual(State.NONE, ti1.state)

    # Checks that all the Downstream tasks for Parent DAG
    # have been cleared
    for task in subdag_op_task.downstream_list:
        ti = TI(
            task=dag.get_task(task.task_id),
            execution_date=DEFAULT_DATE
        )
        self.assertEqual(State.NONE, ti.state)

    subdag.clear()
    dag.clear()
def test_subdag_clear_parentdag_downstream_clear(self):
    dag = self.dagbag.get_dag('clear_subdag_test_dag')
    subdag_op_task = dag.get_task('daily_job')

    subdag = subdag_op_task.subdag

    executor = MockExecutor()
    job = BackfillJob(dag=dag,
                      start_date=DEFAULT_DATE,
                      end_date=DEFAULT_DATE,
                      executor=executor,
                      donot_pickle=True)

    with timeout(seconds=30):
        job.run()

    ti_subdag = TI(task=dag.get_task('daily_job'), execution_date=DEFAULT_DATE)
    ti_subdag.refresh_from_db()
    self.assertEqual(ti_subdag.state, State.SUCCESS)

    ti_irrelevant = TI(task=dag.get_task('daily_job_irrelevant'), execution_date=DEFAULT_DATE)
    ti_irrelevant.refresh_from_db()
    self.assertEqual(ti_irrelevant.state, State.SUCCESS)

    ti_downstream = TI(task=dag.get_task('daily_job_downstream'), execution_date=DEFAULT_DATE)
    ti_downstream.refresh_from_db()
    self.assertEqual(ti_downstream.state, State.SUCCESS)

    sdag = subdag.sub_dag(task_regex='daily_job_subdag_task',
                          include_downstream=True,
                          include_upstream=False)

    sdag.clear(start_date=DEFAULT_DATE,
               end_date=DEFAULT_DATE,
               include_parentdag=True)

    ti_subdag.refresh_from_db()
    self.assertEqual(State.NONE, ti_subdag.state)

    ti_irrelevant.refresh_from_db()
    self.assertEqual(State.SUCCESS, ti_irrelevant.state)

    ti_downstream.refresh_from_db()
    self.assertEqual(State.NONE, ti_downstream.state)

    subdag.clear()
    dag.clear()
def _execute_task(self, context, task_copy):
    """Executes Task (optionally with a Timeout) and pushes Xcom results"""
    # If a timeout is specified for the task, make it fail
    # if it goes beyond
    if task_copy.execution_timeout:
        try:
            with timeout(task_copy.execution_timeout.total_seconds()):
                result = task_copy.execute(context=context)
        except AirflowTaskTimeout:
            task_copy.on_kill()
            raise
    else:
        result = task_copy.execute(context=context)

    # If the task returns a result, push an XCom containing it
    if task_copy.do_xcom_push and result is not None:
        self.xcom_push(key=XCOM_RETURN_KEY, value=result)
    return result
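`_execute_task` only arms the timeout when the task defines `execution_timeout`, which operators accept as a `datetime.timedelta` on the base operator. A short, hedged usage example follows; the task id, the bash command, and the `dag` object are hypothetical, and the BashOperator import path varies across Airflow versions.

from datetime import timedelta

from airflow.operators.bash_operator import BashOperator

# Hypothetical task: the 5-minute limit makes _execute_task wrap execute()
# in timeout(300.0); exceeding it raises AirflowTaskTimeout and triggers on_kill().
slow_task = BashOperator(
    task_id='example_with_execution_timeout',
    bash_command='sleep 600',
    execution_timeout=timedelta(minutes=5),
    dag=dag,  # assumes a DAG object defined elsewhere
)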
def test_backfill_execute_subdag_with_removed_task(self):
    """
    Ensure that subdag operators execute properly in the case where
    an associated task of the subdag has been removed from the dag
    definition, but has instances in the database from previous runs.
    """
    dag = self.dagbag.get_dag('example_subdag_operator')
    subdag = dag.get_task('section-1').subdag

    executor = MockExecutor()
    job = BackfillJob(dag=subdag,
                      start_date=DEFAULT_DATE,
                      end_date=DEFAULT_DATE,
                      executor=executor,
                      donot_pickle=True)

    removed_task_ti = TI(
        task=DummyOperator(task_id='removed_task'),
        execution_date=DEFAULT_DATE,
        state=State.REMOVED)
    removed_task_ti.dag_id = subdag.dag_id

    session = settings.Session()
    session.merge(removed_task_ti)
    session.commit()

    with timeout(seconds=30):
        job.run()

    for task in subdag.tasks:
        instance = session.query(TI).filter(
            TI.dag_id == subdag.dag_id,
            TI.task_id == task.task_id,
            TI.execution_date == DEFAULT_DATE).first()

        self.assertIsNotNone(instance)
        self.assertEqual(instance.state, State.SUCCESS)

    removed_task_ti.refresh_from_db()
    self.assertEqual(removed_task_ti.state, State.REMOVED)

    subdag.clear()
    dag.clear()
def process_file(self, filepath, only_if_updated=True, safe_mode=False):
    found_templates = []
    if filepath is None or not os.path.isfile(filepath):
        return found_templates

    try:
        # This failed before in what may have been a git sync
        # race condition
        file_last_changed_on_disk = datetime.fromtimestamp(os.path.getmtime(filepath))
        if only_if_updated and file_last_changed_on_disk == self.file_last_changed.get(filepath, None):
            return found_templates
    except Exception as e:
        self.log.exception(e)
        return found_templates

    if safe_mode:
        # TODO: heuristic to process only if file contains template
        pass

    self.log.debug(f'Importing {filepath}')
    modname, _ = os.path.splitext(os.path.split(filepath)[-1])

    mods = []
    with timeout(self.TEMPLATE_IMPORT_TIMEOUT):
        try:
            m = imp.load_source(modname, filepath)
            mods.append(m)
        except Exception as e:
            self.log.exception(f'Failed to import: {filepath}')
            self.import_errors[filepath] = str(e)
            self.file_last_changed[filepath] = file_last_changed_on_disk

    for mod in mods:
        for val in list(mod.__dict__.values()):
            if isinstance(val, type) and val != BaseDagTemplate and issubclass(val, BaseDagTemplate):
                tmpl = val
                self._add_template(tmpl)
                found_templates.append(tmpl)

    self.file_last_changed[filepath] = file_last_changed_on_disk
    return found_templates
def fetch_celery_task_state(async_result: AsyncResult) \
        -> Tuple[str, Union[str, ExceptionWithTraceback], Any]:
    """
    Fetch and return the state of the given celery task. The scope of this function is
    global so that it can be called by subprocesses in the pool.

    :param async_result: the async Celery result object used to fetch the task's state
    :return: a tuple of the Celery task key, the Celery state, and the celery info
        of the task
    :rtype: tuple[str, str, str]
    """
    try:
        with timeout(seconds=OPERATION_TIMEOUT):
            # Accessing state property of celery task will make actual network request
            # to get the current state of the task
            info = async_result.info if hasattr(async_result, 'info') else None
            return async_result.task_id, async_result.state, info
    except Exception as e:
        exception_traceback = f"Celery Task ID: {async_result}\n{traceback.format_exc()}"
        return async_result.task_id, ExceptionWithTraceback(e, exception_traceback), None
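The docstring notes that this function is global so it can be called by subprocesses in a pool: each state fetch is a network round-trip bounded by OPERATION_TIMEOUT, so fanning the calls out keeps one hung broker connection from stalling the whole sync loop. The driver below is a hedged sketch of that pattern, not the executor's actual code; `fetch_all_states` and `num_processes` are hypothetical names, and the `AsyncResult` objects are assumed to be picklable.

from multiprocessing import Pool

def fetch_all_states(async_results, num_processes=4):
    # async_results: a list of AsyncResult objects collected from earlier
    # apply_async() calls; each worker runs fetch_celery_task_state on one of them.
    with Pool(processes=num_processes) as pool:
        return pool.map(fetch_celery_task_state, async_results)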
def fetch_celery_task_state(celery_task):
    """
    Fetch and return the state of the given celery task. The scope of this function is
    global so that it can be called by subprocesses in the pool.

    :param celery_task: a tuple of the Celery task key and the async Celery object used
        to fetch the task's state
    :type celery_task: (str, celery.result.AsyncResult)
    :return: a tuple of the Celery task key and the Celery state of the task
    :rtype: (str, str)
    """
    try:
        with timeout(seconds=2):
            # Accessing state property of celery task will make actual network request
            # to get the current state of the task.
            res = (celery_task[0], celery_task[1].state)
    except Exception as e:
        exception_traceback = "Celery Task ID: {}\n{}".format(
            celery_task[0], traceback.format_exc())
        res = ExceptionWithTraceback(e, exception_traceback)

    return res
def execution_parallelism(self, parallelism=0):
    executor = LocalExecutor(parallelism=parallelism)
    executor.start()

    success_key = 'success {}'
    success_command = ['true', 'some_parameter']
    fail_command = ['false', 'some_parameter']

    for i in range(self.TEST_SUCCESS_COMMANDS):
        key, command = success_key.format(i), success_command
        executor.execute_async(key=key, command=command)
        executor.running[key] = True

    # errors are propagated for some reason
    try:
        executor.execute_async(key='fail', command=fail_command)
    except Exception:
        pass

    executor.running['fail'] = True

    if parallelism == 0:
        with timeout(seconds=10):
            executor.end()
    else:
        executor.end()

    if isinstance(executor.impl, LocalExecutor._LimitedParallelism):
        self.assertTrue(executor.queue.empty())

    self.assertEqual(len(executor.running), 0)
    self.assertTrue(executor.result_queue.empty())

    for i in range(self.TEST_SUCCESS_COMMANDS):
        key = success_key.format(i)
        self.assertEqual(executor.event_buffer[key], State.SUCCESS)
    self.assertEqual(executor.event_buffer['fail'], State.FAILED)

    expected = self.TEST_SUCCESS_COMMANDS + 1 if parallelism == 0 else parallelism
    self.assertEqual(executor.workers_used, expected)
def fetch_celery_task_state(celery_task: Tuple[TaskInstanceKeyType, AsyncResult]) \
        -> Union[TaskInstanceStateType, ExceptionWithTraceback]:
    """
    Fetch and return the state of the given celery task. The scope of this function is
    global so that it can be called by subprocesses in the pool.

    :param celery_task: a tuple of the Celery task key and the async Celery object used
        to fetch the task's state
    :type celery_task: tuple(str, celery.result.AsyncResult)
    :return: a tuple of the Celery task key and the Celery state of the task
    :rtype: tuple[str, str]
    """
    try:
        with timeout(seconds=OPERATION_TIMEOUT):
            # Accessing state property of celery task will make actual network request
            # to get the current state of the task.
            return celery_task[0], celery_task[1].state
    except Exception as e:  # pylint: disable=broad-except
        exception_traceback = "Celery Task ID: {}\n{}".format(
            celery_task[0], traceback.format_exc())
        return ExceptionWithTraceback(e, exception_traceback)
def execution_parallelism(self, parallelism=0):
    executor = LocalExecutor(parallelism=parallelism)
    executor.start()

    success_key = 'success {}'
    success_command = 'echo {}'
    fail_command = 'exit 1'

    for i in range(self.TEST_SUCCESS_COMMANDS):
        key, command = success_key.format(i), success_command.format(i)
        executor.execute_async(key=key, command=command)
        executor.running[key] = True

    # errors are propagated for some reason
    try:
        executor.execute_async(key='fail', command=fail_command)
    except Exception:
        pass

    executor.running['fail'] = True

    if parallelism == 0:
        with timeout(seconds=5):
            executor.end()
    else:
        executor.end()

    for i in range(self.TEST_SUCCESS_COMMANDS):
        key = success_key.format(i)
        self.assertEqual(executor.event_buffer[key], State.SUCCESS)
    self.assertEqual(executor.event_buffer['fail'], State.FAILED)

    for i in range(self.TEST_SUCCESS_COMMANDS):
        self.assertNotIn(success_key.format(i), executor.running)
    self.assertNotIn('fail', executor.running)

    expected = self.TEST_SUCCESS_COMMANDS + 1 if parallelism == 0 else parallelism
    self.assertEqual(executor.workers_used, expected)
def process_file(self, filepath, only_if_updated=True):
    """
    Given a path to a python module or zip file, this method imports
    the module and looks for dag objects within it.
    """
    found_dags = []

    # if the source file no longer exists in the DB or in the filesystem,
    # return an empty list
    # todo: raise exception?
    if filepath is None or not os.path.isfile(filepath):
        return found_dags

    try:
        # This failed before in what may have been a git sync
        # race condition
        file_last_changed = datetime.fromtimestamp(os.path.getmtime(filepath))
        if only_if_updated \
                and filepath in self.dagbag.file_last_changed \
                and file_last_changed == self.dagbag.file_last_changed[filepath]:
            return found_dags
    except Exception as e:
        self.log.exception(e)
        return found_dags

    mods = []
    if not zipfile.is_zipfile(filepath):
        if self.safe_mode and os.path.isfile(filepath):
            with open(filepath, 'rb') as f:
                content = f.read()
                if not all([s in content for s in (b'DAG', b'airflow')]):
                    self.dagbag.file_last_changed[filepath] = file_last_changed
                    return found_dags

        self.log.debug("Importing %s", filepath)
        org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
        mod_name = ('unusual_prefix_' +
                    hashlib.sha1(filepath.encode('utf-8')).hexdigest() +
                    '_' + org_mod_name)

        if mod_name in sys.modules:
            del sys.modules[mod_name]

        with timeout(configuration.getint('core', "DAGBAG_IMPORT_TIMEOUT")):
            try:
                m = imp.load_source(mod_name, filepath)
                mods.append(m)
            except Exception as e:
                self.log.exception("Failed to import: %s", filepath)
                self.dagbag.import_errors[filepath] = str(e)
                self.dagbag.file_last_changed[filepath] = file_last_changed
    else:
        zip_file = zipfile.ZipFile(filepath)
        for mod in zip_file.infolist():
            head, _ = os.path.split(mod.filename)
            mod_name, ext = os.path.splitext(mod.filename)
            if not head and (ext == '.py' or ext == '.pyc'):
                if mod_name == '__init__':
                    self.log.warning("Found __init__.%s at root of %s", ext, filepath)
                if self.safe_mode:
                    with zip_file.open(mod.filename) as zf:
                        self.log.debug("Reading %s from %s", mod.filename, filepath)
                        content = zf.read()
                        if not all([s in content for s in (b'DAG', b'airflow')]):
                            self.dagbag.file_last_changed[filepath] = (
                                file_last_changed)
                            # todo: create ignore list
                            return found_dags

                if mod_name in sys.modules:
                    del sys.modules[mod_name]

                try:
                    sys.path.insert(0, filepath)
                    m = importlib.import_module(mod_name)
                    mods.append(m)
                except Exception as e:
                    self.log.exception("Failed to import: %s", filepath)
                    self.dagbag.import_errors[filepath] = str(e)
                    self.dagbag.file_last_changed[filepath] = file_last_changed

    for m in mods:
        for dag in list(m.__dict__.values()):
            if isinstance(dag, airflow.models.DAG):
                if not dag.full_filepath:
                    dag.full_filepath = filepath
                    if dag.fileloc != filepath:
                        dag.fileloc = filepath
                try:
                    dag.is_subdag = False
                    self.dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
                    found_dags.append(dag)
                    found_dags += dag.subdags
                except AirflowDagCycleException as cycle_exception:
                    self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
                    self.dagbag.import_errors[dag.full_filepath] = \
                        str(cycle_exception)
                    self.dagbag.file_last_changed[dag.full_filepath] = \
                        file_last_changed

    self.dagbag.file_last_changed[filepath] = file_last_changed
    return found_dags
def run_with_timeout():
    with timeout(seconds=30):
        job.run()
def process_file(self, filepath, only_if_updated=True, safe_mode=True):
    """
    Given a path to a python module or zip file, this method imports
    the module and looks for dag objects within it.
    """
    from airflow.models.dag import DAG  # Avoid circular import

    found_dags = []

    # if the source file no longer exists in the DB or in the filesystem,
    # return an empty list
    # todo: raise exception?
    if filepath is None or not os.path.isfile(filepath):
        return found_dags

    try:
        # This failed before in what may have been a git sync
        # race condition
        file_last_changed_on_disk = datetime.fromtimestamp(os.path.getmtime(filepath))
        if only_if_updated \
                and filepath in self.file_last_changed \
                and file_last_changed_on_disk == self.file_last_changed[filepath]:
            return found_dags
    except Exception as e:
        self.log.exception(e)
        return found_dags

    mods = []
    is_zipfile = zipfile.is_zipfile(filepath)
    if not is_zipfile:
        if safe_mode:
            with open(filepath, 'rb') as f:
                content = f.read()
                if not all([s in content for s in (b'DAG', b'airflow')]):
                    self.file_last_changed[filepath] = file_last_changed_on_disk
                    # Don't want to spam user with skip messages
                    if not self.has_logged:
                        self.has_logged = True
                        self.log.info(
                            "File %s assumed to contain no DAGs. Skipping.",
                            filepath)
                    return found_dags

        self.log.debug("Importing %s", filepath)
        org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
        mod_name = ('unusual_prefix_' +
                    hashlib.sha1(filepath.encode('utf-8')).hexdigest() +
                    '_' + org_mod_name)

        if mod_name in sys.modules:
            del sys.modules[mod_name]

        with timeout(self.DAGBAG_IMPORT_TIMEOUT):
            try:
                m = imp.load_source(mod_name, filepath)
                mods.append(m)
            except Exception as e:
                self.log.exception("Failed to import: %s", filepath)
                self.import_errors[filepath] = str(e)
                self.file_last_changed[filepath] = file_last_changed_on_disk
    else:
        zip_file = zipfile.ZipFile(filepath)
        for mod in zip_file.infolist():
            head, _ = os.path.split(mod.filename)
            mod_name, ext = os.path.splitext(mod.filename)
            if not head and (ext == '.py' or ext == '.pyc'):
                if mod_name == '__init__':
                    self.log.warning("Found __init__.%s at root of %s", ext, filepath)
                if safe_mode:
                    with zip_file.open(mod.filename) as zf:
                        self.log.debug("Reading %s from %s", mod.filename, filepath)
                        content = zf.read()
                        if not all([s in content for s in (b'DAG', b'airflow')]):
                            self.file_last_changed[filepath] = (
                                file_last_changed_on_disk)
                            # todo: create ignore list
                            # Don't want to spam user with skip messages
                            if not self.has_logged:
                                self.has_logged = True
                                self.log.info(
                                    "File %s assumed to contain no DAGs. Skipping.",
                                    filepath)

                if mod_name in sys.modules:
                    del sys.modules[mod_name]

                try:
                    sys.path.insert(0, filepath)
                    m = importlib.import_module(mod_name)
                    mods.append(m)
                except Exception as e:
                    self.log.exception("Failed to import: %s", filepath)
                    self.import_errors[filepath] = str(e)
                    self.file_last_changed[filepath] = file_last_changed_on_disk

    for m in mods:
        for dag in list(m.__dict__.values()):
            if isinstance(dag, DAG):
                if not dag.full_filepath:
                    dag.full_filepath = filepath
                    if dag.fileloc != filepath and not is_zipfile:
                        dag.fileloc = filepath
                try:
                    dag.is_subdag = False
                    self.bag_dag(dag, parent_dag=dag, root_dag=dag)
                    if isinstance(dag._schedule_interval, six.string_types):
                        croniter(dag._schedule_interval)
                    found_dags.append(dag)
                    found_dags += dag.subdags
                except (CroniterBadCronError,
                        CroniterBadDateError,
                        CroniterNotAlphaError) as cron_e:
                    self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
                    self.import_errors[dag.full_filepath] = \
                        "Invalid Cron expression: " + str(cron_e)
                    self.file_last_changed[dag.full_filepath] = \
                        file_last_changed_on_disk
                except AirflowDagCycleException as cycle_exception:
                    self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
                    self.import_errors[dag.full_filepath] = str(cycle_exception)
                    self.file_last_changed[dag.full_filepath] = \
                        file_last_changed_on_disk

    self.file_last_changed[filepath] = file_last_changed_on_disk
    return found_dags
def test_failure_callback_only_called_once(self, mock_return_code, _check_call):
    """
    Test that ensures that when a task exits with failure by itself,
    failure callback is only called once
    """
    # use shared memory value so we can properly track value change even if
    # it's been updated across processes.
    failure_callback_called = Value('i', 0)
    callback_count_lock = Lock()

    def failure_callback(context):
        with callback_count_lock:
            failure_callback_called.value += 1
        assert context['dag_run'].dag_id == 'test_failure_callback_race'
        assert isinstance(context['exception'], AirflowFailException)

    def task_function(ti):
        raise AirflowFailException()

    dag = DAG(dag_id='test_failure_callback_race', start_date=DEFAULT_DATE)
    task = PythonOperator(
        task_id='test_exit_on_failure',
        python_callable=task_function,
        on_failure_callback=failure_callback,
        dag=dag,
    )

    dag.clear()
    with create_session() as session:
        dag.create_dagrun(
            run_id="test",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti.refresh_from_db()

    job1 = LocalTaskJob(task_instance=ti,
                        ignore_ti_state=True,
                        executor=SequentialExecutor())

    # Simulate race condition where job1 heartbeat ran right after task
    # state got set to failed by ti.handle_failure but before task process
    # fully exits. See _execute loop in airflow/jobs/local_task_job.py.
    # In this case, we have:
    #  * task_runner.return_code() is None
    #  * ti.state == State.Failed
    #
    # We also need to set return_code to a valid int after job1.terminating
    # is set to True so _execute loop won't loop forever.
    def dummy_return_code(*args, **kwargs):
        return None if not job1.terminating else -9

    mock_return_code.side_effect = dummy_return_code

    with timeout(10):
        # This should be _much_ shorter to run.
        # If you change this limit, make the timeout in the callable above bigger
        job1.run()

    ti.refresh_from_db()
    assert ti.state == State.FAILED  # task exits with failure state
    assert failure_callback_called.value == 1