def test_with_skip_in_branch_downstream_dependencies2(self):
    """Choosing 'branch_2' skips branch_1 but leaves branch_2 unscheduled.

    The DAG is ``make_choice >> branch_1 >> branch_2`` plus a direct
    ``make_choice >> branch_2`` edge. After running only the branch operator,
    branch_1 (not chosen) must be SKIPPED, and branch_2 (chosen) must still
    have no state.

    NOTE(review): a test method with this exact name is defined again later in
    this file; if both live in the same class, the later definition shadows
    this one and this body never runs. Confirm, then rename or delete.
    """
    branch_op = BranchPythonOperator(task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_2')
    branch_op >> self.branch_1 >> self.branch_2
    branch_op >> self.branch_2
    self.dag.clear()

    dr = self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    tis = dr.get_task_instances()
    for ti in tis:
        if ti.task_id == 'make_choice':
            self.assertEqual(ti.state, State.SUCCESS)
        elif ti.task_id == 'branch_1':
            self.assertEqual(ti.state, State.SKIPPED)
        elif ti.task_id == 'branch_2':
            self.assertEqual(ti.state, State.NONE)
        else:
            # Any other task instance means the DAG is not the one built above;
            # name the offending task so the failure is diagnosable.
            raise ValueError(f'Invalid task id {ti.task_id} found!')
def test_branch_list_without_dag_run(self):
    """Check that BranchPythonOperator supports branching to a list of tasks.

    The callable returns ``['branch_1', 'branch_2']``, so of the three
    downstream tasks only ``branch_3`` should be skipped. No DagRun is created;
    task instances are read directly from the database for this DAG and date.
    """
    branch_op = BranchPythonOperator(
        task_id='make_choice', dag=self.dag, python_callable=lambda: ['branch_1', 'branch_2']
    )
    self.branch_1.set_upstream(branch_op)
    self.branch_2.set_upstream(branch_op)
    self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
    self.branch_3.set_upstream(branch_op)
    self.dag.clear()

    branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    with create_session() as session:
        tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE)

        expected = {
            "make_choice": State.SUCCESS,
            "branch_1": State.NONE,
            "branch_2": State.NONE,
            "branch_3": State.SKIPPED,
        }

        # Only rows that actually exist are checked; this loop does not assert
        # that every expected task has a row.
        for ti in tis:
            if ti.task_id in expected:
                self.assertEqual(ti.state, expected[ti.task_id])
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
def test_without_dag_run(self):
    """Check that branching works defensively when no DagRun exists.

    The branch operator is run without creating a DagRun; task instances are
    then read directly from the database for this DAG and execution date.
    """
    branch_op = BranchPythonOperator(task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1')
    self.branch_1.set_upstream(branch_op)
    self.branch_2.set_upstream(branch_op)
    self.dag.clear()

    branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    with create_session() as session:
        tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE)

        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1':
                # If a row exists for the chosen branch it must have no state;
                # this loop does not assert that the row exists.
                self.assertEqual(ti.state, State.NONE)
            elif ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.SKIPPED)
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
def test_with_skip_in_branch_downstream_dependencies2(self):
    """A skipped branch in the middle of a chain does not skip the chosen task.

    Layout: ``make_choice >> branch_1 >> branch_2`` plus ``make_choice >> branch_2``.
    The callable picks 'branch_2', so branch_1 is skipped and branch_2 is
    left without a state after only the branch operator has run.
    """
    branch_op = BranchPythonOperator(task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_2')
    branch_op >> self.branch_1 >> self.branch_2
    branch_op >> self.branch_2
    self.dag.clear()

    dag_run = self.dag.create_dagrun(
        run_type=DagRunType.MANUAL,
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )
    branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    # Expected state per task id; any unknown task id is an error.
    expected_states = {
        'make_choice': State.SUCCESS,
        'branch_1': State.SKIPPED,
        'branch_2': State.NONE,
    }
    for task_instance in dag_run.get_task_instances():
        if task_instance.task_id not in expected_states:
            raise ValueError(f'Invalid task id {task_instance.task_id} found!')
        assert task_instance.state == expected_states[task_instance.task_id]
def test_with_dag_run(self):
    """Inside a DagRun, choosing 'branch_1' skips branch_2 and leaves branch_1 pending."""
    branch_op = BranchPythonOperator(task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1')
    for downstream in (self.branch_1, self.branch_2):
        downstream.set_upstream(branch_op)
    self.dag.clear()

    dag_run = self.dag.create_dagrun(
        run_type=DagRunType.MANUAL,
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )
    branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    # Expected state per task id; any unknown task id is an error.
    expected_states = {
        'make_choice': State.SUCCESS,
        'branch_1': State.NONE,
        'branch_2': State.SKIPPED,
    }
    for task_instance in dag_run.get_task_instances():
        if task_instance.task_id not in expected_states:
            raise ValueError(f'Invalid task id {task_instance.task_id} found!')
        self.assertEqual(task_instance.state, expected_states[task_instance.task_id])
def test_clear_skipped_downstream_task(self):
    """
    After a downstream task is skipped by BranchPythonOperator, clearing the
    skipped task should not cause it to be executed.
    """
    branch_op = BranchPythonOperator(task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1')
    branches = [self.branch_1, self.branch_2]
    branch_op >> branches
    self.dag.clear()

    dag_run = self.dag.create_dagrun(
        run_type=DagRunType.MANUAL,
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    # The same expectations hold before and after clearing: the chosen branch
    # succeeds and the other branch stays skipped.
    expected_states = {
        'make_choice': State.SUCCESS,
        'branch_1': State.SUCCESS,
        'branch_2': State.SKIPPED,
    }

    def run_branches():
        # Run every downstream branch task for DEFAULT_DATE.
        for branch in branches:
            branch.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    def check_states(task_instances):
        # Assert each task instance matches expected_states; unknown ids are errors.
        for task_instance in task_instances:
            if task_instance.task_id not in expected_states:
                raise ValueError(f'Invalid task id {task_instance.task_id} found!')
            assert task_instance.state == expected_states[task_instance.task_id]

    branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    run_branches()

    task_instances = dag_run.get_task_instances()
    check_states(task_instances)

    # Clear only the tasks directly downstream of the branch operator.
    children_tis = [
        ti for ti in task_instances if ti.task_id in branch_op.get_direct_relative_ids()
    ]
    with create_session() as session:
        clear_task_instances(children_tis, session=session, dag=self.dag)

    # Re-run the cleared tasks, then verify states from a fresh query.
    run_branches()
    check_states(dag_run.get_task_instances())
def test_xcom_push(self):
    """The branch operator pushes its callable's return value ('branch_1') to XCom."""
    branch_op = BranchPythonOperator(task_id='make_choice', dag=self.dag, python_callable=lambda: 'branch_1')
    self.branch_1.set_upstream(branch_op)
    self.branch_2.set_upstream(branch_op)
    self.dag.clear()

    dr = self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    # Locate the branch task's instance explicitly. The previous loop-with-if
    # form passed vacuously when no 'make_choice' instance was returned.
    make_choice_ti = next(
        (ti for ti in dr.get_task_instances() if ti.task_id == 'make_choice'),
        None,
    )
    self.assertIsNotNone(make_choice_ti)
    self.assertEqual(make_choice_ti.xcom_pull(task_ids='make_choice'), 'branch_1')