def test_dagrun_deadlock(self):
    """A dag run whose only remaining task can never be scheduled must be
    detected as deadlocked and marked FAILED by ``update_state()``.
    """
    session = settings.Session()
    # NOTE(review): dag_id previously had a typo ('text_dagrun_deadlock');
    # renamed to match the run_id and the test-name convention used by the
    # sibling tests in this file.
    dag = DAG('test_dagrun_deadlock', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
    with dag:
        op1 = DummyOperator(task_id='A')
        op2 = DummyOperator(task_id='B')
        # B only runs if an upstream task failed; with A succeeding, B stays
        # queued but (per the first assertion) does not deadlock the run.
        op2.trigger_rule = TriggerRule.ONE_FAILED
        op2.set_upstream(op1)

    dag.clear()
    now = timezone.utcnow()
    dr = dag.create_dagrun(
        run_id='test_dagrun_deadlock',
        state=State.RUNNING,
        execution_date=now,
        start_date=now,
    )

    ti_op1 = dr.get_task_instance(task_id=op1.task_id)
    ti_op1.set_state(state=State.SUCCESS, session=session)
    ti_op2 = dr.get_task_instance(task_id=op2.task_id)
    ti_op2.set_state(state=State.NONE, session=session)

    # With a valid (if unsatisfied) trigger rule the run is still RUNNING.
    dr.update_state()
    assert dr.state == State.RUNNING

    ti_op2.set_state(state=State.NONE, session=session)
    # An invalid trigger rule makes B impossible to ever schedule -> the run
    # is deadlocked and update_state() must fail it.
    op2.trigger_rule = 'invalid'
    dr.update_state()
    assert dr.state == State.FAILED
def test_dagrun_update_state_end_date(self):
    """``DagRun.update_state()`` end_date bookkeeping:

    * entering a terminal state (SUCCESS / FAILED) sets ``end_date``,
    * returning to RUNNING resets ``end_date`` to NULL,

    and in each case the value is persisted to the database.
    """
    session = settings.Session()
    dag = DAG(
        'test_dagrun_update_state_end_date', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}
    )

    # A -> B
    with dag:
        op1 = DummyOperator(task_id='A')
        op2 = DummyOperator(task_id='B')
        op1.set_upstream(op2)

    dag.clear()

    now = timezone.utcnow()
    dr = dag.create_dagrun(
        run_id='test_dagrun_update_state_end_date',
        state=State.RUNNING,
        execution_date=now,
        start_date=now,
    )

    # Initial end_date should be NULL
    # State.SUCCESS and State.FAILED are all ending state and should set end_date
    # State.RUNNING set end_date back to NULL
    session.merge(dr)
    session.commit()
    assert dr.end_date is None

    ti_op1 = dr.get_task_instance(task_id=op1.task_id)
    ti_op1.set_state(state=State.SUCCESS, session=session)
    ti_op2 = dr.get_task_instance(task_id=op2.task_id)
    ti_op2.set_state(state=State.SUCCESS, session=session)

    dr.update_state()

    dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_update_state_end_date').one()
    assert dr_database.end_date is not None
    assert dr.end_date == dr_database.end_date

    ti_op1.set_state(state=State.RUNNING, session=session)
    ti_op2.set_state(state=State.RUNNING, session=session)
    dr.update_state()

    dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_update_state_end_date').one()

    # Consistency fix: use the public `state` accessor like every other
    # assertion in this test (was the private `dr._state`).
    assert dr.state == State.RUNNING
    assert dr.end_date is None
    assert dr_database.end_date is None

    ti_op1.set_state(state=State.FAILED, session=session)
    ti_op2.set_state(state=State.FAILED, session=session)
    dr.update_state()

    dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_update_state_end_date').one()
    assert dr_database.end_date is not None
    assert dr.end_date == dr_database.end_date
def test_without_dag_run(self):
    """This checks the defensive against non existent tasks in a dag run.

    ``upstream`` is run through ``short_op.run()`` only indirectly, so its
    task instance must never appear in the database; the branches downstream
    of the ShortCircuitOperator are skipped while the callable returns False
    and left untouched (state None) once it returns True.
    """
    value = False
    dag = DAG(
        'shortcircuit_operator_test_without_dag_run',
        default_args={
            'owner': 'airflow',
            'start_date': DEFAULT_DATE
        },
        schedule_interval=INTERVAL,
    )
    # The callable closes over `value`, so flipping `value` below changes
    # the short-circuit decision without rebuilding the operator.
    short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: value)
    branch_1 = DummyOperator(task_id='branch_1', dag=dag)
    branch_1.set_upstream(short_op)
    branch_2 = DummyOperator(task_id='branch_2', dag=dag)
    branch_2.set_upstream(branch_1)
    upstream = DummyOperator(task_id='upstream', dag=dag)
    upstream.set_downstream(short_op)
    dag.clear()

    short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    with create_session() as session:
        tis = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.execution_date == DEFAULT_DATE)

        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'upstream':
                # should not exist
                raise ValueError(f'Invalid task id {ti.task_id} found!')
            elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                assert ti.state == State.SKIPPED
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')

        value = True
        dag.clear()

        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        # `tis` is a lazy SQLAlchemy query: iterating it again re-executes
        # the SELECT and picks up the post-clear states.
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'upstream':
                # should not exist
                raise ValueError(f'Invalid task id {ti.task_id} found!')
            elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                assert ti.state == State.NONE
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
def _make_smart_operator(self, index, **kwargs):
    """Build a DummySmartSensor (with a DummyOperator wired downstream).

    Timing knobs default to 0 so tests never sit waiting on poke cycles;
    callers may override either via ``kwargs``.
    """
    kwargs.setdefault('poke_interval', 0)
    kwargs.setdefault('smart_sensor_timeout', 0)

    sensor = DummySmartSensor(task_id=SMART_OP + "_" + str(index), dag=self.dag, **kwargs)
    downstream = DummyOperator(task_id=DUMMY_OP, dag=self.dag)
    downstream.set_upstream(sensor)
    return sensor
def test_get_states_count_upstream_ti(self):
    """
    this test tests the helper function '_get_states_count_upstream_ti' as a unit and inside update_state
    """
    from airflow.ti_deps.dep_context import DepContext

    get_states_count_upstream_ti = TriggerRuleDep._get_states_count_upstream_ti
    session = settings.Session()
    now = timezone.utcnow()
    dag = DAG('test_dagrun_with_pre_tis', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})

    with dag:
        op1 = DummyOperator(task_id='A')
        op2 = DummyOperator(task_id='B')
        op3 = DummyOperator(task_id='C')
        op4 = DummyOperator(task_id='D')
        op5 = DummyOperator(task_id='E', trigger_rule=TriggerRule.ONE_FAILED)

        op1.set_downstream([op2, op3])  # op1 >> op2, op3
        op4.set_upstream([op3, op2])  # op3, op2 >> op4
        op5.set_upstream([op2, op3, op4])  # (op2, op3, op4) >> op5

    clear_db_runs()
    dag.clear()
    dr = dag.create_dagrun(
        run_id='test_dagrun_with_pre_tis', state=State.RUNNING, execution_date=now, start_date=now
    )

    ti_op1 = TaskInstance(task=dag.get_task(op1.task_id), execution_date=dr.execution_date)
    ti_op2 = TaskInstance(task=dag.get_task(op2.task_id), execution_date=dr.execution_date)
    ti_op3 = TaskInstance(task=dag.get_task(op3.task_id), execution_date=dr.execution_date)
    ti_op4 = TaskInstance(task=dag.get_task(op4.task_id), execution_date=dr.execution_date)
    ti_op5 = TaskInstance(task=dag.get_task(op5.task_id), execution_date=dr.execution_date)

    # Only B fails; everything else succeeds.
    ti_op1.set_state(state=State.SUCCESS, session=session)
    ti_op2.set_state(state=State.FAILED, session=session)
    ti_op3.set_state(state=State.SUCCESS, session=session)
    ti_op4.set_state(state=State.SUCCESS, session=session)
    ti_op5.set_state(state=State.SUCCESS, session=session)

    session.commit()

    # check handling with cases that tasks are triggered from backfill with no finished tasks
    finished_tasks = DepContext().ensure_finished_tasks(ti_op2.task.dag, ti_op2.execution_date, session)
    # Judging by the expected values the returned tuple counts, in order:
    # (successes, skipped, failed, upstream_failed, done) over the upstream
    # tasks -- TODO(review): confirm against TriggerRuleDep.
    assert get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op2) == (1, 0, 0, 0, 1)

    finished_tasks = dr.get_task_instances(state=State.finished, session=session)
    assert get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op4) == (1, 0, 1, 0, 2)
    assert get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op5) == (2, 0, 1, 0, 3)

    dr.update_state()
    assert State.SUCCESS == dr.state
def _make_sensor(self, return_value, task_id=SENSOR_OP, **kwargs):
    """Build a DummySensor (with a DummyOperator wired downstream).

    Timing knobs default to 0 so tests never block on poke cycles;
    callers may override either via ``kwargs``.
    """
    kwargs.setdefault('poke_interval', 0)
    kwargs.setdefault('timeout', 0)

    sensor = DummySensor(task_id=task_id, return_value=return_value, dag=self.dag, **kwargs)
    downstream = DummyOperator(task_id=DUMMY_OP, dag=self.dag)
    downstream.set_upstream(sensor)
    return sensor
def test_dagrun_success_conditions(self):
    """A dag run becomes SUCCESS once its root task (A) finished
    successfully, even though another task (B) failed along the way.
    """
    session = settings.Session()
    dag = DAG('test_dagrun_success_conditions', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})

    # A -> B
    # A -> C -> D
    # ordered: B, D, C, A or D, B, C, A or D, C, B, A
    with dag:
        op1 = DummyOperator(task_id='A')
        op2 = DummyOperator(task_id='B')
        op3 = DummyOperator(task_id='C')
        op4 = DummyOperator(task_id='D')
        op1.set_upstream([op2, op3])
        op3.set_upstream(op4)

    dag.clear()

    now = timezone.utcnow()
    dr = dag.create_dagrun(
        run_id='test_dagrun_success_conditions',
        state=State.RUNNING,
        execution_date=now,
        start_date=now,
    )

    # op1 = root; grab all task instances up front.
    instances = {op.task_id: dr.get_task_instance(task_id=op.task_id) for op in (op1, op2, op3, op4)}
    instances['A'].set_state(state=State.SUCCESS, session=session)

    # Root is successful, but the other tasks are unfinished -> still RUNNING.
    dr.update_state()
    assert dr.state == State.RUNNING

    # One task failed, but the root succeeded -> the run is SUCCESS.
    instances['B'].set_state(state=State.FAILED, session=session)
    instances['C'].set_state(state=State.SUCCESS, session=session)
    instances['D'].set_state(state=State.SUCCESS, session=session)
    dr.update_state()
    assert dr.state == State.SUCCESS
def test_dagrun_no_deadlock_with_shutdown(self):
    """An upstream task in SHUTDOWN must not make the run look deadlocked:
    update_state() keeps the dag run RUNNING.
    """
    session = settings.Session()
    dag = DAG('test_dagrun_no_deadlock_with_shutdown', start_date=DEFAULT_DATE)
    with dag:
        upstream = DummyOperator(task_id='upstream_task')
        downstream = DummyOperator(task_id='downstream_task')
        downstream.set_upstream(upstream)

    dr = dag.create_dagrun(
        run_id='test_dagrun_no_deadlock_with_shutdown',
        state=State.RUNNING,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
    )

    upstream_ti = dr.get_task_instance(task_id='upstream_task')
    upstream_ti.set_state(State.SHUTDOWN, session=session)

    dr.update_state()
    assert dr.state == State.RUNNING
def test_operator_clear(self):
    """Clearing an operator (with upstream=True) makes the cleared task
    instances re-runnable and bumps try_number on the next run, while
    max_tries only reflects retries for instances that exist in the DB.
    """
    dag = DAG(
        'test_operator_clear',
        start_date=DEFAULT_DATE,
        end_date=DEFAULT_DATE + datetime.timedelta(days=10),
    )
    op1 = DummyOperator(task_id='bash_op', owner='test', dag=dag)
    op2 = DummyOperator(task_id='dummy_op', owner='test', dag=dag, retries=1)
    op2.set_upstream(op1)

    ti1 = TI(task=op1, execution_date=DEFAULT_DATE)
    ti2 = TI(task=op2, execution_date=DEFAULT_DATE)

    dag.create_dagrun(
        execution_date=ti1.execution_date,
        state=State.RUNNING,
        run_type=DagRunType.SCHEDULED,
    )

    # ti2's upstream (ti1) has not run yet, so this attempt is a no-op
    # scheduling-wise but still increments the attempt counter.
    ti2.run()
    # Dependency not met
    assert ti2.try_number == 1
    assert ti2.max_tries == 1

    op2.clear(upstream=True)
    ti1.run()
    ti2.run(ignore_ti_state=True)
    assert ti1.try_number == 2
    # max_tries is 0 because there is no task instance in db for ti1
    # so clear won't change the max_tries.
    assert ti1.max_tries == 0
    assert ti2.try_number == 2
    # try_number (0) + retries(1)
    assert ti2.max_tries == 1
# specific language governing permissions and limitations # under the License. """This dag only runs some simple tasks to test Airflow's task execution.""" from datetime import datetime, timedelta from airflow.models.dag import DAG from airflow.operators.dummy import DummyOperator from airflow.utils.dates import days_ago now = datetime.now() now_to_the_hour = (now - timedelta(0, 0, 0, 0, 0, 3)).replace(minute=0, second=0, microsecond=0) START_DATE = now_to_the_hour DAG_NAME = 'test_dag_v2' default_args = { 'owner': 'airflow', 'depends_on_past': True, 'start_date': days_ago(2) } dag = DAG(DAG_NAME, schedule_interval='*/10 * * * *', default_args=default_args) run_this_1 = DummyOperator(task_id='run_this_1', dag=dag) run_this_2 = DummyOperator(task_id='run_this_2', dag=dag) run_this_2.set_upstream(run_this_1) run_this_3 = DummyOperator(task_id='run_this_3', dag=dag) run_this_3.set_upstream(run_this_2)
def test_skipping_non_latest(self):
    """LatestOnlyOperator skips its direct downstream on non-latest runs;
    transitively-downstream tasks never run (state None) unless they use
    TriggerRule.NONE_FAILED, in which case they succeed everywhere.
    """
    latest_task = LatestOnlyOperator(task_id='latest', dag=self.dag)
    downstream_task = DummyOperator(task_id='downstream', dag=self.dag)
    downstream_task2 = DummyOperator(task_id='downstream_2', dag=self.dag)
    downstream_task3 = DummyOperator(
        task_id='downstream_3', trigger_rule=TriggerRule.NONE_FAILED, dag=self.dag
    )

    downstream_task.set_upstream(latest_task)
    downstream_task2.set_upstream(downstream_task)
    downstream_task3.set_upstream(downstream_task)

    # Three scheduled runs; the assertions below show only the 2016-01-02
    # run is treated as "latest".
    self.dag.create_dagrun(
        run_type=DagRunType.SCHEDULED,
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    self.dag.create_dagrun(
        run_type=DagRunType.SCHEDULED,
        start_date=timezone.utcnow(),
        execution_date=timezone.datetime(2016, 1, 1, 12),
        state=State.RUNNING,
    )

    self.dag.create_dagrun(
        run_type=DagRunType.SCHEDULED,
        start_date=timezone.utcnow(),
        execution_date=END_DATE,
        state=State.RUNNING,
    )

    latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
    downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
    downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE)
    downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE)

    latest_instances = get_task_instances('latest')
    exec_date_to_latest_state = {ti.execution_date: ti.state for ti in latest_instances}
    assert {
        timezone.datetime(2016, 1, 1): 'success',
        timezone.datetime(2016, 1, 1, 12): 'success',
        timezone.datetime(2016, 1, 2): 'success',
    } == exec_date_to_latest_state

    # Direct downstream: skipped on non-latest runs, runs on the latest.
    downstream_instances = get_task_instances('downstream')
    exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
    assert {
        timezone.datetime(2016, 1, 1): 'skipped',
        timezone.datetime(2016, 1, 1, 12): 'skipped',
        timezone.datetime(2016, 1, 2): 'success',
    } == exec_date_to_downstream_state

    # Behind a skipped parent: never scheduled at all (state None).
    downstream_instances = get_task_instances('downstream_2')
    exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
    assert {
        timezone.datetime(2016, 1, 1): None,
        timezone.datetime(2016, 1, 1, 12): None,
        timezone.datetime(2016, 1, 2): 'success',
    } == exec_date_to_downstream_state

    # NONE_FAILED tolerates a skipped parent, so this one succeeds everywhere.
    downstream_instances = get_task_instances('downstream_3')
    exec_date_to_downstream_state = {ti.execution_date: ti.state for ti in downstream_instances}
    assert {
        timezone.datetime(2016, 1, 1): 'success',
        timezone.datetime(2016, 1, 1, 12): 'success',
        timezone.datetime(2016, 1, 2): 'success',
    } == exec_date_to_downstream_state
class TestBranchDateTimeOperator(unittest.TestCase):
    """Tests for BranchDateTimeOperator: follow branch_1 when the current
    (or execution) time falls inside [target_lower, target_upper], else
    follow branch_2."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

        # (target_lower, target_upper) pairs mixing datetime and time values.
        # NOTE(review): the last pair has lower (11:00) above upper (10:00) —
        # presumably exercising a range that crosses midnight; confirm
        # against the operator's target handling.
        cls.targets = [
            (datetime.datetime(2020, 7, 7, 10, 0, 0), datetime.datetime(2020, 7, 7, 11, 0, 0)),
            (datetime.time(10, 0, 0), datetime.time(11, 0, 0)),
            (datetime.datetime(2020, 7, 7, 10, 0, 0), datetime.time(11, 0, 0)),
            (datetime.time(10, 0, 0), datetime.datetime(2020, 7, 7, 11, 0, 0)),
            (datetime.time(11, 0, 0), datetime.time(10, 0, 0)),
        ]

    def setUp(self):
        self.dag = DAG(
            'branch_datetime_operator_test',
            default_args={
                'owner': 'airflow',
                'start_date': DEFAULT_DATE
            },
            schedule_interval=INTERVAL,
        )
        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)

        self.branch_op = BranchDateTimeOperator(
            task_id='datetime_branch',
            follow_task_ids_if_true='branch_1',
            follow_task_ids_if_false='branch_2',
            target_upper=datetime.datetime(2020, 7, 7, 11, 0, 0),
            target_lower=datetime.datetime(2020, 7, 7, 10, 0, 0),
            dag=self.dag,
        )

        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

        self.dr = self.dag.create_dagrun(
            run_id='manual__', start_date=DEFAULT_DATE, execution_date=DEFAULT_DATE, state=State.RUNNING
        )

    def tearDown(self):
        super().tearDown()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def _assert_task_ids_match_states(self, task_ids_to_states):
        """Helper that asserts task instances with a given id are in a given state"""
        tis = self.dr.get_task_instances()
        for ti in tis:
            try:
                expected_state = task_ids_to_states[ti.task_id]
            except KeyError:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
            else:
                self.assertEqual(
                    ti.state,
                    expected_state,
                    f"Task {ti.task_id} has state {ti.state} instead of expected {expected_state}",
                )

    def test_no_target_time(self):
        """Check if BranchDateTimeOperator raises exception on missing target"""
        with self.assertRaises(AirflowException):
            BranchDateTimeOperator(
                task_id='datetime_branch',
                follow_task_ids_if_true='branch_1',
                follow_task_ids_if_false='branch_2',
                target_upper=None,
                target_lower=None,
                dag=self.dag,
            )

    @freezegun.freeze_time("2020-07-07 10:54:05")
    def test_branch_datetime_operator_falls_within_range(self):
        """Check BranchDateTimeOperator branch operation"""
        # Frozen "now" (10:54:05) is inside every target range -> branch_1.
        for target_lower, target_upper in self.targets:
            with self.subTest(target_lower=target_lower, target_upper=target_upper):
                self.branch_op.target_lower = target_lower
                self.branch_op.target_upper = target_upper
                self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

                self._assert_task_ids_match_states({
                    'datetime_branch': State.SUCCESS,
                    'branch_1': State.NONE,
                    'branch_2': State.SKIPPED,
                })

    def test_branch_datetime_operator_falls_outside_range(self):
        """Check BranchDateTimeOperator branch operation"""
        # Both frozen dates fall outside every target range -> branch_2.
        dates = [
            datetime.datetime(2020, 7, 7, 12, 0, 0, tzinfo=datetime.timezone.utc),
            datetime.datetime(2020, 6, 7, 12, 0, 0, tzinfo=datetime.timezone.utc),
        ]

        for target_lower, target_upper in self.targets:
            with self.subTest(target_lower=target_lower, target_upper=target_upper):
                self.branch_op.target_lower = target_lower
                self.branch_op.target_upper = target_upper

                for date in dates:
                    with freezegun.freeze_time(date):
                        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

                        self._assert_task_ids_match_states({
                            'datetime_branch': State.SUCCESS,
                            'branch_1': State.SKIPPED,
                            'branch_2': State.NONE,
                        })

    @freezegun.freeze_time("2020-07-07 10:54:05")
    def test_branch_datetime_operator_upper_comparison_within_range(self):
        """Check BranchDateTimeOperator branch operation"""
        # Only an upper bound: "now" below it -> branch_1.
        for _, target_upper in self.targets:
            with self.subTest(target_upper=target_upper):
                self.branch_op.target_upper = target_upper
                self.branch_op.target_lower = None

                self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

                self._assert_task_ids_match_states({
                    'datetime_branch': State.SUCCESS,
                    'branch_1': State.NONE,
                    'branch_2': State.SKIPPED,
                })

    @freezegun.freeze_time("2020-07-07 10:54:05")
    def test_branch_datetime_operator_lower_comparison_within_range(self):
        """Check BranchDateTimeOperator branch operation"""
        # Only a lower bound: "now" above it -> branch_1.
        for target_lower, _ in self.targets:
            with self.subTest(target_lower=target_lower):
                self.branch_op.target_lower = target_lower
                self.branch_op.target_upper = None

                self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

                self._assert_task_ids_match_states({
                    'datetime_branch': State.SUCCESS,
                    'branch_1': State.NONE,
                    'branch_2': State.SKIPPED,
                })

    @freezegun.freeze_time("2020-07-07 12:00:00")
    def test_branch_datetime_operator_upper_comparison_outside_range(self):
        """Check BranchDateTimeOperator branch operation"""
        # Only an upper bound: "now" past it -> branch_2.
        for _, target_upper in self.targets:
            with self.subTest(target_upper=target_upper):
                self.branch_op.target_upper = target_upper
                self.branch_op.target_lower = None

                self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

                self._assert_task_ids_match_states({
                    'datetime_branch': State.SUCCESS,
                    'branch_1': State.SKIPPED,
                    'branch_2': State.NONE,
                })

    @freezegun.freeze_time("2020-07-07 09:00:00")
    def test_branch_datetime_operator_lower_comparison_outside_range(self):
        """Check BranchDateTimeOperator branch operation"""
        # Only a lower bound: "now" before it -> branch_2.
        for target_lower, _ in self.targets:
            with self.subTest(target_lower=target_lower):
                self.branch_op.target_lower = target_lower
                self.branch_op.target_upper = None

                self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

                self._assert_task_ids_match_states({
                    'datetime_branch': State.SUCCESS,
                    'branch_1': State.SKIPPED,
                    'branch_2': State.NONE,
                })

    @freezegun.freeze_time("2020-12-01 09:00:00")
    def test_branch_datetime_operator_use_task_execution_date(self):
        """Check if BranchDateTimeOperator uses task execution date"""
        # Wall clock is frozen OUTSIDE the range; the execution date is
        # inside it, so with use_task_execution_date=True branch_1 is taken.
        in_between_date = timezone.datetime(2020, 7, 7, 10, 30, 0)
        self.branch_op.use_task_execution_date = True
        self.dr = self.dag.create_dagrun(
            run_id='manual_exec_date__',
            start_date=in_between_date,
            execution_date=in_between_date,
            state=State.RUNNING,
        )

        for target_lower, target_upper in self.targets:
            with self.subTest(target_lower=target_lower, target_upper=target_upper):
                self.branch_op.target_lower = target_lower
                self.branch_op.target_upper = target_upper

                self.branch_op.run(start_date=in_between_date, end_date=in_between_date)

                self._assert_task_ids_match_states({
                    'datetime_branch': State.SUCCESS,
                    'branch_1': State.NONE,
                    'branch_2': State.SKIPPED,
                })
    dag=dag1,
    pool='test_backfill_pooled_task_pool',
)

# dag2 has been moved to test_prev_dagrun_dep.py

# DAG tests that a Dag run that doesn't complete is marked failed
dag3 = DAG(dag_id='test_dagrun_states_fail', default_args=default_args)
dag3_task1 = PythonOperator(task_id='test_dagrun_fail', dag=dag3, python_callable=fail)
dag3_task2 = DummyOperator(
    task_id='test_dagrun_succeed',
    dag=dag3,
)
# default trigger rule: the succeed task can never run after its upstream fails
dag3_task2.set_upstream(dag3_task1)

# DAG tests that a Dag run that completes but has a failure is marked success
dag4 = DAG(dag_id='test_dagrun_states_success', default_args=default_args)
dag4_task1 = PythonOperator(
    task_id='test_dagrun_fail',
    dag=dag4,
    python_callable=fail,
)
# ALL_FAILED: this task runs precisely BECAUSE the upstream failed,
# so the run still completes.
dag4_task2 = DummyOperator(task_id='test_dagrun_succeed', dag=dag4, trigger_rule=TriggerRule.ALL_FAILED)
dag4_task2.set_upstream(dag4_task1)

# DAG tests that a Dag run that completes but has a root failure is marked fail
dag5 = DAG(dag_id='test_dagrun_states_root_fail', default_args=default_args)
# Start at midnight seven days ago so there is a backlog of daily runs.
seven_days_ago = datetime.combine(datetime.today() - timedelta(7), datetime.min.time())
args = {
    'owner': 'airflow',
    'start_date': seven_days_ago,
}

dag = DAG(dag_id='example_branch_operator', default_args=args, schedule_interval="@daily")

cmd = 'ls -l'
run_this_first = DummyOperator(task_id='run_this_first', dag=dag)

options = ['branch_a', 'branch_b', 'branch_c', 'branch_d']

# Randomly pick one of the option task_ids at runtime.
branching = BranchPythonOperator(
    task_id='branching', python_callable=lambda: random.choice(options), dag=dag
)
run_this_first >> branching

# 'one_success' lets the join run even though the other branches are skipped.
join = DummyOperator(task_id='join', trigger_rule='one_success', dag=dag)

for option in options:
    branch = DummyOperator(task_id=option, dag=dag)
    follow = DummyOperator(task_id='follow_' + option, dag=dag)
    branching >> branch >> follow >> join
class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
    """
    Test for SQL Branch Operator
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        super().setUp()
        self.dag = DAG(
            "sql_branch_operator_test",
            default_args={
                "owner": "airflow",
                "start_date": DEFAULT_DATE
            },
            schedule_interval=INTERVAL,
        )
        self.branch_1 = DummyOperator(task_id="branch_1", dag=self.dag)
        self.branch_2 = DummyOperator(task_id="branch_2", dag=self.dag)
        # Created on demand by tests that need a third branch.
        self.branch_3 = None

    def tearDown(self):
        super().tearDown()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def test_unsupported_conn_type(self):
        """Check if BranchSQLOperator throws an exception for unsupported connection type"""
        op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="redis_default",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        with pytest.raises(AirflowException):
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    def test_invalid_conn(self):
        """Check if BranchSQLOperator throws an exception for invalid connection"""
        op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="invalid_connection",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        with pytest.raises(AirflowException):
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    def test_invalid_follow_task_true(self):
        """Check if BranchSQLOperator throws an exception for invalid connection"""
        op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="invalid_connection",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true=None,
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        with pytest.raises(AirflowException):
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    def test_invalid_follow_task_false(self):
        """Check if BranchSQLOperator throws an exception for invalid connection"""
        op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="invalid_connection",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false=None,
            dag=self.dag,
        )

        with pytest.raises(AirflowException):
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    @pytest.mark.backend("mysql")
    def test_sql_branch_operator_mysql(self):
        """Check if BranchSQLOperator works with backend"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )
        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    @pytest.mark.backend("postgres")
    def test_sql_branch_operator_postgres(self):
        """Check if BranchSQLOperator works with backend"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="postgres_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )
        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_branch_single_value_with_dag_run(self, mock_get_db_hook):
        """Check BranchSQLOperator branch operation"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        # The hook is mocked, so the "query result" is whatever get_first returns.
        mock_get_records = mock_get_db_hook.return_value.get_first
        mock_get_records.return_value = 1

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == "make_choice":
                assert ti.state == State.SUCCESS
            elif ti.task_id == "branch_1":
                assert ti.state == State.NONE
            elif ti.task_id == "branch_2":
                assert ti.state == State.SKIPPED
            else:
                raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_branch_true_with_dag_run(self, mock_get_db_hook):
        """Check BranchSQLOperator branch operation"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        # Every supported truthy representation must follow branch_1.
        for true_value in SUPPORTED_TRUE_VALUES:
            mock_get_records.return_value = true_value

            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            tis = dr.get_task_instances()
            for ti in tis:
                if ti.task_id == "make_choice":
                    assert ti.state == State.SUCCESS
                elif ti.task_id == "branch_1":
                    assert ti.state == State.NONE
                elif ti.task_id == "branch_2":
                    assert ti.state == State.SKIPPED
                else:
                    raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_branch_false_with_dag_run(self, mock_get_db_hook):
        """Check BranchSQLOperator branch operation"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        # Every supported falsy representation must follow branch_2.
        for false_value in SUPPORTED_FALSE_VALUES:
            mock_get_records.return_value = false_value

            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            tis = dr.get_task_instances()
            for ti in tis:
                if ti.task_id == "make_choice":
                    assert ti.state == State.SUCCESS
                elif ti.task_id == "branch_1":
                    assert ti.state == State.SKIPPED
                elif ti.task_id == "branch_2":
                    assert ti.state == State.NONE
                else:
                    raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_branch_list_with_dag_run(self, mock_get_db_hook):
        """Checks if the BranchSQLOperator supports branching off to a list of tasks."""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true=["branch_1", "branch_2"],
            follow_task_ids_if_false="branch_3",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.branch_3 = DummyOperator(task_id="branch_3", dag=self.dag)
        self.branch_3.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first
        mock_get_records.return_value = [["1"]]

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == "make_choice":
                assert ti.state == State.SUCCESS
            elif ti.task_id == "branch_1":
                assert ti.state == State.NONE
            elif ti.task_id == "branch_2":
                assert ti.state == State.NONE
            elif ti.task_id == "branch_3":
                assert ti.state == State.SKIPPED
            else:
                raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_invalid_query_result_with_dag_run(self, mock_get_db_hook):
        """Check BranchSQLOperator branch operation"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        # A result that is neither truthy nor falsy per the operator's rules
        # must be rejected.
        mock_get_records.return_value = ["Invalid Value"]

        with pytest.raises(AirflowException):
            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_with_skip_in_branch_downstream_dependencies(self, mock_get_db_hook):
        """Test SQL Branch with skipping all downstream dependencies"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        # branch_2 is reachable both through branch_1 and directly.
        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        for true_value in SUPPORTED_TRUE_VALUES:
            mock_get_records.return_value = [true_value]

            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            tis = dr.get_task_instances()
            for ti in tis:
                if ti.task_id == "make_choice":
                    assert ti.state == State.SUCCESS
                elif ti.task_id == "branch_1":
                    assert ti.state == State.NONE
                elif ti.task_id == "branch_2":
                    assert ti.state == State.NONE
                else:
                    raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_with_skip_in_branch_downstream_dependencies2(self, mock_get_db_hook):
        """Test skipping downstream dependency for false condition"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        # branch_2 is reachable both through branch_1 and directly.
        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        for false_value in SUPPORTED_FALSE_VALUES:
            mock_get_records.return_value = [false_value]

            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            tis = dr.get_task_instances()
            for ti in tis:
                if ti.task_id == "make_choice":
                    assert ti.state == State.SUCCESS
                elif ti.task_id == "branch_1":
                    assert ti.state == State.SKIPPED
                elif ti.task_id == "branch_2":
                    assert ti.state == State.NONE
                else:
                    raise ValueError(f"Invalid task id {ti.task_id} found!")
class TestBranchDayOfWeekOperator(unittest.TestCase):
    """
    Tests for BranchDayOfWeekOperator
    """

    @classmethod
    def setUpClass(cls):
        # Start from a clean slate so leftover runs/instances cannot bleed in.
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        self.dag = DAG(
            "branch_day_of_week_operator_test",
            start_date=DEFAULT_DATE,
            schedule_interval=INTERVAL,
        )
        self.branch_1 = DummyOperator(task_id="branch_1", dag=self.dag)
        self.branch_2 = DummyOperator(task_id="branch_2", dag=self.dag)
        self.branch_3 = None

    def tearDown(self):
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def _assert_task_ids_match_states(self, dr, task_ids_to_states):
        """Helper that asserts task instances with a given id are in a given state"""
        tis = dr.get_task_instances()
        for ti in tis:
            try:
                expected_state = task_ids_to_states[ti.task_id]
            except KeyError:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
            else:
                self.assertEqual(
                    ti.state,
                    expected_state,
                    f"Task {ti.task_id} has state {ti.state} instead of expected {expected_state}",
                )

    @parameterized.expand(
        [
            ("with-string", "Monday"),
            ("with-enum", WeekDay.MONDAY),
            ("with-enum-set", {WeekDay.MONDAY}),
            ("with-enum-set-2-items", {WeekDay.MONDAY, WeekDay.FRIDAY}),
            ("with-string-set", {"Monday"}),
            ("with-string-set-2-items", {"Monday", "Friday"}),
        ]
    )
    @freeze_time("2021-01-25")  # Monday
    def test_branch_follow_true(self, _, weekday):
        """Checks if BranchDayOfWeekOperator follows true branch"""
        # NOTE: a leftover debug print of datetime.now() was removed here.
        branch_op = BranchDayOfWeekOperator(
            task_id="make_choice",
            follow_task_ids_if_true=["branch_1", "branch_2"],
            follow_task_ids_if_false="branch_3",
            week_day=weekday,
            dag=self.dag,
        )
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.branch_3 = DummyOperator(task_id="branch_3", dag=self.dag)
        self.branch_3.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_task_ids_match_states(
            dr,
            {
                'make_choice': State.SUCCESS,
                'branch_1': State.NONE,
                'branch_2': State.NONE,
                'branch_3': State.SKIPPED,
            },
        )

    @freeze_time("2021-01-25")  # Monday
    def test_branch_follow_true_with_execution_date(self):
        """Checks if BranchDayOfWeekOperator follows true branch when set use_task_execution_day"""
        branch_op = BranchDayOfWeekOperator(
            task_id="make_choice",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            week_day="Wednesday",
            use_task_execution_day=True,  # We compare to DEFAULT_DATE which is Wednesday
            dag=self.dag,
        )
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_task_ids_match_states(
            dr,
            {
                'make_choice': State.SUCCESS,
                'branch_1': State.NONE,
                'branch_2': State.SKIPPED,
            },
        )

    @freeze_time("2021-01-25")  # Monday
    def test_branch_follow_false(self):
        """Checks if BranchDayOfWeekOperator follow false branch"""
        branch_op = BranchDayOfWeekOperator(
            task_id="make_choice",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            week_day="Sunday",
            dag=self.dag,
        )
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_task_ids_match_states(
            dr,
            {
                'make_choice': State.SUCCESS,
                'branch_1': State.SKIPPED,
                'branch_2': State.NONE,
            },
        )

    def test_branch_with_no_weekday(self):
        """Check if BranchDayOfWeekOperator raises exception on missing weekday"""
        with self.assertRaises(AirflowException):
            BranchDayOfWeekOperator(  # pylint: disable=missing-kwoa
                task_id="make_choice",
                follow_task_ids_if_true="branch_1",
                follow_task_ids_if_false="branch_2",
                dag=self.dag,
            )

    def test_branch_with_invalid_type(self):
        """Check if BranchDayOfWeekOperator raises exception on unsupported weekday type"""
        invalid_week_day = ['Monday']
        with pytest.raises(
            TypeError,
            match='Unsupported Type for week_day parameter:'
            ' {}. It should be one of str, set or '
            'Weekday enum type'.format(type(invalid_week_day)),
        ):
            BranchDayOfWeekOperator(
                task_id="make_choice",
                follow_task_ids_if_true="branch_1",
                follow_task_ids_if_false="branch_2",
                week_day=invalid_week_day,
                dag=self.dag,
            )

    def test_weekday_branch_invalid_weekday_number(self):
        """Check if BranchDayOfWeekOperator raises exception on wrong value of weekday"""
        invalid_week_day = 'Thsday'
        with pytest.raises(
            AttributeError, match=f'Invalid Week Day passed: "{invalid_week_day}"'
        ):
            BranchDayOfWeekOperator(
                task_id="make_choice",
                follow_task_ids_if_true="branch_1",
                follow_task_ids_if_false="branch_2",
                week_day=invalid_week_day,
                dag=self.dag,
            )

    @freeze_time("2021-01-25")  # Monday
    def test_branch_xcom_push_true_branch(self):
        """Check if BranchDayOfWeekOperator push to xcom value of follow_task_ids_if_true"""
        branch_op = BranchDayOfWeekOperator(
            task_id="make_choice",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            week_day="Monday",
            dag=self.dag,
        )
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.xcom_pull(task_ids='make_choice') == 'branch_1'
def test_not_skipping_external(self): latest_task = LatestOnlyOperator(task_id='latest', dag=self.dag) downstream_task = DummyOperator(task_id='downstream', dag=self.dag) downstream_task2 = DummyOperator(task_id='downstream_2', dag=self.dag) downstream_task.set_upstream(latest_task) downstream_task2.set_upstream(downstream_task) self.dag.create_dagrun( run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, state=State.RUNNING, external_trigger=True, ) self.dag.create_dagrun( run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=timezone.datetime(2016, 1, 1, 12), state=State.RUNNING, external_trigger=True, ) self.dag.create_dagrun( run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=END_DATE, state=State.RUNNING, external_trigger=True, ) latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE) latest_instances = get_task_instances('latest') exec_date_to_latest_state = { ti.execution_date: ti.state for ti in latest_instances } assert { timezone.datetime(2016, 1, 1): 'success', timezone.datetime(2016, 1, 1, 12): 'success', timezone.datetime(2016, 1, 2): 'success', } == exec_date_to_latest_state downstream_instances = get_task_instances('downstream') exec_date_to_downstream_state = { ti.execution_date: ti.state for ti in downstream_instances } assert { timezone.datetime(2016, 1, 1): 'success', timezone.datetime(2016, 1, 1, 12): 'success', timezone.datetime(2016, 1, 2): 'success', } == exec_date_to_downstream_state downstream_instances = get_task_instances('downstream_2') exec_date_to_downstream_state = { ti.execution_date: ti.state for ti in downstream_instances } assert { timezone.datetime(2016, 1, 1): 'success', timezone.datetime(2016, 1, 1, 12): 'success', timezone.datetime(2016, 1, 2): 'success', } == exec_date_to_downstream_state
def test_with_dag_run(self): value = False dag = DAG( 'shortcircuit_operator_test_with_dag_run', default_args={ 'owner': 'airflow', 'start_date': DEFAULT_DATE }, schedule_interval=INTERVAL, ) short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: value) branch_1 = DummyOperator(task_id='branch_1', dag=dag) branch_1.set_upstream(short_op) branch_2 = DummyOperator(task_id='branch_2', dag=dag) branch_2.set_upstream(branch_1) upstream = DummyOperator(task_id='upstream', dag=dag) upstream.set_downstream(short_op) dag.clear() logging.error("Tasks %s", dag.tasks) dr = dag.create_dagrun( run_type=DagRunType.MANUAL, start_date=timezone.utcnow(), execution_date=DEFAULT_DATE, state=State.RUNNING, ) upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) tis = dr.get_task_instances() assert len(tis) == 4 for ti in tis: if ti.task_id == 'make_choice': assert ti.state == State.SUCCESS elif ti.task_id == 'upstream': assert ti.state == State.SUCCESS elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2': assert ti.state == State.SKIPPED else: raise ValueError(f'Invalid task id {ti.task_id} found!') value = True dag.clear() dr.verify_integrity() upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) tis = dr.get_task_instances() assert len(tis) == 4 for ti in tis: if ti.task_id == 'make_choice': assert ti.state == State.SUCCESS elif ti.task_id == 'upstream': assert ti.state == State.SUCCESS elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2': assert ti.state == State.NONE else: raise ValueError(f'Invalid task id {ti.task_id} found!')
class TestBranchOperator(unittest.TestCase):
    """Tests for BranchPythonOperator branching and skip propagation."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        self.dag = DAG(
            'branch_operator_test',
            default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
            schedule_interval=INTERVAL,
        )
        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
        self.branch_3 = None

    def tearDown(self):
        super().tearDown()
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def _assert_task_states(self, tis, expected):
        """Assert each instance's state matches *expected*; unknown ids raise ValueError."""
        for ti in tis:
            if ti.task_id not in expected:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
            assert ti.state == expected[ti.task_id]

    def _make_manual_dagrun(self):
        """Create the manual RUNNING dagrun used by all with-dag-run tests."""
        return self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

    def test_without_dag_run(self):
        """This checks the defensive against non existent tasks in a dag run"""
        branch_op = BranchPythonOperator(
            task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1'
        )
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(
                TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE
            )
            self._assert_task_states(
                tis,
                {
                    'make_choice': State.SUCCESS,
                    # branch_1 should exist with state None
                    'branch_1': State.NONE,
                    'branch_2': State.SKIPPED,
                },
            )

    def test_branch_list_without_dag_run(self):
        """This checks if the BranchPythonOperator supports branching off to a list of tasks."""
        branch_op = BranchPythonOperator(
            task_id='make_choice', dag=self.dag, python_callable=lambda: ['branch_1', 'branch_2']
        )
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
        self.branch_3.set_upstream(branch_op)
        self.dag.clear()

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(
                TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE
            )
            self._assert_task_states(
                tis,
                {
                    "make_choice": State.SUCCESS,
                    "branch_1": State.NONE,
                    "branch_2": State.NONE,
                    "branch_3": State.SKIPPED,
                },
            )

    def test_with_dag_run(self):
        branch_op = BranchPythonOperator(
            task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1'
        )
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self._make_manual_dagrun()
        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_task_states(
            dr.get_task_instances(),
            {
                'make_choice': State.SUCCESS,
                'branch_1': State.NONE,
                'branch_2': State.SKIPPED,
            },
        )

    def test_with_skip_in_branch_downstream_dependencies(self):
        branch_op = BranchPythonOperator(
            task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1'
        )
        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
        self.dag.clear()

        dr = self._make_manual_dagrun()
        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        # branch_2 is reachable directly from branch_op, so it is not skipped.
        self._assert_task_states(
            dr.get_task_instances(),
            {
                'make_choice': State.SUCCESS,
                'branch_1': State.NONE,
                'branch_2': State.NONE,
            },
        )

    def test_with_skip_in_branch_downstream_dependencies2(self):
        branch_op = BranchPythonOperator(
            task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_2'
        )
        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
        self.dag.clear()

        dr = self._make_manual_dagrun()
        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_task_states(
            dr.get_task_instances(),
            {
                'make_choice': State.SUCCESS,
                'branch_1': State.SKIPPED,
                'branch_2': State.NONE,
            },
        )

    def test_xcom_push(self):
        branch_op = BranchPythonOperator(
            task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1'
        )
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self._make_manual_dagrun()
        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        # The chosen branch id is pushed to XCom by the branch operator.
        for ti in dr.get_task_instances():
            if ti.task_id == 'make_choice':
                assert ti.xcom_pull(task_ids='make_choice') == 'branch_1'

    def test_clear_skipped_downstream_task(self):
        """
        After a downstream task is skipped by BranchPythonOperator, clearing the skipped task
        should not cause it to be executed.
        """
        branch_op = BranchPythonOperator(
            task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1'
        )
        branches = [self.branch_1, self.branch_2]
        branch_op >> branches
        self.dag.clear()

        dr = self._make_manual_dagrun()
        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        for task in branches:
            task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        expected_states = {
            'make_choice': State.SUCCESS,
            'branch_1': State.SUCCESS,
            'branch_2': State.SKIPPED,
        }
        tis = dr.get_task_instances()
        self._assert_task_states(tis, expected_states)

        children_tis = [ti for ti in tis if ti.task_id in branch_op.get_direct_relative_ids()]

        # Clear the children tasks.
        with create_session() as session:
            clear_task_instances(children_tis, session=session, dag=self.dag)

        # Run the cleared tasks again.
        for task in branches:
            task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        # Check if the states are correct after children tasks are cleared.
        self._assert_task_states(dr.get_task_instances(), expected_states)
class TestBranchOperator(unittest.TestCase):
    """Tests for project-defined branch operators (ChooseBranchOne / ChooseBranchOneTwo)."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        self.dag = DAG(
            'branch_operator_test',
            default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
            schedule_interval=INTERVAL,
        )
        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
        self.branch_3 = None
        self.branch_op = None

    def tearDown(self):
        super().tearDown()
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def test_without_dag_run(self):
        """This checks the defensive against non existent tasks in a dag run"""
        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(
                TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE
            )
            for ti in tis:
                if ti.task_id == 'make_choice':
                    assert ti.state == State.SUCCESS
                elif ti.task_id == 'branch_1':
                    # should exist with state None
                    assert ti.state == State.NONE
                elif ti.task_id == 'branch_2':
                    assert ti.state == State.SKIPPED
                else:
                    # Fixed: a bare ``raise Exception`` gave no diagnostic; use the
                    # descriptive ValueError that the sibling test classes raise.
                    raise ValueError(f'Invalid task id {ti.task_id} found!')

    def test_branch_list_without_dag_run(self):
        """This checks if the BranchOperator supports branching off to a list of tasks."""
        self.branch_op = ChooseBranchOneTwo(task_id='make_choice', dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
        self.branch_3.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(
                TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE
            )
            expected = {
                "make_choice": State.SUCCESS,
                "branch_1": State.NONE,
                "branch_2": State.NONE,
                "branch_3": State.SKIPPED,
            }
            for ti in tis:
                if ti.task_id in expected:
                    assert ti.state == expected[ti.task_id]
                else:
                    raise ValueError(f'Invalid task id {ti.task_id} found!')

    def test_with_dag_run(self):
        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

        dagrun = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dagrun.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_1':
                assert ti.state == State.NONE
            elif ti.task_id == 'branch_2':
                assert ti.state == State.SKIPPED
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')

    def test_with_skip_in_branch_downstream_dependencies(self):
        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
        # branch_2 is reachable both via branch_1 and directly, so it is not skipped.
        self.branch_op >> self.branch_1 >> self.branch_2
        self.branch_op >> self.branch_2
        self.dag.clear()

        dagrun = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dagrun.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_1':
                assert ti.state == State.NONE
            elif ti.task_id == 'branch_2':
                assert ti.state == State.NONE
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')