Example #1
    @patch('os.getpid')  # assumed patch target supplying mock_pid; the decorator is omitted from this excerpt
    def test_localtaskjob_heartbeat(self, mock_pid):
        session = settings.Session()
        dag = DAG('test_localtaskjob_heartbeat',
                  start_date=DEFAULT_DATE,
                  default_args={'owner': 'owner1'})

        with dag:
            op1 = DummyOperator(task_id='op1')

        dag.clear()
        dr = dag.create_dagrun(run_id="test",
                               state=State.SUCCESS,
                               execution_date=DEFAULT_DATE,
                               start_date=DEFAULT_DATE,
                               session=session)
        ti = dr.get_task_instance(task_id=op1.task_id, session=session)
        ti.state = State.RUNNING
        ti.hostname = "blablabla"
        session.commit()

        job1 = LocalTaskJob(task_instance=ti,
                            ignore_ti_state=True,
                            executor=SequentialExecutor())
        self.assertRaises(AirflowException, job1.heartbeat_callback)

        mock_pid.return_value = 1
        ti.state = State.RUNNING
        ti.hostname = get_hostname()
        ti.pid = 1
        session.merge(ti)
        session.commit()

        job1.heartbeat_callback(session=None)

        mock_pid.return_value = 2
        self.assertRaises(AirflowException, job1.heartbeat_callback)
Example #2
class TestLatestOnlyOperator(unittest.TestCase):

    def setUp(self):
        super().setUp()
        self.dag = DAG(
            'test_dag',
            default_args={
                'owner': 'airflow',
                'start_date': DEFAULT_DATE},
            schedule_interval=INTERVAL)
        with db.create_session() as session:
            session.query(DagRun).delete()
            session.query(TaskInstance).delete()
        freezer = freeze_time(FROZEN_NOW)
        freezer.start()
        self.addCleanup(freezer.stop)

    def test_run(self):
        task = LatestOnlyOperator(
            task_id='latest',
            dag=self.dag)
        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    def test_skipping_non_latest(self):
        latest_task = LatestOnlyOperator(
            task_id='latest',
            dag=self.dag)
        downstream_task = DummyOperator(
            task_id='downstream',
            dag=self.dag)
        downstream_task2 = DummyOperator(
            task_id='downstream_2',
            dag=self.dag)
        downstream_task3 = DummyOperator(
            task_id='downstream_3',
            trigger_rule=TriggerRule.NONE_FAILED,
            dag=self.dag)

        downstream_task.set_upstream(latest_task)
        downstream_task2.set_upstream(downstream_task)
        downstream_task3.set_upstream(downstream_task)

        self.dag.create_dagrun(
            run_id="scheduled__1",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

        self.dag.create_dagrun(
            run_id="scheduled__2",
            start_date=timezone.utcnow(),
            execution_date=timezone.datetime(2016, 1, 1, 12),
            state=State.RUNNING,
        )

        self.dag.create_dagrun(
            run_id="scheduled__3",
            start_date=timezone.utcnow(),
            execution_date=END_DATE,
            state=State.RUNNING,
        )

        latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
        downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
        downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE)
        downstream_task3.run(start_date=DEFAULT_DATE, end_date=END_DATE)

        latest_instances = get_task_instances('latest')
        exec_date_to_latest_state = {
            ti.execution_date: ti.state for ti in latest_instances}
        self.assertEqual({
            timezone.datetime(2016, 1, 1): 'success',
            timezone.datetime(2016, 1, 1, 12): 'success',
            timezone.datetime(2016, 1, 2): 'success'},
            exec_date_to_latest_state)

        downstream_instances = get_task_instances('downstream')
        exec_date_to_downstream_state = {
            ti.execution_date: ti.state for ti in downstream_instances}
        self.assertEqual({
            timezone.datetime(2016, 1, 1): 'skipped',
            timezone.datetime(2016, 1, 1, 12): 'skipped',
            timezone.datetime(2016, 1, 2): 'success'},
            exec_date_to_downstream_state)

        downstream_instances = get_task_instances('downstream_2')
        exec_date_to_downstream_state = {
            ti.execution_date: ti.state for ti in downstream_instances}
        self.assertEqual({
            timezone.datetime(2016, 1, 1): None,
            timezone.datetime(2016, 1, 1, 12): None,
            timezone.datetime(2016, 1, 2): 'success'},
            exec_date_to_downstream_state)

        downstream_instances = get_task_instances('downstream_3')
        exec_date_to_downstream_state = {
            ti.execution_date: ti.state for ti in downstream_instances}
        self.assertEqual({
            timezone.datetime(2016, 1, 1): 'success',
            timezone.datetime(2016, 1, 1, 12): 'success',
            timezone.datetime(2016, 1, 2): 'success'},
            exec_date_to_downstream_state)

    def test_not_skipping_external(self):
        latest_task = LatestOnlyOperator(
            task_id='latest',
            dag=self.dag)
        downstream_task = DummyOperator(
            task_id='downstream',
            dag=self.dag)
        downstream_task2 = DummyOperator(
            task_id='downstream_2',
            dag=self.dag)

        downstream_task.set_upstream(latest_task)
        downstream_task2.set_upstream(downstream_task)

        self.dag.create_dagrun(
            run_id="manual__1",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
            external_trigger=True,
        )

        self.dag.create_dagrun(
            run_id="manual__2",
            start_date=timezone.utcnow(),
            execution_date=timezone.datetime(2016, 1, 1, 12),
            state=State.RUNNING,
            external_trigger=True,
        )

        self.dag.create_dagrun(
            run_id="manual__3",
            start_date=timezone.utcnow(),
            execution_date=END_DATE,
            state=State.RUNNING,
            external_trigger=True,
        )

        latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
        downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
        downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE)

        latest_instances = get_task_instances('latest')
        exec_date_to_latest_state = {
            ti.execution_date: ti.state for ti in latest_instances}
        self.assertEqual({
            timezone.datetime(2016, 1, 1): 'success',
            timezone.datetime(2016, 1, 1, 12): 'success',
            timezone.datetime(2016, 1, 2): 'success'},
            exec_date_to_latest_state)

        downstream_instances = get_task_instances('downstream')
        exec_date_to_downstream_state = {
            ti.execution_date: ti.state for ti in downstream_instances}
        self.assertEqual({
            timezone.datetime(2016, 1, 1): 'success',
            timezone.datetime(2016, 1, 1, 12): 'success',
            timezone.datetime(2016, 1, 2): 'success'},
            exec_date_to_downstream_state)

        downstream_instances = get_task_instances('downstream_2')
        exec_date_to_downstream_state = {
            ti.execution_date: ti.state for ti in downstream_instances}
        self.assertEqual({
            timezone.datetime(2016, 1, 1): 'success',
            timezone.datetime(2016, 1, 1, 12): 'success',
            timezone.datetime(2016, 1, 2): 'success'},
            exec_date_to_downstream_state)
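
Note: the assertions above rely on a module-level get_task_instances helper that this excerpt omits. A minimal sketch of the assumed helper follows; the exact query in the source file may differ:

def get_task_instances(task_id):
    # Collect all task instances for the given task id, ordered by
    # execution date, so the exec-date -> state dicts above are easy to build.
    session = settings.Session()
    return (
        session.query(TaskInstance)
        .filter(TaskInstance.task_id == task_id)
        .order_by(TaskInstance.execution_date)
        .all()
    )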
Example #3
class TestCore(unittest.TestCase):
    default_scheduler_args = {"num_runs": 1}

    def setUp(self):
        self.dagbag = DagBag(dag_folder=DEV_NULL,
                             include_examples=True,
                             read_dags_from_db=False)
        self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
        self.dag = DAG(TEST_DAG_ID, default_args=self.args)
        self.dag_bash = self.dagbag.dags['example_bash_operator']
        self.runme_0 = self.dag_bash.get_task('runme_0')
        self.run_after_loop = self.dag_bash.get_task('run_after_loop')
        self.run_this_last = self.dag_bash.get_task('run_this_last')

    def tearDown(self):
        session = Session()
        session.query(DagRun).filter(DagRun.dag_id == TEST_DAG_ID).delete(
            synchronize_session=False)
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == TEST_DAG_ID).delete(
                synchronize_session=False)
        session.query(TaskFail).filter(TaskFail.dag_id == TEST_DAG_ID).delete(
            synchronize_session=False)
        session.commit()
        session.close()
        clear_db_dags()
        clear_db_runs()

    def test_check_operators(self):

        conn_id = "sqlite_default"

        captain_hook = BaseHook.get_hook(conn_id=conn_id)  # quite funny :D
        captain_hook.run("CREATE TABLE operator_test_table (a, b)")
        captain_hook.run("insert into operator_test_table values (1,2)")

        self.dag.create_dagrun(run_type=DagRunType.MANUAL,
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE)
        op = CheckOperator(task_id='check',
                           sql="select count(*) from operator_test_table",
                           conn_id=conn_id,
                           dag=self.dag)

        op.run(start_date=DEFAULT_DATE,
               end_date=DEFAULT_DATE,
               ignore_ti_state=True)

        op = ValueCheckOperator(
            task_id='value_check',
            pass_value=95,
            tolerance=0.1,
            conn_id=conn_id,
            sql="SELECT 100",
            dag=self.dag,
        )
        op.run(start_date=DEFAULT_DATE,
               end_date=DEFAULT_DATE,
               ignore_ti_state=True)

        captain_hook.run("drop table operator_test_table")

    def test_clear_api(self):
        task = self.dag_bash.tasks[0]
        task.clear(start_date=DEFAULT_DATE,
                   end_date=DEFAULT_DATE,
                   upstream=True,
                   downstream=True)
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.are_dependents_done()

    def test_illegal_args(self):
        """
        Tests that Operators reject illegal arguments
        """
        msg = 'Invalid arguments were passed to BashOperator (task_id: test_illegal_args).'
        with conf_vars({('operators', 'allow_illegal_arguments'): 'True'}):
            with pytest.warns(PendingDeprecationWarning) as warnings:
                BashOperator(
                    task_id='test_illegal_args',
                    bash_command='echo success',
                    dag=self.dag,
                    illegal_argument_1234='hello?',
                )
                assert any(msg in str(w) for w in warnings)

    def test_illegal_args_forbidden(self):
        """
        Tests that operators raise exceptions on illegal arguments when
        illegal arguments are not allowed.
        """
        with pytest.raises(AirflowException) as ctx:
            BashOperator(
                task_id='test_illegal_args',
                bash_command='echo success',
                dag=self.dag,
                illegal_argument_1234='hello?',
            )
        assert 'Invalid arguments were passed to BashOperator (task_id: test_illegal_args).' in str(
            ctx.value)

    def test_bash_operator(self):
        op = BashOperator(task_id='test_bash_operator',
                          bash_command="echo success",
                          dag=self.dag)
        self.dag.create_dagrun(run_type=DagRunType.MANUAL,
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE)

        op.run(start_date=DEFAULT_DATE,
               end_date=DEFAULT_DATE,
               ignore_ti_state=True)

    def test_bash_operator_multi_byte_output(self):
        op = BashOperator(
            task_id='test_multi_byte_bash_operator',
            bash_command="echo \u2600",
            dag=self.dag,
            output_encoding='utf-8',
        )
        self.dag.create_dagrun(run_type=DagRunType.MANUAL,
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE)
        op.run(start_date=DEFAULT_DATE,
               end_date=DEFAULT_DATE,
               ignore_ti_state=True)

    def test_bash_operator_kill(self):
        import psutil

        sleep_time = "100%d" % os.getpid()
        op = BashOperator(
            task_id='test_bash_operator_kill',
            execution_timeout=timedelta(seconds=1),
            bash_command=f"/bin/bash -c 'sleep {sleep_time}'",
            dag=self.dag,
        )
        with pytest.raises(AirflowTaskTimeout):
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        sleep(2)
        pid = -1
        for proc in psutil.process_iter():
            if proc.cmdline() == ['sleep', sleep_time]:
                pid = proc.pid
        if pid != -1:
            os.kill(pid, signal.SIGTERM)
            self.fail(
                "BashOperator's subprocess still running after stopping on timeout!"
            )

    def test_on_failure_callback(self):
        # Annoying workaround for nonlocal not existing in python 2
        data = {'called': False}

        def check_failure(context, test_case=self):  # pylint: disable=unused-argument
            data['called'] = True
            error = context.get("exception")
            test_case.assertIsInstance(error, AirflowException)

        op = BashOperator(
            task_id='check_on_failure_callback',
            bash_command="exit 1",
            dag=self.dag,
            on_failure_callback=check_failure,
        )
        with pytest.raises(AirflowException):
            op.run(start_date=DEFAULT_DATE,
                   end_date=DEFAULT_DATE,
                   ignore_ti_state=True)
        assert data['called']

    def test_dryrun(self):
        op = BashOperator(task_id='test_dryrun',
                          bash_command="echo success",
                          dag=self.dag)
        op.dry_run()

    def test_sqlite(self):
        import airflow.providers.sqlite.operators.sqlite

        op = airflow.providers.sqlite.operators.sqlite.SqliteOperator(
            task_id='time_sqlite',
            sql="CREATE TABLE IF NOT EXISTS unitest (dummy VARCHAR(20))",
            dag=self.dag)
        self.dag.create_dagrun(run_type=DagRunType.MANUAL,
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE)
        op.run(start_date=DEFAULT_DATE,
               end_date=DEFAULT_DATE,
               ignore_ti_state=True)

    def test_timeout(self):
        op = PythonOperator(
            task_id='test_timeout',
            execution_timeout=timedelta(seconds=1),
            python_callable=lambda: sleep(5),
            dag=self.dag,
        )
        with pytest.raises(AirflowTaskTimeout):
            op.run(start_date=DEFAULT_DATE,
                   end_date=DEFAULT_DATE,
                   ignore_ti_state=True)

    def test_python_op(self):
        def test_py_op(templates_dict, ds, **kwargs):
            if templates_dict['ds'] != ds:
                raise Exception("failure")

        op = PythonOperator(task_id='test_py_op',
                            python_callable=test_py_op,
                            templates_dict={'ds': "{{ ds }}"},
                            dag=self.dag)
        self.dag.create_dagrun(run_type=DagRunType.MANUAL,
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE)
        op.run(start_date=DEFAULT_DATE,
               end_date=DEFAULT_DATE,
               ignore_ti_state=True)

    def test_complex_template(self):
        def verify_templated_field(context):
            assert context['ti'].task.some_templated_field['bar'][
                1] == context['ds']

        op = OperatorSubclass(
            task_id='test_complex_template',
            some_templated_field={
                'foo': '123',
                'bar': ['baz', '{{ ds }}']
            },
            dag=self.dag,
        )
        op.execute = verify_templated_field
        self.dag.create_dagrun(run_type=DagRunType.MANUAL,
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE)
        op.run(start_date=DEFAULT_DATE,
               end_date=DEFAULT_DATE,
               ignore_ti_state=True)

    def test_template_non_bool(self):
        """
        Test templates can handle objects with no sense of truthiness
        """
        class NonBoolObject:
            def __len__(self):  # pylint: disable=invalid-length-returned
                return NotImplemented

            def __bool__(self):  # pylint: disable=invalid-bool-returned, bad-option-value
                return NotImplemented

        op = OperatorSubclass(task_id='test_bad_template_obj',
                              some_templated_field=NonBoolObject(),
                              dag=self.dag)
        op.resolve_template_files()

    def test_task_get_template(self):
        TI = TaskInstance
        ti = TI(task=self.runme_0, execution_date=DEFAULT_DATE)
        ti.dag = self.dag_bash
        self.dag_bash.create_dagrun(run_type=DagRunType.MANUAL,
                                    state=State.RUNNING,
                                    execution_date=DEFAULT_DATE)
        ti.run(ignore_ti_state=True)
        context = ti.get_template_context()

        # DEFAULT DATE is 2015-01-01
        assert context['ds'] == '2015-01-01'
        assert context['ds_nodash'] == '20150101'

        # next_ds is 2015-01-02 as the dag interval is daily
        assert context['next_ds'] == '2015-01-02'
        assert context['next_ds_nodash'] == '20150102'

        # prev_ds is 2014-12-31 as the dag interval is daily
        assert context['prev_ds'] == '2014-12-31'
        assert context['prev_ds_nodash'] == '20141231'

        assert context['ts'] == '2015-01-01T00:00:00+00:00'
        assert context['ts_nodash'] == '20150101T000000'
        assert context['ts_nodash_with_tz'] == '20150101T000000+0000'

        assert context['yesterday_ds'] == '2014-12-31'
        assert context['yesterday_ds_nodash'] == '20141231'

        assert context['tomorrow_ds'] == '2015-01-02'
        assert context['tomorrow_ds_nodash'] == '20150102'

    def test_local_task_job(self):
        TI = TaskInstance
        ti = TI(task=self.runme_0, execution_date=DEFAULT_DATE)
        job = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
        job.run()

    def test_raw_job(self):
        TI = TaskInstance
        ti = TI(task=self.runme_0, execution_date=DEFAULT_DATE)
        ti.dag = self.dag_bash
        self.dag_bash.create_dagrun(run_type=DagRunType.MANUAL,
                                    state=State.RUNNING,
                                    execution_date=DEFAULT_DATE)
        ti.run(ignore_ti_state=True)

    def test_bad_trigger_rule(self):
        with pytest.raises(AirflowException):
            DummyOperator(task_id='test_bad_trigger',
                          trigger_rule="non_existent",
                          dag=self.dag)

    def test_terminate_task(self):
        """If a task instance's db state get deleted, it should fail"""
        from airflow.executors.sequential_executor import SequentialExecutor

        TI = TaskInstance
        dag = self.dagbag.dags.get('test_utils')
        task = dag.task_dict.get('sleeps_forever')

        ti = TI(task=task, execution_date=DEFAULT_DATE)
        job = LocalTaskJob(task_instance=ti,
                           ignore_ti_state=True,
                           executor=SequentialExecutor())

        # Running task instance asynchronously
        proc = multiprocessing.Process(target=job.run)
        proc.start()
        sleep(5)
        settings.engine.dispose()
        session = settings.Session()
        ti.refresh_from_db(session=session)
        # making sure it's actually running
        assert State.RUNNING == ti.state
        ti = (session.query(TI).filter_by(dag_id=task.dag_id,
                                          task_id=task.task_id,
                                          execution_date=DEFAULT_DATE).one())

        # deleting the instance should result in a failure
        session.delete(ti)
        session.commit()
        # waiting for the async task to finish
        proc.join()

        # making sure that the task ended up as failed
        ti.refresh_from_db(session=session)
        assert State.FAILED == ti.state
        session.close()

    def test_task_fail_duration(self):
        """If a task fails, the duration should be recorded in TaskFail"""

        op1 = BashOperator(task_id='pass_sleepy',
                           bash_command='sleep 3',
                           dag=self.dag)
        op2 = BashOperator(
            task_id='fail_sleepy',
            bash_command='sleep 5',
            execution_timeout=timedelta(seconds=3),
            retry_delay=timedelta(seconds=0),
            dag=self.dag,
        )
        session = settings.Session()
        try:
            op1.run(start_date=DEFAULT_DATE,
                    end_date=DEFAULT_DATE,
                    ignore_ti_state=True)
        except Exception:  # pylint: disable=broad-except
            pass
        try:
            op2.run(start_date=DEFAULT_DATE,
                    end_date=DEFAULT_DATE,
                    ignore_ti_state=True)
        except Exception:  # pylint: disable=broad-except
            pass
        op1_fails = (session.query(TaskFail).filter_by(
            task_id='pass_sleepy',
            dag_id=self.dag.dag_id,
            execution_date=DEFAULT_DATE).all())
        op2_fails = (session.query(TaskFail).filter_by(
            task_id='fail_sleepy',
            dag_id=self.dag.dag_id,
            execution_date=DEFAULT_DATE).all())

        assert 0 == len(op1_fails)
        assert 1 == len(op2_fails)
        assert sum([f.duration for f in op2_fails]) >= 3

    def test_externally_triggered_dagrun(self):
        TI = TaskInstance

        # Create the dagrun between two "scheduled" execution dates of the DAG
        execution_date = DEFAULT_DATE + timedelta(days=2)
        execution_ds = execution_date.strftime('%Y-%m-%d')
        execution_ds_nodash = execution_ds.replace('-', '')

        dag = DAG(TEST_DAG_ID,
                  default_args=self.args,
                  schedule_interval=timedelta(weeks=1),
                  start_date=DEFAULT_DATE)
        task = DummyOperator(task_id='test_externally_triggered_dag_context',
                             dag=dag)
        dag.create_dagrun(
            run_type=DagRunType.SCHEDULED,
            execution_date=execution_date,
            state=State.RUNNING,
            external_trigger=True,
        )
        task.run(start_date=execution_date, end_date=execution_date)

        ti = TI(task=task, execution_date=execution_date)
        context = ti.get_template_context()

        # next_ds/prev_ds should be the execution date for manually triggered runs
        assert context['next_ds'] == execution_ds
        assert context['next_ds_nodash'] == execution_ds_nodash

        assert context['prev_ds'] == execution_ds
        assert context['prev_ds_nodash'] == execution_ds_nodash

    def test_dag_params_and_task_params(self):
        # This test case guards how params of DAG and Operator work together.
        # - If any key exists in either DAG's or Operator's params,
        #   it is guaranteed to be available eventually.
        # - If any key exists in both DAG's params and Operator's params,
        #   the latter has precedence.
        TI = TaskInstance

        dag = DAG(
            TEST_DAG_ID,
            default_args=self.args,
            schedule_interval=timedelta(weeks=1),
            start_date=DEFAULT_DATE,
            params={
                'key_1': 'value_1',
                'key_2': 'value_2_old'
            },
        )
        task1 = DummyOperator(
            task_id='task1',
            dag=dag,
            params={
                'key_2': 'value_2_new',
                'key_3': 'value_3'
            },
        )
        task2 = DummyOperator(task_id='task2', dag=dag)
        dag.create_dagrun(
            run_type=DagRunType.SCHEDULED,
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
            external_trigger=True,
        )
        task1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        task2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
        ti2 = TI(task=task2, execution_date=DEFAULT_DATE)
        context1 = ti1.get_template_context()
        context2 = ti2.get_template_context()

        assert context1['params'] == {
            'key_1': 'value_1',
            'key_2': 'value_2_new',
            'key_3': 'value_3'
        }
        assert context2['params'] == {
            'key_1': 'value_1',
            'key_2': 'value_2_old'
        }
Example #4
def _get_dag_run(
    *,
    dag: DAG,
    exec_date_or_run_id: str,
    create_if_necessary: CreateIfNecessary,
    session: Session,
) -> Tuple[DagRun, bool]:
    """Try to retrieve a DAG run from a string representing either a run ID or logical date.

    This checks DAG runs like this:

    1. If the input ``exec_date_or_run_id`` matches a DAG run ID, return the run.
    2. Try to parse the input as a date. If that works, and the resulting
       date matches a DAG run's logical date, return the run.
    3. If ``create_if_necessary`` is *False* and the input works for neither of
       the above, raise ``DagRunNotFound``.
    4. Try to create a new DAG run. If the input looks like a date, use it as
       the logical date; otherwise use it as a run ID and set the logical date
       to the current time.
    """
    dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session)
    if dag_run:
        return dag_run, False

    try:
        execution_date: Optional[datetime.datetime] = timezone.parse(
            exec_date_or_run_id)
    except (ParserError, TypeError):
        execution_date = None

    try:
        dag_run = (session.query(DagRun).filter(
            DagRun.dag_id == dag.dag_id,
            DagRun.execution_date == execution_date).one())
    except NoResultFound:
        if not create_if_necessary:
            raise DagRunNotFound(
                f"DagRun for {dag.dag_id} with run_id or execution_date of {exec_date_or_run_id!r} not found"
            ) from None
    else:
        return dag_run, False

    if execution_date is not None:
        dag_run_execution_date = execution_date
    else:
        dag_run_execution_date = timezone.utcnow()
    if create_if_necessary == "memory":
        dag_run = DagRun(dag.dag_id,
                         run_id=exec_date_or_run_id,
                         execution_date=dag_run_execution_date)
        return dag_run, True
    elif create_if_necessary == "db":
        dag_run = dag.create_dagrun(
            state=DagRunState.QUEUED,
            execution_date=dag_run_execution_date,
            run_id=_generate_temporary_run_id(),
            session=session,
        )
        return dag_run, True
    raise ValueError(
        f"unknown create_if_necessary value: {create_if_necessary!r}")
Example #5
class TestCore(unittest.TestCase):
    default_scheduler_args = {"num_runs": 1}

    def setUp(self):
        self.dagbag = DagBag(dag_folder=DEV_NULL,
                             include_examples=True,
                             read_dags_from_db=False)
        self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
        self.dag = DAG(TEST_DAG_ID, default_args=self.args)
        self.dag_bash = self.dagbag.dags['example_bash_operator']
        self.runme_0 = self.dag_bash.get_task('runme_0')
        self.run_after_loop = self.dag_bash.get_task('run_after_loop')
        self.run_this_last = self.dag_bash.get_task('run_this_last')

    def tearDown(self):
        session = Session()
        session.query(DagRun).filter(DagRun.dag_id == TEST_DAG_ID).delete(
            synchronize_session=False)
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == TEST_DAG_ID).delete(
                synchronize_session=False)
        session.query(TaskFail).filter(TaskFail.dag_id == TEST_DAG_ID).delete(
            synchronize_session=False)
        session.commit()
        session.close()
        clear_db_dags()
        clear_db_runs()

    def test_check_operators(self):

        conn_id = "sqlite_default"

        captain_hook = BaseHook.get_hook(conn_id=conn_id)  # quite funny :D
        captain_hook.run("CREATE TABLE operator_test_table (a, b)")
        captain_hook.run("insert into operator_test_table values (1,2)")

        self.dag.create_dagrun(run_type=DagRunType.MANUAL,
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE)
        op = CheckOperator(task_id='check',
                           sql="select count(*) from operator_test_table",
                           conn_id=conn_id,
                           dag=self.dag)

        op.run(start_date=DEFAULT_DATE,
               end_date=DEFAULT_DATE,
               ignore_ti_state=True)

        op = ValueCheckOperator(
            task_id='value_check',
            pass_value=95,
            tolerance=0.1,
            conn_id=conn_id,
            sql="SELECT 100",
            dag=self.dag,
        )
        op.run(start_date=DEFAULT_DATE,
               end_date=DEFAULT_DATE,
               ignore_ti_state=True)

        captain_hook.run("drop table operator_test_table")

    def test_clear_api(self):
        task = self.dag_bash.tasks[0]
        task.clear(start_date=DEFAULT_DATE,
                   end_date=DEFAULT_DATE,
                   upstream=True,
                   downstream=True)
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.are_dependents_done()

    def test_illegal_args(self):
        """
        Tests that Operators reject illegal arguments
        """
        msg = 'Invalid arguments were passed to BashOperator (task_id: test_illegal_args).'
        with conf_vars({('operators', 'allow_illegal_arguments'): 'True'}):
            with self.assertWarns(PendingDeprecationWarning) as warning:
                BashOperator(
                    task_id='test_illegal_args',
                    bash_command='echo success',
                    dag=self.dag,
                    illegal_argument_1234='hello?',
                )
                assert any(msg in str(w) for w in warning.warnings)

    def test_illegal_args_forbidden(self):
        """
        Tests that operators raise exceptions on illegal arguments when
        illegal arguments are not allowed.
        """
        with self.assertRaises(AirflowException) as ctx:
            BashOperator(
                task_id='test_illegal_args',
                bash_command='echo success',
                dag=self.dag,
                illegal_argument_1234='hello?',
            )
        self.assertIn(
            'Invalid arguments were passed to BashOperator (task_id: test_illegal_args).',
            str(ctx.exception),
        )

    def test_bash_operator(self):
        op = BashOperator(task_id='test_bash_operator',
                          bash_command="echo success",
                          dag=self.dag)
        self.dag.create_dagrun(run_type=DagRunType.MANUAL,
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE)

        op.run(start_date=DEFAULT_DATE,
               end_date=DEFAULT_DATE,
               ignore_ti_state=True)

    def test_bash_operator_multi_byte_output(self):
        op = BashOperator(
            task_id='test_multi_byte_bash_operator',
            bash_command="echo \u2600",
            dag=self.dag,
            output_encoding='utf-8',
        )
        self.dag.create_dagrun(run_type=DagRunType.MANUAL,
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE)
        op.run(start_date=DEFAULT_DATE,
               end_date=DEFAULT_DATE,
               ignore_ti_state=True)

    def test_bash_operator_kill(self):
        import psutil

        sleep_time = "100%d" % os.getpid()
        op = BashOperator(
            task_id='test_bash_operator_kill',
            execution_timeout=timedelta(seconds=1),
            bash_command="/bin/bash -c 'sleep %s'" % sleep_time,
            dag=self.dag,
        )
        self.assertRaises(AirflowTaskTimeout,
                          op.run,
                          start_date=DEFAULT_DATE,
                          end_date=DEFAULT_DATE)
        sleep(2)
        pid = -1
        for proc in psutil.process_iter():
            if proc.cmdline() == ['sleep', sleep_time]:
                pid = proc.pid
        if pid != -1:
            os.kill(pid, signal.SIGTERM)
            self.fail(
                "BashOperator's subprocess still running after stopping on timeout!"
            )

    def test_on_failure_callback(self):
        # Annoying workaround for nonlocal not existing in python 2
        data = {'called': False}

        def check_failure(context, test_case=self):
            data['called'] = True
            error = context.get('exception')
            test_case.assertIsInstance(error, AirflowException)

        op = BashOperator(
            task_id='check_on_failure_callback',
            bash_command="exit 1",
            dag=self.dag,
            on_failure_callback=check_failure,
        )
        self.assertRaises(AirflowException,
                          op.run,
                          start_date=DEFAULT_DATE,
                          end_date=DEFAULT_DATE,
                          ignore_ti_state=True)
        self.assertTrue(data['called'])

    def test_dryrun(self):
        op = BashOperator(task_id='test_dryrun',
                          bash_command="echo success",
                          dag=self.dag)
        op.dry_run()

    def test_sqlite(self):
        import airflow.providers.sqlite.operators.sqlite

        op = airflow.providers.sqlite.operators.sqlite.SqliteOperator(
            task_id='time_sqlite',
            sql="CREATE TABLE IF NOT EXISTS unitest (dummy VARCHAR(20))",
            dag=self.dag)
        self.dag.create_dagrun(run_type=DagRunType.MANUAL,
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE)
        op.run(start_date=DEFAULT_DATE,
               end_date=DEFAULT_DATE,
               ignore_ti_state=True)

    def test_timeout(self):
        op = PythonOperator(
            task_id='test_timeout',
            execution_timeout=timedelta(seconds=1),
            python_callable=lambda: sleep(5),
            dag=self.dag,
        )
        self.assertRaises(AirflowTaskTimeout,
                          op.run,
                          start_date=DEFAULT_DATE,
                          end_date=DEFAULT_DATE,
                          ignore_ti_state=True)

    def test_python_op(self):
        def test_py_op(templates_dict, ds, **kwargs):
            if templates_dict['ds'] != ds:
                raise Exception("failure")

        op = PythonOperator(task_id='test_py_op',
                            python_callable=test_py_op,
                            templates_dict={'ds': "{{ ds }}"},
                            dag=self.dag)
        self.dag.create_dagrun(run_type=DagRunType.MANUAL,
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE)
        op.run(start_date=DEFAULT_DATE,
               end_date=DEFAULT_DATE,
               ignore_ti_state=True)

    def test_complex_template(self):
        def verify_templated_field(context):
            self.assertEqual(context['ti'].task.some_templated_field['bar'][1],
                             context['ds'])

        op = OperatorSubclass(
            task_id='test_complex_template',
            some_templated_field={
                'foo': '123',
                'bar': ['baz', '{{ ds }}']
            },
            dag=self.dag,
        )
        op.execute = verify_templated_field
        self.dag.create_dagrun(run_type=DagRunType.MANUAL,
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE)
        op.run(start_date=DEFAULT_DATE,
               end_date=DEFAULT_DATE,
               ignore_ti_state=True)

    def test_template_non_bool(self):
        """
        Test templates can handle objects with no sense of truthiness
        """
        class NonBoolObject:
            def __len__(self):  # pylint: disable=invalid-length-returned
                return NotImplemented

            def __bool__(self):  # pylint: disable=invalid-bool-returned, bad-option-value
                return NotImplemented

        op = OperatorSubclass(task_id='test_bad_template_obj',
                              some_templated_field=NonBoolObject(),
                              dag=self.dag)
        op.resolve_template_files()

    def test_task_get_template(self):
        TI = TaskInstance
        ti = TI(task=self.runme_0, execution_date=DEFAULT_DATE)
        ti.dag = self.dag_bash
        self.dag_bash.create_dagrun(run_type=DagRunType.MANUAL,
                                    state=State.RUNNING,
                                    execution_date=DEFAULT_DATE)
        ti.run(ignore_ti_state=True)
        context = ti.get_template_context()

        # DEFAULT DATE is 2015-01-01
        self.assertEqual(context['ds'], '2015-01-01')
        self.assertEqual(context['ds_nodash'], '20150101')

        # next_ds is 2015-01-02 as the dag interval is daily
        self.assertEqual(context['next_ds'], '2015-01-02')
        self.assertEqual(context['next_ds_nodash'], '20150102')

        # prev_ds is 2014-12-31 as the dag interval is daily
        self.assertEqual(context['prev_ds'], '2014-12-31')
        self.assertEqual(context['prev_ds_nodash'], '20141231')

        self.assertEqual(context['ts'], '2015-01-01T00:00:00+00:00')
        self.assertEqual(context['ts_nodash'], '20150101T000000')
        self.assertEqual(context['ts_nodash_with_tz'], '20150101T000000+0000')

        self.assertEqual(context['yesterday_ds'], '2014-12-31')
        self.assertEqual(context['yesterday_ds_nodash'], '20141231')

        self.assertEqual(context['tomorrow_ds'], '2015-01-02')
        self.assertEqual(context['tomorrow_ds_nodash'], '20150102')

    def test_local_task_job(self):
        TI = TaskInstance
        ti = TI(task=self.runme_0, execution_date=DEFAULT_DATE)
        job = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
        job.run()

    def test_raw_job(self):
        TI = TaskInstance
        ti = TI(task=self.runme_0, execution_date=DEFAULT_DATE)
        ti.dag = self.dag_bash
        self.dag_bash.create_dagrun(run_type=DagRunType.MANUAL,
                                    state=State.RUNNING,
                                    execution_date=DEFAULT_DATE)
        ti.run(ignore_ti_state=True)

    def test_round_time(self):

        rt1 = round_time(datetime(2015, 1, 1, 6), timedelta(days=1))
        self.assertEqual(datetime(2015, 1, 1, 0, 0), rt1)

        rt2 = round_time(datetime(2015, 1, 2), relativedelta(months=1))
        self.assertEqual(datetime(2015, 1, 1, 0, 0), rt2)

        rt3 = round_time(datetime(2015, 9, 16, 0, 0), timedelta(1),
                         datetime(2015, 9, 14, 0, 0))
        self.assertEqual(datetime(2015, 9, 16, 0, 0), rt3)

        rt4 = round_time(datetime(2015, 9, 15, 0, 0), timedelta(1),
                         datetime(2015, 9, 14, 0, 0))
        self.assertEqual(datetime(2015, 9, 15, 0, 0), rt4)

        rt5 = round_time(datetime(2015, 9, 14, 0, 0), timedelta(1),
                         datetime(2015, 9, 14, 0, 0))
        self.assertEqual(datetime(2015, 9, 14, 0, 0), rt5)

        rt6 = round_time(datetime(2015, 9, 13, 0, 0), timedelta(1),
                         datetime(2015, 9, 14, 0, 0))
        self.assertEqual(datetime(2015, 9, 14, 0, 0), rt6)

    def test_infer_time_unit(self):

        self.assertEqual('minutes', infer_time_unit([130, 5400, 10]))

        self.assertEqual('seconds', infer_time_unit([110, 50, 10, 100]))

        self.assertEqual('hours',
                         infer_time_unit([100000, 50000, 10000, 20000]))

        self.assertEqual('days', infer_time_unit([200000, 100000]))

    def test_scale_time_units(self):

        # use assert_array_almost_equal from numpy.testing since we are
        # comparing floating point arrays
        arr1 = scale_time_units([130, 5400, 10], 'minutes')
        assert_array_almost_equal(arr1, [2.167, 90.0, 0.167], decimal=3)

        arr2 = scale_time_units([110, 50, 10, 100], 'seconds')
        assert_array_almost_equal(arr2, [110.0, 50.0, 10.0, 100.0], decimal=3)

        arr3 = scale_time_units([100000, 50000, 10000, 20000], 'hours')
        assert_array_almost_equal(arr3, [27.778, 13.889, 2.778, 5.556],
                                  decimal=3)

        arr4 = scale_time_units([200000, 100000], 'days')
        assert_array_almost_equal(arr4, [2.315, 1.157], decimal=3)

    def test_bad_trigger_rule(self):
        with self.assertRaises(AirflowException):
            DummyOperator(task_id='test_bad_trigger',
                          trigger_rule="non_existent",
                          dag=self.dag)

    def test_terminate_task(self):
        """If a task instance's db state get deleted, it should fail"""
        from airflow.executors.sequential_executor import SequentialExecutor

        TI = TaskInstance
        dag = self.dagbag.dags.get('test_utils')
        task = dag.task_dict.get('sleeps_forever')

        ti = TI(task=task, execution_date=DEFAULT_DATE)
        job = LocalTaskJob(task_instance=ti,
                           ignore_ti_state=True,
                           executor=SequentialExecutor())

        # Running task instance asynchronously
        proc = multiprocessing.Process(target=job.run)
        proc.start()
        sleep(5)
        settings.engine.dispose()
        session = settings.Session()
        ti.refresh_from_db(session=session)
        # making sure it's actually running
        self.assertEqual(State.RUNNING, ti.state)
        ti = (session.query(TI).filter_by(dag_id=task.dag_id,
                                          task_id=task.task_id,
                                          execution_date=DEFAULT_DATE).one())

        # deleting the instance should result in a failure
        session.delete(ti)
        session.commit()
        # waiting for the async task to finish
        proc.join()

        # making sure that the task ended up as failed
        ti.refresh_from_db(session=session)
        self.assertEqual(State.FAILED, ti.state)
        session.close()

    def test_task_fail_duration(self):
        """If a task fails, the duration should be recorded in TaskFail"""

        op1 = BashOperator(task_id='pass_sleepy',
                           bash_command='sleep 3',
                           dag=self.dag)
        op2 = BashOperator(
            task_id='fail_sleepy',
            bash_command='sleep 5',
            execution_timeout=timedelta(seconds=3),
            retry_delay=timedelta(seconds=0),
            dag=self.dag,
        )
        session = settings.Session()
        try:
            op1.run(start_date=DEFAULT_DATE,
                    end_date=DEFAULT_DATE,
                    ignore_ti_state=True)
        except Exception:  # pylint: disable=broad-except
            pass
        try:
            op2.run(start_date=DEFAULT_DATE,
                    end_date=DEFAULT_DATE,
                    ignore_ti_state=True)
        except Exception:  # pylint: disable=broad-except
            pass
        op1_fails = (session.query(TaskFail).filter_by(
            task_id='pass_sleepy',
            dag_id=self.dag.dag_id,
            execution_date=DEFAULT_DATE).all())
        op2_fails = (session.query(TaskFail).filter_by(
            task_id='fail_sleepy',
            dag_id=self.dag.dag_id,
            execution_date=DEFAULT_DATE).all())

        self.assertEqual(0, len(op1_fails))
        self.assertEqual(1, len(op2_fails))
        self.assertGreaterEqual(sum([f.duration for f in op2_fails]), 3)

    def test_externally_triggered_dagrun(self):
        TI = TaskInstance

        # Create the dagrun between two "scheduled" execution dates of the DAG
        execution_date = DEFAULT_DATE + timedelta(days=2)
        execution_ds = execution_date.strftime('%Y-%m-%d')
        execution_ds_nodash = execution_ds.replace('-', '')

        dag = DAG(TEST_DAG_ID,
                  default_args=self.args,
                  schedule_interval=timedelta(weeks=1),
                  start_date=DEFAULT_DATE)
        task = DummyOperator(task_id='test_externally_triggered_dag_context',
                             dag=dag)
        dag.create_dagrun(
            run_type=DagRunType.SCHEDULED,
            execution_date=execution_date,
            state=State.RUNNING,
            external_trigger=True,
        )
        task.run(start_date=execution_date, end_date=execution_date)

        ti = TI(task=task, execution_date=execution_date)
        context = ti.get_template_context()

        # next_ds/prev_ds should be the execution date for manually triggered runs
        self.assertEqual(context['next_ds'], execution_ds)
        self.assertEqual(context['next_ds_nodash'], execution_ds_nodash)

        self.assertEqual(context['prev_ds'], execution_ds)
        self.assertEqual(context['prev_ds_nodash'], execution_ds_nodash)
Example #6
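The DummySensor exercised below is defined elsewhere in the test module. A minimal sketch of the assumed implementation, a BaseSensorOperator whose poke reports a canned value (the real class may differ):

class DummySensor(BaseSensorOperator):
    def __init__(self, return_value=False, **kwargs):
        super().__init__(**kwargs)
        self.return_value = return_value

    def poke(self, context):
        # Returning the canned value lets tests force success, failure,
        # or (by mocking poke with side effects) an arbitrary poke sequence.
        return self.return_value
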
class TestBaseSensor(unittest.TestCase):
    @staticmethod
    def clean_db():
        db.clear_db_runs()
        db.clear_db_task_reschedule()
        db.clear_db_xcom()

    def setUp(self):
        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
        self.dag = DAG(TEST_DAG_ID, default_args=args)
        self.clean_db()

    def tearDown(self) -> None:
        self.clean_db()

    def _make_dag_run(self):
        return self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

    def _make_sensor(self, return_value, task_id=SENSOR_OP, **kwargs):
        kwargs.setdefault('poke_interval', 0)
        kwargs.setdefault('timeout', 0)

        sensor = DummySensor(task_id=task_id, return_value=return_value, dag=self.dag, **kwargs)

        dummy_op = DummyOperator(task_id=DUMMY_OP, dag=self.dag)
        dummy_op.set_upstream(sensor)
        return sensor

    @classmethod
    def _run(cls, task):
        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    def test_ok(self):
        sensor = self._make_sensor(True)
        dr = self._make_dag_run()

        self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                self.assertEqual(ti.state, State.SUCCESS)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

    def test_fail(self):
        sensor = self._make_sensor(False)
        dr = self._make_dag_run()

        with self.assertRaises(AirflowSensorTimeout):
            self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                self.assertEqual(ti.state, State.FAILED)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

    def test_soft_fail(self):
        sensor = self._make_sensor(False, soft_fail=True)
        dr = self._make_dag_run()

        self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                self.assertEqual(ti.state, State.SKIPPED)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

    def test_soft_fail_with_retries(self):
        sensor = self._make_sensor(
            return_value=False, soft_fail=True, retries=1, retry_delay=timedelta(milliseconds=1)
        )
        dr = self._make_dag_run()

        # first run times out and the task instance is marked up for retry
        with self.assertRaises(AirflowSensorTimeout):
            self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                self.assertEqual(ti.state, State.UP_FOR_RETRY)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

        sleep(0.001)
        # on the retry the sensor times out again and, with soft_fail, the task is skipped
        self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                self.assertEqual(ti.state, State.SKIPPED)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

    def test_ok_with_reschedule(self):
        sensor = self._make_sensor(return_value=None, poke_interval=10, timeout=25, mode='reschedule')
        sensor.poke = Mock(side_effect=[False, False, True])
        dr = self._make_dag_run()

        # first poke returns False and task is re-scheduled
        date1 = timezone.utcnow()
        with freeze_time(date1):
            self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                # verify task is re-scheduled, i.e. state set to UP_FOR_RESCHEDULE
                self.assertEqual(ti.state, State.UP_FOR_RESCHEDULE)
                # verify task start date is the initial one
                self.assertEqual(ti.start_date, date1)
                # verify one row in task_reschedule table
                task_reschedules = TaskReschedule.find_for_task_instance(ti)
                self.assertEqual(len(task_reschedules), 1)
                self.assertEqual(task_reschedules[0].start_date, date1)
                self.assertEqual(
                    task_reschedules[0].reschedule_date, date1 + timedelta(seconds=sensor.poke_interval)
                )
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

        # second poke returns False and task is re-scheduled
        date2 = date1 + timedelta(seconds=sensor.poke_interval)
        with freeze_time(date2):
            self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                # verify task is re-scheduled, i.e. state set to UP_FOR_RESCHEDULE
                self.assertEqual(ti.state, State.UP_FOR_RESCHEDULE)
                # verify task start date is the initial one
                self.assertEqual(ti.start_date, date1)
                # verify two rows in task_reschedule table
                task_reschedules = TaskReschedule.find_for_task_instance(ti)
                self.assertEqual(len(task_reschedules), 2)
                self.assertEqual(task_reschedules[1].start_date, date2)
                self.assertEqual(
                    task_reschedules[1].reschedule_date, date2 + timedelta(seconds=sensor.poke_interval)
                )
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

        # third poke returns True and task succeeds
        date3 = date2 + timedelta(seconds=sensor.poke_interval)
        with freeze_time(date3):
            self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                self.assertEqual(ti.state, State.SUCCESS)
                # verify task start date is the initial one
                self.assertEqual(ti.start_date, date1)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

    def test_fail_with_reschedule(self):
        sensor = self._make_sensor(return_value=False, poke_interval=10, timeout=5, mode='reschedule')
        dr = self._make_dag_run()

        # first poke returns False and task is re-scheduled
        date1 = timezone.utcnow()
        with freeze_time(date1):
            self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                self.assertEqual(ti.state, State.UP_FOR_RESCHEDULE)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

        # second poke returns False, timeout occurs
        date2 = date1 + timedelta(seconds=sensor.poke_interval)
        with freeze_time(date2):
            with self.assertRaises(AirflowSensorTimeout):
                self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                self.assertEqual(ti.state, State.FAILED)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

    def test_soft_fail_with_reschedule(self):
        sensor = self._make_sensor(
            return_value=False, poke_interval=10, timeout=5, soft_fail=True, mode='reschedule'
        )
        dr = self._make_dag_run()

        # first poke returns False and task is re-scheduled
        date1 = timezone.utcnow()
        with freeze_time(date1):
            self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                self.assertEqual(ti.state, State.UP_FOR_RESCHEDULE)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

        # second poke returns False, timeout occurs
        date2 = date1 + timedelta(seconds=sensor.poke_interval)
        with freeze_time(date2):
            self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                self.assertEqual(ti.state, State.SKIPPED)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

    def test_ok_with_reschedule_and_retry(self):
        sensor = self._make_sensor(
            return_value=None,
            poke_interval=10,
            timeout=5,
            retries=1,
            retry_delay=timedelta(seconds=10),
            mode='reschedule',
        )
        sensor.poke = Mock(side_effect=[False, False, False, True])
        dr = self._make_dag_run()

        # first poke returns False and task is re-scheduled
        date1 = timezone.utcnow()
        with freeze_time(date1):
            self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                self.assertEqual(ti.state, State.UP_FOR_RESCHEDULE)
                # verify one row in task_reschedule table
                task_reschedules = TaskReschedule.find_for_task_instance(ti)
                self.assertEqual(len(task_reschedules), 1)
                self.assertEqual(task_reschedules[0].start_date, date1)
                self.assertEqual(
                    task_reschedules[0].reschedule_date, date1 + timedelta(seconds=sensor.poke_interval)
                )
                self.assertEqual(task_reschedules[0].try_number, 1)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

        # second poke times out and the task instance is marked up for retry
        date2 = date1 + timedelta(seconds=sensor.poke_interval)
        with freeze_time(date2):
            with self.assertRaises(AirflowSensorTimeout):
                self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                self.assertEqual(ti.state, State.UP_FOR_RETRY)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

        # third poke returns False and task is rescheduled again
        date3 = date2 + timedelta(seconds=sensor.poke_interval) + sensor.retry_delay
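        # move past the retry delay (plus another poke interval) so the retried task is eligible to run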
        with freeze_time(date3):
            self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                self.assertEqual(ti.state, State.UP_FOR_RESCHEDULE)
                # verify one row in task_reschedule table
                task_reschedules = TaskReschedule.find_for_task_instance(ti)
                self.assertEqual(len(task_reschedules), 1)
                self.assertEqual(task_reschedules[0].start_date, date3)
                self.assertEqual(
                    task_reschedules[0].reschedule_date, date3 + timedelta(seconds=sensor.poke_interval)
                )
                self.assertEqual(task_reschedules[0].try_number, 2)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

        # fourth poke returns True and task succeeds
        date4 = date3 + timedelta(seconds=sensor.poke_interval)
        with freeze_time(date4):
            self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                self.assertEqual(ti.state, State.SUCCESS)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

    def test_should_include_ready_to_reschedule_dep_in_reschedule_mode(self):
        sensor = self._make_sensor(True, mode='reschedule')
        deps = sensor.deps
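        # dep instances compare by value, so a fresh ReadyToRescheduleDep() matches the one in deps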
        self.assertIn(ReadyToRescheduleDep(), deps)

    def test_should_not_include_ready_to_reschedule_dep_in_poke_mode(self):
        sensor = self._make_sensor(True)
        deps = sensor.deps
        self.assertNotIn(ReadyToRescheduleDep(), deps)

    def test_invalid_mode(self):
        with self.assertRaises(AirflowException):
            self._make_sensor(return_value=True, mode='foo')

    def test_ok_with_custom_reschedule_exception(self):
        sensor = self._make_sensor(return_value=None, mode='reschedule')
        date1 = timezone.utcnow()
        date2 = date1 + timedelta(seconds=60)
        date3 = date1 + timedelta(seconds=120)
        sensor.poke = Mock(
            side_effect=[
                AirflowRescheduleException(date2),
                AirflowRescheduleException(date3),
                True,
            ]
        )
        dr = self._make_dag_run()

        # first poke raises AirflowRescheduleException and task is re-scheduled for date2
        with freeze_time(date1):
            self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                # verify task is re-scheduled, i.e. state set to UP_FOR_RESCHEDULE
                self.assertEqual(ti.state, State.UP_FOR_RESCHEDULE)
                # verify one row in task_reschedule table
                task_reschedules = TaskReschedule.find_for_task_instance(ti)
                self.assertEqual(len(task_reschedules), 1)
                self.assertEqual(task_reschedules[0].start_date, date1)
                self.assertEqual(task_reschedules[0].reschedule_date, date2)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

        # second poke raises AirflowRescheduleException and task is re-scheduled for date3
        with freeze_time(date2):
            self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                # verify task is re-scheduled, i.e. state set to UP_FOR_RESCHEDULE
                self.assertEqual(ti.state, State.UP_FOR_RESCHEDULE)
                # verify two rows in task_reschedule table
                task_reschedules = TaskReschedule.find_for_task_instance(ti)
                self.assertEqual(len(task_reschedules), 2)
                self.assertEqual(task_reschedules[1].start_date, date2)
                self.assertEqual(task_reschedules[1].reschedule_date, date3)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

        # third poke returns True and task succeeds
        with freeze_time(date3):
            self._run(sensor)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                self.assertEqual(ti.state, State.SUCCESS)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

    def test_reschedule_with_test_mode(self):
        sensor = self._make_sensor(return_value=None, poke_interval=10, timeout=25, mode='reschedule')
        sensor.poke = Mock(side_effect=[False])
        dr = self._make_dag_run()

        # poke returns False and AirflowRescheduleException is raised
        date1 = timezone.utcnow()
        with freeze_time(date1):
            for date in self.dag.date_range(DEFAULT_DATE, end_date=DEFAULT_DATE):
                TaskInstance(sensor, date).run(ignore_ti_state=True, test_mode=True)
        tis = dr.get_task_instances()
        self.assertEqual(len(tis), 2)
        for ti in tis:
            if ti.task_id == SENSOR_OP:
                # in test mode state is not modified
                self.assertEqual(ti.state, State.NONE)
                # in test mode no reschedule request is recorded
                task_reschedules = TaskReschedule.find_for_task_instance(ti)
                self.assertEqual(len(task_reschedules), 0)
            if ti.task_id == DUMMY_OP:
                self.assertEqual(ti.state, State.NONE)

    def test_sensor_with_invalid_poke_interval(self):
        negative_poke_interval = -10
        non_number_poke_interval = "abcd"
        positive_poke_interval = 10
        with self.assertRaises(AirflowException):
            self._make_sensor(
                task_id='test_sensor_task_1',
                return_value=None,
                poke_interval=negative_poke_interval,
                timeout=25,
            )

        with self.assertRaises(AirflowException):
            self._make_sensor(
                task_id='test_sensor_task_2',
                return_value=None,
                poke_interval=non_number_poke_interval,
                timeout=25,
            )

        self._make_sensor(
            task_id='test_sensor_task_3', return_value=None, poke_interval=positive_poke_interval, timeout=25
        )

    def test_sensor_with_invalid_timeout(self):
        negative_timeout = -25
        non_number_timeout = "abcd"
        positive_timeout = 25
        with self.assertRaises(AirflowException):
            self._make_sensor(
                task_id='test_sensor_task_1', return_value=None, poke_interval=10, timeout=negative_timeout
            )

        with self.assertRaises(AirflowException):
            self._make_sensor(
                task_id='test_sensor_task_2', return_value=None, poke_interval=10, timeout=non_number_timeout
            )

        self._make_sensor(
            task_id='test_sensor_task_3', return_value=None, poke_interval=10, timeout=positive_timeout
        )

    def test_sensor_with_exponential_backoff_off(self):
        sensor = self._make_sensor(return_value=None, poke_interval=5, timeout=60, exponential_backoff=False)

        started_at = timezone.utcnow() - timedelta(seconds=10)
        self.assertEqual(sensor._get_next_poke_interval(started_at, 1), sensor.poke_interval)
        self.assertEqual(sensor._get_next_poke_interval(started_at, 2), sensor.poke_interval)

    def test_sensor_with_exponential_backoff_on(self):
        sensor = self._make_sensor(return_value=None, poke_interval=5, timeout=60, exponential_backoff=True)

        with patch('airflow.utils.timezone.utcnow') as mock_utctime:
            mock_utctime.return_value = DEFAULT_DATE

            started_at = timezone.utcnow() - timedelta(seconds=10)

            interval1 = sensor._get_next_poke_interval(started_at, 1)
            interval2 = sensor._get_next_poke_interval(started_at, 2)

            self.assertGreaterEqual(interval1, 0)
            self.assertLessEqual(interval1, sensor.poke_interval)
            self.assertGreaterEqual(interval2, sensor.poke_interval)
            self.assertGreater(interval2, interval1)
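
The two backoff tests above pin down only loose bounds: with exponential backoff off, every interval equals poke_interval; with it on, the first interval falls within [0, poke_interval] and each later interval is at least poke_interval and strictly larger than the previous one. A minimal sketch of an interval function that satisfies those bounds (an illustration only, not Airflow's actual _get_next_poke_interval) might look like this:

import random
from datetime import datetime

def next_poke_interval(poke_interval, timeout, started_at, try_number,
                       exponential_backoff=True):
    # hypothetical helper, not part of Airflow's API
    if not exponential_backoff:
        return poke_interval
    if try_number <= 1:
        # first wait: anywhere in [0, poke_interval], matching the test's bounds
        return random.uniform(0, poke_interval)
    # later waits: double the base interval on each try...
    backoff = poke_interval * (2 ** (try_number - 1))
    # ...but never sleep past the remaining timeout budget
    remaining = timeout - (datetime.utcnow() - started_at).total_seconds()
    return min(backoff, max(remaining, poke_interval))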
Example #7
    def test_mark_success_on_success_callback(self):
        """
        Test that ensures that when a task is marked success in the UI,
        its on_success_callback gets executed.
        """
        # use shared memory values so we can reliably track changes even when
        # they are updated from another process.
        success_callback_called = Value('i', 0)
        task_terminated_externally = Value('i', 1)
        shared_mem_lock = Lock()

        def success_callback(context):
            with shared_mem_lock:
                success_callback_called.value += 1
            assert context['dag_run'].dag_id == 'test_mark_success'

        dag = DAG(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})

        def task_function(ti):
            # pylint: disable=unused-argument
            time.sleep(60)
            # This should not happen -- the state change should be noticed and the task should get killed
            with shared_mem_lock:
                task_terminated_externally.value = 0

        task = PythonOperator(
            task_id='test_state_succeeded1',
            python_callable=task_function,
            on_success_callback=success_callback,
            dag=dag,
        )

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(
            run_id="test",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
        job1.task_runner = StandardTaskRunner(job1)

        # dispose the connection pool so the forked process doesn't inherit open DB connections
        settings.engine.dispose()
        process = multiprocessing.Process(target=job1.run)
        process.start()

        for _ in range(25):
            ti.refresh_from_db()
            if ti.state == State.RUNNING:
                break
            time.sleep(0.2)
        assert ti.state == State.RUNNING
        ti.state = State.SUCCESS
        session.merge(ti)
        session.commit()

        process.join(timeout=10)
        assert success_callback_called.value == 1
        assert task_terminated_externally.value == 1
        assert not process.is_alive()
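
The refresh/check/sleep loop used above to wait for the child process to mark the task RUNNING is a recurring synchronization pattern in these tests. Pulled out as a self-contained helper (the name wait_for_state is hypothetical, not an Airflow API), it might read:

import time

def wait_for_state(ti, expected_state, attempts=25, delay=0.2):
    # spin until the task instance reaches the expected state or we give up;
    # refresh_from_db() re-reads the state committed by the other process
    for _ in range(attempts):
        ti.refresh_from_db()
        if ti.state == expected_state:
            return True
        time.sleep(delay)
    return False

With it, the polling block above reduces to assert wait_for_state(ti, State.RUNNING).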
Example #8
    def test_failure_callback_only_called_once(self, mock_return_code, _check_call):
        """
        Test that ensures that when a task exits with failure by itself,
        the failure callback is only called once.
        """
        # use a shared memory value so we can reliably track changes even when
        # it is updated from another process.
        failure_callback_called = Value('i', 0)
        callback_count_lock = Lock()

        def failure_callback(context):
            with callback_count_lock:
                failure_callback_called.value += 1
            assert context['dag_run'].dag_id == 'test_failure_callback_race'
            assert isinstance(context['exception'], AirflowFailException)

        def task_function(ti):
            raise AirflowFailException()

        dag = DAG(dag_id='test_failure_callback_race', start_date=DEFAULT_DATE)
        task = PythonOperator(
            task_id='test_exit_on_failure',
            python_callable=task_function,
            on_failure_callback=failure_callback,
            dag=dag,
        )

        dag.clear()
        with create_session() as session:
            dag.create_dagrun(
                run_id="test",
                state=State.RUNNING,
                execution_date=DEFAULT_DATE,
                start_date=DEFAULT_DATE,
                session=session,
            )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()

        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())

        # Simulate race condition where job1 heartbeat ran right after task
        # state got set to failed by ti.handle_failure but before task process
        # fully exits. See _execute loop in airflow/jobs/local_task_job.py.
        # In this case, we have:
        #  * task_runner.return_code() is None
        #  * ti.state == State.FAILED
        #
        # We also need to set return_code to a valid int after job1.terminating
        # is set to True so _execute loop won't loop forever.
        def dummy_return_code(*args, **kwargs):
            return None if not job1.terminating else -9

        mock_return_code.side_effect = dummy_return_code

        with timeout(10):
            # This should take _much_ less than 10 seconds to run.
            # If you change this limit, make the timeout in the callable above bigger.
            job1.run()

        ti.refresh_from_db()
        assert ti.state == State.FAILED  # task exits with failure state
        assert failure_callback_called.value == 1
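
The decorators that supply mock_return_code and _check_call sit outside this excerpt. One plausible wiring, with patch targets that are assumptions inferred from how the mocks are used rather than confirmed by the excerpt, is:

from unittest import mock

# decorators apply bottom-up, so the bottom-most patch binds to the first
# mock argument (mock_return_code); both targets here are assumptions
@mock.patch('subprocess.check_call')
@mock.patch('airflow.task.task_runner.standard_task_runner.StandardTaskRunner.return_code')
def test_failure_callback_only_called_once(self, mock_return_code, _check_call):
    ...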