Пример #1
0
    def test_default_pool_open_slots(self):
        """
        Open slots of the default pool drop by the slots taken by running
        and queued task instances, and ``Pool.slots_stats()`` reports the
        same numbers.
        """
        set_default_pool_slots(5)
        assert 5 == Pool.get_default_pool().open_slots()

        dag = DAG(dag_id='test_default_pool_open_slots', start_date=DEFAULT_DATE)
        single_slot_op = DummyOperator(task_id='dummy1', dag=dag)
        double_slot_op = DummyOperator(task_id='dummy2', dag=dag, pool_slots=2)
        running_ti = TI(task=single_slot_op, execution_date=DEFAULT_DATE)
        queued_ti = TI(task=double_slot_op, execution_date=DEFAULT_DATE)
        running_ti.state = State.RUNNING
        queued_ti.state = State.QUEUED

        session = settings.Session
        for task_instance in (running_ti, queued_ti):
            session.add(task_instance)
        session.commit()
        session.close()

        # 5 total - 1 running - 2 queued (pool_slots=2) leaves 2 open.
        expected_stats = {
            "default_pool": {
                "open": 2,
                "queued": 2,
                "total": 5,
                "running": 1,
            }
        }
        assert 2 == Pool.get_default_pool().open_slots()
        assert expected_stats == Pool.slots_stats()
Пример #2
0
    def kill_zombies(self, zombies, session=None):
        """
        Fail the given zombie tasks, i.e. tasks that haven't had a
        heartbeat for too long, in the current DagBag.

        Zombies whose DAG is not in ``self.dags``, or whose task id is not
        among that DAG's task ids, are skipped. The session is committed
        once after all zombies have been processed.

        :param zombies: zombie task instances to kill.
        :type zombies: airflow.utils.dag_processing.SimpleTaskInstance
        :param session: DB session.
        :type session: sqlalchemy.orm.session.Session
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        for zombie in zombies:
            if zombie.dag_id not in self.dags:
                continue
            dag = self.dags[zombie.dag_id]
            if zombie.task_id not in dag.task_ids:
                continue

            ti = TaskInstance(dag.get_task(zombie.task_id), zombie.execution_date)
            # Copy over the properties that failure handling needs from
            # the SimpleTaskInstance.
            ti.start_date = zombie.start_date
            ti.end_date = zombie.end_date
            ti.try_number = zombie.try_number
            ti.state = zombie.state
            ti.test_mode = self.UNIT_TEST_MODE

            failure_message = "{} detected as zombie".format(ti)
            ti.handle_failure(failure_message, ti.test_mode, ti.get_template_context())
            self.log.info('Marked zombie job %s as %s', ti, ti.state)
        session.commit()
Пример #3
0
    def _set_state_to_skipped(self, dag_run, execution_date, tasks, session):
        """
        Internal helper that marks task instances of one dag run as skipped.

        With a dag run, the matching rows are updated in a single bulk
        query. Without one, a TaskInstance is built per task for
        ``execution_date`` and merged into the session. This method does
        not commit.

        :param dag_run: the DagRun whose task instances are skipped (may be falsy)
        :param execution_date: execution date used when there is no dag run
        :param tasks: tasks to skip (not task_ids)
        :param session: db session to use
        :raises ValueError: if there is neither a dag run nor an execution date
        """
        task_ids = [task.task_id for task in tasks]
        now = timezone.utcnow()

        if dag_run:
            matching_tis = session.query(TaskInstance).filter(
                TaskInstance.dag_id == dag_run.dag_id,
                TaskInstance.execution_date == dag_run.execution_date,
                TaskInstance.task_id.in_(task_ids),
            )
            skipped_values = {
                TaskInstance.state: State.SKIPPED,
                TaskInstance.start_date: now,
                TaskInstance.end_date: now,
            }
            matching_tis.update(skipped_values, synchronize_session=False)
            return

        if execution_date is None:
            raise ValueError("Execution date is None and no dag run")

        self.log.warning("No DAG RUN present this should not happen")
        # Defensive path for dag runs that are not complete.
        for task in tasks:
            skipped_ti = TaskInstance(task, execution_date=execution_date)
            skipped_ti.state = State.SKIPPED
            skipped_ti.start_date = now
            skipped_ti.end_date = now
            session.merge(skipped_ti)
Пример #4
0
    def test_default_pool_open_slots(self):
        """
        With one running and one queued task instance assigned to the
        default pool, ``Pool.default_pool_open_slots()`` must return 3.
        """
        dag = DAG(dag_id='test_default_pool_open_slots', start_date=DEFAULT_DATE)
        running_op = DummyOperator(task_id='dummy1', dag=dag)
        queued_op = DummyOperator(task_id='dummy2', dag=dag)
        running_ti = TI(task=running_op, execution_date=DEFAULT_DATE)
        queued_ti = TI(task=queued_op, execution_date=DEFAULT_DATE)
        running_ti.state = State.RUNNING
        queued_ti.state = State.QUEUED
        running_ti.pool = Pool.default_pool_name
        queued_ti.pool = Pool.default_pool_name

        session = settings.Session
        for task_instance in (running_ti, queued_ti):
            session.add(task_instance)
        session.commit()
        session.close()

        self.assertEqual(3, Pool.default_pool_open_slots())
Пример #5
0
    def test_mark_success_on_success_callback(self):
        """
        Test that ensures that when a task is marked success in the UI,
        on_success_callback gets executed
        """
        from airflow.task.task_runner.standard_task_runner import StandardTaskRunner

        data = {'called': False}

        def success_callback(context):
            self.assertEqual(context['dag_run'].dag_id, 'test_mark_success')
            data['called'] = True

        dag = DAG(
            dag_id='test_mark_success',
            start_date=DEFAULT_DATE,
            default_args={'owner': 'owner1'},
        )
        task = DummyOperator(
            task_id='test_state_succeeded1',
            dag=dag,
            on_success_callback=success_callback,
        )

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(
            run_id="test",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        job1 = LocalTaskJob(
            task_instance=ti,
            ignore_ti_state=True,
            executor=SequentialExecutor(),
        )
        job1.task_runner = StandardTaskRunner(job1)

        process = multiprocessing.Process(target=job1.run)
        process.start()

        # Poll the DB until the task instance is reported as running,
        # sleeping 0.1s between at most 50 extra refreshes.
        ti.refresh_from_db()
        polls_left = 50
        while ti.state != State.RUNNING and polls_left > 0:
            time.sleep(0.1)
            ti.refresh_from_db()
            polls_left -= 1
        self.assertEqual(State.RUNNING, ti.state)

        # Flip the state to SUCCESS directly in the DB while the job runs.
        ti.state = State.SUCCESS
        session.merge(ti)
        session.commit()

        job1.heartbeat_callback(session=None)
        self.assertTrue(data['called'])
        process.join(timeout=10)
        self.assertFalse(process.is_alive())
Пример #6
0
    def test_open_slots(self):
        """
        A 5-slot pool with one running and one queued task instance must
        report 3 open slots.
        """
        pool = Pool(pool='test_pool', slots=5)
        dag = DAG(dag_id='test_open_slots', start_date=DEFAULT_DATE)
        running_op = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool')
        queued_op = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool')
        running_ti = TI(task=running_op, execution_date=DEFAULT_DATE)
        queued_ti = TI(task=queued_op, execution_date=DEFAULT_DATE)
        running_ti.state = State.RUNNING
        queued_ti.state = State.QUEUED

        session = settings.Session
        for record in (pool, running_ti, queued_ti):
            session.add(record)
        session.commit()
        session.close()

        self.assertEqual(3, pool.open_slots())
Пример #7
0
    def test_open_slots(self):
        """
        Check the open/running/queued/occupied slot counts of a 5-slot
        pool holding one running and one queued task instance, and the
        per-pool summary returned by ``slots_stats()``.
        """
        pool = Pool(pool='test_pool', slots=5)
        dag = DAG(dag_id='test_open_slots', start_date=DEFAULT_DATE)
        running_op = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool')
        queued_op = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool')
        running_ti = TI(task=running_op, execution_date=DEFAULT_DATE)
        queued_ti = TI(task=queued_op, execution_date=DEFAULT_DATE)
        running_ti.state = State.RUNNING
        queued_ti.state = State.QUEUED

        session = settings.Session
        for record in (pool, running_ti, queued_ti):
            session.add(record)
        session.commit()
        session.close()

        expected_stats = {
            "default_pool": {
                "open": 128,
                "queued": 0,
                "total": 128,
                "running": 0,
            },
            "test_pool": {
                "open": 3,
                "queued": 1,
                "running": 1,
                "total": 5,
            },
        }

        self.assertEqual(3, pool.open_slots())  # pylint: disable=no-value-for-parameter
        self.assertEqual(1, pool.running_slots())  # pylint: disable=no-value-for-parameter
        self.assertEqual(1, pool.queued_slots())  # pylint: disable=no-value-for-parameter
        self.assertEqual(2, pool.occupied_slots())  # pylint: disable=no-value-for-parameter
        self.assertEqual(expected_stats, pool.slots_stats())
Пример #8
0
    def test_default_pool_open_slots(self):
        """
        The default pool starts with all 5 configured slots open; a running
        task instance plus a queued one with ``pool_slots=2`` must leave 2.
        """
        set_default_pool_slots(5)
        self.assertEqual(5, Pool.get_default_pool().open_slots())

        dag = DAG(dag_id='test_default_pool_open_slots', start_date=DEFAULT_DATE)
        single_slot_op = DummyOperator(task_id='dummy1', dag=dag)
        double_slot_op = DummyOperator(task_id='dummy2', dag=dag, pool_slots=2)
        running_ti = TI(task=single_slot_op, execution_date=DEFAULT_DATE)
        queued_ti = TI(task=double_slot_op, execution_date=DEFAULT_DATE)
        running_ti.state = State.RUNNING
        queued_ti.state = State.QUEUED

        session = settings.Session
        for task_instance in (running_ti, queued_ti):
            session.add(task_instance)
        session.commit()
        session.close()

        self.assertEqual(2, Pool.get_default_pool().open_slots())
Пример #9
0
    def test_infinite_slots(self):
        """
        A pool created with ``slots=-1`` must report infinite open slots
        while still counting its used, queued and occupied slots.
        """
        pool = Pool(pool='test_pool', slots=-1)
        dag = DAG(dag_id='test_infinite_slots', start_date=DEFAULT_DATE)
        running_op = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool')
        queued_op = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool')
        running_ti = TI(task=running_op, execution_date=DEFAULT_DATE)
        queued_ti = TI(task=queued_op, execution_date=DEFAULT_DATE)
        running_ti.state = State.RUNNING
        queued_ti.state = State.QUEUED

        session = settings.Session
        for record in (pool, running_ti, queued_ti):
            session.add(record)
        session.commit()
        session.close()

        unlimited = float('inf')
        self.assertEqual(unlimited, pool.open_slots())  # pylint: disable=no-value-for-parameter
        self.assertEqual(1, pool.used_slots())  # pylint: disable=no-value-for-parameter
        self.assertEqual(1, pool.queued_slots())  # pylint: disable=no-value-for-parameter
        self.assertEqual(2, pool.occupied_slots())  # pylint: disable=no-value-for-parameter
Пример #10
0
    def test_heartbeat_failed_fast(self):
        """
        Test that task heartbeat will sleep when it fails fast

        More than two heartbeats must be recorded, and consecutive ones
        must be ``job.heartrate`` seconds apart (within 0.05s).
        """
        self.mock_base_job_sleep.side_effect = time.sleep

        dag_id = 'test_heartbeat_failed_fast'
        task_id = 'test_heartbeat_failed_fast_op'

        with create_session() as session:
            dagbag = DagBag(dag_folder=TEST_DAG_FOLDER, include_examples=False)
            dag = dagbag.get_dag(dag_id)
            task = dag.get_task(task_id)

            dag.create_dagrun(
                run_id="test_heartbeat_failed_fast_run",
                state=State.RUNNING,
                execution_date=DEFAULT_DATE,
                start_date=DEFAULT_DATE,
                session=session,
            )
            ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
            ti.refresh_from_db()
            ti.state = State.RUNNING
            ti.hostname = get_hostname()
            ti.pid = 1
            session.commit()

            job = LocalTaskJob(task_instance=ti, executor=MockExecutor(do_update=False))
            job.heartrate = 2
            heartbeat_records = []

            def record_heartbeat(session):  # pylint: disable=unused-argument
                # Remember when each heartbeat happened instead of running
                # the real callback.
                heartbeat_records.append(job.latest_heartbeat)

            job.heartbeat_callback = record_heartbeat
            job._execute()
            self.assertGreater(len(heartbeat_records), 2)

            # Compare every heartbeat with the one right before it.
            for earlier, later in zip(heartbeat_records, heartbeat_records[1:]):
                elapsed = (later - earlier).total_seconds()
                self.assertAlmostEqual(elapsed, job.heartrate, delta=0.05)
Пример #11
0
    def test_mark_success_no_kill(self):
        """
        Test that ensures that mark_success in the UI doesn't cause
        the task to fail, and that the task exits
        """
        dagbag = DagBag(dag_folder=TEST_DAG_FOLDER, include_examples=False)
        dag = dagbag.dags.get('test_mark_success')
        task = dag.get_task('task1')

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(
            run_id="test",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)

        process = multiprocessing.Process(target=job1.run)
        process.start()

        # Poll the DB until the task instance is reported as running,
        # sleeping 0.1s between at most 50 extra refreshes.
        ti.refresh_from_db()
        polls_left = 50
        while ti.state != State.RUNNING and polls_left > 0:
            time.sleep(0.1)
            ti.refresh_from_db()
            polls_left -= 1
        self.assertEqual(State.RUNNING, ti.state)

        # Flip the state to SUCCESS directly in the DB while the job runs.
        ti.state = State.SUCCESS
        session.merge(ti)
        session.commit()

        process.join(timeout=10)
        self.assertFalse(process.is_alive())
        ti.refresh_from_db()
        self.assertEqual(State.SUCCESS, ti.state)
    def skip(self, dag_run, execution_date, tasks, session=None):
        """
        Mark the task instances of the given tasks as skipped and commit.

        Does nothing when ``tasks`` is empty. With a dag run, the matching
        rows are updated in one bulk query; without one, a TaskInstance is
        built per task for ``execution_date`` and merged into the session.

        :param dag_run: the DagRun for which to set the tasks to skipped
        :param execution_date: execution_date
        :param tasks: tasks to skip (not task_ids)
        :param session: db session to use
        :raises ValueError: if there is neither a dag run nor an execution date
        """
        if not tasks:
            return

        task_ids = [task.task_id for task in tasks]
        now = timezone.utcnow()

        if not dag_run:
            if execution_date is None:
                raise ValueError("Execution date is None and no dag run")

            self.log.warning("No DAG RUN present this should not happen")
            # Defensive path for dag runs that are not complete.
            for task in tasks:
                skipped_ti = TaskInstance(task, execution_date=execution_date)
                skipped_ti.state = State.SKIPPED
                skipped_ti.start_date = now
                skipped_ti.end_date = now
                session.merge(skipped_ti)
        else:
            matching_tis = session.query(TaskInstance).filter(
                TaskInstance.dag_id == dag_run.dag_id,
                TaskInstance.execution_date == dag_run.execution_date,
                TaskInstance.task_id.in_(task_ids),
            )
            skipped_values = {
                TaskInstance.state: State.SKIPPED,
                TaskInstance.start_date: now,
                TaskInstance.end_date: now,
            }
            matching_tis.update(skipped_values, synchronize_session=False)

        # Both paths end by committing their changes.
        session.commit()
Пример #13
0
    def skip(self, dag_run, execution_date, tasks, session=None):
        """
        Sets tasks instances to skipped from the same dag run.

        Does nothing when ``tasks`` is empty. With a dag run, the matching
        rows are updated in one bulk query; without one, a TaskInstance is
        built per task for ``execution_date`` and merged into the session.
        Both paths commit the session.

        :param dag_run: the DagRun for which to set the tasks to skipped
        :param execution_date: execution_date
        :param tasks: tasks to skip (not task_ids)
        :param session: db session to use
        :raises ValueError: if there is neither a dag run nor an execution date
        """
        if not tasks:
            return

        task_ids = [d.task_id for d in tasks]
        now = timezone.utcnow()

        if dag_run:
            session.query(TaskInstance).filter(
                TaskInstance.dag_id == dag_run.dag_id,
                TaskInstance.execution_date == dag_run.execution_date,
                TaskInstance.task_id.in_(task_ids)
            ).update({TaskInstance.state: State.SKIPPED,
                      TaskInstance.start_date: now,
                      TaskInstance.end_date: now},
                     synchronize_session=False)
            session.commit()
        else:
            # Raise explicitly instead of using ``assert``: assertions are
            # stripped when Python runs with -O, which would let a missing
            # execution date slip through to TaskInstance creation.
            if execution_date is None:
                raise ValueError("Execution date is None and no dag run")

            self.log.warning("No DAG RUN present this should not happen")
            # this is defensive against dag runs that are not complete
            for task in tasks:
                ti = TaskInstance(task, execution_date=execution_date)
                ti.state = State.SKIPPED
                ti.start_date = now
                ti.end_date = now
                session.merge(ti)

            session.commit()
Пример #14
0
    def _kill_zombies(self, dag, zombies, session):
        """
        Fail the given zombie task instances that belong to ``dag``.

        Copied from airflow.models.dagbag.DagBag.kill_zombies. Zombies
        whose task id is not among ``dag.task_ids`` are skipped, and the
        session is committed once after all zombies have been processed.

        :param dag: DAG the zombies are looked up in
        :param zombies: zombie task instances to kill
        :param session: DB session
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        for zombie in zombies:
            if zombie.task_id not in dag.task_ids:
                continue

            ti = TaskInstance(dag.get_task(zombie.task_id), zombie.execution_date)
            # Copy over the properties that failure handling needs from
            # the SimpleTaskInstance.
            ti.start_date = zombie.start_date
            ti.end_date = zombie.end_date
            ti.try_number = zombie.try_number
            ti.state = zombie.state
            # NOTE: ti.test_mode is not overridden from self.UNIT_TEST_MODE
            # here; that assignment was commented out in this copy.

            failure_message = "{} detected as zombie".format(ti)
            ti.handle_failure(failure_message, ti.test_mode, ti.get_template_context())
            self.log.info("Marked zombie job %s as %s", ti, ti.state)
        session.commit()
Пример #15
0
            def _per_task_process(key, ti: TaskInstance, session=None):
                """
                Examine one task instance of the backfill and either record its
                outcome in ``ti_status``, reset it to SCHEDULED for a rerun, or
                queue it on the executor.

                ``ti_status``, ``executor``, ``start_date`` and ``pickle_id`` are
                not parameters; they are taken from the enclosing scope.

                :param key: key under which ``ti`` is tracked in the ``ti_status``
                    collections (``to_run``, ``running``, ``succeeded``, ...)
                :param ti: task instance to process; it is refreshed from the DB
                    with ``lock_for_update=True`` before being inspected
                :param session: DB session used for the refresh, the state
                    changes, the merge and the commit
                """
                ti.refresh_from_db(lock_for_update=True, session=session)

                task = self.dag.get_task(ti.task_id, include_subdags=True)
                ti.task = task

                self.log.debug("Task instance to run %s state %s", ti,
                               ti.state)

                # The task was already marked successful or skipped by a
                # different Job. Don't rerun it.
                if ti.state == State.SUCCESS and not self.rerun_succeeded_tasks:
                    ti_status.succeeded.add(key)
                    self.log.debug("Task instance %s succeeded. Don't rerun.",
                                   ti)
                    ti_status.to_run.pop(key)
                    if key in ti_status.running:
                        ti_status.running.pop(key)
                    return
                elif ti.state == State.SKIPPED:
                    ti_status.skipped.add(key)
                    self.log.debug("Task instance %s skipped. Don't rerun.",
                                   ti)
                    ti_status.to_run.pop(key)
                    if key in ti_status.running:
                        ti_status.running.pop(key)
                    return

                # guard against externally modified tasks instances or
                # in case max concurrency has been reached at task runtime
                elif ti.state == State.NONE:
                    self.log.warning(
                        "FIXME: Task instance %s state was set to None externally. This should not happen",
                        ti)
                    ti.set_state(State.SCHEDULED, session=session)
                # NOTE(review): this ``if`` starts a new chain, so the
                # rerun_succeeded_tasks branch below is only considered when
                # rerun_failed_tasks is falsy -- confirm that is intended.
                if self.rerun_failed_tasks:
                    # Rerun failed tasks or upstreamed failed tasks
                    if ti.state in (State.FAILED, State.UPSTREAM_FAILED):
                        self.log.error("Task instance %s with state %s", ti,
                                       ti.state)
                        if key in ti_status.running:
                            ti_status.running.pop(key)
                        # Reset the failed task in backfill to scheduled state
                        ti.set_state(State.SCHEDULED, session=session)
                elif self.rerun_succeeded_tasks and ti.state == State.SUCCESS:
                    # Rerun succeeded tasks
                    self.log.info(
                        "Task instance %s with state %s, rerunning succeeded task ",
                        ti, ti.state)
                    if key in ti_status.running:
                        ti_status.running.pop(key)
                    # Reset the succeeded task in backfill to scheduled state
                    ti.set_state(State.SCHEDULED, session=session)
                else:
                    # Default behaviour which works for subdag.
                    if ti.state in (State.FAILED, State.UPSTREAM_FAILED):
                        self.log.error("Task instance %s with state %s", ti,
                                       ti.state)
                        ti_status.failed.add(key)
                        ti_status.to_run.pop(key)
                        if key in ti_status.running:
                            ti_status.running.pop(key)
                        return

                # depends_on_past is ignored only for the dag run whose
                # execution_date equals start_date (from the enclosing scope),
                # falling back to ti.start_date when start_date is falsy.
                if self.ignore_first_depends_on_past:
                    dagrun = ti.get_dagrun(session=session)
                    ignore_depends_on_past = dagrun.execution_date == (
                        start_date or ti.start_date)
                else:
                    ignore_depends_on_past = False

                backfill_context = DepContext(
                    deps=BACKFILL_QUEUED_DEPS,
                    ignore_depends_on_past=ignore_depends_on_past,
                    ignore_task_deps=self.ignore_task_deps,
                    flag_upstream_failed=True,
                )

                # Is the task runnable? -- then run it
                # the dependency checker can change states of tis
                if ti.are_dependencies_met(dep_context=backfill_context,
                                           session=session,
                                           verbose=self.verbose):
                    if executor.has_task(ti):
                        self.log.debug(
                            "Task Instance %s already in executor waiting for queue to clear",
                            ti)
                    else:
                        self.log.debug('Sending %s to executor', ti)
                        # Skip scheduled state, we are executing immediately
                        ti.state = State.QUEUED
                        ti.queued_by_job_id = self.id
                        ti.queued_dttm = timezone.utcnow()
                        session.merge(ti)

                        # cfg_path is only populated (via tmp_configuration_copy)
                        # for the local and sequential executors.
                        cfg_path = None
                        if self.executor_class in (
                                executor_constants.LOCAL_EXECUTOR,
                                executor_constants.SEQUENTIAL_EXECUTOR,
                        ):
                            cfg_path = tmp_configuration_copy()

                        executor.queue_task_instance(
                            ti,
                            mark_success=self.mark_success,
                            pickle_id=pickle_id,
                            ignore_task_deps=self.ignore_task_deps,
                            ignore_depends_on_past=ignore_depends_on_past,
                            pool=self.pool,
                            cfg_path=cfg_path,
                        )
                        # Track the task as running and drop it from the
                        # to-run set now that it has been queued.
                        ti_status.running[key] = ti
                        ti_status.to_run.pop(key)
                    # Commit whether the task was just queued or was already
                    # present in the executor.
                    session.commit()
                    return

                if ti.state == State.UPSTREAM_FAILED:
                    self.log.error("Task instance %s upstream failed", ti)
                    ti_status.failed.add(key)
                    ti_status.to_run.pop(key)
                    if key in ti_status.running:
                        ti_status.running.pop(key)
                    return

                # special case
                if ti.state == State.UP_FOR_RETRY:
                    self.log.debug(
                        "Task instance %s retry period not expired yet", ti)
                    if key in ti_status.running:
                        ti_status.running.pop(key)
                    ti_status.to_run[key] = ti
                    return

                # special case
                if ti.state == State.UP_FOR_RESCHEDULE:
                    self.log.debug(
                        "Task instance %s reschedule period not expired yet",
                        ti)
                    if key in ti_status.running:
                        ti_status.running.pop(key)
                    ti_status.to_run[key] = ti
                    return

                # all remaining tasks
                self.log.debug('Adding %s to not_ready', ti)
                ti_status.not_ready.add(key)
Пример #16
0
    def test_mark_success_on_success_callback(self):
        """
        Test that ensures that when a task is marked success in the UI,
        on_success_callback gets executed
        """
        # Shared-memory values let us observe changes made in the child
        # process that actually runs the task.
        success_callback_called = Value('i', 0)
        task_terminated_externally = Value('i', 1)
        shared_mem_lock = Lock()

        def success_callback(context):
            with shared_mem_lock:
                success_callback_called.value += 1
            assert context['dag_run'].dag_id == 'test_mark_success'

        def task_function(ti):
            # pylint: disable=unused-argument
            time.sleep(60)
            # Reaching this point is unexpected: the state change should be
            # noticed and the task killed before the sleep finishes.
            with shared_mem_lock:
                task_terminated_externally.value = 0

        dag = DAG(
            dag_id='test_mark_success',
            start_date=DEFAULT_DATE,
            default_args={'owner': 'owner1'},
        )
        task = PythonOperator(
            task_id='test_state_succeeded1',
            python_callable=task_function,
            on_success_callback=success_callback,
            dag=dag,
        )

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(
            run_id="test",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        job1 = LocalTaskJob(
            task_instance=ti,
            ignore_ti_state=True,
            executor=SequentialExecutor(),
        )
        job1.task_runner = StandardTaskRunner(job1)

        # Dispose of the engine before starting the child process
        # (presumably so pooled DB connections are not shared with it --
        # TODO confirm).
        settings.engine.dispose()
        process = multiprocessing.Process(target=job1.run)
        process.start()

        # Poll the DB up to 25 times, 0.2s apart, for the RUNNING state.
        for _ in range(25):
            ti.refresh_from_db()
            if ti.state == State.RUNNING:
                break
            time.sleep(0.2)
        assert ti.state == State.RUNNING

        # Flip the state to SUCCESS directly in the DB while the job runs.
        ti.state = State.SUCCESS
        session.merge(ti)
        session.commit()

        process.join(timeout=10)
        assert success_callback_called.value == 1
        assert task_terminated_externally.value == 1
        assert not process.is_alive()