Example #1
0
    def test_invalid_query_result_with_dag_run(self, mock_get_db_hook):
        """Expect ``AirflowException`` when the query yields an invalid value.

        The mocked hook's ``get_first`` is set to return ``["Invalid Value"]``
        and the operator is then run inside a freshly created DAG run.
        """
        sql_branch = BranchSQLOperator(
            dag=self.dag,
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
        )

        # Both candidate branches hang directly off the operator under test.
        for downstream_task in (self.branch_1, self.branch_2):
            downstream_task.set_upstream(sql_branch)
        self.dag.clear()

        self.dag.create_dagrun(
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=timezone.utcnow(),
            run_id="manual__",
        )

        # The value handed back by the mocked hook's get_first().
        mock_get_db_hook.return_value.get_first.return_value = ["Invalid Value"]

        run_window = {"start_date": DEFAULT_DATE, "end_date": DEFAULT_DATE}
        with pytest.raises(AirflowException):
            sql_branch.run(**run_window)
Example #2
0
 def test_sql_branch_operator_postgres(self):
     """Run a BranchSQLOperator against the ``postgres_default`` connection.

     The test makes no explicit assertions; it passes when the run completes
     without raising.
     """
     operator_args = {
         "task_id": "make_choice",
         "conn_id": "postgres_default",
         "sql": "SELECT 1",
         "follow_task_ids_if_true": "branch_1",
         "follow_task_ids_if_false": "branch_2",
         "dag": self.dag,
     }
     sql_branch = BranchSQLOperator(**operator_args)
     sql_branch.run(
         start_date=DEFAULT_DATE,
         end_date=DEFAULT_DATE,
         ignore_ti_state=True,
     )
Example #3
0
    def test_invalid_follow_task_false(self):
        """Expect ``AirflowException`` when running without a false-branch target.

        The operator is built with ``follow_task_ids_if_false=None`` and the
        connection id ``"invalid_connection"``.

        NOTE(review): both the connection id and the missing false branch look
        like possible causes of the exception, and this test does not tell them
        apart -- confirm which one is meant to be exercised.
        """
        operator_args = {
            "task_id": "make_choice",
            "conn_id": "invalid_connection",
            "sql": "SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            "follow_task_ids_if_true": "branch_1",
            "follow_task_ids_if_false": None,
            "dag": self.dag,
        }
        op = BranchSQLOperator(**operator_args)

        with pytest.raises(AirflowException):
            op.run(
                start_date=DEFAULT_DATE,
                end_date=DEFAULT_DATE,
                ignore_ti_state=True,
            )
Example #4
0
    def test_unsupported_conn_type(self):
        """Expect ``AirflowException`` when the operator uses ``redis_default``.

        The operator is configured with the ``redis_default`` connection id and
        running it must raise.
        """
        redis_branch = BranchSQLOperator(
            dag=self.dag,
            task_id="make_choice",
            conn_id="redis_default",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
        )

        run_window = {"start_date": DEFAULT_DATE, "end_date": DEFAULT_DATE}
        with self.assertRaises(AirflowException):
            redis_branch.run(ignore_ti_state=True, **run_window)
Example #5
0
    def test_branch_list_with_dag_run(self, mock_hook):
        """Verify that the true branch may be given as a list of task ids.

        With the mocked ``get_first`` returning ``[["1"]]``, ``branch_1`` and
        ``branch_2`` must stay un-run (``State.NONE``) while ``branch_3`` is
        marked ``State.SKIPPED``.
        """
        sql_branch = BranchSQLOperator(
            dag=self.dag,
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true=["branch_1", "branch_2"],
            follow_task_ids_if_false="branch_3",
        )

        self.branch_1.set_upstream(sql_branch)
        self.branch_2.set_upstream(sql_branch)
        # The third branch only exists for this test, so it is created here.
        self.branch_3 = DummyOperator(task_id="branch_3", dag=self.dag)
        self.branch_3.set_upstream(sql_branch)
        self.dag.clear()

        dag_run = self.dag.create_dagrun(
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=timezone.utcnow(),
            run_id="manual__",
        )

        mock_hook.get_connection("mysql_default").conn_type = "mysql"
        mocked_db_hook = mock_hook.get_connection.return_value.get_hook.return_value
        mocked_db_hook.get_first.return_value = [["1"]]

        sql_branch.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        # Expected final state for every task instance in the DAG run.
        expected_states = {
            "make_choice": State.SUCCESS,
            "branch_1": State.NONE,
            "branch_2": State.NONE,
            "branch_3": State.SKIPPED,
        }
        for task_instance in dag_run.get_task_instances():
            if task_instance.task_id not in expected_states:
                raise ValueError(
                    f"Invalid task id {task_instance.task_id} found!")
            self.assertEqual(
                task_instance.state, expected_states[task_instance.task_id])
Example #6
0
    def test_with_skip_in_branch_downstream_dependencies2(self, mock_hook):
        """Check the skip behaviour for false results with chained branches.

        ``branch_2`` is downstream of both the operator and ``branch_1``. For
        each entry of ``SUPPORTED_FALSE_VALUES`` the test expects ``branch_1``
        to be ``State.SKIPPED`` and ``branch_2`` to remain ``State.NONE``.
        """
        sql_branch = BranchSQLOperator(
            dag=self.dag,
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
        )

        sql_branch >> self.branch_1 >> self.branch_2
        sql_branch >> self.branch_2
        self.dag.clear()

        dag_run = self.dag.create_dagrun(
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=timezone.utcnow(),
            run_id="manual__",
        )

        mock_hook.get_connection("mysql_default").conn_type = "mysql"
        mocked_db_hook = mock_hook.get_connection.return_value.get_hook.return_value
        mocked_get_first = mocked_db_hook.get_first

        # Expected final state for every task instance in the DAG run.
        expected_states = {
            "make_choice": State.SUCCESS,
            "branch_1": State.SKIPPED,
            "branch_2": State.NONE,
        }

        for falsy_result in SUPPORTED_FALSE_VALUES:
            mocked_get_first.return_value = [falsy_result]

            sql_branch.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            for task_instance in dag_run.get_task_instances():
                if task_instance.task_id not in expected_states:
                    raise ValueError(
                        f"Invalid task id {task_instance.task_id} found!")
                self.assertEqual(
                    task_instance.state,
                    expected_states[task_instance.task_id],
                )
Example #7
0
    def test_branch_false_with_dag_run(self, mock_hook):
        """Check that false results skip ``branch_1`` and leave ``branch_2`` alone.

        Each entry of ``SUPPORTED_FALSE_VALUES`` is returned as-is (not wrapped
        in a list) by the mocked ``get_first`` before the operator is run.
        """
        sql_branch = BranchSQLOperator(
            dag=self.dag,
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
        )

        # Both candidate branches hang directly off the operator under test.
        for downstream_task in (self.branch_1, self.branch_2):
            downstream_task.set_upstream(sql_branch)
        self.dag.clear()

        dag_run = self.dag.create_dagrun(
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=timezone.utcnow(),
            run_id="manual__",
        )

        mock_hook.get_connection("mysql_default").conn_type = "mysql"
        mocked_db_hook = mock_hook.get_connection.return_value.get_hook.return_value
        mocked_get_first = mocked_db_hook.get_first

        # Expected final state for every task instance in the DAG run.
        expected_states = {
            "make_choice": State.SUCCESS,
            "branch_1": State.SKIPPED,
            "branch_2": State.NONE,
        }

        for falsy_result in SUPPORTED_FALSE_VALUES:
            mocked_get_first.return_value = falsy_result

            sql_branch.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            for task_instance in dag_run.get_task_instances():
                if task_instance.task_id not in expected_states:
                    raise ValueError(
                        f"Invalid task id {task_instance.task_id} found!")
                self.assertEqual(
                    task_instance.state,
                    expected_states[task_instance.task_id],
                )
Example #8
0
    def test_with_skip_in_branch_downstream_dependencies(
            self, mock_get_db_hook):
        """Check that true results skip nothing when the branches are chained.

        ``branch_2`` is downstream of both the operator and ``branch_1``. For
        each entry of ``SUPPORTED_TRUE_VALUES`` the test expects both branches
        to remain ``State.NONE`` after the operator succeeds.
        """
        sql_branch = BranchSQLOperator(
            dag=self.dag,
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
        )

        sql_branch >> self.branch_1 >> self.branch_2
        sql_branch >> self.branch_2
        self.dag.clear()

        dag_run = self.dag.create_dagrun(
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=timezone.utcnow(),
            run_id="manual__",
        )

        mocked_get_first = mock_get_db_hook.return_value.get_first

        # Expected final state for every task instance in the DAG run.
        expected_states = {
            "make_choice": State.SUCCESS,
            "branch_1": State.NONE,
            "branch_2": State.NONE,
        }

        for truthy_result in SUPPORTED_TRUE_VALUES:
            mocked_get_first.return_value = [truthy_result]

            sql_branch.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            for task_instance in dag_run.get_task_instances():
                if task_instance.task_id not in expected_states:
                    raise ValueError(
                        f"Invalid task id {task_instance.task_id} found!")
                assert task_instance.state == expected_states[task_instance.task_id]
Example #9
0
    def test_branch_single_value_with_dag_run(self, mock_get_db_hook):
        """Check branching when the query result is a bare scalar.

        The mocked ``get_first`` returns the integer ``1`` (not a row/list);
        the test expects ``branch_1`` to stay ``State.NONE`` and ``branch_2``
        to be ``State.SKIPPED``.
        """
        sql_branch = BranchSQLOperator(
            dag=self.dag,
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
        )

        # Both candidate branches hang directly off the operator under test.
        for downstream_task in (self.branch_1, self.branch_2):
            downstream_task.set_upstream(sql_branch)
        self.dag.clear()

        dag_run = self.dag.create_dagrun(
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=timezone.utcnow(),
            run_id="manual__",
        )

        mock_get_db_hook.return_value.get_first.return_value = 1

        sql_branch.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        # Expected final state for every task instance in the DAG run.
        expected_states = {
            "make_choice": State.SUCCESS,
            "branch_1": State.NONE,
            "branch_2": State.SKIPPED,
        }
        for task_instance in dag_run.get_task_instances():
            if task_instance.task_id not in expected_states:
                raise ValueError(
                    f"Invalid task id {task_instance.task_id} found!")
            assert task_instance.state == expected_states[task_instance.task_id]
    # Build one notebook-training task per (max_features, n_estimators)
    # combination. Both iterables are defined outside this excerpt.
    training_model_tasks = []
    for feature in max_features:
        for estimator in n_estimators:
            # ml_id tags the task id, the output notebook path and the
            # parameters passed to the notebook for this combination.
            ml_id = f"{feature}_{estimator}"
            training_model_tasks.append(
                PapermillOperator(
                    task_id=f'training_model_{ml_id}',
                    input_nb=
                    '/usr/local/airflow/include/notebooks/avocado_prediction.ipynb',
                    output_nb=
                    f'/tmp/out-model-avocado-prediction-{ml_id}.ipynb',
                    parameters={
                        'filepath': '/tmp/avocado.csv',
                        'n_estimators': estimator,
                        'max_features': feature,
                        'ml_id': ml_id
                    }))

    # Branch on the result of sql/FETCH_MIN_RMSE.sql: the 'accurate' task is
    # followed for a true result and 'inaccurate' for a false one.
    # NOTE(review): the query itself is not visible here -- presumably it
    # compares the minimum RMSE against a threshold; confirm against the file.
    evaluating_rmse = BranchSQLOperator(task_id='evaluating_rmse',
                                        sql='sql/FETCH_MIN_RMSE.sql',
                                        conn_id='postgres',
                                        follow_task_ids_if_true='accurate',
                                        follow_task_ids_if_false='inaccurate')

    # Placeholder targets for the two branch outcomes.
    accurate = DummyOperator(task_id='accurate')

    inaccurate = DummyOperator(task_id='inaccurate')

    # Linear pipeline up to waiting_for_data, then every training task is
    # chained between waiting_for_data and evaluating_rmse. The upstream tasks
    # named here are defined outside this excerpt.
    creating_table >> downloading_data >> sanity_check >> waiting_for_data >> training_model_tasks >> evaluating_rmse
    # Both branch targets must be direct downstream tasks of the branch operator.
    evaluating_rmse >> [accurate, inaccurate]