def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync):
    executor = SequentialExecutor()
    executor.heartbeat()
    calls = [mock.call('executor.open_slots', mock.ANY),
             mock.call('executor.queued_tasks', mock.ANY),
             mock.call('executor.running_tasks', mock.ANY)]
    mock_stats_gauge.assert_has_calls(calls)
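The patch decorators that supply the three mock arguments are omitted from this excerpt. A plausible stack is sketched below (an assumption; the exact patch targets may differ). Note that decorators apply bottom-up, so the bottom-most patch (Stats.gauge) maps to the first mock argument.

@mock.patch('airflow.executors.base_executor.BaseExecutor.sync')
@mock.patch('airflow.executors.base_executor.BaseExecutor.trigger_tasks')
@mock.patch('airflow.executors.base_executor.Stats.gauge')
def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync):
    ...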
Example #2
    def test_mark_failure_on_failure_callback(self):
        """
        Test that ensures that mark_failure in the UI fails
        the task, and executes on_failure_callback
        """
        # use a shared memory value so we can properly track the value change
        # even if it's been updated across processes.
        failure_callback_called = Value('i', 0)
        task_terminated_externally = Value('i', 1)

        def check_failure(context):
            with failure_callback_called.get_lock():
                failure_callback_called.value += 1
            assert context['dag_run'].dag_id == 'test_mark_failure'
            assert context['exception'] == "task marked as failed externally"

        def task_function(ti):
            with create_session() as session:
                assert State.RUNNING == ti.state
                ti.log.info("Marking TI as failed 'externally'")
                ti.state = State.FAILED
                session.merge(ti)
                session.commit()

            time.sleep(10)
            # This should not happen -- the state change should be noticed and the task should get killed
            with task_terminated_externally.get_lock():
                task_terminated_externally.value = 0

        with DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE) as dag:
            task = PythonOperator(
                task_id='test_state_succeeded1',
                python_callable=task_function,
                on_failure_callback=check_failure,
            )

        dag.clear()
        with create_session() as session:
            dag.create_dagrun(
                run_id="test",
                state=State.RUNNING,
                execution_date=DEFAULT_DATE,
                start_date=DEFAULT_DATE,
                session=session,
            )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()

        job1 = LocalTaskJob(task_instance=ti,
                            ignore_ti_state=True,
                            executor=SequentialExecutor())
        with timeout(30):
            # This should be _much_ shorter to run.
            # If you change this limit, make the timeout in the callable above bigger
            job1.run()

        ti.refresh_from_db()
        assert ti.state == State.FAILED
        assert failure_callback_called.value == 1
        assert task_terminated_externally.value == 1
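For context, Value('i', 0) allocates a shared-memory signed integer that stays visible across process boundaries, which is why the test can observe updates made in the forked task process. A standalone illustration of the pattern:

from multiprocessing import Process, Value

def bump(v):
    # get_lock() guards the increment against concurrent updates
    with v.get_lock():
        v.value += 1

if __name__ == '__main__':
    counter = Value('i', 0)  # shared-memory signed int, initial value 0
    p = Process(target=bump, args=(counter,))
    p.start()
    p.join()
    assert counter.value == 1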
Example #3
    def _get_executor(executor_name: str) -> BaseExecutor:
        """
        Creates a new instance of the named executor.
        If the executor name is unknown to Airflow,
        look for it in the plugins.
        """
        if executor_name == ExecutorLoader.LOCAL_EXECUTOR:
            from airflow.executors.local_executor import LocalExecutor
            return LocalExecutor()
        elif executor_name == ExecutorLoader.SEQUENTIAL_EXECUTOR:
            from airflow.executors.sequential_executor import SequentialExecutor
            return SequentialExecutor()
        elif executor_name == ExecutorLoader.CELERY_EXECUTOR:
            from airflow.executors.celery_executor import CeleryExecutor
            return CeleryExecutor()
        elif executor_name == ExecutorLoader.DASK_EXECUTOR:
            from airflow.executors.dask_executor import DaskExecutor
            return DaskExecutor()
        elif executor_name == ExecutorLoader.KUBERNETES_EXECUTOR:
            from airflow.executors.kubernetes_executor import KubernetesExecutor
            return KubernetesExecutor()
        else:
            # Load plugins here for executors, as the plugins might not have
            # been initialized yet at this point
            # TODO: verify the above and remove the two lines below in case
            # plugins are always initialized first
            from airflow import plugins_manager
            plugins_manager.integrate_executor_plugins()
            executor_path = executor_name.split('.')
            assert len(executor_path) == 2, f"Executor {executor_name} not supported: " \
                                            f"please specify in format plugin_module.executor"

            assert executor_path[0] in globals(), f"Executor {executor_name} not supported"
            return globals()[executor_path[0]].__dict__[executor_path[1]]()
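A hypothetical call showing both resolution paths (the plugin module and executor class names are made up, and _get_executor is assumed to be exposed on ExecutorLoader, as the constant references suggest):

# Built-in executor, resolved via the ExecutorLoader constants:
executor = ExecutorLoader._get_executor(ExecutorLoader.LOCAL_EXECUTOR)

# Plugin executor, specified in 'plugin_module.executor' form:
executor = ExecutorLoader._get_executor('my_plugin.MyCustomExecutor')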
Example #4
    def test_terminate_task(self):
        """If a task instance's db state get deleted, it should fail"""
        from airflow.executors.sequential_executor import SequentialExecutor
        TI = TaskInstance
        dag = self.dagbag.dags.get('test_utils')
        task = dag.task_dict.get('sleeps_forever')

        ti = TI(task=task, execution_date=DEFAULT_DATE)
        job = LocalTaskJob(task_instance=ti,
                           ignore_ti_state=True,
                           executor=SequentialExecutor())

        # Running task instance asynchronously
        proc = multiprocessing.Process(target=job.run)
        proc.start()
        sleep(5)
        settings.engine.dispose()
        session = settings.Session()
        ti.refresh_from_db(session=session)
        # making sure it's actually running
        self.assertEqual(State.RUNNING, ti.state)
        ti = session.query(TI).filter_by(dag_id=task.dag_id,
                                         task_id=task.task_id,
                                         execution_date=DEFAULT_DATE).one()

        # deleting the instance should result in a failure
        session.delete(ti)
        session.commit()
        # waiting for the async task to finish
        proc.join()

        # making sure that the task ended up as failed
        ti.refresh_from_db(session=session)
        self.assertEqual(State.FAILED, ti.state)
        session.close()
Example #5
def _get_executor(executor_name):
    """
    Creates a new instance of the named executor.
    If the executor name is not known to Airflow,
    look for it in the plugins.
    """
    parallelism = PARALLELISM
    if executor_name == Executors.LocalExecutor:
        return LocalExecutor(parallelism)
    elif executor_name == Executors.SequentialExecutor:
        return SequentialExecutor(parallelism)
    elif executor_name == Executors.CeleryExecutor:
        from airflow.executors.celery_executor import CeleryExecutor, execute_command
        return CeleryExecutor(parallelism, execute_command)
    elif executor_name == Executors.DaskExecutor:
        from airflow.executors.dask_executor import DaskExecutor
        cluster_address = configuration.conf.get('dask', 'cluster_address')
        tls_ca = configuration.conf.get('dask', 'tls_ca')
        tls_key = configuration.conf.get('dask', 'tls_key')
        tls_cert = configuration.conf.get('dask', 'tls_cert')
        return DaskExecutor(parallelism, cluster_address, tls_ca, tls_key,
                            tls_cert)
    elif executor_name == Executors.MesosExecutor:
        from airflow.contrib.executors.mesos_executor import MesosExecutor
        return MesosExecutor(parallelism)
    elif executor_name == Executors.KubernetesExecutor:
        from airflow.contrib.executors.kubernetes_executor import KubernetesExecutor
        return KubernetesExecutor()
    else:
        # Loading plugins
        _integrate_plugins()
        # Fetch the named class from the plugin module
        args = []
        kwargs = {'parallelism': PARALLELISM}
        return create_object_from_plugin_module(executor_name, *args, **kwargs)
Example #6
    def test_localtaskjob_essential_attr(self):
        """
        Check that the essential attributes
        of LocalTaskJob are assigned
        proper values without intervention
        """
        dag = DAG('test_localtaskjob_essential_attr',
                  start_date=DEFAULT_DATE,
                  default_args={'owner': 'owner1'})

        with dag:
            op1 = DummyOperator(task_id='op1')

        dag.clear()
        dr = dag.create_dagrun(run_id="test",
                               state=State.SUCCESS,
                               execution_date=DEFAULT_DATE,
                               start_date=DEFAULT_DATE)
        ti = dr.get_task_instance(task_id=op1.task_id)

        job1 = LocalTaskJob(task_instance=ti,
                            ignore_ti_state=True,
                            executor=SequentialExecutor())

        essential_attr = ["dag_id", "job_type", "start_date", "hostname"]

        check_result_1 = [hasattr(job1, attr) for attr in essential_attr]
        self.assertTrue(all(check_result_1))

        check_result_2 = [
            getattr(job1, attr) is not None for attr in essential_attr
        ]
        self.assertTrue(all(check_result_2))
Example #7
def get_executor_for_test():
    try:
        from dbnd_airflow.executors.simple_executor import InProcessExecutor

        return InProcessExecutor()
    except Exception:
        return SequentialExecutor()
Example #8
def _get_executor(executor_name):
    """
    Creates a new instance of the named executor. If the executor name is not
    known to Airflow, look for it in the plugins.
    """
    if executor_name == 'LocalExecutor':
        return LocalExecutor()
    elif executor_name == 'SequentialExecutor':
        return SequentialExecutor()
    elif executor_name == 'CeleryExecutor':
        from airflow.executors.celery_executor import CeleryExecutor
        return CeleryExecutor()
    elif executor_name == 'DaskExecutor':
        from airflow.executors.dask_executor import DaskExecutor
        return DaskExecutor()
    elif executor_name == 'MesosExecutor':
        from airflow.contrib.executors.mesos_executor import MesosExecutor
        return MesosExecutor()
    else:
        # Loading plugins
        _integrate_plugins()
        executor_path = executor_name.split('.')
        if len(executor_path) != 2:
            raise AirflowException(
                "Executor {0} not supported: please specify in format plugin_module.executor"
                .format(executor_name))

        if executor_path[0] in globals():
            return globals()[executor_path[0]].__dict__[executor_path[1]]()
        else:
            raise AirflowException(
                "Executor {0} not supported.".format(executor_name))
Example #9
    def test_mark_failure_on_failure_callback(self):
        """
        Test that ensures that mark_failure in the UI fails
        the task, and executes on_failure_callback
        """
        data = {'called': False}

        def check_failure(context):
            self.assertEqual(context['dag_run'].dag_id, 'test_mark_failure')
            data['called'] = True

        def task_function(ti):
            print("python_callable run in pid %s", os.getpid())
            with create_session() as session:
                self.assertEqual(State.RUNNING, ti.state)
                ti.log.info("Marking TI as failed 'externally'")
                ti.state = State.FAILED
                session.merge(ti)
                session.commit()

            time.sleep(60)
            # This should not happen -- the state change should be noticed and the task should get killed
            data['reached_end_of_sleep'] = True

        with DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE) as dag:
            task = PythonOperator(
                task_id='test_state_succeeded1',
                python_callable=task_function,
                on_failure_callback=check_failure,
            )

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(
            run_id="test",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()

        job1 = LocalTaskJob(task_instance=ti,
                            ignore_ti_state=True,
                            executor=SequentialExecutor())
        with timeout(30):
            # This should be _much_ shorter to run.
            # If you change this limit, make the timeout in the callable above bigger
            job1.run()

        ti.refresh_from_db()
        self.assertEqual(ti.state, State.FAILED)
        self.assertTrue(data['called'])
        self.assertNotIn(
            'reached_end_of_sleep', data,
            'Task should not have been allowed to run to completion')
Example #10
    def test_mark_success_on_success_callback(self):
        """
        Test that ensures that when a task is marked success in the UI,
        on_success_callback gets executed
        """
        data = {'called': False}

        def success_callback(context):
            self.assertEqual(context['dag_run'].dag_id, 'test_mark_success')
            data['called'] = True

        dag = DAG(dag_id='test_mark_success',
                  start_date=DEFAULT_DATE,
                  default_args={'owner': 'owner1'})

        task = DummyOperator(task_id='test_state_succeeded1',
                             dag=dag,
                             on_success_callback=success_callback)

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(
            run_id="test",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti,
                            ignore_ti_state=True,
                            executor=SequentialExecutor())
        from airflow.task.task_runner.standard_task_runner import StandardTaskRunner

        job1.task_runner = StandardTaskRunner(job1)
        process = multiprocessing.Process(target=job1.run)
        process.start()
        ti.refresh_from_db()
        for _ in range(0, 50):
            if ti.state == State.RUNNING:
                break
            time.sleep(0.1)
            ti.refresh_from_db()
        self.assertEqual(State.RUNNING, ti.state)
        ti.state = State.SUCCESS
        session.merge(ti)
        session.commit()

        job1.heartbeat_callback(session=None)
        self.assertTrue(data['called'])
        process.join(timeout=10)
        self.assertFalse(process.is_alive())
Example #11
    def __init__(self, subdag, executor=SequentialExecutor(), *args, **kwargs):
        """
        This runs a sub dag. By convention, a sub dag's dag_id
        should be prefixed by its parent's dag_id and a dot, as in `parent.child`.

        :param subdag: the DAG object to run as a subdag of the current DAG.
        :type subdag: airflow.models.DAG
        :param dag: the parent DAG for the subdag.
        :type dag: airflow.models.DAG
        :param executor: the executor for this subdag. Defaults to SequentialExecutor.
                         See AIRFLOW-74 for more details.
        :type executor: airflow.executors.base_executor.BaseExecutor
        """
        import airflow.models
        dag = kwargs.get('dag') or airflow.models._CONTEXT_MANAGER_DAG
        if not dag:
            raise AirflowException('Please pass in the `dag` param or call '
                                   'within a DAG context manager')
        session = kwargs.pop('session')
        super(SubDagOperator, self).__init__(*args, **kwargs)

        # validate subdag name
        if dag.dag_id + '.' + kwargs['task_id'] != subdag.dag_id:
            raise AirflowException(
                "The subdag's dag_id should have the form "
                "'{{parent_dag_id}}.{{this_task_id}}'. Expected "
                "'{d}.{t}'; received '{rcvd}'.".format(d=dag.dag_id,
                                                       t=kwargs['task_id'],
                                                       rcvd=subdag.dag_id))

        # validate that subdag operator and subdag tasks don't have a
        # pool conflict
        if self.pool:
            conflicts = [t for t in subdag.tasks if t.pool == self.pool]
            if conflicts:
                # only query for pool conflicts if one may exist
                pool = (session.query(Pool).filter(Pool.slots == 1).filter(
                    Pool.pool == self.pool).first())
                if pool and any(t.pool == self.pool for t in subdag.tasks):
                    raise AirflowException(
                        'SubDagOperator {sd} and subdag task{plural} {t} both '
                        'use pool {p}, but the pool only has 1 slot. The '
                        'subdag tasks will never run.'.format(
                            sd=self.task_id,
                            plural='s' if len(conflicts) > 1 else '',
                            t=', '.join(t.task_id for t in conflicts),
                            p=self.pool))

        self.subdag = subdag
        # Airflow pools are not honored by SubDagOperator, so resources
        # could be consumed by SubDagOperators beyond any pool limits.
        # Use another executor at your own risk.
        self.executor = executor
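The constructor pops a `session` kwarg without a default; in the original operator this is supplied by a provide_session decorator that the excerpt omits (an assumption here). A minimal usage sketch of the naming rule, with illustrative dag ids:

parent = DAG(dag_id='parent', start_date=DEFAULT_DATE)
# The subdag's dag_id must be '<parent_dag_id>.<task_id>',
# otherwise __init__ raises AirflowException.
child = DAG(dag_id='parent.child', start_date=DEFAULT_DATE)
run_child = SubDagOperator(task_id='child', subdag=child, dag=parent)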
Example #12
    def test_localtaskjob_maintain_heart_rate(self):
        dagbag = DagBag(
            dag_folder=TEST_DAG_FOLDER,
            include_examples=False,
        )
        dag = dagbag.dags.get('test_localtaskjob_double_trigger')
        task = dag.get_task('test_localtaskjob_double_trigger_task')

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(
            run_id="test",
            state=State.SUCCESS,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )

        ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti_run.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti_run,
                            executor=SequentialExecutor())

        # this should make sure we only heartbeat once and exit at the second
        # loop in _execute()
        return_codes = [None, 0]

        def multi_return_code():
            return return_codes.pop(0)

        time_start = time.time()
        from airflow.task.task_runner.standard_task_runner import StandardTaskRunner

        with patch.object(StandardTaskRunner, 'start',
                          return_value=None) as mock_start:
            with patch.object(StandardTaskRunner,
                              'return_code') as mock_ret_code:
                mock_ret_code.side_effect = multi_return_code
                job1.run()
                self.assertEqual(mock_start.call_count, 1)
                self.assertEqual(mock_ret_code.call_count, 2)
        time_end = time.time()

        self.assertEqual(self.mock_base_job_sleep.call_count, 1)
        self.assertEqual(job1.state, State.SUCCESS)

        # Since we have patched the sleep call, it should not be sleeping to
        # keep up with the heart rate in other unpatched places
        #
        # We already made sure the patched sleep call is only called once
        self.assertLess(time_end - time_start, job1.heartrate)
        session.close()
Example #13
    def test_essential_attr(self, mock_getuser, mock_hostname, mock_default_executor):
        mock_sequential_executor = SequentialExecutor()
        mock_hostname.return_value = "test_hostname"
        mock_getuser.return_value = "testuser"
        mock_default_executor.return_value = mock_sequential_executor

        test_job = self.TestJob(None, heartrate=10, dag_id="example_dag", state=State.RUNNING)
        self.assertEqual(test_job.executor_class, "SequentialExecutor")
        self.assertEqual(test_job.heartrate, 10)
        self.assertEqual(test_job.dag_id, "example_dag")
        self.assertEqual(test_job.hostname, "test_hostname")
        self.assertEqual(test_job.max_tis_per_query, 100)
        self.assertEqual(test_job.unixname, "testuser")
        self.assertEqual(test_job.state, "running")
        self.assertEqual(test_job.executor, mock_sequential_executor)
Example #14
    def test_essential_attr(self, mock_getuser, mock_hostname, mock_default_executor):
        mock_sequential_executor = SequentialExecutor()
        mock_hostname.return_value = "test_hostname"
        mock_getuser.return_value = "testuser"
        mock_default_executor.return_value = mock_sequential_executor

        test_job = MockJob(None, heartrate=10, dag_id="example_dag", state=State.RUNNING)
        assert test_job.executor_class == "SequentialExecutor"
        assert test_job.heartrate == 10
        assert test_job.dag_id == "example_dag"
        assert test_job.hostname == "test_hostname"
        assert test_job.max_tis_per_query == 100
        assert test_job.unixname == "testuser"
        assert test_job.state == "running"
        assert test_job.executor == mock_sequential_executor
Example #15
    def __init__(self,
                 subdag: DAG,
                 executor: BaseExecutor = SequentialExecutor(),
                 *args,
                 **kwargs) -> None:
        dag = kwargs.get('dag') or settings.CONTEXT_MANAGER_DAG
        if not dag:
            raise AirflowException('Please pass in the `dag` param or call '
                                   'within a DAG context manager')
        session = kwargs.pop('session')
        super().__init__(*args, **kwargs)

        # validate subdag name
        if dag.dag_id + '.' + kwargs['task_id'] != subdag.dag_id:
            raise AirflowException(
                "The subdag's dag_id should have the form "
                "'{{parent_dag_id}}.{{this_task_id}}'. Expected "
                "'{d}.{t}'; received '{rcvd}'.".format(d=dag.dag_id,
                                                       t=kwargs['task_id'],
                                                       rcvd=subdag.dag_id))

        # validate that subdag operator and subdag tasks don't have a
        # pool conflict
        if self.pool:
            conflicts = [t for t in subdag.tasks if t.pool == self.pool]
            if conflicts:
                # only query for pool conflicts if one may exist
                pool = (session.query(Pool).filter(Pool.slots == 1).filter(
                    Pool.pool == self.pool).first())
                if pool and any(t.pool == self.pool for t in subdag.tasks):
                    raise AirflowException(
                        'SubDagOperator {sd} and subdag task{plural} {t} both '
                        'use pool {p}, but the pool only has 1 slot. The '
                        'subdag tasks will never run.'.format(
                            sd=self.task_id,
                            plural='s' if len(conflicts) > 1 else '',
                            t=', '.join(t.task_id for t in conflicts),
                            p=self.pool))

        self.subdag = subdag
        # Airflow pools are not honored by SubDagOperator, so resources
        # could be consumed by SubDagOperators beyond any pool limits.
        # Use another executor at your own risk.
        self.executor = executor
Example #16
    def test_localtaskjob_heartbeat(self):
        session = settings.Session()
        dag = DAG('test_localtaskjob_heartbeat',
                  start_date=DEFAULT_DATE,
                  default_args={'owner': 'owner1'})

        with dag:
            op1 = DummyOperator(task_id='op1')

        dag.clear()
        dr = dag.create_dagrun(
            run_id="test",
            state=State.SUCCESS,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = dr.get_task_instance(task_id=op1.task_id, session=session)
        ti.state = State.RUNNING
        ti.hostname = "blablabla"
        session.commit()

        job1 = LocalTaskJob(task_instance=ti,
                            ignore_ti_state=True,
                            executor=SequentialExecutor())
        ti.task = op1
        ti.refresh_from_task(op1)
        job1.task_runner = StandardTaskRunner(job1)
        job1.task_runner.process = mock.Mock()
        with pytest.raises(AirflowException):
            job1.heartbeat_callback()  # pylint: disable=no-value-for-parameter

        job1.task_runner.process.pid = 1
        ti.state = State.RUNNING
        ti.hostname = get_hostname()
        ti.pid = 1
        session.merge(ti)
        session.commit()
        assert ti.pid != os.getpid()
        job1.heartbeat_callback(session=None)

        job1.task_runner.process.pid = 2
        with pytest.raises(AirflowException):
            job1.heartbeat_callback()  # pylint: disable=no-value-for-parameter
Example #17
    def test_localtaskjob_double_trigger(self):
        dagbag = DagBag(
            dag_folder=TEST_DAG_FOLDER,
            include_examples=False,
        )
        dag = dagbag.dags.get('test_localtaskjob_double_trigger')
        task = dag.get_task('test_localtaskjob_double_trigger_task')

        session = settings.Session()

        dag.clear()
        dr = dag.create_dagrun(
            run_id="test",
            state=State.SUCCESS,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = dr.get_task_instance(task_id=task.task_id, session=session)
        ti.state = State.RUNNING
        ti.hostname = get_hostname()
        ti.pid = 1
        session.merge(ti)
        session.commit()

        ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti_run.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti_run,
                            executor=SequentialExecutor())
        from airflow.task.task_runner.standard_task_runner import StandardTaskRunner

        with patch.object(StandardTaskRunner, 'start',
                          return_value=None) as mock_method:
            job1.run()
            mock_method.assert_not_called()

        ti = dr.get_task_instance(task_id=task.task_id, session=session)
        self.assertEqual(ti.pid, 1)
        self.assertEqual(ti.state, State.RUNNING)

        session.close()
Example #18
    def test_localtaskjob_heartbeat(self, mock_pid):
        session = settings.Session()
        dag = DAG('test_localtaskjob_heartbeat',
                  start_date=DEFAULT_DATE,
                  default_args={'owner': 'owner1'})

        with dag:
            op1 = DummyOperator(task_id='op1')

        dag.clear()
        dr = dag.create_dagrun(
            run_id="test",
            state=State.SUCCESS,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = dr.get_task_instance(task_id=op1.task_id, session=session)
        ti.state = State.RUNNING
        ti.hostname = "blablabla"
        session.commit()

        job1 = LocalTaskJob(task_instance=ti,
                            ignore_ti_state=True,
                            executor=SequentialExecutor())
        self.assertRaises(AirflowException, job1.heartbeat_callback)

        mock_pid.return_value = 1
        ti.state = State.RUNNING
        ti.hostname = get_hostname()
        ti.pid = 1
        session.merge(ti)
        session.commit()

        job1.heartbeat_callback(session=None)

        mock_pid.return_value = 2
        self.assertRaises(AirflowException, job1.heartbeat_callback)
Example #19
    def __init__(self, subdag, executor=SequentialExecutor(), *args, **kwargs):
        self.subdag = subdag
        self.executor = executor

        super(AirflowSubDagOperator, self).__init__(*args, **kwargs)
Example #20

def _integrate_plugins():
    """Integrate plugins to the context."""
    from airflow.plugins_manager import executors_modules
    for executors_module in executors_modules:
        sys.modules[executors_module.__name__] = executors_module
        globals()[executors_module._name] = executors_module


_EXECUTOR = configuration.get('core', 'EXECUTOR')

if _EXECUTOR == 'LocalExecutor':
    DEFAULT_EXECUTOR = LocalExecutor()
elif _EXECUTOR == 'CeleryExecutor':
    DEFAULT_EXECUTOR = CeleryExecutor()
elif _EXECUTOR == 'SequentialExecutor':
    DEFAULT_EXECUTOR = SequentialExecutor()
elif _EXECUTOR == 'MesosExecutor':
    from airflow.contrib.executors.mesos_executor import MesosExecutor
    DEFAULT_EXECUTOR = MesosExecutor()
else:
    # Loading plugins
    _integrate_plugins()
    if _EXECUTOR in globals():
        DEFAULT_EXECUTOR = globals()[_EXECUTOR]()
    else:
        raise AirflowException("Executor {0} not supported.".format(_EXECUTOR))

_log.info("Using executor " + _EXECUTOR)
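For reference, _EXECUTOR is read from the [core] section of airflow.cfg; a quick way to inspect the active value with the same legacy configuration module (a sketch, assuming a stock installation):

from airflow import configuration

# Corresponds to this airflow.cfg entry:
#   [core]
#   executor = SequentialExecutor
print(configuration.get('core', 'EXECUTOR'))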
Example #21
    def test_failure_callback_only_called_once(self, mock_return_code, _check_call):
        """
        Test that ensures that when a task exits with failure on its own,
        the failure callback is only called once
        """
        # use a shared memory value so we can properly track the value change
        # even if it's been updated across processes.
        failure_callback_called = Value('i', 0)
        callback_count_lock = Lock()

        def failure_callback(context):
            with callback_count_lock:
                failure_callback_called.value += 1
            assert context['dag_run'].dag_id == 'test_failure_callback_race'
            assert isinstance(context['exception'], AirflowFailException)

        def task_function(ti):
            raise AirflowFailException()

        dag = DAG(dag_id='test_failure_callback_race', start_date=DEFAULT_DATE)
        task = PythonOperator(
            task_id='test_exit_on_failure',
            python_callable=task_function,
            on_failure_callback=failure_callback,
            dag=dag,
        )

        dag.clear()
        with create_session() as session:
            dag.create_dagrun(
                run_id="test",
                state=State.RUNNING,
                execution_date=DEFAULT_DATE,
                start_date=DEFAULT_DATE,
                session=session,
            )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()

        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())

        # Simulate race condition where job1 heartbeat ran right after task
        # state got set to failed by ti.handle_failure but before task process
        # fully exits. See _execute loop in airflow/jobs/local_task_job.py.
        # In this case, we have:
        #  * task_runner.return_code() is None
        #  * ti.state == State.Failed
        #
        # We also need to set return_code to a valid int after job1.terminating
        # is set to True so _execute loop won't loop forever.
        def dummy_return_code(*args, **kwargs):
            return None if not job1.terminating else -9

        mock_return_code.side_effect = dummy_return_code

        with timeout(10):
            # This should be _much_ shorter to run.
            # If you change this limit, make the timeout in the callable above bigger
            job1.run()

        ti.refresh_from_db()
        assert ti.state == State.FAILED  # task exits with failure state
        assert failure_callback_called.value == 1
Example #22
    def test_mark_success_on_success_callback(self):
        """
        Test that ensures that when a task is marked success in the UI,
        on_success_callback gets executed
        """
        # use a shared memory value so we can properly track the value change
        # even if it's been updated across processes.
        success_callback_called = Value('i', 0)
        task_terminated_externally = Value('i', 1)
        shared_mem_lock = Lock()

        def success_callback(context):
            with shared_mem_lock:
                success_callback_called.value += 1
            assert context['dag_run'].dag_id == 'test_mark_success'

        dag = DAG(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})

        def task_function(ti):
            # pylint: disable=unused-argument
            time.sleep(60)
            # This should not happen -- the state change should be noticed and the task should get killed
            with shared_mem_lock:
                task_terminated_externally.value = 0

        task = PythonOperator(
            task_id='test_state_succeeded1',
            python_callable=task_function,
            on_success_callback=success_callback,
            dag=dag,
        )

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(
            run_id="test",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
        job1.task_runner = StandardTaskRunner(job1)

        settings.engine.dispose()
        process = multiprocessing.Process(target=job1.run)
        process.start()

        for _ in range(0, 25):
            ti.refresh_from_db()
            if ti.state == State.RUNNING:
                break
            time.sleep(0.2)
        assert ti.state == State.RUNNING
        ti.state = State.SUCCESS
        session.merge(ti)
        session.commit()

        process.join(timeout=10)
        assert success_callback_called.value == 1
        assert task_terminated_externally.value == 1
        assert not process.is_alive()
Example #23
import airflow.utils.dates
from airflow.models import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.subdag_operator import SubDagOperator
from airflow.executors.sequential_executor import SequentialExecutor
from airflow.executors.celery_executor import CeleryExecutor

DAG_NAME = "test_subdag"

default_args = {
    'owner': 'Airflow',
    'start_date': airflow.utils.dates.days_ago(2)
}
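factory_subdag is called below but not defined in this excerpt; a minimal sketch of such a factory (the task layout inside the child DAG is an assumption):

def factory_subdag(parent_dag_name, child_dag_name, args):
    # The child dag_id must follow SubDagOperator's 'parent.child' convention.
    subdag = DAG(
        dag_id='%s.%s' % (parent_dag_name, child_dag_name),
        default_args=args,
        schedule_interval="@once",
    )
    for i in range(2):
        DummyOperator(task_id='%s-task-%s' % (child_dag_name, i + 1), dag=subdag)
    return subdag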

with DAG(dag_id=DAG_NAME, default_args=default_args,
         schedule_interval="@once") as dag:
    start = DummyOperator(task_id='start')

    subdag_1 = SubDagOperator(task_id='subdag-1',
                              subdag=factory_subdag(DAG_NAME, 'subdag-1',
                                                    default_args),
                              executor=SequentialExecutor())

    some_other_task = DummyOperator(task_id='check')

    subdag_2 = SubDagOperator(task_id='subdag-2',
                              subdag=factory_subdag(DAG_NAME, 'subdag-2',
                                                    default_args),
                              executor=SequentialExecutor())

    end = DummyOperator(task_id='final')

    start >> subdag_1 >> some_other_task >> subdag_2 >> end