def test_skip(self, mock_now):
    """SkipMixin.skip() must persist the task as SKIPPED, stamped with the mocked 'now'."""
    db_session = settings.Session()
    frozen_now = datetime.datetime.utcnow().replace(
        tzinfo=pendulum.timezone('UTC'))
    mock_now.return_value = frozen_now

    dag = DAG(
        'dag',
        start_date=DEFAULT_DATE,
    )
    with dag:
        skipped_tasks = [DummyOperator(task_id='task')]

    run = dag.create_dagrun(
        run_id='manual__' + frozen_now.isoformat(),
        state=State.FAILED,
    )
    SkipMixin().skip(
        dag_run=run,
        execution_date=frozen_now,
        tasks=skipped_tasks,
        session=db_session,
    )

    # Query.one() raises unless exactly one row matches every filter below.
    matching = db_session.query(TI).filter(
        TI.dag_id == 'dag',
        TI.task_id == 'task',
        TI.state == State.SKIPPED,
        TI.start_date == frozen_now,
        TI.end_date == frozen_now,
    )
    matching.one()
def test_dagrun_update_state_end_date(self):
    """
    DagRun.update_state() must keep end_date in sync with the run's state:
    finished states (SUCCESS / FAILED) set it, RUNNING resets it to NULL.
    """
    session = settings.Session()
    run_id = 'test_dagrun_update_state_end_date'

    dag = DAG(
        'test_dagrun_update_state_end_date',
        start_date=DEFAULT_DATE,
        default_args={'owner': 'owner1'})

    # op1.set_upstream(op2) makes B upstream of A, i.e. B -> A
    with dag:
        op1 = DummyOperator(task_id='A')
        op2 = DummyOperator(task_id='B')
        op1.set_upstream(op2)

    dag.clear()

    now = timezone.utcnow()
    dr = dag.create_dagrun(run_id=run_id,
                           state=State.RUNNING,
                           execution_date=now,
                           start_date=now)

    def fetch_persisted_run():
        return session.query(DagRun).filter(DagRun.run_id == run_id).one()

    session.merge(dr)
    session.commit()
    # A newly created run has no end_date yet.
    self.assertIsNone(dr.end_date)

    # Phase 1: every task succeeds, so the run finishes and gets an end_date.
    task_instances = []
    for op in (op1, op2):
        ti = dr.get_task_instance(task_id=op.task_id)
        ti.set_state(state=State.SUCCESS, session=session)
        task_instances.append(ti)
    dr.update_state()
    persisted = fetch_persisted_run()
    self.assertIsNotNone(persisted.end_date)
    self.assertEqual(dr.end_date, persisted.end_date)

    # Phase 2: tasks are running again, so end_date is cleared.
    for ti in task_instances:
        ti.set_state(state=State.RUNNING, session=session)
    dr.update_state()
    persisted = fetch_persisted_run()
    self.assertEqual(dr._state, State.RUNNING)
    self.assertIsNone(dr.end_date)
    self.assertIsNone(persisted.end_date)

    # Phase 3: every task fails, which is also a finished state.
    for ti in task_instances:
        ti.set_state(state=State.FAILED, session=session)
    dr.update_state()
    persisted = fetch_persisted_run()
    self.assertIsNotNone(persisted.end_date)
    self.assertEqual(dr.end_date, persisted.end_date)
def test_dagrun_deadlock(self):
    """A run stays RUNNING with a valid trigger rule and fails once the rule is unsatisfiable."""
    session = settings.Session()
    dag = DAG(
        'text_dagrun_deadlock',
        start_date=DEFAULT_DATE,
        default_args={'owner': 'owner1'})

    with dag:
        op1 = DummyOperator(task_id='A')
        op2 = DummyOperator(task_id='B')
        op2.trigger_rule = TriggerRule.ONE_FAILED
        op2.set_upstream(op1)

    dag.clear()
    now = timezone.utcnow()
    dr = dag.create_dagrun(run_id='test_dagrun_deadlock',
                           state=State.RUNNING,
                           execution_date=now,
                           start_date=now)

    upstream_ti = dr.get_task_instance(task_id=op1.task_id)
    upstream_ti.set_state(state=State.SUCCESS, session=session)
    downstream_ti = dr.get_task_instance(task_id=op2.task_id)
    downstream_ti.set_state(state=State.NONE, session=session)

    dr.update_state()
    self.assertEqual(dr.state, State.RUNNING)

    # With an unknown trigger rule the downstream task can never be scheduled;
    # update_state() is expected to report the run as FAILED.
    downstream_ti.set_state(state=State.NONE, session=session)
    op2.trigger_rule = 'invalid'
    dr.update_state()
    self.assertEqual(dr.state, State.FAILED)
def test_utc_transformations(self):
    """
    Datetimes written to a DagRun must read back unchanged and stay in UTC.
    """
    dag_id = 'test_utc_transformations'
    started = utcnow()
    started_iso = started.isoformat()
    executed = started + datetime.timedelta(hours=1, days=1)

    dag = DAG(
        dag_id=dag_id,
        start_date=started,
    )
    dag.clear()

    run = dag.create_dagrun(
        run_id=started_iso,
        state=State.NONE,
        execution_date=executed,
        start_date=started,
        session=self.session,
    )

    # Round-trip equality.
    assert executed == run.execution_date
    assert started == run.start_date

    # Both inputs carry a zero UTC offset.
    assert executed.utcoffset().total_seconds() == 0.0
    assert started.utcoffset().total_seconds() == 0.0

    # The run_id is the ISO form of the stored start date.
    assert started_iso == run.run_id
    assert run.start_date.isoformat() == run.run_id

    dag.clear()
def test_execute_create_dagrun_with_conf(self):
    """
    When SubDagOperator is given a ``conf`` and no DagRun exists yet, the DagRun
    it creates receives that conf; the operator then polls until success.
    """
    conf = {"key": "value"}
    dag = DAG('parent', default_args=default_args)
    subdag = DAG('parent.test', default_args=default_args)
    subdag_task = SubDagOperator(task_id='test', subdag=subdag, dag=dag, poke_interval=1, conf=conf)

    subdag.create_dagrun = Mock(return_value=self.dag_run_running)

    # First lookup finds nothing (forcing creation); later lookups see success.
    subdag_task._get_dagrun = Mock(side_effect=[
        None,
        self.dag_run_success,
        self.dag_run_success,
    ])

    # A fresh context dict is built for each lifecycle hook.
    for lifecycle_hook in (subdag_task.pre_execute, subdag_task.execute, subdag_task.post_execute):
        lifecycle_hook(context={'execution_date': DEFAULT_DATE})

    subdag.create_dagrun.assert_called_once_with(
        run_type=DagRunType.SCHEDULED,
        execution_date=DEFAULT_DATE,
        conf=conf,
        state=State.RUNNING,
        external_trigger=True,
    )

    self.assertEqual(3, len(subdag_task._get_dagrun.mock_calls))
def test_fractional_seconds(self):
    """
    Tests if fractional seconds are stored in the database.

    The DagRun's execution_date and start_date must round-trip through the DB
    without losing sub-second precision. Cleanup runs in ``finally`` so a
    failed assertion does not leave rows behind for later tests.
    """
    dag_id = "test_fractional_seconds"
    dag = DAG(dag_id=dag_id)
    dag.schedule_interval = '@once'
    dag.add_task(BaseOperator(
        task_id="faketastic",
        owner='Also fake',
        start_date=datetime_tz(2015, 1, 2, 0, 0)))

    start_date = timezone.utcnow()

    try:
        run = dag.create_dagrun(
            run_id='test_' + start_date.isoformat(),
            execution_date=start_date,
            start_date=start_date,
            state=State.RUNNING,
            external_trigger=False
        )

        # Re-read from the DB so the comparison is against persisted values.
        run.refresh_from_db()

        self.assertEqual(start_date, run.execution_date,
                         "dag run execution_date loses precision")
        self.assertEqual(start_date, run.start_date,
                         "dag run start_date loses precision ")
    finally:
        self._clean_up(dag_id)
def _get_dag_run(self, run_date: datetime, dag: DAG, session: Session = None):
    """
    Returns a dag run for the given run date, which will be matched to an existing
    dag run if available or create a new dag run otherwise. If the max_active_runs
    limit is reached, this function will return None.

    :param run_date: the execution date for the dag run
    :param dag: DAG
    :param session: the database session object
    :return: a DagRun in state RUNNING or None
    """
    # consider max_active_runs but ignore when running subdags
    respect_dag_max_active_limit = bool(dag.schedule_interval and not dag.is_subdag)

    # check if we are scheduling on top of a already existing dag_run
    # we could find a "scheduled" run instead of a "backfill"
    runs = DagRun.find(dag_id=dag.dag_id, execution_date=run_date, session=session)
    run: Optional[DagRun]
    if runs:
        run = runs[0]
        # Reusing a run that is already RUNNING does not add a new active run,
        # so the max_active_runs limit is not enforced for it.
        if run.state == State.RUNNING:
            respect_dag_max_active_limit = False
    else:
        run = None

    # enforce max_active_runs limit for dag, special cases already
    # handled by respect_dag_max_active_limit.
    # The active-run count is read through the caller's session, the same session
    # DagRun.find and verify_integrity use, so all reads see one consistent view.
    # The count is only queried when the limit actually applies.
    if respect_dag_max_active_limit:
        current_active_dag_count = dag.get_num_active_runs(
            external_trigger=False, session=session)
        if current_active_dag_count >= dag.max_active_runs:
            return None

    run = run or dag.create_dagrun(
        execution_date=run_date,
        start_date=timezone.utcnow(),
        state=State.RUNNING,
        external_trigger=False,
        session=session,
        conf=self.conf,
        run_type=DagRunType.BACKFILL_JOB,
        creating_job_id=self.id,
    )

    # set required transient field
    run.dag = dag

    # explicitly mark as backfill and running
    run.state = State.RUNNING
    run.run_id = run.generate_run_id(DagRunType.BACKFILL_JOB, run_date)
    run.run_type = DagRunType.BACKFILL_JOB
    run.verify_integrity(session=session)
    return run
def test_emit_scheduling_delay(self, stats_mock):
    """
    Running a scheduled DAG must emit the first-task scheduling-delay stat.

    dag_run.update_state() is the entry point that ends up calling
    _emit_true_scheduling_delay_stats_for_finished_state.
    """
    dag = DAG(dag_id='test_emit_dag_stats', start_date=days_ago(1))
    dag_task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')

    session = settings.Session()
    session.add(DagModel(
        dag_id=dag.dag_id,
        has_task_concurrency_limits=False,
        next_dagrun=dag.start_date,
        next_dagrun_create_after=dag.following_schedule(dag.start_date),
        is_active=True,
    ))
    session.flush()

    dag_run = dag.create_dagrun(
        run_type=DagRunType.SCHEDULED,
        state=State.SUCCESS,
        execution_date=dag.start_date,
        start_date=dag.start_date,
        session=session,
    )
    ti = dag_run.get_task_instance(dag_task.task_id)
    ti.set_state(State.SUCCESS, session)
    session.commit()
    session.close()

    dag_run.update_state()

    scheduled_for = dag.following_schedule(dag_run.execution_date)
    expected_delay = (ti.start_date - scheduled_for).total_seconds()
    stats_mock.assert_called()
    expected_call = call(f'dagrun.{dag.dag_id}.first_task_scheduling_delay', expected_delay)
    self.assertIn(expected_call, stats_mock.mock_calls)
def test_localtaskjob_essential_attr(self):
    """
    A freshly built LocalTaskJob must already expose non-None values for its
    essential attributes, without any extra setup by the caller.
    """
    dag = DAG(
        'test_localtaskjob_essential_attr',
        start_date=DEFAULT_DATE,
        default_args={'owner': 'owner1'})

    with dag:
        op1 = DummyOperator(task_id='op1')

    dag.clear()
    dr = dag.create_dagrun(run_id="test",
                           state=State.SUCCESS,
                           execution_date=DEFAULT_DATE,
                           start_date=DEFAULT_DATE)
    ti = dr.get_task_instance(task_id=op1.task_id)

    job = LocalTaskJob(task_instance=ti,
                       ignore_ti_state=True,
                       executor=SequentialExecutor())

    required_names = ["dag_id", "job_type", "start_date", "hostname"]

    presence = [hasattr(job, name) for name in required_names]
    self.assertTrue(all(presence))

    populated = [getattr(job, name) is not None for name in required_names]
    self.assertTrue(all(populated))
def test_skip(self, mock_now):
    """Skipped tasks are stored as SKIPPED with start and end dates equal to the mocked now."""
    session = settings.Session()
    utc = pendulum.timezone('UTC')
    current = datetime.datetime.utcnow().replace(tzinfo=utc)
    mock_now.return_value = current

    dag = DAG('dag', start_date=DEFAULT_DATE)
    with dag:
        to_skip = [DummyOperator(task_id='task')]

    run_identifier = 'manual__' + current.isoformat()
    dag_run = dag.create_dagrun(run_id=run_identifier, state=State.FAILED)

    skipper = SkipMixin()
    skipper.skip(dag_run=dag_run, execution_date=current, tasks=to_skip, session=session)

    # Exactly one TI row must satisfy all of these conditions (.one() enforces it).
    conditions = (
        TI.dag_id == 'dag',
        TI.task_id == 'task',
        TI.state == State.SKIPPED,
        TI.start_date == current,
        TI.end_date == current,
    )
    session.query(TI).filter(*conditions).one()
def test_rerun_failed_subdag(self):
    """
    Resetting a FAILED sub-DagRun puts the run back into RUNNING and clears the
    state of its task instances.
    """
    dag = DAG('parent', default_args=default_args)
    subdag = DAG('parent.test', default_args=default_args)
    subdag_task = SubDagOperator(task_id='test', subdag=subdag, dag=dag, poke_interval=1)
    dummy_task = DummyOperator(task_id='dummy', dag=subdag)

    failed_ti = TaskInstance(
        task=dummy_task,
        execution_date=DEFAULT_DATE,
        state=State.FAILED,
    )
    with create_session() as session:
        session.add(failed_ti)
        session.commit()

    failed_run = subdag.create_dagrun(
        run_id="scheduled__{}".format(DEFAULT_DATE.isoformat()),
        execution_date=DEFAULT_DATE,
        state=State.FAILED,
        external_trigger=True,
    )

    subdag_task._reset_dag_run_and_task_instances(failed_run, execution_date=DEFAULT_DATE)

    failed_ti.refresh_from_db()
    self.assertEqual(failed_ti.state, State.NONE)

    failed_run.refresh_from_db()
    self.assertEqual(failed_run.state, State.RUNNING)
def test_execute_create_dagrun_wait_until_success(self):
    """
    With no pre-existing DagRun, SubDagOperator creates one and keeps polling
    until that DagRun reports success.
    """
    dag = DAG('parent', default_args=default_args)
    subdag = DAG('parent.test', default_args=default_args)
    subdag_task = SubDagOperator(task_id='test', subdag=subdag, dag=dag, poke_interval=1)

    subdag.create_dagrun = Mock(return_value=self.dag_run_running)
    # None first (so a run is created), then success on each subsequent poll.
    lookup_results = [None, self.dag_run_success, self.dag_run_success]
    subdag_task._get_dagrun = Mock(side_effect=lookup_results)

    for hook in (subdag_task.pre_execute, subdag_task.execute, subdag_task.post_execute):
        hook(context={'execution_date': DEFAULT_DATE})

    subdag.create_dagrun.assert_called_once_with(
        run_id="scheduled__{}".format(DEFAULT_DATE.isoformat()),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
        external_trigger=True,
    )

    self.assertEqual(3, len(subdag_task._get_dagrun.mock_calls))
def test_file_task_handler_when_ti_value_is_invalid(self):
    """
    FileTaskHandler.read() with an invalid try_number (0) must return a single
    placeholder log entry carrying an explanatory error message.

    The temporary log file created by set_context() is removed in ``finally`` so
    a failing assertion does not leave it behind.
    """
    def task_callable(ti, **kwargs):
        ti.log.info("test")

    dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE)
    dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE)
    task = PythonOperator(
        task_id='task_for_testing_file_log_handler',
        dag=dag,
        python_callable=task_callable,
    )
    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)

    logger = ti.log
    ti.log.disabled = False

    file_handler = next(
        (handler for handler in logger.handlers if handler.name == FILE_TASK_HANDLER), None
    )
    assert file_handler is not None

    set_context(logger, ti)
    assert file_handler.handler is not None
    # We expect set_context generates a file locally.
    log_filename = file_handler.handler.baseFilename
    try:
        assert os.path.isfile(log_filename)
        assert log_filename.endswith("1.log"), log_filename

        ti.run(ignore_ti_state=True)

        file_handler.flush()
        file_handler.close()

        assert hasattr(file_handler, 'read')
        # Return value of read must be a tuple of list and list.
        # passing invalid `try_number` to read function
        logs, metadatas = file_handler.read(ti, 0)
        assert isinstance(logs, list)
        assert isinstance(metadatas, list)
        assert len(logs) == 1
        assert len(logs) == len(metadatas)
        assert isinstance(metadatas[0], dict)
        assert logs[0][0][0] == "default_host"
        assert logs[0][0][1] == "Error fetching the logs. Try number 0 is invalid."
    finally:
        # Remove the generated tmp log file even if an assertion above failed;
        # the existence check keeps cleanup from masking the original failure.
        if os.path.isfile(log_filename):
            os.remove(log_filename)
def test_with_dag_run(self):
    """
    ShortCircuitOperator inside a DagRun skips everything downstream when its
    callable returns False, and leaves downstream tasks untouched (NONE) when it
    returns True.
    """
    value = False
    dag = DAG('shortcircuit_operator_test_with_dag_run',
              default_args={
                  'owner': 'airflow',
                  'start_date': DEFAULT_DATE
              },
              schedule_interval=INTERVAL)
    # The lambda reads `value` late, so rebinding it below changes the outcome.
    short_op = ShortCircuitOperator(task_id='make_choice',
                                    dag=dag,
                                    python_callable=lambda: value)
    branch_1 = DummyOperator(task_id='branch_1', dag=dag)
    branch_1.set_upstream(short_op)
    branch_2 = DummyOperator(task_id='branch_2', dag=dag)
    branch_2.set_upstream(branch_1)
    upstream = DummyOperator(task_id='upstream', dag=dag)
    upstream.set_downstream(short_op)
    dag.clear()

    logging.error("Tasks {}".format(dag.tasks))
    dr = dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING
    )

    def assert_states(expected_branch_state):
        # Every TI in the run must be one of the four known tasks.
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 4)
        for ti in tis:
            if ti.task_id in ('make_choice', 'upstream'):
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id in ('branch_1', 'branch_2'):
                self.assertEqual(ti.state, expected_branch_state)
            else:
                # A bare `raise` here has no active exception and would only
                # produce an opaque RuntimeError.
                raise ValueError(f'Invalid task id {ti.task_id} found!')

    upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    assert_states(State.SKIPPED)

    value = True
    dag.clear()
    dr.verify_integrity()
    upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    assert_states(State.NONE)
def test_clear_skipped_downstream_task(self):
    """
    After a downstream task is skipped by ShortCircuitOperator, clearing the
    skipped task should not cause it to be executed.
    """
    dag = DAG(
        'shortcircuit_clear_skipped_downstream_task',
        default_args={
            'owner': 'airflow',
            'start_date': DEFAULT_DATE
        },
        schedule_interval=INTERVAL,
    )
    short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: False)
    downstream = DummyOperator(task_id='downstream', dag=dag)
    short_op >> downstream

    dag.clear()

    dr = dag.create_dagrun(
        run_type=DagRunType.MANUAL,
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    expected = {'make_choice': State.SUCCESS, 'downstream': State.SKIPPED}

    def check_states(task_instances):
        for ti in task_instances:
            if ti.task_id not in expected:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
            assert ti.state == expected[ti.task_id]

    short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    downstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    tis = dr.get_task_instances()
    check_states(tis)

    # Clear only the skipped downstream task.
    to_clear = [t for t in tis if t.task_id == "downstream"]
    with create_session() as session:
        clear_task_instances(to_clear, session=session, dag=dag)

    # Re-running it must leave it skipped.
    downstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    check_states(dr.get_task_instances())
def test_get_states_count_upstream_ti(self):
    """
    Unit-test TriggerRuleDep._get_states_count_upstream_ti directly, then check
    that update_state() on the same run resolves to SUCCESS.
    """
    from airflow.ti_deps.dep_context import DepContext

    count_upstream = TriggerRuleDep._get_states_count_upstream_ti
    session = settings.Session()
    now = timezone.utcnow()
    dag = DAG(
        'test_dagrun_with_pre_tis',
        start_date=DEFAULT_DATE,
        default_args={'owner': 'owner1'})

    with dag:
        op1 = DummyOperator(task_id='A')
        op2 = DummyOperator(task_id='B')
        op3 = DummyOperator(task_id='C')
        op4 = DummyOperator(task_id='D')
        op5 = DummyOperator(task_id='E', trigger_rule=TriggerRule.ONE_FAILED)

        op1.set_downstream([op2, op3])  # op1 >> op2, op3
        op4.set_upstream([op3, op2])  # op3, op2 >> op4
        op5.set_upstream([op2, op3, op4])  # (op2, op3, op4) >> op5

    clear_db_runs()
    dag.clear()
    dr = dag.create_dagrun(run_id='test_dagrun_with_pre_tis',
                           state=State.RUNNING,
                           execution_date=now,
                           start_date=now)

    ti_op1, ti_op2, ti_op3, ti_op4, ti_op5 = [
        TaskInstance(task=dag.get_task(op.task_id), execution_date=dr.execution_date)
        for op in (op1, op2, op3, op4, op5)
    ]

    initial_states = (State.SUCCESS, State.FAILED, State.SUCCESS, State.SUCCESS, State.SUCCESS)
    for ti, state in zip((ti_op1, ti_op2, ti_op3, ti_op4, ti_op5), initial_states):
        ti.set_state(state=state, session=session)

    session.commit()

    # check handling with cases that tasks are triggered from backfill with no finished tasks
    finished = DepContext().ensure_finished_tasks(ti_op2.task.dag, ti_op2.execution_date, session)
    self.assertEqual(count_upstream(finished_tasks=finished, ti=ti_op2), (1, 0, 0, 0, 1))

    finished = dr.get_task_instances(state=State.finished() + [State.UPSTREAM_FAILED], session=session)
    self.assertEqual(count_upstream(finished_tasks=finished, ti=ti_op4), (1, 0, 1, 0, 2))
    self.assertEqual(count_upstream(finished_tasks=finished, ti=ti_op5), (2, 0, 1, 0, 3))

    dr.update_state()
    self.assertEqual(State.SUCCESS, dr.state)
def test_with_dag_run(self):
    """
    Inside a DagRun, a falsy ShortCircuitOperator result skips every downstream
    task, while a truthy result leaves them unscheduled (NONE).
    """
    value = False
    dag = DAG('shortcircuit_operator_test_with_dag_run',
              default_args={
                  'owner': 'airflow',
                  'start_date': DEFAULT_DATE
              },
              schedule_interval=INTERVAL)
    # `value` is looked up when the lambda runs, not when it is defined.
    short_op = ShortCircuitOperator(task_id='make_choice',
                                    dag=dag,
                                    python_callable=lambda: value)
    branch_1 = DummyOperator(task_id='branch_1', dag=dag)
    branch_1.set_upstream(short_op)
    branch_2 = DummyOperator(task_id='branch_2', dag=dag)
    branch_2.set_upstream(branch_1)
    upstream = DummyOperator(task_id='upstream', dag=dag)
    upstream.set_downstream(short_op)
    dag.clear()

    logging.error("Tasks {}".format(dag.tasks))
    dr = dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING
    )

    def run_and_verify(branch_state):
        upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        task_instances = dr.get_task_instances()
        self.assertEqual(len(task_instances), 4)
        for ti in task_instances:
            if ti.task_id == 'make_choice' or ti.task_id == 'upstream':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                self.assertEqual(ti.state, branch_state)
            else:
                raise Exception

    run_and_verify(State.SKIPPED)

    value = True
    dag.clear()
    dr.verify_integrity()
    run_and_verify(State.NONE)
def test_dag_clear(self):
    """dag.clear() keeps try_number and raises max_tries so cleared tasks can run again."""
    dag = DAG('test_dag_clear', start_date=DEFAULT_DATE,
              end_date=DEFAULT_DATE + datetime.timedelta(days=10))
    no_retry_task = DummyOperator(task_id='test_dag_clear_task_0', owner='test', dag=dag)
    first_ti = TI(task=no_retry_task, execution_date=DEFAULT_DATE)

    dag.create_dagrun(
        execution_date=first_ti.execution_date,
        state=State.RUNNING,
        run_type=DagRunType.SCHEDULED,
    )

    # Before any run, the next attempt is try 1.
    assert first_ti.try_number == 1
    first_ti.run()
    assert first_ti.try_number == 2

    dag.clear()
    first_ti.refresh_from_db()
    assert first_ti.try_number == 2
    assert first_ti.state == State.NONE
    assert first_ti.max_tries == 1

    retrying_task = DummyOperator(task_id='test_dag_clear_task_1', owner='test', dag=dag, retries=2)
    second_ti = TI(task=retrying_task, execution_date=DEFAULT_DATE)
    assert second_ti.max_tries == 2

    second_ti.try_number = 1
    # Next try will be 2
    second_ti.run()
    assert second_ti.try_number == 3
    assert second_ti.max_tries == 2

    dag.clear()
    first_ti.refresh_from_db()
    second_ti.refresh_from_db()

    # After the second clear, the retrying TI (ti1) is on attempt 3 of 5.
    assert second_ti.max_tries == 4
    assert second_ti.try_number == 3
    # The non-retrying TI (ti0) is on attempt 2 of 2.
    assert first_ti.try_number == 2
    assert first_ti.max_tries == 1
def test_process_bind_param_naive(self):
    """
    Check if naive datetimes are prevented from saving to the db.

    ``run_id`` must be a real string (``isoformat()`` is *called*); passing the
    bound method would make the insert fail for an unrelated reason and let this
    test pass without ever exercising the naive-datetime check.
    """
    dag_id = 'test_process_bind_param_naive'

    # naive
    start_date = datetime.datetime.now()
    dag = DAG(dag_id=dag_id, start_date=start_date)
    dag.clear()

    with self.assertRaises((ValueError, StatementError)):
        dag.create_dagrun(
            run_id=start_date.isoformat(),
            state=State.NONE,
            execution_date=start_date,
            start_date=start_date,
            session=self.session
        )
    dag.clear()
def setUp(self):
    """Register the test DAG with the www_rbac dagbag and create one SUCCESS run per RUNS_DATA row."""
    from airflow.www_rbac.views import dagbag
    from airflow.utils.state import State

    dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
    dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)

    self.runs = []
    for row in self.RUNS_DATA:
        run_id, execution_date = row[0], row[1]
        created = dag.create_dagrun(
            run_id=run_id,
            execution_date=execution_date,
            state=State.SUCCESS,
            external_trigger=True,
        )
        self.runs.append(created)
def test_clear_task_instances_without_task(self):
    """Clearing TIs whose tasks were removed from the DAG still resets try/max_tries sensibly."""
    dag = DAG(
        'test_clear_task_instances_without_task',
        start_date=DEFAULT_DATE,
        end_date=DEFAULT_DATE + datetime.timedelta(days=10),
    )
    plain_task = DummyOperator(task_id='task0', owner='test', dag=dag)
    retrying_task = DummyOperator(task_id='task1', owner='test', dag=dag, retries=2)
    plain_ti = TI(task=plain_task, execution_date=DEFAULT_DATE)
    retrying_ti = TI(task=retrying_task, execution_date=DEFAULT_DATE)

    dag.create_dagrun(
        execution_date=plain_ti.execution_date,
        state=State.RUNNING,
        run_type=DagRunType.SCHEDULED,
    )

    for ti in (plain_ti, retrying_ti):
        ti.run()

    # Detach every task from the DAG.
    dag.task_dict = {}
    assert not dag.has_task(plain_task.task_id)
    assert not dag.has_task(retrying_task.task_id)

    with create_session() as session:
        stored_tis = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
        clear_task_instances(stored_tis, session)

    # Without the task, max_tries falls back to the larger of the original
    # max_tries and try_number.
    plain_ti.refresh_from_db()
    retrying_ti.refresh_from_db()
    # The next attempt for both is try 2.
    assert plain_ti.try_number == 2
    assert plain_ti.max_tries == 1
    assert retrying_ti.try_number == 2
    assert retrying_ti.max_tries == 2
def test_runtype_enum_escape():
    """
    DagRunType.SCHEDULED must be rendered as the plain string 'scheduled' both in
    generated SQL and in the value stored in the DB.
    """
    with create_session() as session:
        dag = DAG(dag_id='test_enum_dags', start_date=DEFAULT_DATE)
        dag.create_dagrun(
            run_type=DagRunType.SCHEDULED,
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )

        selected = session.query(DagRun.dag_id, DagRun.state, DagRun.run_type)
        # The enum member itself is used as the filter value.
        query = selected.filter(
            DagRun.dag_id == dag.dag_id,
            DagRun.run_type == DagRunType.SCHEDULED,
        )

        compiled_sql = str(query.statement.compile(compile_kwargs={"literal_binds": True}))
        expected_sql = (
            'SELECT dag_run.dag_id, dag_run.state, dag_run.run_type \n'
            'FROM dag_run \n'
            "WHERE dag_run.dag_id = 'test_enum_dags' AND dag_run.run_type = 'scheduled'"
        )
        assert compiled_sql == expected_sql

        rows = query.all()
        assert len(rows) == 1
        only_row = rows[0]
        assert only_row.dag_id == dag.dag_id
        assert only_row.state == State.RUNNING
        # The stored value is the string, not the enum repr.
        assert only_row.run_type == 'scheduled'

        session.rollback()
def test_dagrun_set_state_end_date(self):
    """
    DagRun.set_state() populates end_date for SUCCESS/FAILED and clears it for RUNNING.
    """
    session = settings.Session()
    run_id = 'test_dagrun_set_state_end_date'

    dag = DAG(
        'test_dagrun_set_state_end_date',
        start_date=DEFAULT_DATE,
        default_args={'owner': 'owner1'})

    dag.clear()

    now = timezone.utcnow()
    dr = dag.create_dagrun(run_id=run_id,
                           state=State.RUNNING,
                           execution_date=now,
                           start_date=now)

    # A new run starts with end_date NULL.
    session.add(dr)
    session.commit()
    self.assertIsNone(dr.end_date)

    def transition_and_reload(new_state):
        dr.set_state(new_state)
        session.merge(dr)
        session.commit()
        return session.query(DagRun).filter(
            DagRun.run_id == run_id
        ).one()

    stored = transition_and_reload(State.SUCCESS)
    self.assertIsNotNone(stored.end_date)
    self.assertEqual(dr.end_date, stored.end_date)

    stored = transition_and_reload(State.RUNNING)
    self.assertIsNone(stored.end_date)

    stored = transition_and_reload(State.FAILED)
    self.assertIsNotNone(stored.end_date)
    self.assertEqual(dr.end_date, stored.end_date)
def test_emit_scheduling_delay(self, schedule_interval, expected):
    """
    update_state() on a finished scheduled run emits the first-task scheduling
    delay metric only when ``expected`` is truthy for the given schedule_interval.
    """
    dag = DAG(dag_id='test_emit_dag_stats', start_date=days_ago(1), schedule_interval=schedule_interval)
    dag_task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')

    session = settings.Session()
    try:
        session.add(DagModel(
            dag_id=dag.dag_id,
            has_task_concurrency_limits=False,
            next_dagrun=dag.start_date,
            next_dagrun_create_after=dag.following_schedule(dag.start_date),
            is_active=True,
        ))
        session.flush()

        dag_run = dag.create_dagrun(
            run_type=DagRunType.SCHEDULED,
            state=State.SUCCESS,
            execution_date=dag.start_date,
            start_date=dag.start_date,
            session=session,
        )

        ti = dag_run.get_task_instance(dag_task.task_id, session)
        ti.set_state(State.SUCCESS, session)
        session.flush()

        with mock.patch.object(Stats, 'timing') as stats_mock:
            dag_run.update_state(session)

        metric_name = f'dagrun.{dag.dag_id}.first_task_scheduling_delay'

        if not expected:
            # No call with this metric name, whatever the value.
            assert call(metric_name, mock.ANY) not in stats_mock.mock_calls
        else:
            delay = ti.start_date - dag.following_schedule(dag_run.execution_date)
            assert call(metric_name, delay) in stats_mock.mock_calls
    finally:
        # Nothing from this test is persisted.
        session.rollback()
        session.close()
def test_operator_clear(self):
    """Clearing an operator with upstream=True lets both tasks run again with the expected counters."""
    dag = DAG(
        'test_operator_clear',
        start_date=DEFAULT_DATE,
        end_date=DEFAULT_DATE + datetime.timedelta(days=10),
    )
    first_op = DummyOperator(task_id='bash_op', owner='test', dag=dag)
    second_op = DummyOperator(task_id='dummy_op', owner='test', dag=dag, retries=1)
    second_op.set_upstream(first_op)

    first_ti = TI(task=first_op, execution_date=DEFAULT_DATE)
    second_ti = TI(task=second_op, execution_date=DEFAULT_DATE)

    dag.create_dagrun(
        execution_date=first_ti.execution_date,
        state=State.RUNNING,
        run_type=DagRunType.SCHEDULED,
    )

    # Upstream has not run yet, so this attempt does not start.
    second_ti.run()
    assert second_ti.try_number == 1
    assert second_ti.max_tries == 1

    second_op.clear(upstream=True)
    first_ti.run()
    second_ti.run(ignore_ti_state=True)

    assert first_ti.try_number == 2
    # No TI row for first_ti existed when clear() ran, so its max_tries was left at 0.
    assert first_ti.max_tries == 0
    assert second_ti.try_number == 2
    # try_number (0) + retries (1)
    assert second_ti.max_tries == 1
def test_dagrun_success_conditions(self):
    """The run's outcome follows the root task: RUNNING while unfinished, SUCCESS/FAILED after."""
    session = settings.Session()

    dag = DAG(
        'test_dagrun_success_conditions',
        start_date=DEFAULT_DATE,
        default_args={'owner': 'owner1'})

    # A -> B
    # A -> C -> D
    # ordered: B, D, C, A or D, B, C, A or D, C, B, A
    with dag:
        op1 = DummyOperator(task_id='A')
        op2 = DummyOperator(task_id='B')
        op3 = DummyOperator(task_id='C')
        op4 = DummyOperator(task_id='D')
        op1.set_upstream([op2, op3])
        op3.set_upstream(op4)

    dag.clear()

    now = datetime.datetime.now()
    dr = dag.create_dagrun(run_id='test_dagrun_success_conditions',
                           state=State.RUNNING,
                           execution_date=now,
                           start_date=now)

    # op1 is the root (the task with nothing downstream).
    root_ti = dr.get_task_instance(task_id=op1.task_id)
    root_ti.set_state(state=State.SUCCESS, session=session)

    ti_b, ti_c, ti_d = (dr.get_task_instance(task_id=op.task_id) for op in (op2, op3, op4))

    # Root succeeded but other tasks are still unfinished.
    self.assertEqual(State.RUNNING, dr.update_state())

    # One task failed, yet the root succeeded.
    ti_b.set_state(state=State.FAILED, session=session)
    ti_c.set_state(state=State.SUCCESS, session=session)
    ti_d.set_state(state=State.SUCCESS, session=session)
    self.assertEqual(State.SUCCESS, dr.update_state())

    # An upstream failed and the root has not run.
    root_ti.set_state(State.NONE, session)
    self.assertEqual(State.FAILED, dr.update_state())
def setUp(self):
    """Bag the test DAG for the www_rbac views and create SUCCESS runs from RUNS_DATA."""
    from airflow.www_rbac.views import dagbag
    from airflow.utils.state import State

    dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
    dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)

    self.runs = []
    for entry in self.RUNS_DATA:
        self.runs.append(dag.create_dagrun(
            run_id=entry[0],
            execution_date=entry[1],
            state=State.SUCCESS,
            external_trigger=True,
        ))
def test_parent_follow_branch():
    """
    A BranchPythonOperator that chooses op2 must not mark op2 as previously
    skipped, so NotPreviouslySkippedDep is met for op2.
    """
    start_date = pendulum.datetime(2020, 1, 1)
    dag = DAG("test_parent_follow_branch_dag", schedule_interval=None, start_date=start_date)
    dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=start_date)

    branch_op = BranchPythonOperator(task_id="op1", python_callable=lambda: "op2", dag=dag)
    followed_op = DummyOperator(task_id="op2", dag=dag)
    branch_op >> followed_op

    TaskInstance(branch_op, start_date).run()
    followed_ti = TaskInstance(followed_op, start_date)

    with create_session() as session:
        dep = NotPreviouslySkippedDep()
        statuses = list(dep.get_dep_statuses(followed_ti, session, DepContext()))
        assert len(statuses) == 0
        assert dep.is_met(followed_ti, session)
        assert followed_ti.state != State.SKIPPED
def test_dagstats_crud(self):
    """Create, dirty-mark and update DagStat rows, checking counts and dirty flags."""
    DagStat.create(dag_id='test_dagstats_crud')

    session = settings.Session()
    stats_query = session.query(DagStat).filter(
        DagStat.dag_id == 'test_dagstats_crud')
    self.assertEqual(len(stats_query.all()), len(State.dag_states))

    DagStat.set_dirty(dag_id='test_dagstats_crud')
    for stat in stats_query.all():
        self.assertTrue(stat.dirty)

    # set_dirty on an unknown dag creates the missing rows.
    DagStat.set_dirty(dag_id='test_dagstats_crud_2')
    other_query = session.query(DagStat).filter(
        DagStat.dag_id == 'test_dagstats_crud_2')
    self.assertEqual(len(other_query.all()), len(State.dag_states))

    dag = DAG(
        'test_dagstats_crud',
        start_date=DEFAULT_DATE,
        default_args={'owner': 'owner1'})

    with dag:
        DummyOperator(task_id='A')

    now = datetime.datetime.now()
    dag.create_dagrun(
        run_id='manual__' + now.isoformat(),
        execution_date=now,
        start_date=now,
        state=State.FAILED,
        external_trigger=False,
    )

    # Only the FAILED bucket counts the single run.
    DagStat.update(dag_ids=['test_dagstats_crud'])
    for stat in stats_query.all():
        expected_count = 1 if stat.state == State.FAILED else 0
        self.assertEqual(stat.count, expected_count)

    DagStat.update()
    for stat in other_query.all():
        self.assertFalse(stat.dirty)
def setUp(self):
    """Build a CSRF-free test client and seed the webserver dagbag with runs.

    ``self.app`` is a Flask test client for a testing app whose
    ``WTF_CSRF_METHODS`` list is emptied, and ``self.session`` is a new
    ``Session()``. A DAG named ``self.DAG_ID`` is bagged into the ``www``
    dagbag. For each RUNS_DATA entry, a SUCCESS, externally triggered DagRun is
    created and appended to ``self.runs``.
    """
    flask_app = application.create_app(testing=True)
    flask_app.config['WTF_CSRF_METHODS'] = []
    self.app = flask_app.test_client()
    self.session = Session()

    from airflow.utils.state import State
    from airflow.www.views import dagbag

    test_dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
    dagbag.bag_dag(test_dag, parent_dag=test_dag, root_dag=test_dag)

    self.runs = []
    for entry in self.RUNS_DATA:
        created = test_dag.create_dagrun(
            run_id=entry[0],
            execution_date=entry[1],
            state=State.SUCCESS,
            external_trigger=True,
        )
        self.runs.append(created)
def test_dagstats_crud(self):
    """Exercise DagStat.create, DagStat.set_dirty and DagStat.update.

    Checks that:
    * create() yields one row per entry in ``State.dag_states``;
    * set_dirty() marks those rows dirty, and creates the rows for a dag_id
      that has none yet;
    * update(dag_ids=[...]) counts the single FAILED DagRun (all other states 0);
    * an unrestricted update() leaves the second dag_id's rows not dirty.
    """
    DagStat.create(dag_id='test_dagstats_crud')

    session = settings.Session()
    # Only used for read queries below; release it when the test finishes.
    self.addCleanup(session.close)

    qry = session.query(DagStat).filter(DagStat.dag_id == 'test_dagstats_crud')
    self.assertEqual(len(qry.all()), len(State.dag_states))

    DagStat.set_dirty(dag_id='test_dagstats_crud')
    res = qry.all()
    for stat in res:
        self.assertTrue(stat.dirty)

    # create missing
    DagStat.set_dirty(dag_id='test_dagstats_crud_2')
    qry2 = session.query(DagStat).filter(DagStat.dag_id == 'test_dagstats_crud_2')
    self.assertEqual(len(qry2.all()), len(State.dag_states))

    dag = DAG(
        'test_dagstats_crud',
        start_date=DEFAULT_DATE,
        default_args={'owner': 'owner1'})
    with dag:
        # Created inside the DAG context so the task is attached to ``dag``;
        # the object itself is not needed afterwards.
        DummyOperator(task_id='A')

    now = datetime.datetime.now()
    dag.create_dagrun(
        run_id='manual__' + now.isoformat(),
        execution_date=now,
        start_date=now,
        state=State.FAILED,
        external_trigger=False,
    )

    DagStat.update(dag_ids=['test_dagstats_crud'])
    res = qry.all()
    for stat in res:
        if stat.state == State.FAILED:
            self.assertEqual(stat.count, 1)
        else:
            self.assertEqual(stat.count, 0)

    DagStat.update()
    res = qry2.all()
    for stat in res:
        self.assertFalse(stat.dirty)
def test_sub_set_subdag(self):
    """Backfill a sub-DAG selected with task_regex "leave*" (no up/downstream).

    The pre-created DagRun can no longer be refreshed afterwards. The DagRun
    found for DEFAULT_DATE carries the BackfillJob run_id, only 'leave1' and
    'leave2' end in SUCCESS, and every other task instance stays in State.NONE.
    """
    dag = DAG('test_sub_set_subdag', start_date=DEFAULT_DATE,
              default_args={'owner': 'owner1'})

    with dag:
        leaf_a = DummyOperator(task_id='leave1')
        leaf_b = DummyOperator(task_id='leave2')
        level_1 = DummyOperator(task_id='upstream_level_1')
        level_2 = DummyOperator(task_id='upstream_level_2')
        level_3 = DummyOperator(task_id='upstream_level_3')
        # order randomly
        leaf_b.set_downstream(level_1)
        leaf_a.set_downstream(level_1)
        level_2.set_downstream(level_3)
        level_1.set_downstream(level_2)

    dag.clear()

    dr = dag.create_dagrun(run_id="test",
                           state=State.RUNNING,
                           execution_date=DEFAULT_DATE,
                           start_date=DEFAULT_DATE)

    sub_dag = dag.sub_dag(task_regex="leave*",
                          include_downstream=False,
                          include_upstream=False)
    job = BackfillJob(dag=sub_dag,
                      start_date=DEFAULT_DATE,
                      end_date=DEFAULT_DATE,
                      executor=MockExecutor())
    job.run()

    # the run_id should have changed, so a refresh won't work
    with self.assertRaises(sqlalchemy.orm.exc.NoResultFound):
        dr.refresh_from_db()

    dr = DagRun.find(dag_id=dag.dag_id, execution_date=DEFAULT_DATE)[0]
    self.assertEqual(
        BackfillJob.ID_FORMAT_PREFIX.format(DEFAULT_DATE.isoformat()),
        dr.run_id)

    leaf_ids = ('leave1', 'leave2')
    for ti in dr.get_task_instances():
        expected_state = State.SUCCESS if ti.task_id in leaf_ids else State.NONE
        self.assertEqual(expected_state, ti.state)
def test_dagrun_no_deadlock_with_shutdown(self):
    """The upstream task is set to SHUTDOWN and update_state() is called.

    The DagRun is expected to remain RUNNING.
    """
    dag = DAG('test_dagrun_no_deadlock_with_shutdown', start_date=DEFAULT_DATE)
    session = settings.Session()

    with dag:
        upstream = DummyOperator(task_id='upstream_task')
        downstream = DummyOperator(task_id='downstream_task')
        downstream.set_upstream(upstream)

    run = dag.create_dagrun(run_id='test_dagrun_no_deadlock_with_shutdown',
                            state=State.RUNNING,
                            execution_date=DEFAULT_DATE,
                            start_date=DEFAULT_DATE)

    run.get_task_instance(task_id='upstream_task').set_state(
        State.SHUTDOWN, session=session)
    run.update_state()

    self.assertEqual(run.state, State.RUNNING)
def test_execute_skip_if_dagrun_success(self):
    """
    When there is an existing DagRun in SUCCESS state, skip the execution.
    """
    parent = DAG('parent', default_args=default_args)
    subdag = DAG('parent.test', default_args=default_args)
    subdag.create_dagrun = Mock()

    subdag_task = SubDagOperator(task_id='test', subdag=subdag, dag=parent, poke_interval=1)
    # Every lookup of the subdag's run returns the pre-built SUCCESS DagRun.
    subdag_task._get_dagrun = Mock(return_value=self.dag_run_success)

    # Each lifecycle hook gets its own fresh context dict.
    for hook in (subdag_task.pre_execute, subdag_task.execute, subdag_task.post_execute):
        hook(context={'execution_date': DEFAULT_DATE})

    subdag.create_dagrun.assert_not_called()
    self.assertEqual(3, len(subdag_task._get_dagrun.mock_calls))
def setUp(self):
    """Load the test config, build a CSRF-free client and seed DagRuns.

    After ``configuration.load_test_config()``, ``self.app`` is a Flask test
    client whose app has an empty ``WTF_CSRF_METHODS`` list, and
    ``self.session`` is a new ``Session()``. A DAG named ``self.DAG_ID`` is
    bagged into the ``www`` dagbag. One SUCCESS, externally triggered DagRun
    per RUNS_DATA entry is created and kept in ``self.runs``.
    """
    configuration.load_test_config()

    flask_app = application.create_app(testing=True)
    flask_app.config['WTF_CSRF_METHODS'] = []
    self.app = flask_app.test_client()
    self.session = Session()

    from airflow.utils.state import State
    from airflow.www.views import dagbag

    test_dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
    dagbag.bag_dag(test_dag, parent_dag=test_dag, root_dag=test_dag)

    self.runs = []
    for entry in self.RUNS_DATA:
        created = test_dag.create_dagrun(
            run_id=entry[0],
            execution_date=entry[1],
            state=State.SUCCESS,
            external_trigger=True,
        )
        self.runs.append(created)
class PythonOperatorTest(unittest.TestCase):
    """Tests for PythonOperator covering:

    * callable invocation;
    * callable validation;
    * deepcopy of op_kwargs / python_callable;
    * the AIRFLOW_CTX_* environment variables seen by the callable.
    """

    @classmethod
    def setUpClass(cls):
        super(PythonOperatorTest, cls).setUpClass()
        session = Session()
        session.query(DagRun).delete()
        session.query(TI).delete()
        session.commit()
        session.close()

    def setUp(self):
        super(PythonOperatorTest, self).setUp()
        configuration.load_test_config()
        self.dag = DAG(
            'test_dag',
            default_args={
                'owner': 'airflow',
                'start_date': DEFAULT_DATE},
            schedule_interval=INTERVAL)
        self.addCleanup(self.dag.clear)
        self.clear_run()
        self.addCleanup(self.clear_run)

    def tearDown(self):
        super(PythonOperatorTest, self).tearDown()
        session = Session()
        session.query(DagRun).delete()
        session.query(TI).delete()
        session.commit()
        session.close()

        # Remove the context variables exported for the task, if any.
        for var in TI_CONTEXT_ENV_VARS:
            os.environ.pop(var, None)

    # The "callable was invoked" flag uses a private name: storing it as
    # ``self.run`` would shadow unittest.TestCase.run on the instance.
    def do_run(self):
        self._callable_was_run = True

    def clear_run(self):
        self._callable_was_run = False

    def is_run(self):
        return self._callable_was_run

    def test_python_operator_run(self):
        """Tests that the python callable is invoked on task run."""
        task = PythonOperator(
            python_callable=self.do_run,
            task_id='python_operator',
            dag=self.dag)
        self.assertFalse(self.is_run())
        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        self.assertTrue(self.is_run())

    def test_python_operator_python_callable_is_callable(self):
        """Tests that PythonOperator will only instantiate if
        the python_callable argument is callable."""
        not_callable = {}
        with self.assertRaises(AirflowException):
            PythonOperator(
                python_callable=not_callable,
                task_id='python_operator',
                dag=self.dag)
        not_callable = None
        with self.assertRaises(AirflowException):
            PythonOperator(
                python_callable=not_callable,
                task_id='python_operator',
                dag=self.dag)

    def test_python_operator_shallow_copy_attr(self):
        """Deep-copying the operator keeps op_kwargs values and the callable shared."""
        def identity(x):
            return x

        original_task = PythonOperator(
            python_callable=identity,
            task_id='python_operator',
            op_kwargs={'certain_attrs': ''},
            dag=self.dag
        )
        new_task = copy.deepcopy(original_task)
        # shallow copy op_kwargs
        self.assertEqual(id(original_task.op_kwargs['certain_attrs']),
                         id(new_task.op_kwargs['certain_attrs']))
        # shallow copy python_callable
        self.assertEqual(id(original_task.python_callable),
                         id(new_task.python_callable))

    def _env_var_check_callback(self):
        self.assertEqual('test_dag', os.environ['AIRFLOW_CTX_DAG_ID'])
        self.assertEqual('hive_in_python_op',
                         os.environ['AIRFLOW_CTX_TASK_ID'])
        self.assertEqual(DEFAULT_DATE.isoformat(),
                         os.environ['AIRFLOW_CTX_EXECUTION_DATE'])
        self.assertEqual('manual__' + DEFAULT_DATE.isoformat(),
                         os.environ['AIRFLOW_CTX_DAG_RUN_ID'])

    def test_echo_env_variables(self):
        """
        Test that env variables are exported correctly to the
        python callback in the task.
        """
        self.dag.create_dagrun(
            run_id='manual__' + DEFAULT_DATE.isoformat(),
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            state=State.RUNNING,
            external_trigger=False,
        )

        t = PythonOperator(task_id='hive_in_python_op',
                           dag=self.dag,
                           python_callable=self._env_var_check_callback
                           )
        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
class BranchOperatorTest(unittest.TestCase):
    """Checks which task instances BranchPythonOperator leaves untouched
    (State.NONE) and which it marks SKIPPED, with and without a DagRun."""

    @classmethod
    def setUpClass(cls):
        super(BranchOperatorTest, cls).setUpClass()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        self.dag = DAG('branch_operator_test',
                       default_args={
                           'owner': 'airflow',
                           'start_date': DEFAULT_DATE},
                       schedule_interval=INTERVAL)

        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)

    def tearDown(self):
        super().tearDown()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def _check_states(self, tis, expected_by_task_id):
        """Assert each TI's state matches the mapping; unknown task ids raise."""
        for ti in tis:
            if ti.task_id not in expected_by_task_id:
                raise Exception
            self.assertEqual(ti.state, expected_by_task_id[ti.task_id])

    def _make_dagrun(self):
        """Create the RUNNING 'manual__' DagRun at DEFAULT_DATE used by several tests."""
        return self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING
        )

    def test_without_dag_run(self):
        """This checks the defensive against non existent tasks in a dag run"""
        self.branch_op = BranchPythonOperator(task_id='make_choice',
                                              dag=self.dag,
                                              python_callable=lambda: 'branch_1')
        for branch in (self.branch_1, self.branch_2):
            branch.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(
                TI.dag_id == self.dag.dag_id,
                TI.execution_date == DEFAULT_DATE
            )
            self._check_states(tis, {
                'make_choice': State.SUCCESS,
                # should exist with state None
                'branch_1': State.NONE,
                'branch_2': State.SKIPPED,
            })

    def test_branch_list_without_dag_run(self):
        """This checks if the BranchPythonOperator supports branching off to a list of tasks."""
        self.branch_op = BranchPythonOperator(task_id='make_choice',
                                              dag=self.dag,
                                              python_callable=lambda: ['branch_1', 'branch_2'])
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
        self.branch_3.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(
                TI.dag_id == self.dag.dag_id,
                TI.execution_date == DEFAULT_DATE
            )
            self._check_states(tis, {
                "make_choice": State.SUCCESS,
                "branch_1": State.NONE,
                "branch_2": State.NONE,
                "branch_3": State.SKIPPED,
            })

    def test_with_dag_run(self):
        self.branch_op = BranchPythonOperator(task_id='make_choice',
                                              dag=self.dag,
                                              python_callable=lambda: 'branch_1')
        for branch in (self.branch_1, self.branch_2):
            branch.set_upstream(self.branch_op)
        self.dag.clear()

        dr = self._make_dagrun()
        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._check_states(dr.get_task_instances(), {
            'make_choice': State.SUCCESS,
            'branch_1': State.NONE,
            'branch_2': State.SKIPPED,
        })

    def test_with_skip_in_branch_downstream_dependencies(self):
        self.branch_op = BranchPythonOperator(task_id='make_choice',
                                              dag=self.dag,
                                              python_callable=lambda: 'branch_1')
        self.branch_op >> self.branch_1 >> self.branch_2
        self.branch_op >> self.branch_2
        self.dag.clear()

        dr = self._make_dagrun()
        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._check_states(dr.get_task_instances(), {
            'make_choice': State.SUCCESS,
            'branch_1': State.NONE,
            'branch_2': State.NONE,
        })

    def test_with_skip_in_branch_downstream_dependencies2(self):
        self.branch_op = BranchPythonOperator(task_id='make_choice',
                                              dag=self.dag,
                                              python_callable=lambda: 'branch_2')
        self.branch_op >> self.branch_1 >> self.branch_2
        self.branch_op >> self.branch_2
        self.dag.clear()

        dr = self._make_dagrun()
        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._check_states(dr.get_task_instances(), {
            'make_choice': State.SUCCESS,
            'branch_1': State.SKIPPED,
            'branch_2': State.NONE,
        })
class BranchOperatorTest(unittest.TestCase):
    """BranchPythonOperator choosing 'branch_1' leaves branch_1 untouched
    (State.NONE) and marks branch_2 SKIPPED, with and without a DagRun."""

    @classmethod
    def setUpClass(cls):
        super(BranchOperatorTest, cls).setUpClass()

        session = Session()

        session.query(DagRun).delete()
        session.query(TI).delete()
        session.commit()
        session.close()

    def setUp(self):
        self.dag = DAG('branch_operator_test',
                       default_args={
                           'owner': 'airflow',
                           'start_date': DEFAULT_DATE},
                       schedule_interval=INTERVAL)
        self.branch_op = BranchPythonOperator(task_id='make_choice',
                                              dag=self.dag,
                                              python_callable=lambda: 'branch_1')

        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

    def tearDown(self):
        super(BranchOperatorTest, self).tearDown()

        session = Session()

        session.query(DagRun).delete()
        session.query(TI).delete()
        session.commit()
        session.close()

    def _assert_task_states(self, tis, expected):
        """Assert every task instance in ``tis`` has the state given in ``expected``.

        Fails the test (instead of a bare ``raise`` with no active exception,
        which would surface as a RuntimeError) when an unexpected task id appears.
        """
        for ti in tis:
            if ti.task_id not in expected:
                self.fail('Unexpected task instance: {}'.format(ti.task_id))
            self.assertEqual(ti.state, expected[ti.task_id])

    def test_without_dag_run(self):
        """This checks the defensive against non existent tasks in a dag run"""
        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        session = Session()
        try:
            # Materialize the rows before closing the session; iterating a lazy
            # Query after close() would silently open a new connection.
            tis = session.query(TI).filter(
                TI.dag_id == self.dag.dag_id,
                TI.execution_date == DEFAULT_DATE
            ).all()
        finally:
            session.close()

        self._assert_task_states(tis, {
            'make_choice': State.SUCCESS,
            # should exist with state None
            'branch_1': State.NONE,
            'branch_2': State.SKIPPED,
        })

    def test_with_dag_run(self):
        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING
        )

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_task_states(dr.get_task_instances(), {
            'make_choice': State.SUCCESS,
            'branch_1': State.NONE,
            'branch_2': State.SKIPPED,
        })
class PythonOperatorTest(unittest.TestCase):
    """Tests for PythonOperator covering:

    * callable invocation;
    * callable validation;
    * templating of op_args / op_kwargs;
    * deepcopy semantics;
    * the AIRFLOW_CTX_* environment variables seen by the callable.
    """

    @classmethod
    def setUpClass(cls):
        super(PythonOperatorTest, cls).setUpClass()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        super().setUp()
        configuration.load_test_config()
        self.dag = DAG(
            'test_dag',
            default_args={
                'owner': 'airflow',
                'start_date': DEFAULT_DATE},
            schedule_interval=INTERVAL)
        self.addCleanup(self.dag.clear)
        self.clear_run()
        self.addCleanup(self.clear_run)

    def tearDown(self):
        super().tearDown()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

        for var in TI_CONTEXT_ENV_VARS:
            if var in os.environ:
                del os.environ[var]

    # The "callable was invoked" flag uses a private name: storing it as
    # ``self.run`` would shadow unittest.TestCase.run on the instance.
    def do_run(self):
        self._callable_was_run = True

    def clear_run(self):
        self._callable_was_run = False

    def is_run(self):
        return self._callable_was_run

    def test_python_operator_run(self):
        """Tests that the python callable is invoked on task run."""
        task = PythonOperator(
            python_callable=self.do_run,
            task_id='python_operator',
            dag=self.dag)
        self.assertFalse(self.is_run())
        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        self.assertTrue(self.is_run())

    def test_python_operator_python_callable_is_callable(self):
        """Tests that PythonOperator will only instantiate if
        the python_callable argument is callable."""
        not_callable = {}
        with self.assertRaises(AirflowException):
            PythonOperator(
                python_callable=not_callable,
                task_id='python_operator',
                dag=self.dag)
        not_callable = None
        with self.assertRaises(AirflowException):
            PythonOperator(
                python_callable=not_callable,
                task_id='python_operator',
                dag=self.dag)

    def _assertCallsEqual(self, first, second):
        self.assertIsInstance(first, Call)
        self.assertIsInstance(second, Call)
        self.assertTupleEqual(first.args, second.args)
        self.assertDictEqual(first.kwargs, second.kwargs)

    def test_python_callable_arguments_are_templatized(self):
        """Test PythonOperator op_args are templatized"""
        recorded_calls = []

        task = PythonOperator(
            task_id='python_operator',
            # a Mock instance cannot be used as a callable function or test fails with a
            # TypeError: Object of type Mock is not JSON serializable
            python_callable=(build_recording_function(recorded_calls)),
            op_args=[
                4,
                date(2019, 1, 1),
                "dag {{dag.dag_id}} ran on {{ds}}."
            ],
            dag=self.dag)

        self.dag.create_dagrun(
            run_id='manual__' + DEFAULT_DATE.isoformat(),
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            state=State.RUNNING
        )
        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self.assertEqual(1, len(recorded_calls))
        self._assertCallsEqual(
            recorded_calls[0],
            Call(4,
                 date(2019, 1, 1),
                 "dag {} ran on {}.".format(self.dag.dag_id, DEFAULT_DATE.date().isoformat()))
        )

    def test_python_callable_keyword_arguments_are_templatized(self):
        """Test PythonOperator op_kwargs are templatized"""
        recorded_calls = []

        task = PythonOperator(
            task_id='python_operator',
            # a Mock instance cannot be used as a callable function or test fails with a
            # TypeError: Object of type Mock is not JSON serializable
            python_callable=(build_recording_function(recorded_calls)),
            op_kwargs={
                'an_int': 4,
                'a_date': date(2019, 1, 1),
                'a_templated_string': "dag {{dag.dag_id}} ran on {{ds}}."
            },
            dag=self.dag)

        self.dag.create_dagrun(
            run_id='manual__' + DEFAULT_DATE.isoformat(),
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            state=State.RUNNING
        )
        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self.assertEqual(1, len(recorded_calls))
        self._assertCallsEqual(
            recorded_calls[0],
            Call(an_int=4,
                 a_date=date(2019, 1, 1),
                 a_templated_string="dag {} ran on {}.".format(
                     self.dag.dag_id, DEFAULT_DATE.date().isoformat()))
        )

    def test_python_operator_shallow_copy_attr(self):
        """Deep-copying the operator keeps op_kwargs values and the callable shared."""
        def identity(x):
            return x

        original_task = PythonOperator(
            python_callable=identity,
            task_id='python_operator',
            op_kwargs={'certain_attrs': ''},
            dag=self.dag
        )
        new_task = copy.deepcopy(original_task)
        # shallow copy op_kwargs
        self.assertEqual(id(original_task.op_kwargs['certain_attrs']),
                         id(new_task.op_kwargs['certain_attrs']))
        # shallow copy python_callable
        self.assertEqual(id(original_task.python_callable),
                         id(new_task.python_callable))

    def _env_var_check_callback(self):
        self.assertEqual('test_dag', os.environ['AIRFLOW_CTX_DAG_ID'])
        self.assertEqual('hive_in_python_op',
                         os.environ['AIRFLOW_CTX_TASK_ID'])
        self.assertEqual(DEFAULT_DATE.isoformat(),
                         os.environ['AIRFLOW_CTX_EXECUTION_DATE'])
        self.assertEqual('manual__' + DEFAULT_DATE.isoformat(),
                         os.environ['AIRFLOW_CTX_DAG_RUN_ID'])

    def test_echo_env_variables(self):
        """
        Test that env variables are exported correctly to the
        python callback in the task.
        """
        self.dag.create_dagrun(
            run_id='manual__' + DEFAULT_DATE.isoformat(),
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            state=State.RUNNING,
            external_trigger=False,
        )

        t = PythonOperator(task_id='hive_in_python_op',
                           dag=self.dag,
                           python_callable=self._env_var_check_callback
                           )
        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)