def test_clear_skipped_downstream_task(self):
    """
    After a downstream task is skipped by ShortCircuitOperator, clearing the skipped task
    should not cause it to be executed.
    """
    dag = DAG(
        'shortcircuit_clear_skipped_downstream_task',
        default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
        schedule_interval=INTERVAL,
    )
    # The callable always returns False, so the operator short-circuits on every run.
    short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: False)
    downstream = DummyOperator(task_id='downstream', dag=dag)

    short_op >> downstream

    dag.clear()

    dr = dag.create_dagrun(
        run_type=DagRunType.MANUAL,
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    # The expected outcome is the same before and after the skipped task is cleared.
    expected_states = {
        'make_choice': State.SUCCESS,
        'downstream': State.SKIPPED,
    }

    def assert_expected_states():
        """Assert the dag run holds exactly the expected task instances and states.

        Returns the task instances that were checked so the caller can reuse them.
        """
        tis = dr.get_task_instances()
        # Without this count check the loop below would pass vacuously when a
        # task instance is missing from the dag run.
        assert len(tis) == len(expected_states)
        for ti in tis:
            if ti.task_id not in expected_states:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
            assert ti.state == expected_states[ti.task_id]
        return tis

    short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    downstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    tis = assert_expected_states()

    # Clear downstream
    with create_session() as session:
        clear_task_instances([t for t in tis if t.task_id == "downstream"], session=session, dag=dag)

    # Run downstream again
    downstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    # Check if the states are correct: the cleared task must end up skipped again.
    assert_expected_states()
def test_with_dag_run(self):
    """Check task states around ShortCircuitOperator when a dag run exists.

    The operator is run twice against the same dag run: first with a callable
    returning False (downstream tasks are expected to be SKIPPED), then with it
    returning True (downstream tasks are expected to be left in State.NONE).
    """
    # The lambda below reads ``value`` when it is called, not when it is defined,
    # so rebinding ``value`` later flips the operator's decision for the second run.
    value = False
    dag = DAG(
        'shortcircuit_operator_test_with_dag_run',
        default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
        schedule_interval=INTERVAL,
    )
    short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: value)
    branch_1 = DummyOperator(task_id='branch_1', dag=dag)
    branch_1.set_upstream(short_op)
    branch_2 = DummyOperator(task_id='branch_2', dag=dag)
    branch_2.set_upstream(branch_1)
    upstream = DummyOperator(task_id='upstream', dag=dag)
    upstream.set_downstream(short_op)
    dag.clear()

    dr = dag.create_dagrun(
        run_type=DagRunType.MANUAL,
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    def run_and_assert_states(branch_state):
        """Run ``upstream`` and ``make_choice``, then verify every task instance.

        ``branch_state`` is the state expected for both ``branch_1`` and ``branch_2``.
        """
        upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        expected = {
            'make_choice': State.SUCCESS,
            'upstream': State.SUCCESS,
            'branch_1': branch_state,
            'branch_2': branch_state,
        }
        tis = dr.get_task_instances()
        # One task instance per task in the DAG (four in total).
        assert len(tis) == len(expected)
        for ti in tis:
            if ti.task_id not in expected:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
            assert ti.state == expected[ti.task_id]

    # Callable returns False: everything downstream of make_choice is skipped.
    run_and_assert_states(State.SKIPPED)

    value = True
    dag.clear()
    dr.verify_integrity()

    # Callable returns True: the branches are not skipped and, since they are
    # never run in this test, stay in State.NONE.
    run_and_assert_states(State.NONE)
def test_without_dag_run(self):
    """Exercise ShortCircuitOperator when no dag run is ever created for the DAG.

    Only ``make_choice`` is run. The ``upstream`` task must never show up among
    the task instances, and any instances found for the two branches must be
    SKIPPED after the first run and State.NONE after the second.
    """
    # Read by the lambda at call time; rebound to True before the second run.
    value = False
    dag = DAG(
        'shortcircuit_operator_test_without_dag_run',
        default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
        schedule_interval=INTERVAL,
    )
    short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: value)
    branch_1 = DummyOperator(task_id='branch_1', dag=dag)
    branch_2 = DummyOperator(task_id='branch_2', dag=dag)
    upstream = DummyOperator(task_id='upstream', dag=dag)
    # upstream -> make_choice -> branch_1 -> branch_2
    upstream >> short_op >> branch_1 >> branch_2
    dag.clear()

    def check_states(task_instances, branch_state):
        """Verify each task instance yielded by ``task_instances``.

        ``make_choice`` must be SUCCESS and the branches must be in
        ``branch_state``; any other task id (including ``upstream``, which
        should not exist) is an error.
        """
        for ti in task_instances:
            if ti.task_id == 'make_choice':
                assert ti.state == State.SUCCESS
            elif ti.task_id in ('branch_1', 'branch_2'):
                assert ti.state == branch_state
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')

    short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    with create_session() as session:
        # Deliberately left as a Query rather than a list: each call to
        # check_states iterates it afresh.
        tis = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.execution_date == DEFAULT_DATE)

        check_states(tis, State.SKIPPED)

        value = True
        dag.clear()

        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        check_states(tis, State.NONE)