def test_invalid_query_result_with_dag_run(self, mock_hook):
    """Check that BranchSqlOperator rejects a query result it cannot interpret.

    The mocked hook's ``get_first`` returns ``["Invalid Value"]``, which is
    neither a recognised true nor false value, so running the operator is
    expected to raise ``AirflowException``.
    """
    branch_op = BranchSqlOperator(
        task_id="make_choice",
        conn_id="mysql_default",
        sql="SELECT 1",
        follow_task_ids_if_true="branch_1",
        follow_task_ids_if_false="branch_2",
        dag=self.dag,
    )
    self.branch_1.set_upstream(branch_op)
    self.branch_2.set_upstream(branch_op)
    self.dag.clear()

    self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    # Route the operator's connection lookup to the mock and make the
    # query return a value that is not a supported boolean literal.
    mock_hook.get_connection("mysql_default").conn_type = "mysql"
    mock_get_records = (
        mock_hook.get_connection.return_value.get_hook.return_value.get_first
    )
    mock_get_records.return_value = ["Invalid Value"]

    with self.assertRaises(AirflowException):
        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
def test_sql_branch_operator_postgres(self):
    """Smoke-test BranchSqlOperator against the ``postgres_default`` connection.

    There are no explicit assertions: the test passes as long as running
    the operator (with ``ignore_ti_state=True``) completes without raising.
    """
    operator = BranchSqlOperator(
        dag=self.dag,
        task_id="make_choice",
        conn_id="postgres_default",
        sql="SELECT 1",
        follow_task_ids_if_false="branch_2",
        follow_task_ids_if_true="branch_1",
    )
    run_kwargs = {
        "start_date": DEFAULT_DATE,
        "end_date": DEFAULT_DATE,
        "ignore_ti_state": True,
    }
    operator.run(**run_kwargs)
def test_unsupported_conn_type(self):
    """Running BranchSqlOperator on ``redis_default`` must raise AirflowException.

    ``redis_default`` is presumably a connection whose type the operator
    does not support (verify against the connection fixtures).
    """
    operator = BranchSqlOperator(
        dag=self.dag,
        task_id="make_choice",
        conn_id="redis_default",
        sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
        follow_task_ids_if_false="branch_2",
        follow_task_ids_if_true="branch_1",
    )
    run_kwargs = {
        "start_date": DEFAULT_DATE,
        "end_date": DEFAULT_DATE,
        "ignore_ti_state": True,
    }
    with self.assertRaises(AirflowException):
        operator.run(**run_kwargs)
def test_branch_list_with_dag_run(self, mock_hook):
    """Verify that ``follow_task_ids_if_true`` accepts a list of task ids.

    The mocked query returns ``[["1"]]``. After the run, ``make_choice``
    must be SUCCESS, both list targets (``branch_1``, ``branch_2``) must be
    untouched (``State.NONE``), and the false target ``branch_3`` must be
    SKIPPED. Any other task instance in the dag run is an error.
    """
    chooser = BranchSqlOperator(
        task_id="make_choice",
        conn_id="mysql_default",
        sql="SELECT 1",
        follow_task_ids_if_true=["branch_1", "branch_2"],
        follow_task_ids_if_false="branch_3",
        dag=self.dag,
    )
    self.branch_3 = DummyOperator(task_id="branch_3", dag=self.dag)
    for downstream in (self.branch_1, self.branch_2, self.branch_3):
        downstream.set_upstream(chooser)
    self.dag.clear()

    dag_run = self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    connection = mock_hook.get_connection("mysql_default")
    connection.conn_type = "mysql"
    connection.get_hook.return_value.get_first.return_value = [["1"]]

    chooser.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    expected_states = {
        "make_choice": State.SUCCESS,
        "branch_1": State.NONE,
        "branch_2": State.NONE,
        "branch_3": State.SKIPPED,
    }
    for task_instance in dag_run.get_task_instances():
        if task_instance.task_id not in expected_states:
            raise ValueError(f"Invalid task id {task_instance.task_id} found!")
        self.assertEqual(task_instance.state, expected_states[task_instance.task_id])
def test_with_skip_in_branch_downstream_dependencies2(self, mock_hook):
    """Test skipping downstream dependency for false condition.

    ``branch_1`` is the follow-if-true target and also an upstream of
    ``branch_2`` (the follow-if-false target). For every value in
    ``SUPPORTED_FALSE_VALUES`` the operator must take the false branch:
    ``make_choice`` ends SUCCESS, ``branch_1`` is SKIPPED and ``branch_2``
    is left untouched (``State.NONE``).
    """
    branch_op = BranchSqlOperator(
        task_id="make_choice",
        conn_id="mysql_default",
        sql="SELECT 1",
        follow_task_ids_if_true="branch_1",
        follow_task_ids_if_false="branch_2",
        dag=self.dag,
    )

    branch_op >> self.branch_1 >> self.branch_2
    branch_op >> self.branch_2
    self.dag.clear()

    dr = self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    mock_hook.get_connection("mysql_default").conn_type = "mysql"
    mock_get_records = (
        mock_hook.get_connection.return_value.get_hook.return_value.get_first
    )

    for false_value in SUPPORTED_FALSE_VALUES:
        mock_get_records.return_value = [false_value]

        # After the first iteration make_choice is already SUCCESS; without
        # ignore_ti_state=True Airflow's task-instance state check would not
        # re-execute it, so later false values would never be evaluated.
        branch_op.run(
            start_date=DEFAULT_DATE,
            end_date=DEFAULT_DATE,
            ignore_ti_state=True,
        )

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == "make_choice":
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == "branch_1":
                self.assertEqual(ti.state, State.SKIPPED)
            elif ti.task_id == "branch_2":
                self.assertEqual(ti.state, State.NONE)
            else:
                raise ValueError(f"Invalid task id {ti.task_id} found!")