Code Example #1
 def test_delete_pool(self):
     self.client.create_pool(name='foo', slots=1, description='')
     with create_session() as session:
         self.assertEqual(session.query(models.Pool).count(), 1)
     self.client.delete_pool(name='foo')
     with create_session() as session:
         self.assertEqual(session.query(models.Pool).count(), 0)
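
All of the snippets in this collection go through Airflow's create_session
context manager (imported from airflow.utils.db in the versions shown here).
As a hedged sketch of the behaviour they rely on, it yields a SQLAlchemy
session, commits on a clean exit, rolls back and re-raises on an exception,
and always closes the session:

import contextlib

from airflow import settings


@contextlib.contextmanager
def create_session():
    """Yield a session; commit on success, roll back on error (sketch)."""
    session = settings.Session()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()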
Code Example #2
    def test_delete_dag(self):
        key = "my_dag_id"

        with create_session() as session:
            self.assertEqual(session.query(DagModel).filter(DagModel.dag_id == key).count(), 0)
            session.add(DagModel(dag_id=key))

        with create_session() as session:
            self.assertEqual(session.query(DagModel).filter(DagModel.dag_id == key).count(), 1)

            self.client.delete_dag(dag_id=key)
            self.assertEqual(session.query(DagModel).filter(DagModel.dag_id == key).count(), 0)
Code Example #3
    def test_without_dag_run(self):
        """This checks the defense against non-existent tasks in a dag run."""
        self.branch_op = BranchPythonOperator(task_id='make_choice',
                                              dag=self.dag,
                                              python_callable=lambda: 'branch_1')
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(
                TI.dag_id == self.dag.dag_id,
                TI.execution_date == DEFAULT_DATE
            )

            for ti in tis:
                if ti.task_id == 'make_choice':
                    self.assertEqual(ti.state, State.SUCCESS)
                elif ti.task_id == 'branch_1':
                    # should exist with state None
                    self.assertEqual(ti.state, State.NONE)
                elif ti.task_id == 'branch_2':
                    self.assertEqual(ti.state, State.SKIPPED)
                else:
                    raise Exception
Code Example #4
 def snapshot_state(dag, execution_dates):
     TI = models.TaskInstance
     with create_session() as session:
         return session.query(TI).filter(
             TI.dag_id == dag.dag_id,
             TI.execution_date.in_(execution_dates)
         ).all()
Code Example #5
    def tearDown(self):
        self.dag1.clear()
        self.dag2.clear()

        with create_session() as session:
            session.query(models.DagRun).delete()
            session.query(models.TaskInstance).delete()
Code Example #6
    def setUp(self):
        self.key = "test_dag_id"

        task = DummyOperator(task_id='dummy',
                             dag=models.DAG(dag_id=self.key,
                                            default_args={'start_date': days_ago(2)}),
                             owner='airflow')

        d = days_ago(1)
        with create_session() as session:
            session.add(DM(dag_id=self.key))
            session.add(DR(dag_id=self.key))
            session.add(TI(task=task,
                           execution_date=d,
                           state=State.SUCCESS))
            # flush to ensure the task instance is written before the
            # task reschedule, because of the FK constraint
            session.flush()
            session.add(LOG(dag_id=self.key, task_id=None, task_instance=None,
                            execution_date=d, event="varimport"))
            session.add(TF(task=task, execution_date=d,
                           start_date=d, end_date=d))
            session.add(TR(task=task, execution_date=d,
                           start_date=d, end_date=d,
                           try_number=1, reschedule_date=d))
Code Example #7
    def test_none_skipped_tr_success(self):
        """
        None-skipped trigger rule success
        """

        ti = self._get_task_instance(TriggerRule.NONE_SKIPPED,
                                     upstream_task_ids=["FakeTaskID",
                                                        "OtherFakeTaskID",
                                                        "FailedFakeTaskID"])
        with create_session() as session:
            dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
                ti=ti,
                successes=2,
                skipped=0,
                failed=1,
                upstream_failed=0,
                done=3,
                flag_upstream_failed=False,
                session=session))
            self.assertEqual(len(dep_statuses), 0)

            # with `flag_upstream_failed` set to True
            dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
                ti=ti,
                successes=0,
                skipped=0,
                failed=3,
                upstream_failed=0,
                done=3,
                flag_upstream_failed=True,
                session=session))
            self.assertEqual(len(dep_statuses), 0)
Code Example #8
    def test_branch_list_without_dag_run(self):
        """This checks if the BranchPythonOperator supports branching off to a list of tasks."""
        self.branch_op = BranchPythonOperator(task_id='make_choice',
                                              dag=self.dag,
                                              python_callable=lambda: ['branch_1', 'branch_2'])
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
        self.branch_3.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(
                TI.dag_id == self.dag.dag_id,
                TI.execution_date == DEFAULT_DATE
            )

            expected = {
                "make_choice": State.SUCCESS,
                "branch_1": State.NONE,
                "branch_2": State.NONE,
                "branch_3": State.SKIPPED,
            }

            for ti in tis:
                if ti.task_id in expected:
                    self.assertEqual(ti.state, expected[ti.task_id])
                else:
                    raise Exception
Code Example #9
    def test_none_skipped_tr_failure(self):
        """
        None-skipped trigger rule failure
        """
        ti = self._get_task_instance(TriggerRule.NONE_SKIPPED,
                                     upstream_task_ids=["FakeTaskID",
                                                        "SkippedTaskID"])

        with create_session() as session:
            dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
                ti=ti,
                successes=1,
                skipped=1,
                failed=0,
                upstream_failed=0,
                done=2,
                flag_upstream_failed=False,
                session=session))
            self.assertEqual(len(dep_statuses), 1)
            self.assertFalse(dep_statuses[0].passed)

            # with `flag_upstream_failed` set to True
            dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
                ti=ti,
                successes=1,
                skipped=1,
                failed=0,
                upstream_failed=0,
                done=2,
                flag_upstream_failed=True,
                session=session))
            self.assertEqual(len(dep_statuses), 1)
            self.assertFalse(dep_statuses[0].passed)
Code Example #10
    def tearDown(self):
        self.dag1.clear()
        self.dag2.clear()

        # just to make sure we are fully cleaned up
        with create_session() as session:
            session.query(models.DagRun).delete()
            session.query(models.TaskInstance).delete()
Code Example #11
 def tearDown(self):
     with create_session() as session:
         session.query(TR).filter(TR.dag_id == self.key).delete()
         session.query(TF).filter(TF.dag_id == self.key).delete()
         session.query(TI).filter(TI.dag_id == self.key).delete()
         session.query(DR).filter(DR.dag_id == self.key).delete()
         session.query(DM).filter(DM.dag_id == self.key).delete()
         session.query(LOG).filter(LOG.dag_id == self.key).delete()
Code Example #12
    def test_delete_dag_successful_delete(self):
        with create_session() as session:
            self.assertEqual(session.query(DM).filter(DM.dag_id == self.key).count(), 1)
            self.assertEqual(session.query(DR).filter(DR.dag_id == self.key).count(), 1)
            self.assertEqual(session.query(TI).filter(TI.dag_id == self.key).count(), 1)
            self.assertEqual(session.query(TF).filter(TF.dag_id == self.key).count(), 1)
            self.assertEqual(session.query(TR).filter(TR.dag_id == self.key).count(), 1)
            self.assertEqual(session.query(LOG).filter(LOG.dag_id == self.key).count(), 1)

        delete_dag(dag_id=self.key)

        with create_session() as session:
            self.assertEqual(session.query(DM).filter(DM.dag_id == self.key).count(), 0)
            self.assertEqual(session.query(DR).filter(DR.dag_id == self.key).count(), 0)
            self.assertEqual(session.query(TI).filter(TI.dag_id == self.key).count(), 0)
            self.assertEqual(session.query(TF).filter(TF.dag_id == self.key).count(), 0)
            self.assertEqual(session.query(TR).filter(TR.dag_id == self.key).count(), 0)
            self.assertEqual(session.query(LOG).filter(LOG.dag_id == self.key).count(), 1)
Code Example #13
File: test_pool.py (Project: Fokko/incubator-airflow)
 def test_create_pool_existing(self):
     pool = pool_api.create_pool(name=self.pools[0].pool,
                                 slots=5,
                                 description='')
     self.assertEqual(pool.pool, self.pools[0].pool)
     self.assertEqual(pool.slots, 5)
     self.assertEqual(pool.description, '')
     with create_session() as session:
         self.assertEqual(session.query(models.Pool).count(), 2)
Code Example #14
File: test_pool.py (Project: Fokko/incubator-airflow)
 def test_create_pool(self):
     pool = pool_api.create_pool(name='foo',
                                 slots=5,
                                 description='')
     self.assertEqual(pool.pool, 'foo')
     self.assertEqual(pool.slots, 5)
     self.assertEqual(pool.description, '')
     with create_session() as session:
         self.assertEqual(session.query(models.Pool).count(), 3)
Code Example #15
    def test_delete_dag_successful_delete_not_keeping_records_in_log(self):

        with create_session() as session:
            self.assertEqual(session.query(DM).filter(DM.dag_id == self.key).count(), 1)
            self.assertEqual(session.query(DR).filter(DR.dag_id == self.key).count(), 1)
            self.assertEqual(session.query(TI).filter(TI.dag_id == self.key).count(), 1)
            self.assertEqual(session.query(TF).filter(TF.dag_id == self.key).count(), 1)
            self.assertEqual(session.query(TR).filter(TR.dag_id == self.key).count(), 1)
            self.assertEqual(session.query(LOG).filter(LOG.dag_id == self.key).count(), 1)

        delete_dag(dag_id=self.key, keep_records_in_log=False)

        with create_session() as session:
            self.assertEqual(session.query(DM).filter(DM.dag_id == self.key).count(), 0)
            self.assertEqual(session.query(DR).filter(DR.dag_id == self.key).count(), 0)
            self.assertEqual(session.query(TI).filter(TI.dag_id == self.key).count(), 0)
            self.assertEqual(session.query(TF).filter(TF.dag_id == self.key).count(), 0)
            self.assertEqual(session.query(TR).filter(TR.dag_id == self.key).count(), 0)
            self.assertEqual(session.query(LOG).filter(LOG.dag_id == self.key).count(), 0)
Code Example #16
def default_action_log(log, **_):
    """
    A default action logger callback that behaves the same as
    www.utils.action_logging, which uses the global session and pushes the
    log ORM object.

    :param log: A Log ORM instance
    :param **_: other keyword arguments that are not used by this function
    :return: None
    """
    with create_session() as session:
        session.add(log)
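
A hedged usage sketch for this callback: build a Log row with the same fields
the decorator examples elsewhere on this page use, then hand it over (the
event name and ids below are made up):

from airflow import models

# Hypothetical invocation: persist a Log row describing a CLI action.
log = models.Log(event='cli_example', task_instance=None, owner='airflow',
                 extra='[]', task_id='my_task', dag_id='my_dag')
default_action_log(log)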
Code Example #17
    def tearDown(self):
        super().tearDown()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

        for var in TI_CONTEXT_ENV_VARS:
            if var in os.environ:
                del os.environ[var]
Code Example #18
File: test_pool.py (Project: Fokko/incubator-airflow)
 def setUp(self):
     self.pools = []
     for i in range(2):
         name = 'experimental_%s' % (i + 1)
         pool = models.Pool(
             pool=name,
             slots=i,
             description=name,
         )
         self.pools.append(pool)
     with create_session() as session:
         session.add_all(self.pools)
Code Example #19
    def test_without_dag_run(self):
        """This checks the defense against non-existent tasks in a dag run."""
        value = False
        dag = DAG('shortcircuit_operator_test_without_dag_run',
                  default_args={
                      'owner': 'airflow',
                      'start_date': DEFAULT_DATE
                  },
                  schedule_interval=INTERVAL)
        short_op = ShortCircuitOperator(task_id='make_choice',
                                        dag=dag,
                                        python_callable=lambda: value)
        branch_1 = DummyOperator(task_id='branch_1', dag=dag)
        branch_1.set_upstream(short_op)
        branch_2 = DummyOperator(task_id='branch_2', dag=dag)
        branch_2.set_upstream(branch_1)
        upstream = DummyOperator(task_id='upstream', dag=dag)
        upstream.set_downstream(short_op)
        dag.clear()

        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(
                TI.dag_id == dag.dag_id,
                TI.execution_date == DEFAULT_DATE
            )

            for ti in tis:
                if ti.task_id == 'make_choice':
                    self.assertEqual(ti.state, State.SUCCESS)
                elif ti.task_id == 'upstream':
                    # should not exist
                    raise Exception
                elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                    self.assertEqual(ti.state, State.SKIPPED)
                else:
                    raise Exception

            value = True
            dag.clear()

            short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
            for ti in tis:
                if ti.task_id == 'make_choice':
                    self.assertEqual(ti.state, State.SUCCESS)
                elif ti.task_id == 'upstream':
                    # should not exist
                    raise Exception
                elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                    self.assertEqual(ti.state, State.NONE)
                else:
                    raise Exception
Code Example #20
    def test_deactivate_unknown_dags(self):
        """
        Test that DAGs whose dag_ids are not passed into deactivate_unknown_dags
        are deactivated when the function is invoked
        """
        dagbag = DagBag(include_examples=True)
        dag_id = "test_deactivate_unknown_dags"
        expected_active_dags = dagbag.dags.keys()

        model_before = DagModel(dag_id=dag_id, is_active=True)
        with create_session() as session:
            session.merge(model_before)

        models.DAG.deactivate_unknown_dags(expected_active_dags)

        after_model = DagModel.get_dagmodel(dag_id)
        self.assertTrue(model_before.is_active)
        self.assertFalse(after_model.is_active)

        # clean up
        with create_session() as session:
            session.query(DagModel).filter(DagModel.dag_id == 'test_deactivate_unknown_dags').delete()
Code Example #21
    def _labels_to_key(self, labels):
        try_num = 1
        try:
            try_num = int(labels.get('try_number', '1'))
        except ValueError:
            self.log.warning("could not get try_number as an int: %s", labels.get('try_number', '1'))

        try:
            dag_id = labels['dag_id']
            task_id = labels['task_id']
            ex_time = self._label_safe_datestring_to_datetime(labels['execution_date'])
        except Exception as e:
            self.log.warning(
                'Error while retrieving labels; labels: %s; exception: %s',
                labels, e
            )
            return None

        with create_session() as session:
            tasks = (
                session
                .query(TaskInstance)
                .filter_by(execution_date=ex_time).all()
            )
            self.log.info(
                'Checking %s task instances.',
                len(tasks)
            )
            for task in tasks:
                if (
                    self._make_safe_label_value(task.dag_id) == dag_id and
                    self._make_safe_label_value(task.task_id) == task_id and
                    task.execution_date == ex_time
                ):
                    self.log.info(
                        'Found matching task %s-%s (%s) with current state of %s',
                        task.dag_id, task.task_id, task.execution_date, task.state
                    )
                    dag_id = task.dag_id
                    task_id = task.task_id
                    return (dag_id, task_id, ex_time, try_num)
        self.log.warning(
            'Failed to find and match task details to a pod; labels: %s',
            labels
        )
        return None
Code Example #22
 def execute(self, context):
     dro = DagRunOrder(run_id='trig__' + timezone.utcnow().isoformat())
     dro = self.python_callable(context, dro)
     if dro:
         with create_session() as session:
             dbag = DagBag(settings.DAGS_FOLDER)
             trigger_dag = dbag.get_dag(self.trigger_dag_id)
             dr = trigger_dag.create_dagrun(
                 run_id=dro.run_id,
                 state=State.RUNNING,
                 conf=dro.payload,
                 external_trigger=True)
             self.log.info("Creating DagRun %s", dr)
             session.add(dr)
             session.commit()
     else:
         self.log.info("Criteria not met, moving on")
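
A hedged sketch of the python_callable contract this execute() assumes: the
callable receives the context and the DagRunOrder, may attach a payload, and
returns the order (or a falsy value to skip triggering). The 'condition'
param below is made up:

def conditionally_trigger(context, dag_run_obj):
    # Trigger the target DAG only when the upstream run asked for it.
    if context['params'].get('condition'):
        dag_run_obj.payload = {'message': 'triggered'}
        return dag_run_obj
    return None  # falls into the "Criteria not met" branch above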
Code Example #23
def dag_paused(dag_id, paused):
    """(Un)pauses a dag"""

    DagModel = models.DagModel
    with create_session() as session:
        orm_dag = (
            session.query(DagModel)
                   .filter(DagModel.dag_id == dag_id).first()
        )
        if paused == 'true':
            orm_dag.is_paused = True
        else:
            orm_dag.is_paused = False
        session.merge(orm_dag)
        session.commit()

    return jsonify({'response': 'ok'})
Code Example #24
    def decorated(*args, **kwargs):
        from flask import request

        header = request.headers.get("Authorization")
        if header:
            userpass = ''.join(header.split()[1:])
            username, password = base64.b64decode(userpass).decode("utf-8").split(":", 1)

            with create_session() as session:
                try:
                    authenticate(session, username, password)

                    response = function(*args, **kwargs)
                    response = make_response(response)
                    return response

                except AuthenticationError:
                    return _forbidden()

        return _unauthorized()
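
In the Airflow tree this decorated function is returned by a
requires_authentication wrapper; a hedged usage sketch with a stand-in Flask
app (the route and endpoint are made up):

from flask import Flask, jsonify

app = Flask(__name__)  # stand-in app for the sketch

@app.route('/api/experimental/test')
@requires_authentication  # assumed wrapper returning `decorated` above
def test_endpoint():
    return jsonify(status='OK')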
Code Example #25
File: kubernetes_executor.py (Project: wooga/airflow)
 def _change_state(self, key, state, pod_id):
     if state != State.RUNNING:
         self.kube_scheduler.delete_pod(pod_id)
         try:
             self.log.info('Deleted pod: %s', str(key))
             self.running.pop(key)
         except KeyError:
             self.log.debug('Could not find key: %s', str(key))
     self.event_buffer[key] = state
     (dag_id, task_id, ex_time, try_number) = key
     with create_session() as session:
         item = session.query(TaskInstance).filter_by(
             dag_id=dag_id,
             task_id=task_id,
             execution_date=ex_time
         ).one()
         if state:
             item.state = state
             session.add(item)
Code Example #26
    def test_delete_dag_dag_still_in_dagbag(self):
        with create_session() as session:
            models_to_check = ['DagModel', 'DagRun', 'TaskInstance']
            record_counts = {}

            for model_name in models_to_check:
                m = getattr(models, model_name)
                record_counts[model_name] = session.query(m).filter(m.dag_id == self.dag_id).count()

            with self.assertRaises(DagFileExists):
                delete_dag(self.dag_id)

            # No change should happen in DB
            for model_name in models_to_check:
                m = getattr(models, model_name)
                self.assertEqual(
                    session.query(m).filter(
                        m.dag_id == self.dag_id
                    ).count(),
                    record_counts[model_name]
                )
Code Example #27
    def test_kill_zombies(self, mock_ti_handle_failure):
        """
        Test that kill_zombies calls the TI's failure handler with the proper context
        """
        dagbag = models.DagBag()
        with create_session() as session:
            session.query(TI).delete()
            dag = dagbag.get_dag('example_branch_operator')
            task = dag.get_task(task_id='run_this_first')

            ti = TI(task, DEFAULT_DATE, State.RUNNING)

            session.add(ti)
            session.commit()

            zombies = [SimpleTaskInstance(ti)]
            dagbag.kill_zombies(zombies)
            mock_ti_handle_failure \
                .assert_called_with(ANY,
                                    configuration.getboolean('core',
                                                             'unit_test_mode'),
                                    ANY)
Code Example #28
File: utils.py (Project: 7digital/incubator-airflow)
    def wrapper(*args, **kwargs):
        if current_user and hasattr(current_user, 'username'):
            user = current_user.username
        else:
            user = '******'

        log = models.Log(
            event=f.__name__,
            task_instance=None,
            owner=user,
            extra=str(list(request.args.items())),
            task_id=request.args.get('task_id'),
            dag_id=request.args.get('dag_id'))

        if 'execution_date' in request.args:
            log.execution_date = timezone.parse(request.args.get('execution_date'))

        with create_session() as session:
            session.add(log)
            session.commit()

        return f(*args, **kwargs)
Code Example #29
    def test_find_zombies(self):
        manager = DagFileProcessorManager(
            dag_directory='directory',
            file_paths=['abc.txt'],
            max_runs=1,
            processor_factory=MagicMock().return_value,
            signal_conn=MagicMock(),
            stat_queue=MagicMock(),
            result_queue=MagicMock(),
            async_mode=True)

        dagbag = DagBag(TEST_DAG_FOLDER)
        with create_session() as session:
            session.query(LJ).delete()
            dag = dagbag.get_dag('example_branch_operator')
            task = dag.get_task(task_id='run_this_first')

            ti = TI(task, DEFAULT_DATE, State.RUNNING)
            lj = LJ(ti)
            lj.state = State.SHUTDOWN
            lj.id = 1
            ti.job_id = lj.id

            session.add(lj)
            session.add(ti)
            session.commit()

            manager._last_zombie_query_time = timezone.utcnow() - timedelta(
                seconds=manager._zombie_threshold_secs + 1)
            zombies = manager._find_zombies()
            self.assertEqual(1, len(zombies))
            self.assertIsInstance(zombies[0], SimpleTaskInstance)
            self.assertEqual(ti.dag_id, zombies[0].dag_id)
            self.assertEqual(ti.task_id, zombies[0].task_id)
            self.assertEqual(ti.execution_date, zombies[0].execution_date)

            session.query(TI).delete()
            session.query(LJ).delete()
Code Example #30
    def test_kill_zombies(self, mock_ti_handle_failure):
        """
        Test that kill_zombies calls the TI's failure handler with the proper context
        """
        dagbag = models.DagBag(dag_folder=self.empty_dir,
                               include_examples=True)
        with create_session() as session:
            session.query(TI).delete()
            dag = dagbag.get_dag('example_branch_operator')
            task = dag.get_task(task_id='run_this_first')

            ti = TI(task, DEFAULT_DATE, State.RUNNING)

            session.add(ti)
            session.commit()

            zombies = [SimpleTaskInstance(ti)]
            dagbag.kill_zombies(zombies)
            mock_ti_handle_failure \
                .assert_called_with(ANY,
                                    configuration.getboolean('core',
                                                             'unit_test_mode'),
                                    ANY)
Code Example #31
    def test_without_dag_run(self):
        """This checks the defense against non-existent tasks in a dag run."""
        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id,
                                           TI.execution_date == DEFAULT_DATE)

            for ti in tis:
                if ti.task_id == 'make_choice':
                    self.assertEqual(ti.state, State.SUCCESS)
                elif ti.task_id == 'branch_1':
                    # should exist with state None
                    self.assertEqual(ti.state, State.NONE)
                elif ti.task_id == 'branch_2':
                    self.assertEqual(ti.state, State.SKIPPED)
                else:
                    raise Exception
Code Example #32
File: decorators.py (Project: Fokko/incubator-airflow)
    def wrapper(*args, **kwargs):

        with create_session() as session:
            if g.user.is_anonymous:
                user = '******'
            else:
                user = g.user.username

            log = Log(
                event=f.__name__,
                task_instance=None,
                owner=user,
                extra=str(list(request.args.items())),
                task_id=request.args.get('task_id'),
                dag_id=request.args.get('dag_id'))

            if 'execution_date' in request.args:
                log.execution_date = pendulum.parse(
                    request.args.get('execution_date'))

            session.add(log)

        return f(*args, **kwargs)
Code Example #33
    def test_not_requeue_non_requeueable_task_instance(self):
        dag = models.DAG(dag_id='test_not_requeue_non_requeueable_task_instance')
        # Use BaseSensorOperator because sensor got
        # one additional DEP in BaseSensorOperator().deps
        task = BaseSensorOperator(
            task_id='test_not_requeue_non_requeueable_task_instance_op',
            dag=dag,
            pool='test_pool',
            owner='airflow',
            start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
        ti = TI(
            task=task, execution_date=timezone.utcnow(), state=State.QUEUED)
        with create_session() as session:
            session.add(ti)
            session.commit()

        all_deps = RUNNING_DEPS | task.deps
        all_non_requeueable_deps = all_deps - REQUEUEABLE_DEPS
        patch_dict = {}
        for dep in all_non_requeueable_deps:
            class_name = dep.__class__.__name__
            dep_patch = patch('%s.%s.%s' % (dep.__module__, class_name,
                                            dep._get_dep_statuses.__name__))
            method_patch = dep_patch.start()
            method_patch.return_value = iter([TIDepStatus('mock_' + class_name, True,
                                                          'mock')])
            patch_dict[class_name] = (dep_patch, method_patch)

        for class_name, (dep_patch, method_patch) in patch_dict.items():
            method_patch.return_value = iter(
                [TIDepStatus('mock_' + class_name, False, 'mock')])
            ti.run()
            self.assertEqual(ti.state, State.QUEUED)
            dep_patch.return_value = TIDepStatus('mock_' + class_name, True, 'mock')

        for (dep_patch, method_patch) in patch_dict.values():
            dep_patch.stop()
Code Example #34
    def test_rerun_failed_subdag(self):
        """
        When there is an existing DagRun with failed state, reset the DagRun and the
        corresponding TaskInstances
        """
        dag = DAG('parent', default_args=default_args)
        subdag = DAG('parent.test', default_args=default_args)
        subdag_task = SubDagOperator(task_id='test',
                                     subdag=subdag,
                                     dag=dag,
                                     poke_interval=1)
        dummy_task = DummyOperator(task_id='dummy', dag=subdag)

        with create_session() as session:
            dummy_task_instance = TaskInstance(
                task=dummy_task,
                execution_date=DEFAULT_DATE,
                state=State.FAILED,
            )
            session.add(dummy_task_instance)
            session.commit()

        sub_dagrun = subdag.create_dagrun(
            run_id="scheduled__{}".format(DEFAULT_DATE.isoformat()),
            execution_date=DEFAULT_DATE,
            state=State.FAILED,
            external_trigger=True,
        )

        subdag_task._reset_dag_run_and_task_instances(
            sub_dagrun, execution_date=DEFAULT_DATE)

        dummy_task_instance.refresh_from_db()
        self.assertEqual(dummy_task_instance.state, State.NONE)

        sub_dagrun.refresh_from_db()
        self.assertEqual(sub_dagrun.state, State.RUNNING)
Code Example #35
File: yml_loader.py (Project: Gemma-Analytics/ewah)
 def __init__(self, stream):
     with create_session() as session:
         airflow_variables = {
             var.key: var.val
             for var in session.query(Variable)
         }
     self._root = os.path.split(stream.name)[0]
     ctx = {
         "env": os.environ,
         "airflow_variables": airflow_variables,
         "pytz": {name: getattr(pytz, name)
                  for name in pytz.__all__},
         "datetime": {
             name: getattr(datetime, name)
             for name in
             ["date", "datetime", "time", "timedelta", "tzinfo"]
         },
         "re": {name: getattr(re, name)
                for name in re.__all__},
     }
     # Enable Jinja2 in the yaml files
     yaml_stream = StringIO(Template(stream.read()).render(ctx))
     yaml_stream.name = stream.name
     super().__init__(yaml_stream)
Code Example #36
def branched_context():
    """
    Generic Airflow context fixture with branched tasks that can be passed to a
    callable that requires context
    """
    with create_session() as session:
        dag = DAGFactory()

        branch_a = PythonOperatorFactory(task_id='branch_a', dag=dag)
        branch_b = PythonOperatorFactory(task_id='branch_b', dag=dag)
        current_task = PythonOperatorFactory(task_id='current_task', dag=dag)
        next_task = PythonOperatorFactory(task_id='next_task', dag=dag)

        branch_a.set_downstream(current_task)
        branch_b.set_downstream(current_task)
        # join
        current_task.set_downstream(next_task)

        dag_run = dag.create_dagrun(run_id="manual__",
                                    start_date=timezone.utcnow(),
                                    execution_date=timezone.utcnow(),
                                    state=State.RUNNING,
                                    conf=None,
                                    session=session)

    ti = None
    for instance in dag_run.get_task_instances():
        if instance.task_id == 'current_task':
            ti = instance
            break

    if ti is None:
        raise ValueError('Unable to find the current task')

    ti.task = current_task
    return ti.get_template_context()
Code Example #37
 def get_job(self, job_name: Text, execution_id: Text) -> Optional[JobInfo]:
     with create_session() as session:
         dag_run = session.query(DagRun).filter(
             DagRun.run_id == execution_id).first()
         if dag_run is None:
             return None
         task = session.query(TaskInstance).filter(
             TaskInstance.dag_id == dag_run.dag_id,
             TaskInstance.execution_date == dag_run.execution_date,
             TaskInstance.task_id == job_name).first()
         if task is None:
             return None
         else:
             project_name, workflow_name = self.dag_id_to_namespace_workflow(
                 dag_run.dag_id)
             return JobInfo(job_name=job_name,
                            state=self.airflow_state_to_state(task.state),
                            workflow_execution=WorkflowExecutionInfo(
                                workflow_info=WorkflowInfo(
                                    namespace=project_name,
                                    workflow_name=workflow_name),
                                execution_id=dag_run.run_id,
                                state=self.airflow_state_to_state(
                                    dag_run.state)))
Code Example #38
    def wrapper(*args, **kwargs):
        # AnonymousUserMixin() has a user attribute, but its value is None.
        if current_user and hasattr(current_user,
                                    'user') and current_user.user:
            user = current_user.user.username
        else:
            user = '******'

        log = models.Log(event=f.__name__,
                         task_instance=None,
                         owner=user,
                         extra=str(list(request.values.items())),
                         task_id=request.values.get('task_id'),
                         dag_id=request.values.get('dag_id'))

        if request.values.get('execution_date'):
            log.execution_date = timezone.parse(
                request.values.get('execution_date'))

        with create_session() as session:
            session.add(log)
            session.commit()

        return f(*args, **kwargs)
Code Example #39
    def setUp(self):
        self.key = "test_dag_id"

        task = DummyOperator(task_id='dummy',
                             dag=models.DAG(
                                 dag_id=self.key,
                                 default_args={'start_date': days_ago(2)}),
                             owner='airflow')

        test_date = days_ago(1)
        with create_session() as session:
            session.add(DM(dag_id=self.key))
            session.add(DR(dag_id=self.key))
            session.add(
                TI(task=task, execution_date=test_date, state=State.SUCCESS))
            # flush to ensure the task instance is written before the
            # task reschedule, because of the FK constraint
            session.flush()
            session.add(
                LOG(dag_id=self.key,
                    task_id=None,
                    task_instance=None,
                    execution_date=test_date,
                    event="varimport"))
            session.add(
                TF(task=task,
                   execution_date=test_date,
                   start_date=test_date,
                   end_date=test_date))
            session.add(
                TR(task=task,
                   execution_date=test_date,
                   start_date=test_date,
                   end_date=test_date,
                   try_number=1,
                   reschedule_date=test_date))
Code Example #40
    def test_kill_zombies_doesn_nothing(self, mock_ti_handle_failure):
        """
        Test that kill_zombies does nothing when the job is running and has received a heartbeat
        """
        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=True)
        with create_session() as session:
            session.query(TI).delete()
            session.query(LJ).delete()
            dag = dagbag.get_dag('example_branch_operator')
            task = dag.get_task(task_id='run_this_first')

            ti = TI(task, DEFAULT_DATE, State.RUNNING)
            lj = LJ(ti)
            lj.latest_heartbeat = utcnow()
            lj.state = State.RUNNING
            lj.id = 1
            ti.job_id = lj.id

            session.add(lj)
            session.add(ti)
            session.commit()

            dagbag.kill_zombies()
            mock_ti_handle_failure.assert_not_called()
Code Example #41
    def tearDown(self):
        super(ShortCircuitOperatorTest, self).tearDown()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()
Code Example #42
    def tearDown(self):
        super().tearDown()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()
Code Example #43
    def setUpClass(cls):
        super().setUpClass()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()
Code Example #44
def clear_rendered_ti_fields():
    with create_session() as session:
        session.query(RenderedTaskInstanceFields).delete()
Code Example #45
def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids):
    """
    This script can be used to measure the total "scheduler overhead" of Airflow.

    By overhead we mean: if the tasks executed instantly as soon as they were
    ready (i.e. they do nothing), how quickly could we schedule them?

    It will monitor the task completion of the mock/stub executor (no actual
    tasks are run), and after the required number of dag runs for all the
    specified dags have completed all their tasks, it will cleanly shut down
    the scheduler.

    The dags you run with need to have an early enough start_date to create the
    desired number of runs.

    Care should be taken that other limits (DAG concurrency, pool size etc) are
    not the bottleneck. This script doesn't help you in that regard.

    It is recommended to repeat the test at least 3 times (`--repeat=3`, the
    default) so that you can get somewhat-accurate variance on the reported
    timing numbers, but this can be disabled for longer runs if needed.
    """

    # Turn on unit test mode so that we don't do any sleep() in the scheduler
    # loop - not needed on master, but this script can run against older
    # releases too!
    os.environ['AIRFLOW__CORE__UNIT_TEST_MODE'] = 'True'

    os.environ['AIRFLOW__CORE__DAG_CONCURRENCY'] = '500'

    # Set this so that dags can dynamically configure their end_date
    os.environ['AIRFLOW_BENCHMARK_MAX_DAG_RUNS'] = str(num_runs)
    os.environ['PERF_MAX_RUNS'] = str(num_runs)

    if pre_create_dag_runs:
        os.environ['AIRFLOW__SCHEDULER__USE_JOB_SCHEDULE'] = 'False'

    from airflow.jobs.scheduler_job import SchedulerJob
    from airflow.models.dagbag import DagBag
    from airflow.utils import db

    dagbag = DagBag()

    dags = []

    with db.create_session() as session:
        pause_all_dags(session)
        for dag_id in dag_ids:
            dag = dagbag.get_dag(dag_id)
            dag.sync_to_db(session=session)
            dags.append(dag)
            reset_dag(dag, session)

            next_run_date = dag.normalize_schedule(dag.start_date
                                                   or min(t.start_date
                                                          for t in dag.tasks))

            for _ in range(num_runs - 1):
                next_run_date = dag.following_schedule(next_run_date)

            end_date = dag.end_date or dag.default_args.get('end_date')
            if end_date != next_run_date:
                message = (
                    f"DAG {dag_id} has incorrect end_date ({end_date}) for "
                    f"number of runs! It should be {next_run_date}"
                )
                sys.exit(message)

            if pre_create_dag_runs:
                create_dag_runs(dag, num_runs, session)

    ShortCircuitExecutor = get_executor_under_test(executor_class)

    executor = ShortCircuitExecutor(dag_ids_to_watch=dag_ids,
                                    num_runs=num_runs)
    scheduler_job = SchedulerJob(dag_ids=dag_ids,
                                 do_pickle=False,
                                 executor=executor)
    executor.scheduler_job = scheduler_job

    total_tasks = sum(len(dag.tasks) for dag in dags)

    if 'PYSPY' in os.environ:
        pid = str(os.getpid())
        filename = os.environ.get('PYSPY_O', 'flame-' + pid + '.html')
        os.spawnlp(os.P_NOWAIT, 'sudo', 'sudo', 'py-spy', 'record', '-o',
                   filename, '-p', pid, '--idle')

    times = []

    # Need a lambda to refer to the _latest_ value for scheduler_job, not just
    # the initial one
    code_to_test = lambda: scheduler_job.run()  # pylint: disable=unnecessary-lambda

    for count in range(repeat):
        gc.disable()
        start = time.perf_counter()

        code_to_test()
        times.append(time.perf_counter() - start)
        gc.enable()
        print("Run %d time: %.5f" % (count + 1, times[-1]))

        if count + 1 != repeat:
            with db.create_session() as session:
                for dag in dags:
                    reset_dag(dag, session)

            executor.reset(dag_ids)
            scheduler_job = SchedulerJob(dag_ids=dag_ids,
                                         do_pickle=False,
                                         executor=executor)
            executor.scheduler_job = scheduler_job

    print()
    print()
    msg = "Time for %d dag runs of %d dags with %d total tasks: %.4fs"

    if len(times) > 1:
        print((msg + " (±%.3fs)") %
              (num_runs, len(dags), total_tasks, statistics.mean(times),
               statistics.stdev(times)))
    else:
        print(msg % (num_runs, len(dags), total_tasks, times[0]))

    print()
    print()
Code Example #46
 def test_create_pool(self):
     pool = self.client.create_pool(name='foo', slots=1, description='')
     self.assertEqual(pool, ('foo', 1, ''))
     with create_session() as session:
         self.assertEqual(session.query(models.Pool).count(), 1)
Code Example #47
def clear_db_dag_code():
    with create_session() as session:
        session.query(DagCode).delete()
Code Example #48
def lineage_parent_id(run_id, task):
    with create_session() as session:
        job_name = f"{task.dag_id}.{task.task_id}"
        ids = str(JobIdMapping.get(job_name, run_id, session))
        return f"{os.getenv('OPENLINEAGE_NAMESPACE')}/{job_name}/{ids}"
Code Example #49
def clear_db_task_instance():
    with create_session() as session:
        session.query(TaskInstance).delete()
        session.query(TaskExecution).delete()
        session.query(TaskState).delete()
Code Example #50
def clear_db_dag_pickle():
    with create_session() as session:
        session.query(DagPickle).delete()
Code Example #51
def clear_db_event_model():
    with create_session() as session:
        session.query(EventModel).delete()
Code Example #52
    def setUpClass(cls):
        super(PythonOperatorTest, cls).setUpClass()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()
Code Example #53
def clear_db_variables():
    with create_session() as session:
        session.query(Variable).delete()
Code Example #54
 def setUp(self):
     db.clear_db_pools()
     with create_session() as session:
         test_pool = Pool(pool='test_pool', slots=1)
         session.add(test_pool)
         session.commit()
Code Example #55
from airflow import DAG
from airflow.models import Variable
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.dates import days_ago
from airflow.utils.db import create_session

with create_session() as session:
    session.query(Variable).all()


with DAG(dag_id="dag_no_top_level_query_fail", schedule_interval=None, start_date=days_ago(1)) as dag:
    DummyOperator(task_id="test")
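
This DAG intentionally queries the metadata DB at module top level, so the
query runs on every parse; that is what the _fail suffix flags. A hedged
sketch of the corrected pattern, moving the session work into a task callable
(the _ok dag_id is made up):

from airflow import DAG
from airflow.models import Variable
from airflow.operators.python_operator import PythonOperator
from airflow.utils.dates import days_ago
from airflow.utils.db import create_session


def read_variables():
    # The query now runs at task execution time, not at parse time.
    with create_session() as session:
        return [v.key for v in session.query(Variable)]


with DAG(dag_id="dag_no_top_level_query_ok", schedule_interval=None, start_date=days_ago(1)) as dag:
    PythonOperator(task_id="read_variables", python_callable=read_variables)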
Code Example #56
 def snapshot_state(dag, execution_dates):
     TI = models.TaskInstance
     with create_session() as session:
         return session.query(TI).filter(
             TI.dag_id == dag.dag_id,
             TI.execution_date.in_(execution_dates)).all()
Code Example #57
def clear_db_connections():
    with create_session() as session:
        session.query(Connection).delete()
        create_default_connections(session)
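
Helpers like these are typically composed into a single reset function run
between tests; a hedged sketch (the composite name is ours, and it assumes
the clear_db_* helpers defined in the surrounding examples):

def reset_test_db():
    # Each helper opens its own create_session, so every table is
    # wiped and committed independently.
    clear_db_connections()
    clear_db_variables()
    clear_db_dag_pickle()
    clear_db_dag_code()
    clear_db_task_instance()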
Code Example #58
 def tearDown(self):
     with create_session() as session:
         session.query(TaskFail).delete()
         session.query(TaskReschedule).delete()
         session.query(models.TaskInstance).delete()
Code Example #59
File: test_backfill_job.py (Project: shivamx/airflow)
    def test_backfill_max_limit_check(self):
        dag_id = 'test_backfill_max_limit_check'
        run_id = 'test_dagrun'
        start_date = DEFAULT_DATE - datetime.timedelta(hours=1)
        end_date = DEFAULT_DATE

        dag_run_created_cond = threading.Condition()

        def run_backfill(cond):
            cond.acquire()
            # this session object is different from the one in the main thread
            with create_session() as thread_session:
                try:
                    dag = self._get_dag_test_max_active_limits(dag_id)

                    # Existing dagrun that is not within the backfill range
                    dag.create_dagrun(
                        run_id=run_id,
                        state=State.RUNNING,
                        execution_date=DEFAULT_DATE +
                        datetime.timedelta(hours=1),
                        start_date=DEFAULT_DATE,
                    )

                    thread_session.commit()
                    cond.notify()
                finally:
                    cond.release()
                    thread_session.close()

                executor = MockExecutor()
                job = BackfillJob(dag=dag,
                                  start_date=start_date,
                                  end_date=end_date,
                                  executor=executor,
                                  donot_pickle=True)
                job.run()

        backfill_job_thread = threading.Thread(target=run_backfill,
                                               name="run_backfill",
                                               args=(dag_run_created_cond, ))

        dag_run_created_cond.acquire()
        with create_session() as session:
            backfill_job_thread.start()
            try:
                # at this point backfill can't run since the max_active_runs has been
                # reached, so it is waiting
                dag_run_created_cond.wait(timeout=1.5)
                dagruns = DagRun.find(dag_id=dag_id)
                dr = dagruns[0]
                self.assertEqual(1, len(dagruns))
                self.assertEqual(dr.run_id, run_id)

                # allow the backfill to execute
                # by setting the existing dag run to SUCCESS,
                # backfill will execute dag runs 1 by 1
                dr.set_state(State.SUCCESS)
                session.merge(dr)
                session.commit()

                backfill_job_thread.join()

                dagruns = DagRun.find(dag_id=dag_id)
                self.assertEqual(3,
                                 len(dagruns))  # 2 from backfill + 1 existing
                self.assertEqual(dagruns[-1].run_id, dr.run_id)
            finally:
                dag_run_created_cond.release()
Code Example #60
def set_default_pool_slots(slots):
    with create_session() as session:
        default_pool = Pool.get_default_pool(session)
        default_pool.slots = slots
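
A hedged sketch of how a test might use this helper, restoring the original
slot count afterwards (the wrapper name is ours):

def with_default_pool_slots(slots, fn):
    """Run fn() with the default pool temporarily resized (sketch)."""
    with create_session() as session:
        original = Pool.get_default_pool(session).slots
    set_default_pool_slots(slots)
    try:
        return fn()
    finally:
        set_default_pool_slots(original)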