def test_invalid_query_result_with_dag_run(self, mock_hook):
    """Raise AirflowException when the query yields a value that is neither true-like nor false-like."""
    op = BranchSqlOperator(
        task_id="make_choice",
        conn_id="mysql_default",
        sql="SELECT 1",
        follow_task_ids_if_true="branch_1",
        follow_task_ids_if_false="branch_2",
        dag=self.dag,
    )
    # Wire both branch tasks directly downstream of the branching operator.
    op >> [self.branch_1, self.branch_2]

    self.dag.clear()
    self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    # Pretend the connection is MySQL, and make the hook return a record
    # the operator cannot map onto a boolean branch decision.
    mock_hook.get_connection("mysql_default").conn_type = "mysql"
    mock_first = mock_hook.get_connection.return_value.get_hook.return_value.get_first
    mock_first.return_value = ["Invalid Value"]

    with self.assertRaises(AirflowException):
        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
def test_sql_branch_operator_postgres(self):
    """Smoke-test that BranchSqlOperator runs end to end against a Postgres backend."""
    make_choice = BranchSqlOperator(
        task_id="make_choice",
        conn_id="postgres_default",
        sql="SELECT 1",
        follow_task_ids_if_true="branch_1",
        follow_task_ids_if_false="branch_2",
        dag=self.dag,
    )
    make_choice.run(
        start_date=DEFAULT_DATE,
        end_date=DEFAULT_DATE,
        ignore_ti_state=True,
    )
def test_unsupported_conn_type(self):
    """An unsupported connection type (redis) must make BranchSqlOperator raise AirflowException."""
    make_choice = BranchSqlOperator(
        task_id="make_choice",
        conn_id="redis_default",
        sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
        follow_task_ids_if_true="branch_1",
        follow_task_ids_if_false="branch_2",
        dag=self.dag,
    )
    with self.assertRaises(AirflowException):
        make_choice.run(
            start_date=DEFAULT_DATE,
            end_date=DEFAULT_DATE,
            ignore_ti_state=True,
        )
def test_branch_list_with_dag_run(self, mock_hook):
    """A true condition may follow a *list* of task ids; everything else is skipped."""
    op = BranchSqlOperator(
        task_id="make_choice",
        conn_id="mysql_default",
        sql="SELECT 1",
        follow_task_ids_if_true=["branch_1", "branch_2"],
        follow_task_ids_if_false="branch_3",
        dag=self.dag,
    )
    # A third branch task is needed for the false path of this test.
    self.branch_3 = DummyOperator(task_id="branch_3", dag=self.dag)
    op >> [self.branch_1, self.branch_2, self.branch_3]

    self.dag.clear()
    dr = self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    mock_hook.get_connection("mysql_default").conn_type = "mysql"
    mock_first = mock_hook.get_connection.return_value.get_hook.return_value.get_first
    mock_first.return_value = [["1"]]  # truthy -> follow the list branch

    op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    # Both listed true-branch tasks stay schedulable; the false branch is skipped.
    expected_states = {
        "make_choice": State.SUCCESS,
        "branch_1": State.NONE,
        "branch_2": State.NONE,
        "branch_3": State.SKIPPED,
    }
    for ti in dr.get_task_instances():
        if ti.task_id not in expected_states:
            raise ValueError(f"Invalid task id {ti.task_id} found!")
        self.assertEqual(ti.state, expected_states[ti.task_id])
def test_with_skip_in_branch_downstream_dependencies2(self, mock_hook):
    """A false condition skips branch_1 but leaves the shared downstream branch_2 runnable."""
    op = BranchSqlOperator(
        task_id="make_choice",
        conn_id="mysql_default",
        sql="SELECT 1",
        follow_task_ids_if_true="branch_1",
        follow_task_ids_if_false="branch_2",
        dag=self.dag,
    )
    # branch_2 is downstream of BOTH the branch operator and branch_1,
    # so skipping branch_1 must not cascade into branch_2.
    op >> self.branch_1 >> self.branch_2
    op >> self.branch_2

    self.dag.clear()
    dr = self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    mock_hook.get_connection("mysql_default").conn_type = "mysql"
    mock_first = mock_hook.get_connection.return_value.get_hook.return_value.get_first

    expected_states = {
        "make_choice": State.SUCCESS,
        "branch_1": State.SKIPPED,
        "branch_2": State.NONE,
    }
    # Every supported false-like query result must produce the same branching.
    for false_value in SUPPORTED_FALSE_VALUES:
        mock_first.return_value = [false_value]
        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        for ti in dr.get_task_instances():
            if ti.task_id not in expected_states:
                raise ValueError(f"Invalid task id {ti.task_id} found!")
            self.assertEqual(ti.state, expected_states[ti.task_id])
python_callable=check_dataset, provide_context=True) waiting_for_data = FileSensor(task_id='waiting_for_data', fs_conn_id='fs_default', filepath='avocado.csv', poke_interval=15) training_model_tasks = SubDagOperator( task_id='training_model_tasks', subdag=subdag_factory('avocado_dag', 'training_model_tasks', default_args)) evaluating_rmse = BranchSqlOperator(task_id='evaluating_rmse', sql='sql/FETCH_MIN_RMSE.sql', conn_id='postgres', follow_task_ids_if_true='accurate', follow_task_ids_if_false='inaccurate') accurate = DummyOperator(task_id='accurate') inaccurate = DummyOperator(task_id='inaccurate') fetch_best_model = NotebookToKeepOperator(task_id='fetch_best_model', sql='sql/FETCH_BEST_MODEL.sql', postgres_conn_id='postgres') publish_notebook = NotebookToGitOperator( task_id='publish_notebook', conn_id='git', nb_path='/tmp',