def test_mark_failure_on_failure_callback(self):
    """
    Test that ensures that mark_failure in the UI fails
    the task, and executes on_failure_callback
    """
    data = {'called': False}

    def check_failure(context):
        self.assertEqual(context['dag_run'].dag_id, 'test_mark_failure')
        data['called'] = True

    def task_function(ti):
        print("python_callable run in pid %s", os.getpid())
        with create_session() as session:
            self.assertEqual(State.RUNNING, ti.state)
            ti.log.info("Marking TI as failed 'externally'")
            ti.state = State.FAILED
            session.merge(ti)
            session.commit()

        time.sleep(60)
        # This should not happen -- the state change should be noticed and the
        # task should get killed long before the sleep finishes.
        data['reached_end_of_sleep'] = True

    with DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE) as dag:
        task = PythonOperator(
            task_id='test_state_succeeded1',
            python_callable=task_function,
            on_failure_callback=check_failure,
        )

    session = settings.Session()

    dag.clear()
    dag.create_dagrun(
        run_id="test",
        state=State.RUNNING,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=session,
    )
    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti.refresh_from_db()

    job1 = LocalTaskJob(task_instance=ti,
                        ignore_ti_state=True,
                        executor=SequentialExecutor())
    with timeout(30):
        # This should be _much_ shorter to run.
        # If you change this limit, make the timeout in the callable above bigger
        job1.run()

    ti.refresh_from_db()
    self.assertEqual(ti.state, State.FAILED)
    self.assertTrue(data['called'])
    self.assertNotIn(
        'reached_end_of_sleep', data,
        'Task should not have been allowed to run to completion')
def _run_task_by_local_task_job(args, ti):
    """Run LocalTaskJob, which monitors the raw task execution process"""
    # Translate the CLI flags into LocalTaskJob keyword arguments and run it.
    run_job = LocalTaskJob(
        task_instance=ti,
        mark_success=args.mark_success,
        pickle_id=args.pickle,
        ignore_all_deps=args.ignore_all_dependencies,
        ignore_depends_on_past=args.ignore_depends_on_past,
        ignore_task_deps=args.ignore_dependencies,
        ignore_ti_state=args.force,
        pool=args.pool,
    )
    run_job.run()
def test_mark_success_on_success_callback(self):
    """
    Test that ensures that where a task is marked success in the UI
    on_success_callback gets executed
    """
    data = {'called': False}

    def success_callback(context):
        self.assertEqual(context['dag_run'].dag_id, 'test_mark_success')
        data['called'] = True

    dag = DAG(dag_id='test_mark_success',
              start_date=DEFAULT_DATE,
              default_args={'owner': 'owner1'})

    task = DummyOperator(
        task_id='test_state_succeeded1',
        dag=dag,
        on_success_callback=success_callback,
    )

    session = settings.Session()

    dag.clear()
    dag.create_dagrun(
        run_id="test",
        state=State.RUNNING,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=session,
    )
    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti.refresh_from_db()
    job1 = LocalTaskJob(task_instance=ti,
                        ignore_ti_state=True,
                        executor=SequentialExecutor())
    from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
    job1.task_runner = StandardTaskRunner(job1)

    # Run the job in a separate process so we can flip the TI state "externally".
    process = multiprocessing.Process(target=job1.run)
    process.start()
    ti.refresh_from_db()

    # Poll until the task is actually running (up to ~5 seconds).
    for _ in range(0, 50):
        if ti.state == State.RUNNING:
            break
        time.sleep(0.1)
        ti.refresh_from_db()
    self.assertEqual(State.RUNNING, ti.state)

    # Mark the TI successful "externally", as the UI would.
    ti.state = State.SUCCESS
    session.merge(ti)
    session.commit()

    job1.heartbeat_callback(session=None)
    self.assertTrue(data['called'])
    process.join(timeout=10)
    self.assertFalse(process.is_alive())
def test_localtaskjob_maintain_heart_rate(self):
    """The job must not over-sleep to "catch up" with its heart rate."""
    dagbag = DagBag(
        dag_folder=TEST_DAG_FOLDER,
        include_examples=False,
    )
    dag = dagbag.dags.get('test_localtaskjob_double_trigger')
    task = dag.get_task('test_localtaskjob_double_trigger_task')

    session = settings.Session()

    dag.clear()
    dag.create_dagrun(
        run_id="test",
        state=State.SUCCESS,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=session,
    )

    ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti_run.refresh_from_db()
    job1 = LocalTaskJob(task_instance=ti_run, executor=SequentialExecutor())

    # this should make sure we only heartbeat once and exit at the second
    # loop in _execute()
    return_codes = [None, 0]

    def multi_return_code():
        return return_codes.pop(0)

    time_start = time.time()
    from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
    with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_start:
        with patch.object(StandardTaskRunner, 'return_code') as mock_ret_code:
            mock_ret_code.side_effect = multi_return_code
            job1.run()
            self.assertEqual(mock_start.call_count, 1)
            self.assertEqual(mock_ret_code.call_count, 2)
    time_end = time.time()

    self.assertEqual(self.mock_base_job_sleep.call_count, 1)
    self.assertEqual(job1.state, State.SUCCESS)

    # Consider we have patched sleep call, it should not be sleeping to
    # keep up with the heart rate in other unpatched places
    #
    # We already make sure patched sleep call is only called once
    self.assertLess(time_end - time_start, job1.heartrate)
    session.close()
def test_find_zombies(self):
    """A SHUTDOWN LocalTaskJob past the zombie threshold yields a failure callback request."""
    manager = DagFileProcessorManager(
        dag_directory='directory',
        max_runs=1,
        processor_factory=MagicMock().return_value,
        processor_timeout=timedelta.max,
        signal_conn=MagicMock(),
        dag_ids=[],
        pickle_dags=False,
        async_mode=True,
    )

    dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False)
    with create_session() as session:
        session.query(LJ).delete()
        dag = dagbag.get_dag('example_branch_operator')
        dag.sync_to_db()
        task = dag.get_task(task_id='run_this_first')

        ti = TI(task, DEFAULT_DATE, State.RUNNING)
        local_job = LJ(ti)
        local_job.state = State.SHUTDOWN

        session.add(local_job)
        session.commit()

        ti.job_id = local_job.id
        session.add(ti)
        session.commit()

        # Push the last query time far enough back that the zombie check fires.
        manager._last_zombie_query_time = timezone.utcnow() - timedelta(
            seconds=manager._zombie_threshold_secs + 1)
        manager._find_zombies()  # pylint: disable=no-value-for-parameter
        requests = manager._callback_to_execute[dag.full_filepath]
        self.assertEqual(1, len(requests))
        self.assertEqual(requests[0].full_filepath, dag.full_filepath)
        self.assertEqual(requests[0].msg, "Detected as zombie")
        self.assertEqual(requests[0].is_failure_callback, True)
        self.assertIsInstance(requests[0].simple_task_instance, SimpleTaskInstance)
        self.assertEqual(ti.dag_id, requests[0].simple_task_instance.dag_id)
        self.assertEqual(ti.task_id, requests[0].simple_task_instance.task_id)
        self.assertEqual(ti.execution_date, requests[0].simple_task_instance.execution_date)

        session.query(TI).delete()
        session.query(LJ).delete()
def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes):
    """
    Run a LocalTaskJob through one loop and assert the DB query budget.

    :param mock_get_task_runner: patched task-runner factory (from decorator)
    :param return_codes: sequence of return codes the fake runner should yield
    """
    unique_prefix = str(uuid.uuid4())
    dag = DAG(dag_id=f'{unique_prefix}_test_number_of_queries', start_date=DEFAULT_DATE)
    task = DummyOperator(task_id='test_state_succeeded1', dag=dag)

    dag.clear()
    dag.create_dagrun(run_id=unique_prefix, state=State.NONE)

    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)

    # BUG FIX: the original wrote ``side_effects`` (trailing "s"), which is not
    # a Mock attribute -- it silently created a new attribute and the supplied
    # return codes were never fed to ``return_code()``.  Mock's real attribute
    # is ``side_effect``.
    mock_get_task_runner.return_value.return_code.side_effect = return_codes

    job = LocalTaskJob(task_instance=ti, executor=MockExecutor())
    with assert_queries_count(13):
        job.run()
def test_terminate_task(self):
    """If a task instance's db state get deleted, it should fail"""
    from airflow.executors.sequential_executor import SequentialExecutor
    TI = TaskInstance
    dag = self.dagbag.dags.get('test_utils')
    task = dag.task_dict.get('sleeps_forever')

    ti = TI(task=task, execution_date=DEFAULT_DATE)
    job = LocalTaskJob(task_instance=ti,
                       ignore_ti_state=True,
                       executor=SequentialExecutor())

    # Running task instance asynchronously
    proc = multiprocessing.Process(target=job.run)
    proc.start()
    sleep(5)

    # Recreate the DB connection in this process before querying.
    settings.engine.dispose()
    session = settings.Session()
    ti.refresh_from_db(session=session)

    # making sure it's actually running
    self.assertEqual(State.RUNNING, ti.state)
    ti = session.query(TI).filter_by(
        dag_id=task.dag_id,
        task_id=task.task_id,
        execution_date=DEFAULT_DATE,
    ).one()

    # deleting the instance should result in a failure
    session.delete(ti)
    session.commit()

    # waiting for the async task to finish
    proc.join()

    # making sure that the task ended up as failed
    ti.refresh_from_db(session=session)
    self.assertEqual(State.FAILED, ti.state)
    session.close()
def test_localtaskjob_essential_attr(self):
    """
    Check whether essential attributes
    of LocalTaskJob can be assigned with
    proper values without intervention
    """
    dag = DAG(
        'test_localtaskjob_essential_attr',
        start_date=DEFAULT_DATE,
        default_args={'owner': 'owner1'})

    with dag:
        op1 = DummyOperator(task_id='op1')

    dag.clear()
    dr = dag.create_dagrun(run_id="test",
                           state=State.SUCCESS,
                           execution_date=DEFAULT_DATE,
                           start_date=DEFAULT_DATE)

    ti = dr.get_task_instance(task_id=op1.task_id)

    job1 = LocalTaskJob(task_instance=ti,
                        ignore_ti_state=True,
                        executor=SequentialExecutor())

    essential_attr = ["dag_id", "job_type", "start_date", "hostname"]

    # Every essential attribute must exist ...
    check_result_1 = [hasattr(job1, attr) for attr in essential_attr]
    self.assertTrue(all(check_result_1))

    # ... and be populated with a non-None value.
    check_result_2 = [getattr(job1, attr) is not None for attr in essential_attr]
    self.assertTrue(all(check_result_2))
def test_on_kill(self):
    """
    Test that ensures that clearing in the UI SIGTERMS
    the task
    """
    path = "/tmp/airflow_on_kill"
    try:
        os.unlink(path)
    except OSError:
        # Marker file from a previous run may not exist; that's fine.
        pass

    dagbag = models.DagBag(
        dag_folder=TEST_DAG_FOLDER,
        include_examples=False,
    )
    dag = dagbag.dags.get('test_on_kill')
    task = dag.get_task('task1')

    session = settings.Session()

    dag.clear()
    dag.create_dagrun(
        run_id="test",
        state=State.RUNNING,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=session,
    )
    ti = TI(task=task, execution_date=DEFAULT_DATE)
    job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
    session.commit()

    runner = StandardTaskRunner(job1)
    runner.start()

    # give the task some time to startup
    time.sleep(3)

    pgid = os.getpgid(runner.process.pid)
    self.assertGreater(pgid, 0)
    self.assertNotEqual(
        pgid, os.getpgid(0),
        "Task should be in a different process group to us")

    processes = list(self._procs_in_pgroup(pgid))

    runner.terminate()

    # Wait some time for the result
    for _ in range(20):
        if os.path.exists(path):
            break
        time.sleep(2)

    with open(path) as f:
        self.assertEqual("ON_KILL_TEST", f.readline())

    for process in processes:
        self.assertFalse(psutil.pid_exists(process.pid),
                         f"{process} is still alive")
def _run_task_by_local_task_job(args, ti):
    """Run LocalTaskJob, which monitors the raw task execution process"""
    run_job = LocalTaskJob(
        task_instance=ti,
        mark_success=args.mark_success,
        pickle_id=args.pickle,
        ignore_all_deps=args.ignore_all_dependencies,
        ignore_depends_on_past=args.ignore_depends_on_past,
        ignore_task_deps=args.ignore_dependencies,
        ignore_ti_state=args.force,
        pool=args.pool,
        server_uri=args.server_uri,
    )
    try:
        run_job.run()
    finally:
        # Flush and close logging handlers even when the job raises.
        if args.shut_down_logging:
            logging.shutdown()
def test_localtaskjob_double_trigger(self):
    """A second LocalTaskJob for an already-running TI must not start a runner."""
    dagbag = DagBag(
        dag_folder=TEST_DAG_FOLDER,
        include_examples=False,
    )
    dag = dagbag.dags.get('test_localtaskjob_double_trigger')
    task = dag.get_task('test_localtaskjob_double_trigger_task')

    session = settings.Session()

    dag.clear()
    dr = dag.create_dagrun(
        run_id="test",
        state=State.SUCCESS,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=session,
    )

    # Pretend the TI is already being run by another process on this host.
    ti = dr.get_task_instance(task_id=task.task_id, session=session)
    ti.state = State.RUNNING
    ti.hostname = get_hostname()
    ti.pid = 1
    session.merge(ti)
    session.commit()

    ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti_run.refresh_from_db()
    job1 = LocalTaskJob(task_instance=ti_run,
                        executor=SequentialExecutor())
    from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
    with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_method:
        job1.run()
        mock_method.assert_not_called()

    # The original TI must be untouched by the second job.
    ti = dr.get_task_instance(task_id=task.task_id, session=session)
    self.assertEqual(ti.pid, 1)
    self.assertEqual(ti.state, State.RUNNING)

    session.close()
def test_localtaskjob_heartbeat(self):
    """heartbeat_callback raises unless hostname and PID both match the TI."""
    session = settings.Session()
    dag = DAG(
        'test_localtaskjob_heartbeat',
        start_date=DEFAULT_DATE,
        default_args={'owner': 'owner1'})

    with dag:
        op1 = DummyOperator(task_id='op1')

    dag.clear()
    dr = dag.create_dagrun(
        run_id="test",
        state=State.SUCCESS,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=session,
    )
    ti = dr.get_task_instance(task_id=op1.task_id, session=session)
    ti.state = State.RUNNING
    ti.hostname = "blablabla"
    session.commit()

    job1 = LocalTaskJob(task_instance=ti,
                        ignore_ti_state=True,
                        executor=SequentialExecutor())
    ti.task = op1
    ti.refresh_from_task(op1)
    job1.task_runner = StandardTaskRunner(job1)
    job1.task_runner.process = mock.Mock()

    # Wrong hostname -> heartbeat must fail.
    with pytest.raises(AirflowException):
        job1.heartbeat_callback()  # pylint: disable=no-value-for-parameter

    # Matching hostname and PID -> heartbeat succeeds.
    job1.task_runner.process.pid = 1
    ti.state = State.RUNNING
    ti.hostname = get_hostname()
    ti.pid = 1
    session.merge(ti)
    session.commit()
    assert ti.pid != os.getpid()
    job1.heartbeat_callback(session=None)

    # PID mismatch -> heartbeat must fail again.
    job1.task_runner.process.pid = 2
    with pytest.raises(AirflowException):
        job1.heartbeat_callback()  # pylint: disable=no-value-for-parameter
def test_localtaskjob_heartbeat(self, mock_pid):
    """heartbeat_callback raises unless hostname and PID both match the TI."""
    session = settings.Session()
    dag = DAG(
        'test_localtaskjob_heartbeat',
        start_date=DEFAULT_DATE,
        default_args={'owner': 'owner1'})

    with dag:
        op1 = DummyOperator(task_id='op1')

    dag.clear()
    dr = dag.create_dagrun(
        run_id="test",
        state=State.SUCCESS,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=session,
    )
    ti = dr.get_task_instance(task_id=op1.task_id, session=session)
    ti.state = State.RUNNING
    ti.hostname = "blablabla"
    session.commit()

    job1 = LocalTaskJob(task_instance=ti,
                        ignore_ti_state=True,
                        executor=SequentialExecutor())

    # Hostname does not match -> heartbeat must fail.
    self.assertRaises(AirflowException, job1.heartbeat_callback)

    # Matching hostname and PID -> heartbeat succeeds.
    mock_pid.return_value = 1
    ti.state = State.RUNNING
    ti.hostname = get_hostname()
    ti.pid = 1
    session.merge(ti)
    session.commit()
    job1.heartbeat_callback(session=None)

    # PID mismatch -> heartbeat must fail again.
    mock_pid.return_value = 2
    self.assertRaises(AirflowException, job1.heartbeat_callback)
def test_find_zombies(self):
    """A SHUTDOWN LocalTaskJob past the zombie threshold is reported as a zombie."""
    manager = DagFileProcessorManager(
        dag_directory='directory',
        file_paths=['abc.txt'],
        max_runs=1,
        processor_factory=MagicMock().return_value,
        processor_timeout=timedelta.max,
        signal_conn=MagicMock(),
        async_mode=True)

    dagbag = DagBag(TEST_DAG_FOLDER)
    with create_session() as session:
        session.query(LJ).delete()
        dag = dagbag.get_dag('example_branch_operator')
        task = dag.get_task(task_id='run_this_first')

        ti = TI(task, DEFAULT_DATE, State.RUNNING)
        local_job = LJ(ti)
        local_job.state = State.SHUTDOWN
        local_job.id = 1
        ti.job_id = local_job.id

        session.add(local_job)
        session.add(ti)
        session.commit()

        # Push the last query time far enough back that the zombie check fires.
        manager._last_zombie_query_time = timezone.utcnow() - timedelta(
            seconds=manager._zombie_threshold_secs + 1)
        manager._find_zombies()  # pylint: disable=no-value-for-parameter
        zombies = manager._zombies
        self.assertEqual(1, len(zombies))
        self.assertIsInstance(zombies[0], SimpleTaskInstance)
        self.assertEqual(ti.dag_id, zombies[0].dag_id)
        self.assertEqual(ti.task_id, zombies[0].task_id)
        self.assertEqual(ti.execution_date, zombies[0].execution_date)

        session.query(TI).delete()
        session.query(LJ).delete()
def test_mark_success_no_kill(self):
    """
    Test that ensures that mark_success in the UI doesn't cause
    the task to fail, and that the task exits
    """
    dagbag = DagBag(
        dag_folder=TEST_DAG_FOLDER,
        include_examples=False,
    )
    dag = dagbag.dags.get('test_mark_success')
    task = dag.get_task('task1')

    session = settings.Session()

    dag.clear()
    dag.create_dagrun(
        run_id="test",
        state=State.RUNNING,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=session,
    )
    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti.refresh_from_db()
    job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)

    # Run the job in another process so we can mark the TI successful here.
    process = multiprocessing.Process(target=job1.run)
    process.start()
    ti.refresh_from_db()

    # Poll until the task is actually running (up to ~5 seconds).
    for _ in range(0, 50):
        if ti.state == State.RUNNING:
            break
        time.sleep(0.1)
        ti.refresh_from_db()
    self.assertEqual(State.RUNNING, ti.state)

    ti.state = State.SUCCESS
    session.merge(ti)
    session.commit()

    process.join(timeout=10)
    self.assertFalse(process.is_alive())
    ti.refresh_from_db()
    self.assertEqual(State.SUCCESS, ti.state)
def test_heartbeat_failed_fast(self):
    """
    Test that task heartbeat will sleep when it fails fast
    """
    # Use the real sleep so heartbeat spacing is observable.
    self.mock_base_job_sleep.side_effect = time.sleep

    with create_session() as session:
        dagbag = DagBag(
            dag_folder=TEST_DAG_FOLDER,
            include_examples=False,
        )
        dag_id = 'test_heartbeat_failed_fast'
        task_id = 'test_heartbeat_failed_fast_op'
        dag = dagbag.get_dag(dag_id)
        task = dag.get_task(task_id)

        dag.create_dagrun(
            run_id="test_heartbeat_failed_fast_run",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        ti.state = State.RUNNING
        ti.hostname = get_hostname()
        ti.pid = 1
        session.commit()

        job = LocalTaskJob(task_instance=ti,
                           executor=MockExecutor(do_update=False))
        job.heartrate = 2

        # Record the timestamp of every heartbeat so spacing can be checked.
        heartbeat_records = []
        job.heartbeat_callback = lambda session: heartbeat_records.append(job.latest_heartbeat)
        job._execute()
        self.assertGreater(len(heartbeat_records), 2)
        for i in range(1, len(heartbeat_records)):
            time1 = heartbeat_records[i - 1]
            time2 = heartbeat_records[i]
            # Assert that difference small enough
            delta = (time2 - time1).total_seconds()
            self.assertAlmostEqual(delta, job.heartrate, delta=0.05)
def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_processor(self):
    """
    Check that the same set of failure callback with zombies are passed to the dag
    file processors until the next zombie detection logic is invoked.
    """
    test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_example_bash_operator.py')
    with conf_vars({('scheduler', 'parsing_processes'): '1',
                    ('core', 'load_examples'): 'False'}):
        dagbag = DagBag(test_dag_path, read_dags_from_db=False)
        with create_session() as session:
            session.query(LJ).delete()
            dag = dagbag.get_dag('test_example_bash_operator')
            dag.sync_to_db()
            task = dag.get_task(task_id='run_this_last')

            ti = TI(task, DEFAULT_DATE, State.RUNNING)
            local_job = LJ(ti)
            local_job.state = State.SHUTDOWN
            session.add(local_job)
            session.commit()

            # TODO: If there was an actual Relationship between TI and Job
            # we wouldn't need this extra commit
            session.add(ti)
            ti.job_id = local_job.id
            session.commit()

            expected_failure_callback_requests = [
                TaskCallbackRequest(
                    full_filepath=dag.full_filepath,
                    simple_task_instance=SimpleTaskInstance(ti),
                    msg="Message",
                )
            ]

        test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_example_bash_operator.py')

        child_pipe, parent_pipe = multiprocessing.Pipe()
        async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn')

        # Track every processor the manager creates so we can inspect the
        # callback requests each one received.
        fake_processors = []

        def fake_processor_factory(*args, **kwargs):
            nonlocal fake_processors
            processor = FakeDagFileProcessorRunner._fake_dag_processor_factory(*args, **kwargs)
            fake_processors.append(processor)
            return processor

        manager = DagFileProcessorManager(
            dag_directory=test_dag_path,
            max_runs=1,
            processor_factory=fake_processor_factory,
            processor_timeout=timedelta.max,
            signal_conn=child_pipe,
            dag_ids=[],
            pickle_dags=False,
            async_mode=async_mode,
        )

        self.run_processor_manager_one_loop(manager, parent_pipe)

        if async_mode:
            # Once for initial parse, and then again for the add_callback_to_queue
            assert len(fake_processors) == 2
            assert fake_processors[0]._file_path == test_dag_path
            assert fake_processors[0]._callback_requests == []
        else:
            assert len(fake_processors) == 1

        assert fake_processors[-1]._file_path == test_dag_path
        callback_requests = fake_processors[-1]._callback_requests
        assert {zombie.simple_task_instance.key for zombie in expected_failure_callback_requests} == {
            result.simple_task_instance.key for result in callback_requests
        }

        child_pipe.close()
        parent_pipe.close()
def test_local_task_job(self):
    """Smoke test: a LocalTaskJob for a single TI runs to completion."""
    TI = TaskInstance
    ti = TI(task=self.runme_0, execution_date=DEFAULT_DATE)
    job = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
    job.run()
def test_failure_callback_only_called_once(self, mock_return_code, _check_call):
    """
    Test that ensures that when a task exits with failure by itself,
    failure callback is only called once
    """
    # use shared memory value so we can properly track value change even if
    # it's been updated across processes.
    failure_callback_called = Value('i', 0)
    callback_count_lock = Lock()

    def failure_callback(context):
        with callback_count_lock:
            failure_callback_called.value += 1
        assert context['dag_run'].dag_id == 'test_failure_callback_race'
        assert isinstance(context['exception'], AirflowFailException)

    def task_function(ti):
        raise AirflowFailException()

    dag = DAG(dag_id='test_failure_callback_race', start_date=DEFAULT_DATE)
    task = PythonOperator(
        task_id='test_exit_on_failure',
        python_callable=task_function,
        on_failure_callback=failure_callback,
        dag=dag,
    )

    dag.clear()
    with create_session() as session:
        dag.create_dagrun(
            run_id="test",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti.refresh_from_db()

    job1 = LocalTaskJob(task_instance=ti,
                        ignore_ti_state=True,
                        executor=SequentialExecutor())

    # Simulate race condition where job1 heartbeat ran right after task
    # state got set to failed by ti.handle_failure but before task process
    # fully exits. See _execute loop in airflow/jobs/local_task_job.py.
    # In this case, we have:
    #  * task_runner.return_code() is None
    #  * ti.state == State.Failed
    #
    # We also need to set return_code to a valid int after job1.terminating
    # is set to True so _execute loop won't loop forever.
    def dummy_return_code(*args, **kwargs):
        return None if not job1.terminating else -9

    mock_return_code.side_effect = dummy_return_code

    with timeout(10):
        # This should be _much_ shorter to run.
        # If you change this limit, make the timeout in the callable above bigger
        job1.run()

    ti.refresh_from_db()
    assert ti.state == State.FAILED  # task exits with failure state
    assert failure_callback_called.value == 1
def test_handle_failure_callback_with_zobmies_are_correctly_passed_to_dag_file_processor(self):
    """
    Check that the same set of failure callback with zombies are passed to the dag
    file processors until the next zombie detection logic is invoked.
    """
    test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_example_bash_operator.py')
    with conf_vars({('scheduler', 'max_threads'): '1',
                    ('core', 'load_examples'): 'False'}):
        dagbag = DagBag(test_dag_path)
        with create_session() as session:
            session.query(LJ).delete()
            dag = dagbag.get_dag('test_example_bash_operator')
            dag.sync_to_db()
            task = dag.get_task(task_id='run_this_last')

            ti = TI(task, DEFAULT_DATE, State.RUNNING)
            local_job = LJ(ti)
            local_job.state = State.SHUTDOWN
            local_job.id = 1
            ti.job_id = local_job.id

            session.add(local_job)
            session.add(ti)
            session.commit()
            fake_failure_callback_requests = [
                FailureCallbackRequest(
                    full_filepath=dag.full_filepath,
                    simple_task_instance=SimpleTaskInstance(ti),
                    msg="Message")
            ]

        class FakeDagFileProcessorRunner(DagFileProcessorProcess):
            # This fake processor will return the zombies it received in constructor
            # as its processing result w/o actually parsing anything.
            def __init__(self, file_path, pickle_dags, dag_id_white_list, failure_callback_requests):
                super().__init__(file_path, pickle_dags, dag_id_white_list,
                                 failure_callback_requests)
                self._result = failure_callback_requests, 0

            def start(self):
                pass

            @property
            def start_time(self):
                return DEFAULT_DATE

            @property
            def pid(self):
                return 1234

            @property
            def done(self):
                return True

            @property
            def result(self):
                return self._result

        def processor_factory(file_path, failure_callback_requests):
            return FakeDagFileProcessorRunner(file_path, False, [],
                                              failure_callback_requests)

        async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn')
        processor_agent = DagFileProcessorAgent(test_dag_path, 1,
                                                processor_factory, timedelta.max,
                                                async_mode)
        processor_agent.start()
        parsing_result = []
        if not async_mode:
            processor_agent.run_single_parsing_loop()
        while not processor_agent.done:
            if not async_mode:
                processor_agent.wait_until_finished()
            parsing_result.extend(processor_agent.harvest_simple_dags())

        self.assertEqual(len(fake_failure_callback_requests), len(parsing_result))
        self.assertEqual(
            set(zombie.simple_task_instance.key for zombie in fake_failure_callback_requests),
            set(result.simple_task_instance.key for result in parsing_result))
def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_processor(self):
    """
    Check that the same set of failure callback with zombies are passed to the dag
    file processors until the next zombie detection logic is invoked.
    """
    test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_example_bash_operator.py')
    with conf_vars({('scheduler', 'max_threads'): '1',
                    ('core', 'load_examples'): 'False'}):
        dagbag = DagBag(test_dag_path)
        with create_session() as session:
            session.query(LJ).delete()
            dag = dagbag.get_dag('test_example_bash_operator')
            dag.sync_to_db()
            task = dag.get_task(task_id='run_this_last')

            ti = TI(task, DEFAULT_DATE, State.RUNNING)
            local_job = LJ(ti)
            local_job.state = State.SHUTDOWN
            session.add(local_job)
            session.commit()

            # TODO: If there was an actual Relationship between TI and Job
            # we wouldn't need this extra commit
            session.add(ti)
            ti.job_id = local_job.id
            session.commit()

            fake_failure_callback_requests = [
                FailureCallbackRequest(
                    full_filepath=dag.full_filepath,
                    simple_task_instance=SimpleTaskInstance(ti),
                    msg="Message")
            ]

        test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_example_bash_operator.py')

        child_pipe, parent_pipe = multiprocessing.Pipe()
        async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn')

        manager = DagFileProcessorManager(
            dag_directory=test_dag_path,
            max_runs=1,
            processor_factory=FakeDagFileProcessorRunner._fake_dag_processor_factory,
            processor_timeout=timedelta.max,
            signal_conn=child_pipe,
            dag_ids=[],
            pickle_dags=False,
            async_mode=async_mode)

        parsing_result = self.run_processor_manager_one_loop(manager, parent_pipe)

        self.assertEqual(len(fake_failure_callback_requests), len(parsing_result))
        self.assertEqual(
            set(zombie.simple_task_instance.key for zombie in fake_failure_callback_requests),
            set(result.simple_task_instance.key for result in parsing_result))
        child_pipe.close()
        parent_pipe.close()
def test_mark_success_on_success_callback(self):
    """
    Test that ensures that where a task is marked success in the UI
    on_success_callback gets executed
    """
    # use shared memory value so we can properly track value change even if
    # it's been updated across processes.
    success_callback_called = Value('i', 0)
    task_terminated_externally = Value('i', 1)
    shared_mem_lock = Lock()

    def success_callback(context):
        with shared_mem_lock:
            success_callback_called.value += 1
        assert context['dag_run'].dag_id == 'test_mark_success'

    dag = DAG(dag_id='test_mark_success',
              start_date=DEFAULT_DATE,
              default_args={'owner': 'owner1'})

    def task_function(ti):  # pylint: disable=unused-argument
        time.sleep(60)
        # This should not happen -- the state change should be noticed and the
        # task should get killed long before the sleep finishes.
        with shared_mem_lock:
            task_terminated_externally.value = 0

    task = PythonOperator(
        task_id='test_state_succeeded1',
        python_callable=task_function,
        on_success_callback=success_callback,
        dag=dag,
    )

    session = settings.Session()

    dag.clear()
    dag.create_dagrun(
        run_id="test",
        state=State.RUNNING,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=session,
    )
    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti.refresh_from_db()
    job1 = LocalTaskJob(task_instance=ti,
                        ignore_ti_state=True,
                        executor=SequentialExecutor())
    job1.task_runner = StandardTaskRunner(job1)

    settings.engine.dispose()
    process = multiprocessing.Process(target=job1.run)
    process.start()

    # Poll until the task is actually running (up to ~5 seconds).
    for _ in range(0, 25):
        ti.refresh_from_db()
        if ti.state == State.RUNNING:
            break
        time.sleep(0.2)
    assert ti.state == State.RUNNING

    # Mark the TI successful "externally", as the UI would.
    ti.state = State.SUCCESS
    session.merge(ti)
    session.commit()

    process.join(timeout=10)
    assert success_callback_called.value == 1
    assert task_terminated_externally.value == 1
    assert not process.is_alive()