def test_invalid_query_result_with_dag_run(self, mock_get_db_hook):
    """Expect a failure when the query result is ``["Invalid Value"]``.

    The mocked hook's ``get_first`` returns ``["Invalid Value"]`` and
    running the operator must raise ``AirflowException``.
    """
    operator = BranchSQLOperator(
        task_id="make_choice",
        conn_id="mysql_default",
        sql="SELECT 1",
        follow_task_ids_if_true="branch_1",
        follow_task_ids_if_false="branch_2",
        dag=self.dag,
    )
    # Both candidate branches hang directly off the branching operator.
    for downstream in (self.branch_1, self.branch_2):
        downstream.set_upstream(operator)
    self.dag.clear()

    self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    db_hook = mock_get_db_hook.return_value
    db_hook.get_first.return_value = ["Invalid Value"]

    with pytest.raises(AirflowException):
        operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
def test_sql_branch_operator_postgres(self):
    """Run BranchSQLOperator against the ``postgres_default`` connection.

    There is no explicit assertion: the test passes as long as ``run``
    completes without raising.
    """
    operator = BranchSQLOperator(
        task_id="make_choice",
        conn_id="postgres_default",
        sql="SELECT 1",
        follow_task_ids_if_true="branch_1",
        follow_task_ids_if_false="branch_2",
        dag=self.dag,
    )
    operator.run(
        start_date=DEFAULT_DATE,
        end_date=DEFAULT_DATE,
        ignore_ti_state=True,
    )
def test_invalid_follow_task_false(self):
    """Expect ``AirflowException`` when ``follow_task_ids_if_false`` is None.

    NOTE(review): ``conn_id`` is also ``"invalid_connection"`` here, so
    confirm which of the two conditions actually triggers the exception.
    """
    operator = BranchSQLOperator(
        task_id="make_choice",
        conn_id="invalid_connection",
        sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
        follow_task_ids_if_true="branch_1",
        follow_task_ids_if_false=None,
        dag=self.dag,
    )

    with pytest.raises(AirflowException):
        operator.run(
            start_date=DEFAULT_DATE,
            end_date=DEFAULT_DATE,
            ignore_ti_state=True,
        )
def test_unsupported_conn_type(self):
    """Expect ``AirflowException`` when run with the ``redis_default`` connection."""
    operator = BranchSQLOperator(
        task_id="make_choice",
        conn_id="redis_default",
        sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
        follow_task_ids_if_true="branch_1",
        follow_task_ids_if_false="branch_2",
        dag=self.dag,
    )

    with self.assertRaises(AirflowException):
        operator.run(
            start_date=DEFAULT_DATE,
            end_date=DEFAULT_DATE,
            ignore_ti_state=True,
        )
def test_branch_list_with_dag_run(self, mock_hook):
    """Check that the operator can branch to a list of task ids.

    ``follow_task_ids_if_true`` is ``["branch_1", "branch_2"]`` and the
    mocked query returns ``[["1"]]``; afterwards ``branch_1`` and
    ``branch_2`` must be in ``State.NONE`` and ``branch_3`` must be
    ``State.SKIPPED``.
    """
    operator = BranchSQLOperator(
        task_id="make_choice",
        conn_id="mysql_default",
        sql="SELECT 1",
        follow_task_ids_if_true=["branch_1", "branch_2"],
        follow_task_ids_if_false="branch_3",
        dag=self.dag,
    )

    # Extra branch that is only followed on a false result.
    self.branch_3 = DummyOperator(task_id="branch_3", dag=self.dag)
    for downstream in (self.branch_1, self.branch_2, self.branch_3):
        downstream.set_upstream(operator)
    self.dag.clear()

    dag_run = self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    mock_hook.get_connection("mysql_default").conn_type = "mysql"
    db_hook = mock_hook.get_connection.return_value.get_hook.return_value
    db_hook.get_first.return_value = [["1"]]

    operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    expected_states = {
        "make_choice": State.SUCCESS,
        "branch_1": State.NONE,
        "branch_2": State.NONE,
        "branch_3": State.SKIPPED,
    }
    for ti in dag_run.get_task_instances():
        if ti.task_id not in expected_states:
            raise ValueError(f"Invalid task id {ti.task_id} found!")
        self.assertEqual(ti.state, expected_states[ti.task_id])
def test_with_skip_in_branch_downstream_dependencies2(self, mock_hook):
    """Check task states for a false result when branches are chained.

    ``branch_2`` is downstream of both the operator and ``branch_1``.
    For every entry of ``SUPPORTED_FALSE_VALUES`` (wrapped in a list),
    ``branch_1`` must be ``State.SKIPPED`` and ``branch_2`` must be in
    ``State.NONE``.
    """
    operator = BranchSQLOperator(
        task_id="make_choice",
        conn_id="mysql_default",
        sql="SELECT 1",
        follow_task_ids_if_true="branch_1",
        follow_task_ids_if_false="branch_2",
        dag=self.dag,
    )

    operator >> self.branch_1 >> self.branch_2
    operator >> self.branch_2
    self.dag.clear()

    dag_run = self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    mock_hook.get_connection("mysql_default").conn_type = "mysql"
    db_hook = mock_hook.get_connection.return_value.get_hook.return_value

    expected_states = {
        "make_choice": State.SUCCESS,
        "branch_1": State.SKIPPED,
        "branch_2": State.NONE,
    }
    # NOTE(review): run() is repeated without ignore_ti_state -- confirm
    # that values after the first actually re-execute the task.
    for false_value in SUPPORTED_FALSE_VALUES:
        db_hook.get_first.return_value = [false_value]

        operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        for ti in dag_run.get_task_instances():
            if ti.task_id not in expected_states:
                raise ValueError(f"Invalid task id {ti.task_id} found!")
            self.assertEqual(ti.state, expected_states[ti.task_id])
def test_branch_false_with_dag_run(self, mock_hook):
    """Check task states when the mocked query yields a false value.

    Each entry of ``SUPPORTED_FALSE_VALUES`` is returned as-is from the
    mocked ``get_first``; afterwards ``branch_1`` must be
    ``State.SKIPPED`` and ``branch_2`` must be in ``State.NONE``.
    """
    operator = BranchSQLOperator(
        task_id="make_choice",
        conn_id="mysql_default",
        sql="SELECT 1",
        follow_task_ids_if_true="branch_1",
        follow_task_ids_if_false="branch_2",
        dag=self.dag,
    )

    for downstream in (self.branch_1, self.branch_2):
        downstream.set_upstream(operator)
    self.dag.clear()

    dag_run = self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    mock_hook.get_connection("mysql_default").conn_type = "mysql"
    db_hook = mock_hook.get_connection.return_value.get_hook.return_value

    expected_states = {
        "make_choice": State.SUCCESS,
        "branch_1": State.SKIPPED,
        "branch_2": State.NONE,
    }
    # NOTE(review): run() is repeated without ignore_ti_state -- confirm
    # that values after the first actually re-execute the task.
    for false_value in SUPPORTED_FALSE_VALUES:
        db_hook.get_first.return_value = false_value

        operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        for ti in dag_run.get_task_instances():
            if ti.task_id not in expected_states:
                raise ValueError(f"Invalid task id {ti.task_id} found!")
            self.assertEqual(ti.state, expected_states[ti.task_id])
def test_with_skip_in_branch_downstream_dependencies(self, mock_get_db_hook):
    """Check task states for a true result when branches are chained.

    ``branch_2`` is downstream of both the operator and ``branch_1``.
    For every entry of ``SUPPORTED_TRUE_VALUES`` (wrapped in a list),
    both ``branch_1`` and ``branch_2`` must remain in ``State.NONE``.
    """
    operator = BranchSQLOperator(
        task_id="make_choice",
        conn_id="mysql_default",
        sql="SELECT 1",
        follow_task_ids_if_true="branch_1",
        follow_task_ids_if_false="branch_2",
        dag=self.dag,
    )

    operator >> self.branch_1 >> self.branch_2
    operator >> self.branch_2
    self.dag.clear()

    dag_run = self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    db_hook = mock_get_db_hook.return_value

    expected_states = {
        "make_choice": State.SUCCESS,
        "branch_1": State.NONE,
        "branch_2": State.NONE,
    }
    # NOTE(review): run() is repeated without ignore_ti_state -- confirm
    # that values after the first actually re-execute the task.
    for true_value in SUPPORTED_TRUE_VALUES:
        db_hook.get_first.return_value = [true_value]

        operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        for ti in dag_run.get_task_instances():
            if ti.task_id not in expected_states:
                raise ValueError(f"Invalid task id {ti.task_id} found!")
            assert ti.state == expected_states[ti.task_id]
def test_branch_single_value_with_dag_run(self, mock_get_db_hook):
    """Check task states when the mocked query returns the bare value ``1``.

    Afterwards ``branch_1`` must be in ``State.NONE`` and ``branch_2``
    must be ``State.SKIPPED``.
    """
    operator = BranchSQLOperator(
        task_id="make_choice",
        conn_id="mysql_default",
        sql="SELECT 1",
        follow_task_ids_if_true="branch_1",
        follow_task_ids_if_false="branch_2",
        dag=self.dag,
    )

    for downstream in (self.branch_1, self.branch_2):
        downstream.set_upstream(operator)
    self.dag.clear()

    dag_run = self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    # A scalar rather than a row/list is returned on purpose here.
    db_hook = mock_get_db_hook.return_value
    db_hook.get_first.return_value = 1

    operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    expected_states = {
        "make_choice": State.SUCCESS,
        "branch_1": State.NONE,
        "branch_2": State.SKIPPED,
    }
    for ti in dag_run.get_task_instances():
        if ti.task_id not in expected_states:
            raise ValueError(f"Invalid task id {ti.task_id} found!")
        assert ti.state == expected_states[ti.task_id]
# Build one notebook-training task per (max_features, n_estimators) pair.
training_model_tasks = []
for feature in max_features:
    for estimator in n_estimators:
        # Identifier reused in the task id, the output path and the
        # notebook parameters.
        ml_id = f"{feature}_{estimator}"
        notebook_params = {
            'filepath': '/tmp/avocado.csv',
            'n_estimators': estimator,
            'max_features': feature,
            'ml_id': ml_id,
        }
        training_task = PapermillOperator(
            task_id=f'training_model_{ml_id}',
            input_nb='/usr/local/airflow/include/notebooks/avocado_prediction.ipynb',
            output_nb=f'/tmp/out-model-avocado-prediction-{ml_id}.ipynb',
            parameters=notebook_params,
        )
        training_model_tasks.append(training_task)

# Branch on the result of the SQL file: follow 'accurate' on a true
# result, 'inaccurate' otherwise.
evaluating_rmse = BranchSQLOperator(
    task_id='evaluating_rmse',
    sql='sql/FETCH_MIN_RMSE.sql',
    conn_id='postgres',
    follow_task_ids_if_true='accurate',
    follow_task_ids_if_false='inaccurate',
)

accurate = DummyOperator(task_id='accurate')
inaccurate = DummyOperator(task_id='inaccurate')

# Pipeline order: the training tasks sit between waiting_for_data and the
# RMSE evaluation, which then leads to the two outcome tasks.
(
    creating_table
    >> downloading_data
    >> sanity_check
    >> waiting_for_data
    >> training_model_tasks
    >> evaluating_rmse
)
evaluating_rmse >> [accurate, inaccurate]