Ejemplo n.º 1
0
    def test_invalid_query_result_with_dag_run(self, mock_hook):
        """Check that an unrecognised query result makes BranchSqlOperator fail.

        The mocked hook's ``get_first`` returns ``["Invalid Value"]``; running
        the operator inside a dag run is expected to raise ``AirflowException``
        instead of choosing between ``branch_1`` and ``branch_2``.

        ``mock_hook`` is presumably injected by a ``mock.patch`` decorator
        outside this view; the test assumes the operator reaches the query via
        ``get_connection(...).get_hook().get_first``.
        """
        branch_op = BranchSqlOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        # Configure the mocked connection through ``return_value`` rather than
        # calling ``get_connection(...)`` here, which would record a spurious
        # call on the mock before the operator ever runs.
        mock_connection = mock_hook.get_connection.return_value
        mock_connection.conn_type = "mysql"
        mock_get_first = mock_connection.get_hook.return_value.get_first
        mock_get_first.return_value = ["Invalid Value"]

        with self.assertRaises(AirflowException):
            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
Ejemplo n.º 2
0
 def test_sql_branch_operator_postgres(self):
     """Run BranchSqlOperator against the ``postgres_default`` connection.

     There are no explicit assertions: the test passes as long as executing
     the operator does not raise. It presumably needs a reachable Postgres
     backend behind ``postgres_default``.
     """
     operator = BranchSqlOperator(
         dag=self.dag,
         task_id="make_choice",
         conn_id="postgres_default",
         sql="SELECT 1",
         follow_task_ids_if_true="branch_1",
         follow_task_ids_if_false="branch_2",
     )
     run_window = dict(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
     # ignore_ti_state lets the task run regardless of any previous state.
     operator.run(ignore_ti_state=True, **run_window)
Ejemplo n.º 3
0
    def test_unsupported_conn_type(self):
        """Expect AirflowException when the connection type is unsupported.

        ``redis_default`` serves as a connection that BranchSqlOperator does
        not support; constructing the operator must succeed, and only running
        it is expected to raise.
        """
        operator_kwargs = dict(
            task_id="make_choice",
            conn_id="redis_default",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )
        unsupported_op = BranchSqlOperator(**operator_kwargs)

        with self.assertRaises(AirflowException):
            unsupported_op.run(
                start_date=DEFAULT_DATE,
                end_date=DEFAULT_DATE,
                ignore_ti_state=True,
            )
Ejemplo n.º 4
0
    def test_branch_list_with_dag_run(self, mock_hook):
        """Check that BranchSqlOperator can branch off to a list of tasks.

        ``follow_task_ids_if_true`` is the list ``["branch_1", "branch_2"]``.
        The mocked ``get_first`` returns ``[["1"]]``, which this test treats as
        a true result. Therefore both listed tasks must stay unscheduled
        (``State.NONE``), ``branch_3`` must be skipped, and ``make_choice``
        must succeed.

        ``mock_hook`` is presumably injected by a ``mock.patch`` decorator
        outside this view.
        """
        branch_op = BranchSqlOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true=["branch_1", "branch_2"],
            follow_task_ids_if_false="branch_3",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.branch_3 = DummyOperator(task_id="branch_3", dag=self.dag)
        self.branch_3.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        # Configure the mock via ``return_value`` so no spurious
        # ``get_connection`` call is recorded before the operator runs.
        mock_connection = mock_hook.get_connection.return_value
        mock_connection.conn_type = "mysql"
        mock_get_first = mock_connection.get_hook.return_value.get_first
        mock_get_first.return_value = [["1"]]

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        # Compare the whole {task_id: state} map at once. Unlike a per-instance
        # if/elif check, this also fails when an expected task instance is
        # missing from the dag run, not only when an unknown one appears.
        expected_states = {
            "make_choice": State.SUCCESS,
            "branch_1": State.NONE,
            "branch_2": State.NONE,
            "branch_3": State.SKIPPED,
        }
        actual_states = {ti.task_id: ti.state for ti in dr.get_task_instances()}
        self.assertEqual(actual_states, expected_states)
Ejemplo n.º 5
0
    def test_with_skip_in_branch_downstream_dependencies2(self, mock_hook):
        """Check downstream skipping for every supported false query result.

        Topology: ``make_choice >> branch_1 >> branch_2`` plus
        ``make_choice >> branch_2``. For each value in
        ``SUPPORTED_FALSE_VALUES`` the operator should follow ``branch_2``.
        That leaves ``branch_2`` unscheduled (``State.NONE``), skips
        ``branch_1``, and marks ``make_choice`` as succeeded.

        ``mock_hook`` is presumably injected by a ``mock.patch`` decorator
        outside this view.
        """
        branch_op = BranchSqlOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        # Configure the mock via ``return_value`` so no spurious
        # ``get_connection`` call is recorded before the operator runs.
        mock_connection = mock_hook.get_connection.return_value
        mock_connection.conn_type = "mysql"
        mock_get_first = mock_connection.get_hook.return_value.get_first

        expected_states = {
            "make_choice": State.SUCCESS,
            "branch_1": State.SKIPPED,
            "branch_2": State.NONE,
        }
        for false_value in SUPPORTED_FALSE_VALUES:
            mock_get_first.return_value = [false_value]

            # After the first iteration ``make_choice`` is already SUCCESS.
            # Airflow does not re-execute a task instance in that state unless
            # ``ignore_ti_state`` is set, so without it only the first false
            # value would actually be tested.
            branch_op.run(
                start_date=DEFAULT_DATE,
                end_date=DEFAULT_DATE,
                ignore_ti_state=True,
            )

            # Whole-map comparison: catches wrong states, unknown task ids and
            # missing task instances in one assertion.
            actual_states = {ti.task_id: ti.state for ti in dr.get_task_instances()}
            self.assertEqual(
                actual_states,
                expected_states,
                f"unexpected task states for false value {false_value!r}",
            )
Ejemplo n.º 6
0
                                  python_callable=check_dataset,
                                  provide_context=True)

    waiting_for_data = FileSensor(task_id='waiting_for_data',
                                  fs_conn_id='fs_default',
                                  filepath='avocado.csv',
                                  poke_interval=15)

    training_model_tasks = SubDagOperator(
        task_id='training_model_tasks',
        subdag=subdag_factory('avocado_dag', 'training_model_tasks',
                              default_args))

    evaluating_rmse = BranchSqlOperator(task_id='evaluating_rmse',
                                        sql='sql/FETCH_MIN_RMSE.sql',
                                        conn_id='postgres',
                                        follow_task_ids_if_true='accurate',
                                        follow_task_ids_if_false='inaccurate')

    accurate = DummyOperator(task_id='accurate')

    inaccurate = DummyOperator(task_id='inaccurate')

    fetch_best_model = NotebookToKeepOperator(task_id='fetch_best_model',
                                              sql='sql/FETCH_BEST_MODEL.sql',
                                              postgres_conn_id='postgres')

    publish_notebook = NotebookToGitOperator(
        task_id='publish_notebook',
        conn_id='git',
        nb_path='/tmp',