Example #1
    def test_dagrun_deadlock(self):
        session = settings.Session()
        dag = DAG('test_dagrun_deadlock', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})

        with dag:
            op1 = DummyOperator(task_id='A')
            op2 = DummyOperator(task_id='B')
            op2.trigger_rule = TriggerRule.ONE_FAILED
            op2.set_upstream(op1)

        dag.clear()
        now = timezone.utcnow()
        dr = dag.create_dagrun(
            run_id='test_dagrun_deadlock', state=State.RUNNING, execution_date=now, start_date=now
        )

        ti_op1 = dr.get_task_instance(task_id=op1.task_id)
        ti_op1.set_state(state=State.SUCCESS, session=session)
        ti_op2 = dr.get_task_instance(task_id=op2.task_id)
        ti_op2.set_state(state=State.NONE, session=session)

        dr.update_state()
        assert dr.state == State.RUNNING

        ti_op2.set_state(state=State.NONE, session=session)
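        # An unrecognized trigger rule means op2 can never be scheduled, so the remaining
        # unfinished task deadlocks the run and update_state() marks it FAILED.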
        op2.trigger_rule = 'invalid'
        dr.update_state()
        assert dr.state == State.FAILED
Example #2
    def test_dagrun_update_state_end_date(self):
        session = settings.Session()

        dag = DAG(
            'test_dagrun_update_state_end_date', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}
        )

        # B -> A (op1 is downstream of op2)
        with dag:
            op1 = DummyOperator(task_id='A')
            op2 = DummyOperator(task_id='B')
            op1.set_upstream(op2)

        dag.clear()

        now = timezone.utcnow()
        dr = dag.create_dagrun(
            run_id='test_dagrun_update_state_end_date',
            state=State.RUNNING,
            execution_date=now,
            start_date=now,
        )

        # Initial end_date should be NULL
        # State.SUCCESS and State.FAILED are both terminal states and should set end_date
        # State.RUNNING resets end_date back to NULL
        session.merge(dr)
        session.commit()
        assert dr.end_date is None

        ti_op1 = dr.get_task_instance(task_id=op1.task_id)
        ti_op1.set_state(state=State.SUCCESS, session=session)
        ti_op2 = dr.get_task_instance(task_id=op2.task_id)
        ti_op2.set_state(state=State.SUCCESS, session=session)

        dr.update_state()

        dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_update_state_end_date').one()
        assert dr_database.end_date is not None
        assert dr.end_date == dr_database.end_date

        ti_op1.set_state(state=State.RUNNING, session=session)
        ti_op2.set_state(state=State.RUNNING, session=session)
        dr.update_state()

        dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_update_state_end_date').one()

        assert dr._state == State.RUNNING
        assert dr.end_date is None
        assert dr_database.end_date is None

        ti_op1.set_state(state=State.FAILED, session=session)
        ti_op2.set_state(state=State.FAILED, session=session)
        dr.update_state()

        dr_database = session.query(DagRun).filter(DagRun.run_id == 'test_dagrun_update_state_end_date').one()

        assert dr_database.end_date is not None
        assert dr.end_date == dr_database.end_date
Example #3
    def test_without_dag_run(self):
        """This checks the defensive against non existent tasks in a dag run"""
        value = False
        dag = DAG(
            'shortcircuit_operator_test_without_dag_run',
            default_args={
                'owner': 'airflow',
                'start_date': DEFAULT_DATE
            },
            schedule_interval=INTERVAL,
        )
        short_op = ShortCircuitOperator(task_id='make_choice',
                                        dag=dag,
                                        python_callable=lambda: value)
        branch_1 = DummyOperator(task_id='branch_1', dag=dag)
        branch_1.set_upstream(short_op)
        branch_2 = DummyOperator(task_id='branch_2', dag=dag)
        branch_2.set_upstream(branch_1)
        upstream = DummyOperator(task_id='upstream', dag=dag)
        upstream.set_downstream(short_op)
        dag.clear()

        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(TI.dag_id == dag.dag_id,
                                           TI.execution_date == DEFAULT_DATE)

            for ti in tis:
                if ti.task_id == 'make_choice':
                    assert ti.state == State.SUCCESS
                elif ti.task_id == 'upstream':
                    # should not exist
                    raise ValueError(f'Invalid task id {ti.task_id} found!')
                elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                    assert ti.state == State.SKIPPED
                else:
                    raise ValueError(f'Invalid task id {ti.task_id} found!')

            value = True
            dag.clear()

            short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
            for ti in tis:
                if ti.task_id == 'make_choice':
                    assert ti.state == State.SUCCESS
                elif ti.task_id == 'upstream':
                    # should not exist
                    raise ValueError(f'Invalid task id {ti.task_id} found!')
                elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                    assert ti.state == State.NONE
                else:
                    raise ValueError(f'Invalid task id {ti.task_id} found!')
Example #4
    def _make_smart_operator(self, index, **kwargs):
        poke_interval = 'poke_interval'
        smart_sensor_timeout = 'smart_sensor_timeout'
        if poke_interval not in kwargs:
            kwargs[poke_interval] = 0
        if smart_sensor_timeout not in kwargs:
            kwargs[smart_sensor_timeout] = 0

        smart_task = DummySmartSensor(task_id=SMART_OP + "_" + str(index), dag=self.dag, **kwargs)

        dummy_op = DummyOperator(task_id=DUMMY_OP, dag=self.dag)
        dummy_op.set_upstream(smart_task)
        return smart_task
Example #5
    def test_get_states_count_upstream_ti(self):
        """
        this test tests the helper function '_get_states_count_upstream_ti' as a unit and inside update_state
        """
        from airflow.ti_deps.dep_context import DepContext

        get_states_count_upstream_ti = TriggerRuleDep._get_states_count_upstream_ti
        session = settings.Session()
        now = timezone.utcnow()
        dag = DAG('test_dagrun_with_pre_tis', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})

        with dag:
            op1 = DummyOperator(task_id='A')
            op2 = DummyOperator(task_id='B')
            op3 = DummyOperator(task_id='C')
            op4 = DummyOperator(task_id='D')
            op5 = DummyOperator(task_id='E', trigger_rule=TriggerRule.ONE_FAILED)

            op1.set_downstream([op2, op3])  # op1 >> op2, op3
            op4.set_upstream([op3, op2])  # op3, op2 >> op4
            op5.set_upstream([op2, op3, op4])  # (op2, op3, op4) >> op5

        clear_db_runs()
        dag.clear()
        dr = dag.create_dagrun(
            run_id='test_dagrun_with_pre_tis', state=State.RUNNING, execution_date=now, start_date=now
        )

        ti_op1 = TaskInstance(task=dag.get_task(op1.task_id), execution_date=dr.execution_date)
        ti_op2 = TaskInstance(task=dag.get_task(op2.task_id), execution_date=dr.execution_date)
        ti_op3 = TaskInstance(task=dag.get_task(op3.task_id), execution_date=dr.execution_date)
        ti_op4 = TaskInstance(task=dag.get_task(op4.task_id), execution_date=dr.execution_date)
        ti_op5 = TaskInstance(task=dag.get_task(op5.task_id), execution_date=dr.execution_date)

        ti_op1.set_state(state=State.SUCCESS, session=session)
        ti_op2.set_state(state=State.FAILED, session=session)
        ti_op3.set_state(state=State.SUCCESS, session=session)
        ti_op4.set_state(state=State.SUCCESS, session=session)
        ti_op5.set_state(state=State.SUCCESS, session=session)

        session.commit()

        # check handling with cases that tasks are triggered from backfill with no finished tasks
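        # The tuple returned by _get_states_count_upstream_ti counts the finished upstream TIs as
        # (success, skipped, failed, upstream_failed, done), which is what the assertions below check.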
        finished_tasks = DepContext().ensure_finished_tasks(ti_op2.task.dag, ti_op2.execution_date, session)
        assert get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op2) == (1, 0, 0, 0, 1)
        finished_tasks = dr.get_task_instances(state=State.finished, session=session)
        assert get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op4) == (1, 0, 1, 0, 2)
        assert get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op5) == (2, 0, 1, 0, 3)

        dr.update_state()
        assert State.SUCCESS == dr.state
Example #6
    def _make_sensor(self, return_value, task_id=SENSOR_OP, **kwargs):
        poke_interval = 'poke_interval'
        timeout = 'timeout'

        if poke_interval not in kwargs:
            kwargs[poke_interval] = 0
        if timeout not in kwargs:
            kwargs[timeout] = 0

        sensor = DummySensor(task_id=task_id,
                             return_value=return_value,
                             dag=self.dag,
                             **kwargs)

        dummy_op = DummyOperator(task_id=DUMMY_OP, dag=self.dag)
        dummy_op.set_upstream(sensor)
        return sensor
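
A minimal usage sketch for this helper (hypothetical test method name; it assumes the surrounding sensor test class with DummySensor, SENSOR_OP, DUMMY_OP, DEFAULT_DATE and self.dag available, as in the snippet above):

    def test_sensor_returns_true(self):
        # poke() immediately returns the configured value, so the sensor task should succeed
        sensor = self._make_sensor(return_value=True)
        sensor.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)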
Example #7
    def test_dagrun_success_conditions(self):
        session = settings.Session()

        dag = DAG('test_dagrun_success_conditions',
                  start_date=DEFAULT_DATE,
                  default_args={'owner': 'owner1'})

        # A -> B
        # A -> C -> D
        # ordered: B, D, C, A or D, B, C, A or D, C, B, A
        with dag:
            op1 = DummyOperator(task_id='A')
            op2 = DummyOperator(task_id='B')
            op3 = DummyOperator(task_id='C')
            op4 = DummyOperator(task_id='D')
            op1.set_upstream([op2, op3])
            op3.set_upstream(op4)

        dag.clear()

        now = timezone.utcnow()
        dr = dag.create_dagrun(run_id='test_dagrun_success_conditions',
                               state=State.RUNNING,
                               execution_date=now,
                               start_date=now)

        # op1 = root
        ti_op1 = dr.get_task_instance(task_id=op1.task_id)
        ti_op1.set_state(state=State.SUCCESS, session=session)

        ti_op2 = dr.get_task_instance(task_id=op2.task_id)
        ti_op3 = dr.get_task_instance(task_id=op3.task_id)
        ti_op4 = dr.get_task_instance(task_id=op4.task_id)

        # root is successful, but unfinished tasks
        dr.update_state()
        assert State.RUNNING == dr.state

        # one has failed, but root is successful
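        # the dag run state is decided by the terminal task here (op1, which has no downstream),
        # so the upstream failure of op2 does not fail the run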
        ti_op2.set_state(state=State.FAILED, session=session)
        ti_op3.set_state(state=State.SUCCESS, session=session)
        ti_op4.set_state(state=State.SUCCESS, session=session)
        dr.update_state()
        assert State.SUCCESS == dr.state
Example #8
    def test_dagrun_no_deadlock_with_shutdown(self):
        session = settings.Session()
        dag = DAG('test_dagrun_no_deadlock_with_shutdown', start_date=DEFAULT_DATE)
        with dag:
            op1 = DummyOperator(task_id='upstream_task')
            op2 = DummyOperator(task_id='downstream_task')
            op2.set_upstream(op1)

        dr = dag.create_dagrun(
            run_id='test_dagrun_no_deadlock_with_shutdown',
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
        )
        upstream_ti = dr.get_task_instance(task_id='upstream_task')
        upstream_ti.set_state(State.SHUTDOWN, session=session)

        dr.update_state()
        assert dr.state == State.RUNNING
Example #9
    def test_operator_clear(self):
        dag = DAG(
            'test_operator_clear',
            start_date=DEFAULT_DATE,
            end_date=DEFAULT_DATE + datetime.timedelta(days=10),
        )
        op1 = DummyOperator(task_id='bash_op', owner='test', dag=dag)
        op2 = DummyOperator(task_id='dummy_op',
                            owner='test',
                            dag=dag,
                            retries=1)

        op2.set_upstream(op1)

        ti1 = TI(task=op1, execution_date=DEFAULT_DATE)
        ti2 = TI(task=op2, execution_date=DEFAULT_DATE)

        dag.create_dagrun(
            execution_date=ti1.execution_date,
            state=State.RUNNING,
            run_type=DagRunType.SCHEDULED,
        )

        ti2.run()
        # Dependency not met
        assert ti2.try_number == 1
        assert ti2.max_tries == 1

        op2.clear(upstream=True)
        ti1.run()
        ti2.run(ignore_ti_state=True)
        assert ti1.try_number == 2
        # max_tries is 0 because there is no task instance in db for ti1
        # so clear won't change the max_tries.
        assert ti1.max_tries == 0
        assert ti2.try_number == 2
        # try_number (0) + retries(1)
        assert ti2.max_tries == 1
Example #10
"""This dag only runs some simple tasks to test Airflow's task execution."""
from datetime import datetime, timedelta

from airflow.models.dag import DAG
from airflow.operators.dummy import DummyOperator
from airflow.utils.dates import days_ago

now = datetime.now()
now_to_the_hour = (now - timedelta(hours=3)).replace(minute=0, second=0, microsecond=0)
START_DATE = now_to_the_hour
DAG_NAME = 'test_dag_v2'

default_args = {
    'owner': 'airflow',
    'depends_on_past': True,
    'start_date': days_ago(2)
}
dag = DAG(DAG_NAME,
          schedule_interval='*/10 * * * *',
          default_args=default_args)

run_this_1 = DummyOperator(task_id='run_this_1', dag=dag)
run_this_2 = DummyOperator(task_id='run_this_2', dag=dag)
run_this_2.set_upstream(run_this_1)
run_this_3 = DummyOperator(task_id='run_this_3', dag=dag)
run_this_3.set_upstream(run_this_2)
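
The same dependency chain can also be written with Airflow's bitshift composition; a minimal equivalent sketch using the operators defined above:

run_this_1 >> run_this_2 >> run_this_3  # same ordering as the set_upstream calls above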
Example #11
    def test_skipping_non_latest(self):
        latest_task = LatestOnlyOperator(task_id='latest', dag=self.dag)
        downstream_task = DummyOperator(task_id='downstream', dag=self.dag)
        downstream_task2 = DummyOperator(task_id='downstream_2', dag=self.dag)
        downstream_task3 = DummyOperator(task_id='downstream_3',
                                         trigger_rule=TriggerRule.NONE_FAILED,
                                         dag=self.dag)

        downstream_task.set_upstream(latest_task)
        downstream_task2.set_upstream(downstream_task)
        downstream_task3.set_upstream(downstream_task)

        self.dag.create_dagrun(
            run_type=DagRunType.SCHEDULED,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        self.dag.create_dagrun(
            run_type=DagRunType.SCHEDULED,
            start_date=timezone.utcnow(),
            execution_date=timezone.datetime(2016, 1, 1, 12),
            state=State.RUNNING,
        )

        self.dag.create_dagrun(
            run_type=DagRunType.SCHEDULED,
            start_date=timezone.utcnow(),
            execution_date=END_DATE,
            state=State.RUNNING,
        )

        latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
        downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
        downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE)
        downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE)

        latest_instances = get_task_instances('latest')
        exec_date_to_latest_state = {
            ti.execution_date: ti.state
            for ti in latest_instances
        }
        assert {
            timezone.datetime(2016, 1, 1): 'success',
            timezone.datetime(2016, 1, 1, 12): 'success',
            timezone.datetime(2016, 1, 2): 'success',
        } == exec_date_to_latest_state

        downstream_instances = get_task_instances('downstream')
        exec_date_to_downstream_state = {
            ti.execution_date: ti.state
            for ti in downstream_instances
        }
        assert {
            timezone.datetime(2016, 1, 1): 'skipped',
            timezone.datetime(2016, 1, 1, 12): 'skipped',
            timezone.datetime(2016, 1, 2): 'success',
        } == exec_date_to_downstream_state

        downstream_instances = get_task_instances('downstream_2')
        exec_date_to_downstream_state = {
            ti.execution_date: ti.state
            for ti in downstream_instances
        }
        assert {
            timezone.datetime(2016, 1, 1): None,
            timezone.datetime(2016, 1, 1, 12): None,
            timezone.datetime(2016, 1, 2): 'success',
        } == exec_date_to_downstream_state

        downstream_instances = get_task_instances('downstream_3')
        exec_date_to_downstream_state = {
            ti.execution_date: ti.state
            for ti in downstream_instances
        }
        assert {
            timezone.datetime(2016, 1, 1): 'success',
            timezone.datetime(2016, 1, 1, 12): 'success',
            timezone.datetime(2016, 1, 2): 'success',
        } == exec_date_to_downstream_state
Example #12
class TestBranchDateTimeOperator(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

        cls.targets = [
            (datetime.datetime(2020, 7, 7, 10, 0, 0), datetime.datetime(2020, 7, 7, 11, 0, 0)),
            (datetime.time(10, 0, 0), datetime.time(11, 0, 0)),
            (datetime.datetime(2020, 7, 7, 10, 0, 0), datetime.time(11, 0, 0)),
            (datetime.time(10, 0, 0), datetime.datetime(2020, 7, 7, 11, 0, 0)),
            (datetime.time(11, 0, 0), datetime.time(10, 0, 0)),
        ]

    def setUp(self):
        self.dag = DAG(
            'branch_datetime_operator_test',
            default_args={
                'owner': 'airflow',
                'start_date': DEFAULT_DATE
            },
            schedule_interval=INTERVAL,
        )

        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)

        self.branch_op = BranchDateTimeOperator(
            task_id='datetime_branch',
            follow_task_ids_if_true='branch_1',
            follow_task_ids_if_false='branch_2',
            target_upper=datetime.datetime(2020, 7, 7, 11, 0, 0),
            target_lower=datetime.datetime(2020, 7, 7, 10, 0, 0),
            dag=self.dag,
        )

        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

        self.dr = self.dag.create_dagrun(run_id='manual__',
                                         start_date=DEFAULT_DATE,
                                         execution_date=DEFAULT_DATE,
                                         state=State.RUNNING)

    def tearDown(self):
        super().tearDown()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def _assert_task_ids_match_states(self, task_ids_to_states):
        """Helper that asserts task instances with a given id are in a given state"""
        tis = self.dr.get_task_instances()
        for ti in tis:
            try:
                expected_state = task_ids_to_states[ti.task_id]
            except KeyError:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
            else:
                self.assertEqual(
                    ti.state,
                    expected_state,
                    f"Task {ti.task_id} has state {ti.state} instead of expected {expected_state}",
                )

    def test_no_target_time(self):
        """Check if BranchDateTimeOperator raises exception on missing target"""
        with self.assertRaises(AirflowException):
            BranchDateTimeOperator(
                task_id='datetime_branch',
                follow_task_ids_if_true='branch_1',
                follow_task_ids_if_false='branch_2',
                target_upper=None,
                target_lower=None,
                dag=self.dag,
            )

    @freezegun.freeze_time("2020-07-07 10:54:05")
    def test_branch_datetime_operator_falls_within_range(self):
        """Check BranchDateTimeOperator branch operation"""
        for target_lower, target_upper in self.targets:
            with self.subTest(target_lower=target_lower,
                              target_upper=target_upper):
                self.branch_op.target_lower = target_lower
                self.branch_op.target_upper = target_upper
                self.branch_op.run(start_date=DEFAULT_DATE,
                                   end_date=DEFAULT_DATE)

                self._assert_task_ids_match_states({
                    'datetime_branch': State.SUCCESS,
                    'branch_1': State.NONE,
                    'branch_2': State.SKIPPED,
                })

    def test_branch_datetime_operator_falls_outside_range(self):
        """Check BranchDateTimeOperator branch operation"""
        dates = [
            datetime.datetime(2020, 7, 7, 12, 0, 0, tzinfo=datetime.timezone.utc),
            datetime.datetime(2020, 6, 7, 12, 0, 0, tzinfo=datetime.timezone.utc),
        ]

        for target_lower, target_upper in self.targets:
            with self.subTest(target_lower=target_lower,
                              target_upper=target_upper):
                self.branch_op.target_lower = target_lower
                self.branch_op.target_upper = target_upper

                for date in dates:
                    with freezegun.freeze_time(date):
                        self.branch_op.run(start_date=DEFAULT_DATE,
                                           end_date=DEFAULT_DATE)

                        self._assert_task_ids_match_states({
                            'datetime_branch': State.SUCCESS,
                            'branch_1': State.SKIPPED,
                            'branch_2': State.NONE,
                        })

    @freezegun.freeze_time("2020-07-07 10:54:05")
    def test_branch_datetime_operator_upper_comparison_within_range(self):
        """Check BranchDateTimeOperator branch operation"""
        for _, target_upper in self.targets:
            with self.subTest(target_upper=target_upper):
                self.branch_op.target_upper = target_upper
                self.branch_op.target_lower = None

                self.branch_op.run(start_date=DEFAULT_DATE,
                                   end_date=DEFAULT_DATE)

                self._assert_task_ids_match_states({
                    'datetime_branch': State.SUCCESS,
                    'branch_1': State.NONE,
                    'branch_2': State.SKIPPED,
                })

    @freezegun.freeze_time("2020-07-07 10:54:05")
    def test_branch_datetime_operator_lower_comparison_within_range(self):
        """Check BranchDateTimeOperator branch operation"""
        for target_lower, _ in self.targets:
            with self.subTest(target_lower=target_lower):
                self.branch_op.target_lower = target_lower
                self.branch_op.target_upper = None

                self.branch_op.run(start_date=DEFAULT_DATE,
                                   end_date=DEFAULT_DATE)

                self._assert_task_ids_match_states({
                    'datetime_branch': State.SUCCESS,
                    'branch_1': State.NONE,
                    'branch_2': State.SKIPPED,
                })

    @freezegun.freeze_time("2020-07-07 12:00:00")
    def test_branch_datetime_operator_upper_comparison_outside_range(self):
        """Check BranchDateTimeOperator branch operation"""
        for _, target_upper in self.targets:
            with self.subTest(target_upper=target_upper):
                self.branch_op.target_upper = target_upper
                self.branch_op.target_lower = None

                self.branch_op.run(start_date=DEFAULT_DATE,
                                   end_date=DEFAULT_DATE)

                self._assert_task_ids_match_states({
                    'datetime_branch': State.SUCCESS,
                    'branch_1': State.SKIPPED,
                    'branch_2': State.NONE,
                })

    @freezegun.freeze_time("2020-07-07 09:00:00")
    def test_branch_datetime_operator_lower_comparison_outside_range(self):
        """Check BranchDateTimeOperator branch operation"""
        for target_lower, _ in self.targets:
            with self.subTest(target_lower=target_lower):
                self.branch_op.target_lower = target_lower
                self.branch_op.target_upper = None

                self.branch_op.run(start_date=DEFAULT_DATE,
                                   end_date=DEFAULT_DATE)

                self._assert_task_ids_match_states({
                    'datetime_branch': State.SUCCESS,
                    'branch_1': State.SKIPPED,
                    'branch_2': State.NONE,
                })

    @freezegun.freeze_time("2020-12-01 09:00:00")
    def test_branch_datetime_operator_use_task_execution_date(self):
        """Check if BranchDateTimeOperator uses task execution date"""
        in_between_date = timezone.datetime(2020, 7, 7, 10, 30, 0)
        self.branch_op.use_task_execution_date = True
        self.dr = self.dag.create_dagrun(
            run_id='manual_exec_date__',
            start_date=in_between_date,
            execution_date=in_between_date,
            state=State.RUNNING,
        )

        for target_lower, target_upper in self.targets:
            with self.subTest(target_lower=target_lower,
                              target_upper=target_upper):
                self.branch_op.target_lower = target_lower
                self.branch_op.target_upper = target_upper
                self.branch_op.run(start_date=in_between_date,
                                   end_date=in_between_date)

                self._assert_task_ids_match_states({
                    'datetime_branch': State.SUCCESS,
                    'branch_1': State.NONE,
                    'branch_2': State.SKIPPED,
                })
Example #13
    dag=dag1,
    pool='test_backfill_pooled_task_pool',
)

# dag2 has been moved to test_prev_dagrun_dep.py

# DAG tests that a Dag run that doesn't complete is marked failed
dag3 = DAG(dag_id='test_dagrun_states_fail', default_args=default_args)
dag3_task1 = PythonOperator(task_id='test_dagrun_fail',
                            dag=dag3,
                            python_callable=fail)
dag3_task2 = DummyOperator(
    task_id='test_dagrun_succeed',
    dag=dag3,
)
dag3_task2.set_upstream(dag3_task1)

# DAG tests that a Dag run that completes but has a failure is marked success
dag4 = DAG(dag_id='test_dagrun_states_success', default_args=default_args)
dag4_task1 = PythonOperator(
    task_id='test_dagrun_fail',
    dag=dag4,
    python_callable=fail,
)
dag4_task2 = DummyOperator(task_id='test_dagrun_succeed',
                           dag=dag4,
                           trigger_rule=TriggerRule.ALL_FAILED)
dag4_task2.set_upstream(dag4_task1)

# DAG tests that a Dag run that completes but has a root failure is marked fail
dag5 = DAG(dag_id='test_dagrun_states_root_fail', default_args=default_args)
Example #14
seven_days_ago = datetime.combine(datetime.today() - timedelta(7),
                                  datetime.min.time())
args = {
    'owner': 'airflow',
    'start_date': seven_days_ago,
}

dag = DAG(dag_id='example_branch_operator',
          default_args=args,
          schedule_interval="@daily")

cmd = 'ls -l'
run_this_first = DummyOperator(task_id='run_this_first', dag=dag)

options = ['branch_a', 'branch_b', 'branch_c', 'branch_d']

branching = BranchPythonOperator(
    task_id='branching',
    python_callable=lambda: random.choice(options),
    dag=dag)
branching.set_upstream(run_this_first)

join = DummyOperator(task_id='join', trigger_rule='one_success', dag=dag)
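# trigger_rule='one_success' lets 'join' run even though every branch not chosen by 'branching' is skipped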

for option in options:
    t = DummyOperator(task_id=option, dag=dag)
    t.set_upstream(branching)
    dummy_follow = DummyOperator(task_id='follow_' + option, dag=dag)
    t.set_downstream(dummy_follow)
    dummy_follow.set_downstream(join)
Example #15
class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
    """
    Test for SQL Branch Operator
    """
    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        super().setUp()
        self.dag = DAG(
            "sql_branch_operator_test",
            default_args={
                "owner": "airflow",
                "start_date": DEFAULT_DATE
            },
            schedule_interval=INTERVAL,
        )
        self.branch_1 = DummyOperator(task_id="branch_1", dag=self.dag)
        self.branch_2 = DummyOperator(task_id="branch_2", dag=self.dag)
        self.branch_3 = None

    def tearDown(self):
        super().tearDown()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def test_unsupported_conn_type(self):
        """Check if BranchSQLOperator throws an exception for unsupported connection type"""
        op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="redis_default",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        with pytest.raises(AirflowException):
            op.run(start_date=DEFAULT_DATE,
                   end_date=DEFAULT_DATE,
                   ignore_ti_state=True)

    def test_invalid_conn(self):
        """Check if BranchSQLOperator throws an exception for invalid connection"""
        op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="invalid_connection",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        with pytest.raises(AirflowException):
            op.run(start_date=DEFAULT_DATE,
                   end_date=DEFAULT_DATE,
                   ignore_ti_state=True)

    def test_invalid_follow_task_true(self):
        """Check if BranchSQLOperator throws an exception for invalid connection"""
        op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="invalid_connection",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true=None,
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        with pytest.raises(AirflowException):
            op.run(start_date=DEFAULT_DATE,
                   end_date=DEFAULT_DATE,
                   ignore_ti_state=True)

    def test_invalid_follow_task_false(self):
        """Check if BranchSQLOperator throws an exception for invalid connection"""
        op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="invalid_connection",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false=None,
            dag=self.dag,
        )

        with pytest.raises(AirflowException):
            op.run(start_date=DEFAULT_DATE,
                   end_date=DEFAULT_DATE,
                   ignore_ti_state=True)

    @pytest.mark.backend("mysql")
    def test_sql_branch_operator_mysql(self):
        """Check if BranchSQLOperator works with backend"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )
        branch_op.run(start_date=DEFAULT_DATE,
                      end_date=DEFAULT_DATE,
                      ignore_ti_state=True)

    @pytest.mark.backend("postgres")
    def test_sql_branch_operator_postgres(self):
        """Check if BranchSQLOperator works with backend"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="postgres_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )
        branch_op.run(start_date=DEFAULT_DATE,
                      end_date=DEFAULT_DATE,
                      ignore_ti_state=True)

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_branch_single_value_with_dag_run(self, mock_get_db_hook):
        """Check BranchSQLOperator branch operation"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        mock_get_records.return_value = 1
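        # the mocked get_first() returns a truthy value, so the flow follows 'branch_1' and skips 'branch_2'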

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == "make_choice":
                assert ti.state == State.SUCCESS
            elif ti.task_id == "branch_1":
                assert ti.state == State.NONE
            elif ti.task_id == "branch_2":
                assert ti.state == State.SKIPPED
            else:
                raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_branch_true_with_dag_run(self, mock_get_db_hook):
        """Check BranchSQLOperator branch operation"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        for true_value in SUPPORTED_TRUE_VALUES:
            mock_get_records.return_value = true_value

            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            tis = dr.get_task_instances()
            for ti in tis:
                if ti.task_id == "make_choice":
                    assert ti.state == State.SUCCESS
                elif ti.task_id == "branch_1":
                    assert ti.state == State.NONE
                elif ti.task_id == "branch_2":
                    assert ti.state == State.SKIPPED
                else:
                    raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_branch_false_with_dag_run(self, mock_get_db_hook):
        """Check BranchSQLOperator branch operation"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        for false_value in SUPPORTED_FALSE_VALUES:
            mock_get_records.return_value = false_value
            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            tis = dr.get_task_instances()
            for ti in tis:
                if ti.task_id == "make_choice":
                    assert ti.state == State.SUCCESS
                elif ti.task_id == "branch_1":
                    assert ti.state == State.SKIPPED
                elif ti.task_id == "branch_2":
                    assert ti.state == State.NONE
                else:
                    raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_branch_list_with_dag_run(self, mock_get_db_hook):
        """Checks if the BranchSQLOperator supports branching off to a list of tasks."""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true=["branch_1", "branch_2"],
            follow_task_ids_if_false="branch_3",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.branch_3 = DummyOperator(task_id="branch_3", dag=self.dag)
        self.branch_3.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first
        mock_get_records.return_value = [["1"]]

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == "make_choice":
                assert ti.state == State.SUCCESS
            elif ti.task_id == "branch_1":
                assert ti.state == State.NONE
            elif ti.task_id == "branch_2":
                assert ti.state == State.NONE
            elif ti.task_id == "branch_3":
                assert ti.state == State.SKIPPED
            else:
                raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_invalid_query_result_with_dag_run(self, mock_get_db_hook):
        """Check BranchSQLOperator branch operation"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        mock_get_records.return_value = ["Invalid Value"]

        with pytest.raises(AirflowException):
            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_with_skip_in_branch_downstream_dependencies(
            self, mock_get_db_hook):
        """Test SQL Branch with skipping all downstream dependencies"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        for true_value in SUPPORTED_TRUE_VALUES:
            mock_get_records.return_value = [true_value]

            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            tis = dr.get_task_instances()
            for ti in tis:
                if ti.task_id == "make_choice":
                    assert ti.state == State.SUCCESS
                elif ti.task_id == "branch_1":
                    assert ti.state == State.NONE
                elif ti.task_id == "branch_2":
                    assert ti.state == State.NONE
                else:
                    raise ValueError(f"Invalid task id {ti.task_id} found!")

    @mock.patch("airflow.operators.sql.BaseSQLOperator.get_db_hook")
    def test_with_skip_in_branch_downstream_dependencies2(
            self, mock_get_db_hook):
        """Test skipping downstream dependency for false condition"""
        branch_op = BranchSQLOperator(
            task_id="make_choice",
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            dag=self.dag,
        )

        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        mock_get_records = mock_get_db_hook.return_value.get_first

        for false_value in SUPPORTED_FALSE_VALUES:
            mock_get_records.return_value = [false_value]

            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

            tis = dr.get_task_instances()
            for ti in tis:
                if ti.task_id == "make_choice":
                    assert ti.state == State.SUCCESS
                elif ti.task_id == "branch_1":
                    assert ti.state == State.SKIPPED
                elif ti.task_id == "branch_2":
                    assert ti.state == State.NONE
                else:
                    raise ValueError(f"Invalid task id {ti.task_id} found!")
Example #16
class TestBranchDayOfWeekOperator(unittest.TestCase):
    """
    Tests for BranchDayOfWeekOperator
    """
    @classmethod
    def setUpClass(cls):

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        self.dag = DAG(
            "branch_day_of_week_operator_test",
            start_date=DEFAULT_DATE,
            schedule_interval=INTERVAL,
        )
        self.branch_1 = DummyOperator(task_id="branch_1", dag=self.dag)
        self.branch_2 = DummyOperator(task_id="branch_2", dag=self.dag)
        self.branch_3 = None

    def tearDown(self):

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def _assert_task_ids_match_states(self, dr, task_ids_to_states):
        """Helper that asserts task instances with a given id are in a given state"""
        tis = dr.get_task_instances()
        for ti in tis:
            try:
                expected_state = task_ids_to_states[ti.task_id]
            except KeyError:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
            else:
                self.assertEqual(
                    ti.state,
                    expected_state,
                    f"Task {ti.task_id} has state {ti.state} instead of expected {expected_state}",
                )

    @parameterized.expand([
        ("with-string", "Monday"),
        ("with-enum", WeekDay.MONDAY),
        ("with-enum-set", {WeekDay.MONDAY}),
        ("with-enum-set-2-items", {WeekDay.MONDAY, WeekDay.FRIDAY}),
        ("with-string-set", {"Monday"}),
        ("with-string-set-2-items", {"Monday", "Friday"}),
    ])
    @freeze_time("2021-01-25")  # Monday
    def test_branch_follow_true(self, _, weekday):
        """Checks if BranchDayOfWeekOperator follows true branch"""
        print(datetime.datetime.now())
        branch_op = BranchDayOfWeekOperator(
            task_id="make_choice",
            follow_task_ids_if_true=["branch_1", "branch_2"],
            follow_task_ids_if_false="branch_3",
            week_day=weekday,
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.branch_3 = DummyOperator(task_id="branch_3", dag=self.dag)
        self.branch_3.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_task_ids_match_states(
            dr,
            {
                'make_choice': State.SUCCESS,
                'branch_1': State.NONE,
                'branch_2': State.NONE,
                'branch_3': State.SKIPPED,
            },
        )

    @freeze_time("2021-01-25")  # Monday
    def test_branch_follow_true_with_execution_date(self):
        """Checks if BranchDayOfWeekOperator follows true branch when set use_task_execution_day """

        branch_op = BranchDayOfWeekOperator(
            task_id="make_choice",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            week_day="Wednesday",
            use_task_execution_day=True,  # We compare to DEFAULT_DATE which is Wednesday
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_task_ids_match_states(
            dr,
            {
                'make_choice': State.SUCCESS,
                'branch_1': State.NONE,
                'branch_2': State.SKIPPED,
            },
        )

    @freeze_time("2021-01-25")  # Monday
    def test_branch_follow_false(self):
        """Checks if BranchDayOfWeekOperator follow false branch"""

        branch_op = BranchDayOfWeekOperator(
            task_id="make_choice",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            week_day="Sunday",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_task_ids_match_states(
            dr,
            {
                'make_choice': State.SUCCESS,
                'branch_1': State.SKIPPED,
                'branch_2': State.NONE,
            },
        )

    def test_branch_with_no_weekday(self):
        """Check if BranchDayOfWeekOperator raises exception on missing weekday"""
        with self.assertRaises(AirflowException):
            BranchDayOfWeekOperator(  # pylint: disable=missing-kwoa
                task_id="make_choice",
                follow_task_ids_if_true="branch_1",
                follow_task_ids_if_false="branch_2",
                dag=self.dag,
            )

    def test_branch_with_invalid_type(self):
        """Check if BranchDayOfWeekOperator raises exception on unsupported weekday type"""
        invalid_week_day = ['Monday']
        with pytest.raises(
                TypeError,
                match='Unsupported Type for week_day parameter:'
                ' {}. It should be one of str, set or '
                'Weekday enum type'.format(type(invalid_week_day)),
        ):
            BranchDayOfWeekOperator(
                task_id="make_choice",
                follow_task_ids_if_true="branch_1",
                follow_task_ids_if_false="branch_2",
                week_day=invalid_week_day,
                dag=self.dag,
            )

    def test_weekday_branch_invalid_weekday_number(self):
        """Check if BranchDayOfWeekOperator raises exception on wrong value of weekday"""
        invalid_week_day = 'Thsday'
        with pytest.raises(
                AttributeError,
                match=f'Invalid Week Day passed: "{invalid_week_day}"'):
            BranchDayOfWeekOperator(
                task_id="make_choice",
                follow_task_ids_if_true="branch_1",
                follow_task_ids_if_false="branch_2",
                week_day=invalid_week_day,
                dag=self.dag,
            )

    @freeze_time("2021-01-25")  # Monday
    def test_branch_xcom_push_true_branch(self):
        """Check if BranchDayOfWeekOperator push to xcom value of follow_task_ids_if_true"""
        branch_op = BranchDayOfWeekOperator(
            task_id="make_choice",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2",
            week_day="Monday",
            dag=self.dag,
        )

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.xcom_pull(task_ids='make_choice') == 'branch_1'
Example #17
    def test_not_skipping_external(self):
        latest_task = LatestOnlyOperator(task_id='latest', dag=self.dag)
        downstream_task = DummyOperator(task_id='downstream', dag=self.dag)
        downstream_task2 = DummyOperator(task_id='downstream_2', dag=self.dag)

        downstream_task.set_upstream(latest_task)
        downstream_task2.set_upstream(downstream_task)

        self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
            external_trigger=True,
        )

        self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=timezone.datetime(2016, 1, 1, 12),
            state=State.RUNNING,
            external_trigger=True,
        )

        self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=END_DATE,
            state=State.RUNNING,
            external_trigger=True,
        )

        latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
        downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
        downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE)

        latest_instances = get_task_instances('latest')
        exec_date_to_latest_state = {
            ti.execution_date: ti.state
            for ti in latest_instances
        }
        assert {
            timezone.datetime(2016, 1, 1): 'success',
            timezone.datetime(2016, 1, 1, 12): 'success',
            timezone.datetime(2016, 1, 2): 'success',
        } == exec_date_to_latest_state

        downstream_instances = get_task_instances('downstream')
        exec_date_to_downstream_state = {
            ti.execution_date: ti.state
            for ti in downstream_instances
        }
        assert {
            timezone.datetime(2016, 1, 1): 'success',
            timezone.datetime(2016, 1, 1, 12): 'success',
            timezone.datetime(2016, 1, 2): 'success',
        } == exec_date_to_downstream_state

        downstream_instances = get_task_instances('downstream_2')
        exec_date_to_downstream_state = {
            ti.execution_date: ti.state
            for ti in downstream_instances
        }
        assert {
            timezone.datetime(2016, 1, 1): 'success',
            timezone.datetime(2016, 1, 1, 12): 'success',
            timezone.datetime(2016, 1, 2): 'success',
        } == exec_date_to_downstream_state
Example #18
    def test_with_dag_run(self):
        value = False
        dag = DAG(
            'shortcircuit_operator_test_with_dag_run',
            default_args={
                'owner': 'airflow',
                'start_date': DEFAULT_DATE
            },
            schedule_interval=INTERVAL,
        )
        short_op = ShortCircuitOperator(task_id='make_choice',
                                        dag=dag,
                                        python_callable=lambda: value)
        branch_1 = DummyOperator(task_id='branch_1', dag=dag)
        branch_1.set_upstream(short_op)
        branch_2 = DummyOperator(task_id='branch_2', dag=dag)
        branch_2.set_upstream(branch_1)
        upstream = DummyOperator(task_id='upstream', dag=dag)
        upstream.set_downstream(short_op)
        dag.clear()

        logging.error("Tasks %s", dag.tasks)
        dr = dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        assert len(tis) == 4
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'upstream':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                assert ti.state == State.SKIPPED
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')

        value = True
        dag.clear()
        dr.verify_integrity()
        upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        assert len(tis) == 4
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'upstream':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                assert ti.state == State.NONE
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
Example #19
class TestBranchOperator(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        self.dag = DAG(
            'branch_operator_test',
            default_args={
                'owner': 'airflow',
                'start_date': DEFAULT_DATE
            },
            schedule_interval=INTERVAL,
        )

        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
        self.branch_3 = None

    def tearDown(self):
        super().tearDown()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def test_without_dag_run(self):
        """This checks the defensive against non existent tasks in a dag run"""
        branch_op = BranchPythonOperator(task_id='make_choice',
                                         dag=self.dag,
                                         python_callable=lambda: 'branch_1')
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id,
                                           TI.execution_date == DEFAULT_DATE)

            for ti in tis:
                if ti.task_id == 'make_choice':
                    assert ti.state == State.SUCCESS
                elif ti.task_id == 'branch_1':
                    # should exist with state None
                    assert ti.state == State.NONE
                elif ti.task_id == 'branch_2':
                    assert ti.state == State.SKIPPED
                else:
                    raise ValueError(f'Invalid task id {ti.task_id} found!')

    def test_branch_list_without_dag_run(self):
        """This checks if the BranchPythonOperator supports branching off to a list of tasks."""
        branch_op = BranchPythonOperator(
            task_id='make_choice',
            dag=self.dag,
            python_callable=lambda: ['branch_1', 'branch_2'])
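        # Returning a list of task ids means every listed branch is followed;
        # only branches not in the list (branch_3 here) end up SKIPPED.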
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
        self.branch_3.set_upstream(branch_op)
        self.dag.clear()

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id,
                                           TI.execution_date == DEFAULT_DATE)

            expected = {
                "make_choice": State.SUCCESS,
                "branch_1": State.NONE,
                "branch_2": State.NONE,
                "branch_3": State.SKIPPED,
            }

            for ti in tis:
                if ti.task_id in expected:
                    assert ti.state == expected[ti.task_id]
                else:
                    raise ValueError(f'Invalid task id {ti.task_id} found!')

    def test_with_dag_run(self):
        branch_op = BranchPythonOperator(task_id='make_choice',
                                         dag=self.dag,
                                         python_callable=lambda: 'branch_1')

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_1':
                assert ti.state == State.NONE
            elif ti.task_id == 'branch_2':
                assert ti.state == State.SKIPPED
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')

    def test_with_skip_in_branch_downstream_dependencies(self):
        branch_op = BranchPythonOperator(task_id='make_choice',
                                         dag=self.dag,
                                         python_callable=lambda: 'branch_1')

        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
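        # branch_2 is a direct child of make_choice but also downstream of the followed
        # branch_1, so it is not skipped and stays in State.NONE (as asserted below).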
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_1':
                assert ti.state == State.NONE
            elif ti.task_id == 'branch_2':
                assert ti.state == State.NONE
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')

    def test_with_skip_in_branch_downstream_dependencies2(self):
        branch_op = BranchPythonOperator(task_id='make_choice',
                                         dag=self.dag,
                                         python_callable=lambda: 'branch_2')

        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
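        # With branch_2 chosen, branch_1 (the non-chosen direct child) is skipped,
        # while branch_2 itself stays schedulable in State.NONE.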
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_1':
                assert ti.state == State.SKIPPED
            elif ti.task_id == 'branch_2':
                assert ti.state == State.NONE
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')

    def test_xcom_push(self):
        branch_op = BranchPythonOperator(task_id='make_choice',
                                         dag=self.dag,
                                         python_callable=lambda: 'branch_1')

        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
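        # BranchPythonOperator pushes its callable's return value to XCom,
        # so xcom_pull(task_ids='make_choice') below yields 'branch_1'.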

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.xcom_pull(task_ids='make_choice') == 'branch_1'

    def test_clear_skipped_downstream_task(self):
        """
        After a downstream task is skipped by BranchPythonOperator, clearing the skipped task
        should not cause it to be executed.
        """
        branch_op = BranchPythonOperator(task_id='make_choice',
                                         dag=self.dag,
                                         python_callable=lambda: 'branch_1')
        branches = [self.branch_1, self.branch_2]
        branch_op >> branches
        self.dag.clear()

        dr = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        for task in branches:
            task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dr.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_1':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_2':
                assert ti.state == State.SKIPPED
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')

        children_tis = [
            ti for ti in tis
            if ti.task_id in branch_op.get_direct_relative_ids()
        ]

        # Clear the children tasks.
        with create_session() as session:
            clear_task_instances(children_tis, session=session, dag=self.dag)

        # Run the cleared tasks again.
        for task in branches:
            task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        # Check if the states are correct after children tasks are cleared.
        for ti in dr.get_task_instances():
            if ti.task_id == 'make_choice':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_1':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_2':
                assert ti.state == State.SKIPPED
            else:
                raise ValueError(f'Invalid task id {ti.task_id} found!')
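
The class below relies on two helper operators, ChooseBranchOne and ChooseBranchOneTwo, that are defined elsewhere in the original test module and are not shown in this excerpt. A minimal sketch of what they plausibly look like, assuming they subclass BaseBranchOperator and that the Airflow 2.x import path applies:

from airflow.operators.branch_operator import BaseBranchOperator


class ChooseBranchOne(BaseBranchOperator):
    """Always follows 'branch_1'; non-chosen direct downstream tasks get skipped."""

    def choose_branch(self, context):
        return 'branch_1'


class ChooseBranchOneTwo(BaseBranchOperator):
    """Follows both 'branch_1' and 'branch_2' by returning a list of task ids."""

    def choose_branch(self, context):
        return ['branch_1', 'branch_2']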
class TestBranchOperator(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        self.dag = DAG(
            'branch_operator_test',
            default_args={
                'owner': 'airflow',
                'start_date': DEFAULT_DATE
            },
            schedule_interval=INTERVAL,
        )

        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
        self.branch_3 = None
        self.branch_op = None

    def tearDown(self):
        super().tearDown()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def test_without_dag_run(self):
        """This checks the defensive against non existent tasks in a dag run"""
        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id,
                                           TI.execution_date == DEFAULT_DATE)

            for ti in tis:
                if ti.task_id == 'make_choice':
                    assert ti.state == State.SUCCESS
                elif ti.task_id == 'branch_1':
                    # should exist with state None
                    assert ti.state == State.NONE
                elif ti.task_id == 'branch_2':
                    assert ti.state == State.SKIPPED
                else:
                    raise Exception(f'Invalid task id {ti.task_id} found!')

    def test_branch_list_without_dag_run(self):
        """This checks if the BranchOperator supports branching off to a list of tasks."""
        self.branch_op = ChooseBranchOneTwo(task_id='make_choice',
                                            dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
        self.branch_3.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id,
                                           TI.execution_date == DEFAULT_DATE)

            expected = {
                "make_choice": State.SUCCESS,
                "branch_1": State.NONE,
                "branch_2": State.NONE,
                "branch_3": State.SKIPPED,
            }

            for ti in tis:
                if ti.task_id in expected:
                    assert ti.state == expected[ti.task_id]
                else:
                    raise Exception(f'Invalid task id {ti.task_id} found!')

    def test_with_dag_run(self):
        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

        dagrun = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dagrun.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_1':
                assert ti.state == State.NONE
            elif ti.task_id == 'branch_2':
                assert ti.state == State.SKIPPED
            else:
                raise Exception(f'Invalid task id {ti.task_id} found!')

    def test_with_skip_in_branch_downstream_dependencies(self):
        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
        self.branch_op >> self.branch_1 >> self.branch_2
        self.branch_op >> self.branch_2
        self.dag.clear()

        dagrun = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dagrun.get_task_instances()
        for ti in tis:
            if ti.task_id == 'make_choice':
                assert ti.state == State.SUCCESS
            elif ti.task_id == 'branch_1':
                assert ti.state == State.NONE
            elif ti.task_id == 'branch_2':
                assert ti.state == State.NONE
            else:
                raise Exception(f'Invalid task id {ti.task_id} found!')