Example #1
    def test_backfill_pooled_tasks(self):
        """
        Test that queued tasks are executed by BackfillJob

        Test for https://github.com/airbnb/airflow/pull/1225
        """
        session = settings.Session()
        pool = Pool(pool='test_backfill_pooled_task_pool', slots=1)
        session.add(pool)
        session.commit()

        dag = self.dagbag.get_dag('test_backfill_pooled_task_dag')
        dag.clear()

        job = BackfillJob(
            dag=dag,
            start_date=DEFAULT_DATE,
            end_date=DEFAULT_DATE)

        # run with timeout because this creates an infinite loop if not
        # caught
        with timeout(seconds=30):
            job.run()

        ti = TI(
            task=dag.get_task('test_backfill_pooled_task'),
            execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        self.assertEqual(ti.state, State.SUCCESS)
Example #2
def send_task_to_executor(task_tuple):
    key, simple_ti, command, queue, task = task_tuple
    try:
        with timeout(seconds=2):
            result = task.apply_async(args=[command], queue=queue)
    except Exception as e:
        exception_traceback = "Celery Task ID: {}\n{}".format(key,
                                                              traceback.format_exc())
        result = ExceptionWithTraceback(e, exception_traceback)

    return key, command, result
Example #3
def send_task_to_executor(task_tuple):
    key, simple_ti, command, queue, task = task_tuple
    try:
        with timeout(seconds=2):
            result = task.apply_async(args=[command], queue=queue)
    except Exception as e:
        exception_traceback = "Celery Task ID: {}\n{}".format(
            key, traceback.format_exc())
        result = ExceptionWithTraceback(e, exception_traceback)

    return key, command, result
Example #4
    def test_mark_failure_on_failure_callback(self):
        """
        Test that ensures that mark_failure in the UI fails
        the task, and executes on_failure_callback
        """
        data = {'called': False}

        def check_failure(context):
            self.assertEqual(context['dag_run'].dag_id, 'test_mark_failure')
            data['called'] = True

        def task_function(ti):
            print("python_callable run in pid %s", os.getpid())
            with create_session() as session:
                self.assertEqual(State.RUNNING, ti.state)
                ti.log.info("Marking TI as failed 'externally'")
                ti.state = State.FAILED
                session.merge(ti)
                session.commit()

            time.sleep(60)
            # This should not happen -- the state change should be noticed and the task should get killed
            data['reached_end_of_sleep'] = True

        with DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE) as dag:
            task = PythonOperator(task_id='test_state_succeeded1',
                                  python_callable=task_function,
                                  on_failure_callback=check_failure)

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(run_id="test",
                          state=State.RUNNING,
                          execution_date=DEFAULT_DATE,
                          start_date=DEFAULT_DATE,
                          session=session)
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()

        job1 = LocalTaskJob(task_instance=ti,
                            ignore_ti_state=True,
                            executor=SequentialExecutor())
        with timeout(30):
            # This should be _much_ shorter to run.
            # If you change this limit, make the timeout in the callable above bigger
            job1.run()

        ti.refresh_from_db()
        self.assertEqual(ti.state, State.FAILED)
        self.assertTrue(data['called'])
        self.assertNotIn(
            'reached_end_of_sleep', data,
            'Task should not have been allowed to run to completion')
Example #5
def send_task_to_executor(task_tuple: TaskInstanceInCelery) \
        -> Tuple[TaskInstanceKey, CommandType, Union[AsyncResult, ExceptionWithTraceback]]:
    """Sends task to executor."""
    key, _, command, queue, task_to_run = task_tuple
    try:
        with timeout(seconds=OPERATION_TIMEOUT):
            result = task_to_run.apply_async(args=[command], queue=queue)
    except Exception as e:  # pylint: disable=broad-except
        exception_traceback = "Celery Task ID: {}\n{}".format(key, traceback.format_exc())
        result = ExceptionWithTraceback(e, exception_traceback)

    return key, command, result
Example #6
    def _load_modules_from_file(self, filepath, safe_mode):
        if not might_contain_dag(filepath, safe_mode):
            # Don't want to spam user with skip messages
            if not self.has_logged:
                self.has_logged = True
                self.log.info("File %s assumed to contain no DAGs. Skipping.",
                              filepath)
            return []

        self.log.debug("Importing %s", filepath)
        org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
        path_hash = hashlib.sha1(filepath.encode('utf-8')).hexdigest()
        mod_name = f'unusual_prefix_{path_hash}_{org_mod_name}'

        if mod_name in sys.modules:
            del sys.modules[mod_name]

        def parse(mod_name, filepath):
            try:
                loader = importlib.machinery.SourceFileLoader(
                    mod_name, filepath)
                spec = importlib.util.spec_from_loader(mod_name, loader)
                new_module = importlib.util.module_from_spec(spec)
                sys.modules[spec.name] = new_module
                loader.exec_module(new_module)
                return [new_module]
            except Exception as e:
                self.log.exception("Failed to import: %s", filepath)
                if self.dagbag_import_error_tracebacks:
                    self.import_errors[filepath] = traceback.format_exc(
                        limit=-self.dagbag_import_error_traceback_depth)
                else:
                    self.import_errors[filepath] = str(e)
                return []

        dagbag_import_timeout = settings.get_dagbag_import_timeout(filepath)

        if not isinstance(dagbag_import_timeout, (int, float)):
            raise TypeError(
                f'Value ({dagbag_import_timeout}) from get_dagbag_import_timeout must be int or float'
            )

        if dagbag_import_timeout <= 0:  # no parsing timeout
            return parse(mod_name, filepath)

        timeout_msg = (
            f"DagBag import timeout for {filepath} after {dagbag_import_timeout}s.\n"
            "Please take a look at these docs to improve your DAG import time:\n"
            f"* {get_docs_url('best-practices.html#top-level-python-code')}\n"
            f"* {get_docs_url('best-practices.html#reducing-dag-complexity')}")
        with timeout(dagbag_import_timeout, error_message=timeout_msg):
            return parse(mod_name, filepath)
Example #7
    def test_subdag_clear_parentdag_downstream_clear(self):
        dag = self.dagbag.get_dag('example_subdag_operator')
        subdag_op_task = dag.get_task('section-1')

        subdag = subdag_op_task.subdag
        subdag.schedule_interval = '@daily'

        executor = TestExecutor()
        job = BackfillJob(dag=subdag,
                          start_date=DEFAULT_DATE,
                          end_date=DEFAULT_DATE,
                          executor=executor,
                          donot_pickle=True)

        with timeout(seconds=30):
            job.run()

        ti0 = TI(
            task=subdag.get_task('section-1-task-1'),
            execution_date=DEFAULT_DATE)
        ti0.refresh_from_db()
        self.assertEqual(ti0.state, State.SUCCESS)

        sdag = subdag.sub_dag(
            task_regex='section-1-task-1',
            include_downstream=True,
            include_upstream=False)

        sdag.clear(
            start_date=DEFAULT_DATE,
            end_date=DEFAULT_DATE,
            include_parentdag=True)

        ti0.refresh_from_db()
        self.assertEqual(State.NONE, ti0.state)

        ti1 = TI(
            task=dag.get_task('some-other-task'),
            execution_date=DEFAULT_DATE)
        self.assertEqual(State.NONE, ti1.state)

        # Checks that all the Downstream tasks for Parent DAG
        # have been cleared
        for task in subdag_op_task.downstream_list:
            ti = TI(
                task=dag.get_task(task.task_id),
                execution_date=DEFAULT_DATE
            )
            self.assertEqual(State.NONE, ti.state)

        subdag.clear()
        dag.clear()
Example #8
    def test_subdag_clear_parentdag_downstream_clear(self):
        dag = self.dagbag.get_dag('clear_subdag_test_dag')
        subdag_op_task = dag.get_task('daily_job')

        subdag = subdag_op_task.subdag

        executor = MockExecutor()
        job = BackfillJob(dag=dag,
                          start_date=DEFAULT_DATE,
                          end_date=DEFAULT_DATE,
                          executor=executor,
                          donot_pickle=True)

        with timeout(seconds=30):
            job.run()

        ti_subdag = TI(task=dag.get_task('daily_job'),
                       execution_date=DEFAULT_DATE)
        ti_subdag.refresh_from_db()
        self.assertEqual(ti_subdag.state, State.SUCCESS)

        ti_irrelevant = TI(task=dag.get_task('daily_job_irrelevant'),
                           execution_date=DEFAULT_DATE)
        ti_irrelevant.refresh_from_db()
        self.assertEqual(ti_irrelevant.state, State.SUCCESS)

        ti_downstream = TI(task=dag.get_task('daily_job_downstream'),
                           execution_date=DEFAULT_DATE)
        ti_downstream.refresh_from_db()
        self.assertEqual(ti_downstream.state, State.SUCCESS)

        sdag = subdag.sub_dag(task_regex='daily_job_subdag_task',
                              include_downstream=True,
                              include_upstream=False)

        sdag.clear(start_date=DEFAULT_DATE,
                   end_date=DEFAULT_DATE,
                   include_parentdag=True)

        ti_subdag.refresh_from_db()
        self.assertEqual(State.NONE, ti_subdag.state)

        ti_irrelevant.refresh_from_db()
        self.assertEqual(State.SUCCESS, ti_irrelevant.state)

        ti_downstream.refresh_from_db()
        self.assertEqual(State.NONE, ti_downstream.state)

        subdag.clear()
        dag.clear()
Example #9
def _execute_task(self, context, task_copy):
    """Executes Task (optionally with a Timeout) and pushes Xcom results"""
    # If a timeout is specified for the task, make it fail
    # if it goes beyond
    if task_copy.execution_timeout:
        try:
            with timeout(task_copy.execution_timeout.total_seconds()):
                result = task_copy.execute(context=context)
        except AirflowTaskTimeout:
            task_copy.on_kill()
            raise
    else:
        result = task_copy.execute(context=context)
    # If the task returns a result, push an XCom containing it
    # if task_copy.do_xcom_push and result is not None:
    #     self.xcom_push(key=XCOM_RETURN_KEY, value=result)
    return result
Example #10
    def test_backfill_execute_subdag_with_removed_task(self):
        """
        Ensure that subdag operators execute properly in the case where
        an associated task of the subdag has been removed from the dag
        definition, but has instances in the database from previous runs.
        """
        dag = self.dagbag.get_dag('example_subdag_operator')
        subdag = dag.get_task('section-1').subdag

        executor = MockExecutor()
        job = BackfillJob(dag=subdag,
                          start_date=DEFAULT_DATE,
                          end_date=DEFAULT_DATE,
                          executor=executor,
                          donot_pickle=True)

        removed_task_ti = TI(
            task=DummyOperator(task_id='removed_task'),
            execution_date=DEFAULT_DATE,
            state=State.REMOVED)
        removed_task_ti.dag_id = subdag.dag_id

        session = settings.Session()
        session.merge(removed_task_ti)
        session.commit()

        with timeout(seconds=30):
            job.run()

        for task in subdag.tasks:
            instance = session.query(TI).filter(
                TI.dag_id == subdag.dag_id,
                TI.task_id == task.task_id,
                TI.execution_date == DEFAULT_DATE).first()

            self.assertIsNotNone(instance)
            self.assertEqual(instance.state, State.SUCCESS)

        removed_task_ti.refresh_from_db()
        self.assertEqual(removed_task_ti.state, State.REMOVED)

        subdag.clear()
        dag.clear()
Example #11
    def process_file(self, filepath, only_if_updated=True, safe_mode=False):
        found_templates = []

        if filepath is None or not os.path.isfile(filepath):
            return found_templates

        try:
            # This failed before in what may have been a git sync
            # race condition
            file_last_changed_on_disk = datetime.fromtimestamp(
                os.path.getmtime(filepath))
            if only_if_updated and file_last_changed_on_disk == self.file_last_changed.get(
                    filepath, None):
                return found_templates
        except Exception as e:
            self.log.exception(e)
            return found_templates

        if safe_mode:
            # TODO: heuristic to process only if file contains template
            pass

        self.log.debug(f'Importing {filepath}')
        modname, _ = os.path.splitext(os.path.split(filepath)[-1])
        mods = []
        with timeout(self.TEMPLATE_IMPORT_TIMEOUT):
            try:
                m = imp.load_source(modname, filepath)
                mods.append(m)
            except Exception as e:
                self.log.exception(f'Failed to import: {filepath}')
                self.import_errors[filepath] = str(e)
                self.file_last_changed[filepath] = file_last_changed_on_disk
        for mod in mods:
            for val in list(mod.__dict__.values()):
                if isinstance(val,
                              type) and val != BaseDagTemplate and issubclass(
                                  val, BaseDagTemplate):
                    tmpl = val
                    self._add_template(tmpl)
                    found_templates.append(tmpl)
        self.file_last_changed[filepath] = file_last_changed_on_disk
        return found_templates
Example #12
def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str, ExceptionWithTraceback], Any]:
    """
    Fetch and return the state of the given celery task. The scope of this function is
    global so that it can be called by subprocesses in the pool.

    :param async_result: a tuple of the Celery task key and the async Celery object used
        to fetch the task's state
    :return: a tuple of the Celery task key and the Celery state and the celery info
        of the task
    :rtype: tuple[str, str, str]
    """
    try:
        with timeout(seconds=OPERATION_TIMEOUT):
            # Accessing state property of celery task will make actual network request
            # to get the current state of the task
            info = async_result.info if hasattr(async_result, 'info') else None
            return async_result.task_id, async_result.state, info
    except Exception as e:
        exception_traceback = f"Celery Task ID: {async_result}\n{traceback.format_exc()}"
        return async_result.task_id, ExceptionWithTraceback(e, exception_traceback), None
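The docstring above stresses that fetch_celery_task_state is kept at module scope so that subprocesses in a worker pool can call it. A minimal, hypothetical sketch of that fan-out (the check_celery_states helper and the pool size are illustrative, not taken from the Airflow source):

from multiprocessing import Pool

def check_celery_states(async_results, num_workers=4):
    # Module-level functions such as fetch_celery_task_state can be pickled,
    # which is what allows multiprocessing to hand them to pool workers;
    # a lambda or a nested function could not be used here.
    with Pool(processes=num_workers) as pool:
        return pool.map(fetch_celery_task_state, async_results)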
Example #13
def fetch_celery_task_state(celery_task):
    """
    Fetch and return the state of the given celery task. The scope of this function is
    global so that it can be called by subprocesses in the pool.
    :param celery_task: a tuple of the Celery task key and the async Celery object used
                        to fetch the task's state
    :type celery_task: (str, celery.result.AsyncResult)
    :return: a tuple of the Celery task key and the Celery state of the task
    :rtype: (str, str)
    """

    try:
        with timeout(seconds=2):
            # Accessing state property of celery task will make actual network request
            # to get the current state of the task.
            res = (celery_task[0], celery_task[1].state)
    except Exception as e:
        exception_traceback = "Celery Task ID: {}\n{}".format(
            celery_task[0], traceback.format_exc())
        res = ExceptionWithTraceback(e, exception_traceback)
    return res
Example #14
def fetch_celery_task_state(celery_task):
    """
    Fetch and return the state of the given celery task. The scope of this function is
    global so that it can be called by subprocesses in the pool.
    :param celery_task: a tuple of the Celery task key and the async Celery object used
                        to fetch the task's state
    :type celery_task: (str, celery.result.AsyncResult)
    :return: a tuple of the Celery task key and the Celery state of the task
    :rtype: (str, str)
    """

    try:
        with timeout(seconds=2):
            # Accessing state property of celery task will make actual network request
            # to get the current state of the task.
            res = (celery_task[0], celery_task[1].state)
    except Exception as e:
        exception_traceback = "Celery Task ID: {}\n{}".format(celery_task[0],
                                                              traceback.format_exc())
        res = ExceptionWithTraceback(e, exception_traceback)
    return res
Example #15
    def execution_parallelism(self, parallelism=0):
        executor = LocalExecutor(parallelism=parallelism)
        executor.start()

        success_key = 'success {}'
        success_command = ['true', 'some_parameter']
        fail_command = ['false', 'some_parameter']

        for i in range(self.TEST_SUCCESS_COMMANDS):
            key, command = success_key.format(i), success_command
            executor.execute_async(key=key, command=command)
            executor.running[key] = True

        # errors are propagated for some reason
        try:
            executor.execute_async(key='fail', command=fail_command)
        except Exception:
            pass

        executor.running['fail'] = True

        if parallelism == 0:
            with timeout(seconds=10):
                executor.end()
        else:
            executor.end()

        if isinstance(executor.impl, LocalExecutor._LimitedParallelism):
            self.assertTrue(executor.queue.empty())

        self.assertEqual(len(executor.running), 0)
        self.assertTrue(executor.result_queue.empty())

        for i in range(self.TEST_SUCCESS_COMMANDS):
            key = success_key.format(i)
            self.assertEqual(executor.event_buffer[key], State.SUCCESS)
        self.assertEqual(executor.event_buffer['fail'], State.FAILED)

        expected = self.TEST_SUCCESS_COMMANDS + 1 if parallelism == 0 else parallelism
        self.assertEqual(executor.workers_used, expected)
Example #16
def fetch_celery_task_state(celery_task: Tuple[TaskInstanceKeyType, AsyncResult]) \
        -> Union[TaskInstanceStateType, ExceptionWithTraceback]:
    """
    Fetch and return the state of the given celery task. The scope of this function is
    global so that it can be called by subprocesses in the pool.

    :param celery_task: a tuple of the Celery task key and the async Celery object used
        to fetch the task's state
    :type celery_task: tuple(str, celery.result.AsyncResult)
    :return: a tuple of the Celery task key and the Celery state of the task
    :rtype: tuple[str, str]
    """

    try:
        with timeout(seconds=OPERATION_TIMEOUT):
            # Accessing state property of celery task will make actual network request
            # to get the current state of the task.
            return celery_task[0], celery_task[1].state
    except Exception as e:  # pylint: disable=broad-except
        exception_traceback = "Celery Task ID: {}\n{}".format(
            celery_task[0], traceback.format_exc())
        return ExceptionWithTraceback(e, exception_traceback)
Example #17
    def execution_parallelism(self, parallelism=0):
        executor = LocalExecutor(parallelism=parallelism)
        executor.start()

        success_key = 'success {}'
        success_command = 'echo {}'
        fail_command = 'exit 1'

        for i in range(self.TEST_SUCCESS_COMMANDS):
            key, command = success_key.format(i), success_command.format(i)
            executor.execute_async(key=key, command=command)
            executor.running[key] = True

        # errors are propagated for some reason
        try:
            executor.execute_async(key='fail', command=fail_command)
        except Exception:
            pass

        executor.running['fail'] = True

        if parallelism == 0:
            with timeout(seconds=5):
                executor.end()
        else:
            executor.end()

        for i in range(self.TEST_SUCCESS_COMMANDS):
            key = success_key.format(i)
            self.assertEqual(executor.event_buffer[key], State.SUCCESS)
        self.assertEqual(executor.event_buffer['fail'], State.FAILED)

        for i in range(self.TEST_SUCCESS_COMMANDS):
            self.assertNotIn(success_key.format(i), executor.running)
        self.assertNotIn('fail', executor.running)

        expected = self.TEST_SUCCESS_COMMANDS + 1 if parallelism == 0 else parallelism
        self.assertEqual(executor.workers_used, expected)
Example #18
    def process_file(self, filepath, only_if_updated=True):
        """
        Given a path to a python module or zip file, this method imports
        the module and looks for dag objects within it.
        """
        found_dags = []
        # if the source file no longer exists in the DB or in the filesystem,
        # return an empty list
        # todo: raise exception?
        if filepath is None or not os.path.isfile(filepath):
            return found_dags

        try:
            # This failed before in what may have been a git sync
            # race condition
            file_last_changed = datetime.fromtimestamp(
                os.path.getmtime(filepath))
            if only_if_updated \
                    and filepath in self.dagbag.file_last_changed \
                    and file_last_changed == self.dagbag.file_last_changed[filepath]:
                return found_dags

        except Exception as e:
            self.log.exception(e)
            return found_dags

        mods = []
        if not zipfile.is_zipfile(filepath):
            if self.safe_mode and os.path.isfile(filepath):
                with open(filepath, 'rb') as f:
                    content = f.read()
                    if not all([s in content for s in (b'DAG', b'airflow')]):
                        self.dagbag.file_last_changed[
                            filepath] = file_last_changed
                        return found_dags

            self.log.debug("Importing %s", filepath)
            org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
            mod_name = ('unusual_prefix_' +
                        hashlib.sha1(filepath.encode('utf-8')).hexdigest() +
                        '_' + org_mod_name)

            if mod_name in sys.modules:
                del sys.modules[mod_name]

            with timeout(configuration.getint('core',
                                              "DAGBAG_IMPORT_TIMEOUT")):
                try:
                    m = imp.load_source(mod_name, filepath)
                    mods.append(m)
                except Exception as e:
                    self.log.exception("Failed to import: %s", filepath)
                    self.dagbag.import_errors[filepath] = str(e)
                    self.dagbag.file_last_changed[filepath] = file_last_changed

        else:
            zip_file = zipfile.ZipFile(filepath)
            for mod in zip_file.infolist():
                head, _ = os.path.split(mod.filename)
                mod_name, ext = os.path.splitext(mod.filename)
                if not head and (ext == '.py' or ext == '.pyc'):
                    if mod_name == '__init__':
                        self.log.warning("Found __init__.%s at root of %s",
                                         ext, filepath)
                    if self.safe_mode:
                        with zip_file.open(mod.filename) as zf:
                            self.log.debug("Reading %s from %s", mod.filename,
                                           filepath)
                            content = zf.read()
                            if not all(
                                [s in content for s in (b'DAG', b'airflow')]):
                                self.dagbag.file_last_changed[filepath] = (
                                    file_last_changed)
                                # todo: create ignore list
                                return found_dags

                    if mod_name in sys.modules:
                        del sys.modules[mod_name]

                    try:
                        sys.path.insert(0, filepath)
                        m = importlib.import_module(mod_name)
                        mods.append(m)
                    except Exception as e:
                        self.log.exception("Failed to import: %s", filepath)
                        self.dagbag.import_errors[filepath] = str(e)
                        self.dagbag.file_last_changed[
                            filepath] = file_last_changed

        for m in mods:
            for dag in list(m.__dict__.values()):
                if isinstance(dag, airflow.models.DAG):
                    if not dag.full_filepath:
                        dag.full_filepath = filepath
                        if dag.fileloc != filepath:
                            dag.fileloc = filepath
                    try:
                        dag.is_subdag = False
                        self.dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
                        found_dags.append(dag)
                        found_dags += dag.subdags
                    except AirflowDagCycleException as cycle_exception:
                        self.log.exception("Failed to bag_dag: %s",
                                           dag.full_filepath)
                        self.dagbag.import_errors[dag.full_filepath] = \
                            str(cycle_exception)
                        self.dagbag.file_last_changed[dag.full_filepath] = \
                            file_last_changed

        self.dagbag.file_last_changed[filepath] = file_last_changed
        return found_dags
Example #19
def run_with_timeout():
    with timeout(seconds=30):
        job.run()
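This two-line helper assumes that a job object and the timeout context manager are already in scope. A self-contained sketch of the same pattern, assuming airflow.utils.timeout.timeout (which is signal-based in the versions these examples come from and raises AirflowTaskTimeout when the limit expires, as Example #9 relies on); the run_job_with_timeout name and the 30-second default are illustrative:

from airflow.exceptions import AirflowTaskTimeout
from airflow.utils.timeout import timeout

def run_job_with_timeout(job, seconds=30):
    # Abort job.run() if it has not finished within the given number of seconds.
    try:
        with timeout(seconds=seconds):
            job.run()
    except AirflowTaskTimeout:
        # The limit expired; let the caller decide how to recover, or re-raise.
        raise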
Example #20
    def process_file(self, filepath, only_if_updated=True, safe_mode=True):
        """
        Given a path to a python module or zip file, this method imports
        the module and looks for dag objects within it.
        """
        from airflow.models.dag import DAG  # Avoid circular import

        found_dags = []

        # if the source file no longer exists in the DB or in the filesystem,
        # return an empty list
        # todo: raise exception?
        if filepath is None or not os.path.isfile(filepath):
            return found_dags

        try:
            # This failed before in what may have been a git sync
            # race condition
            file_last_changed_on_disk = datetime.fromtimestamp(
                os.path.getmtime(filepath))
            if only_if_updated \
                    and filepath in self.file_last_changed \
                    and file_last_changed_on_disk == self.file_last_changed[filepath]:
                return found_dags

        except Exception as e:
            self.log.exception(e)
            return found_dags

        mods = []
        is_zipfile = zipfile.is_zipfile(filepath)
        if not is_zipfile:
            if safe_mode:
                with open(filepath, 'rb') as f:
                    content = f.read()
                    if not all([s in content for s in (b'DAG', b'airflow')]):
                        self.file_last_changed[
                            filepath] = file_last_changed_on_disk
                        # Don't want to spam user with skip messages
                        if not self.has_logged:
                            self.has_logged = True
                            self.log.info(
                                "File %s assumed to contain no DAGs. Skipping.",
                                filepath)
                        return found_dags

            self.log.debug("Importing %s", filepath)
            org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
            mod_name = ('unusual_prefix_' +
                        hashlib.sha1(filepath.encode('utf-8')).hexdigest() +
                        '_' + org_mod_name)

            if mod_name in sys.modules:
                del sys.modules[mod_name]

            with timeout(self.DAGBAG_IMPORT_TIMEOUT):
                try:
                    m = imp.load_source(mod_name, filepath)
                    mods.append(m)
                except Exception as e:
                    self.log.exception("Failed to import: %s", filepath)
                    self.import_errors[filepath] = str(e)
                    self.file_last_changed[
                        filepath] = file_last_changed_on_disk

        else:
            zip_file = zipfile.ZipFile(filepath)
            for mod in zip_file.infolist():
                head, _ = os.path.split(mod.filename)
                mod_name, ext = os.path.splitext(mod.filename)
                if not head and (ext == '.py' or ext == '.pyc'):
                    if mod_name == '__init__':
                        self.log.warning("Found __init__.%s at root of %s",
                                         ext, filepath)
                    if safe_mode:
                        with zip_file.open(mod.filename) as zf:
                            self.log.debug("Reading %s from %s", mod.filename,
                                           filepath)
                            content = zf.read()
                            if not all(
                                [s in content for s in (b'DAG', b'airflow')]):
                                self.file_last_changed[filepath] = (
                                    file_last_changed_on_disk)
                                # todo: create ignore list
                                # Don't want to spam user with skip messages
                                if not self.has_logged:
                                    self.has_logged = True
                                    self.log.info(
                                        "File %s assumed to contain no DAGs. Skipping.",
                                        filepath)

                    if mod_name in sys.modules:
                        del sys.modules[mod_name]

                    try:
                        sys.path.insert(0, filepath)
                        m = importlib.import_module(mod_name)
                        mods.append(m)
                    except Exception as e:
                        self.log.exception("Failed to import: %s", filepath)
                        self.import_errors[filepath] = str(e)
                        self.file_last_changed[
                            filepath] = file_last_changed_on_disk

        for m in mods:
            for dag in list(m.__dict__.values()):
                if isinstance(dag, DAG):
                    if not dag.full_filepath:
                        dag.full_filepath = filepath
                        if dag.fileloc != filepath and not is_zipfile:
                            dag.fileloc = filepath
                    try:
                        dag.is_subdag = False
                        self.bag_dag(dag, parent_dag=dag, root_dag=dag)
                        if isinstance(dag._schedule_interval,
                                      six.string_types):
                            croniter(dag._schedule_interval)
                        found_dags.append(dag)
                        found_dags += dag.subdags
                    except (CroniterBadCronError, CroniterBadDateError,
                            CroniterNotAlphaError) as cron_e:
                        self.log.exception("Failed to bag_dag: %s",
                                           dag.full_filepath)
                        self.import_errors[dag.full_filepath] = \
                            "Invalid Cron expression: " + str(cron_e)
                        self.file_last_changed[dag.full_filepath] = \
                            file_last_changed_on_disk
                    except AirflowDagCycleException as cycle_exception:
                        self.log.exception("Failed to bag_dag: %s",
                                           dag.full_filepath)
                        self.import_errors[dag.full_filepath] = str(
                            cycle_exception)
                        self.file_last_changed[dag.full_filepath] = \
                            file_last_changed_on_disk

        self.file_last_changed[filepath] = file_last_changed_on_disk
        return found_dags
Example #21
    def test_failure_callback_only_called_once(self, mock_return_code, _check_call):
        """
        Test that ensures that when a task exits with failure by itself,
        failure callback is only called once
        """
        # use shared memory value so we can properly track value change even if
        # it's been updated across processes.
        failure_callback_called = Value('i', 0)
        callback_count_lock = Lock()

        def failure_callback(context):
            with callback_count_lock:
                failure_callback_called.value += 1
            assert context['dag_run'].dag_id == 'test_failure_callback_race'
            assert isinstance(context['exception'], AirflowFailException)

        def task_function(ti):
            raise AirflowFailException()

        dag = DAG(dag_id='test_failure_callback_race', start_date=DEFAULT_DATE)
        task = PythonOperator(
            task_id='test_exit_on_failure',
            python_callable=task_function,
            on_failure_callback=failure_callback,
            dag=dag,
        )

        dag.clear()
        with create_session() as session:
            dag.create_dagrun(
                run_id="test",
                state=State.RUNNING,
                execution_date=DEFAULT_DATE,
                start_date=DEFAULT_DATE,
                session=session,
            )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()

        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())

        # Simulate race condition where job1 heartbeat ran right after task
        # state got set to failed by ti.handle_failure but before task process
        # fully exits. See _execute loop in airflow/jobs/local_task_job.py.
        # In this case, we have:
        #  * task_runner.return_code() is None
        #  * ti.state == State.Failed
        #
        # We also need to set return_code to a valid int after job1.terminating
        # is set to True so _execute loop won't loop forever.
        def dummy_return_code(*args, **kwargs):
            return None if not job1.terminating else -9

        mock_return_code.side_effect = dummy_return_code

        with timeout(10):
            # This should be _much_ shorter to run.
            # If you change this limit, make the timeout in the callable above bigger
            job1.run()

        ti.refresh_from_db()
        assert ti.state == State.FAILED  # task exits with failure state
        assert failure_callback_called.value == 1
Example #22
def run_with_timeout():
    with timeout(seconds=30):
        job.run()