def test_get_templated_fields(self, templated_field, expected_rendered_field):
        """
        Render a templated ``bash_command``, persist the resulting RTIF row and
        read it back through ``RTIF.get_templated_fields``.

        A task instance that has no stored RTIF row must yield ``None``.
        """
        dag = DAG("test_serialized_rendered_fields", start_date=START_DATE)
        with dag:
            stored_task = BashOperator(task_id="test", bash_command=templated_field)

        stored_ti = TI(task=stored_task, execution_date=EXECUTION_DATE)
        record = RTIF(ti=stored_ti)

        # The record must carry the same identifying attributes as its task instance.
        for attr in ("dag_id", "task_id", "execution_date"):
            self.assertEqual(getattr(stored_ti, attr), getattr(record, attr))
        self.assertEqual(expected_rendered_field, record.rendered_fields.get("bash_command"))

        with create_session() as session:
            session.add(record)

        expected_fields = {"bash_command": expected_rendered_field, "env": None}
        self.assertEqual(expected_fields, RTIF.get_templated_fields(ti=stored_ti))

        # Lookup miss: this task instance is never added to the RTIF table,
        # so fetching its templated fields is expected to return None.
        with dag:
            unstored_task = BashOperator(task_id="test2", bash_command=templated_field)

        unstored_ti = TI(unstored_task, EXECUTION_DATE)
        self.assertIsNone(RTIF.get_templated_fields(ti=unstored_ti))
Esempio n. 2
0
    def test_default_pool_open_slots(self):
        """
        Check the default pool's slot accounting after persisting one running TI
        and one queued TI that was declared with ``pool_slots=2``.
        """
        set_default_pool_slots(5)
        assert 5 == Pool.get_default_pool().open_slots()

        dag = DAG(dag_id='test_default_pool_open_slots', start_date=DEFAULT_DATE)
        running_op = DummyOperator(task_id='dummy1', dag=dag)
        queued_op = DummyOperator(task_id='dummy2', dag=dag, pool_slots=2)
        running_ti = TI(task=running_op, execution_date=DEFAULT_DATE)
        queued_ti = TI(task=queued_op, execution_date=DEFAULT_DATE)
        running_ti.state = State.RUNNING
        queued_ti.state = State.QUEUED

        session = settings.Session
        for task_instance in (running_ti, queued_ti):
            session.add(task_instance)
        session.commit()
        session.close()

        assert 2 == Pool.get_default_pool().open_slots()

        expected_stats = {
            "default_pool": {
                "open": 2,
                "queued": 2,
                "total": 5,
                "running": 1,
            }
        }
        assert expected_stats == Pool.slots_stats()
    def test_write(self):
        """
        Check that ``RTIF.write`` stores a record and that writing again for the
        same dag_id/task_id/execution_date stores the newly rendered values.
        """
        Variable.set(key="test_key", value="test_val")

        session = settings.Session()
        existing = session.query(RTIF).all()
        assert [] == existing

        def stored_row(record):
            """Return the (dag_id, task_id, rendered_fields) row matching ``record``."""
            columns = session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields)
            matching = columns.filter(
                RTIF.dag_id == record.dag_id,
                RTIF.task_id == record.task_id,
                RTIF.execution_date == record.execution_date,
            )
            return matching.first()

        with DAG("test_write", start_date=START_DATE):
            first_task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")

        first_record = RTIF(TI(task=first_task, execution_date=EXECUTION_DATE))
        first_record.write()
        first_row = stored_row(first_record)
        assert ('test_write', 'test', {'bash_command': 'echo test_val', 'env': None}) == first_row

        # Change the variable and write again under the same key: the stored
        # rendered fields must reflect the updated value.
        Variable.delete("test_key")
        Variable.set(key="test_key", value="test_val_updated")

        with DAG("test_write", start_date=START_DATE):
            second_task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")

        second_record = RTIF(TI(task=second_task, execution_date=EXECUTION_DATE))
        second_record.write()

        second_row = stored_row(second_record)
        expected_second_row = (
            'test_write',
            'test',
            {'bash_command': 'echo test_val_updated', 'env': None},
        )
        assert expected_second_row == second_row
Esempio n. 4
0
 def create_ti(task: "Operator", indexes: Tuple[int, ...]) -> Generator:
     """
     Yield one task instance of ``task`` per entry in ``indexes``.

     Each instance is passed through ``task_instance_mutation_hook`` and counted
     in ``created_counts`` under its ``operator`` attribute before being yielded.
     ``self`` and ``created_counts`` are taken from the enclosing scope.
     """
     for idx in indexes:
         instance = TI(task, run_id=self.run_id, map_index=idx)
         task_instance_mutation_hook(instance)
         created_counts[instance.operator] += 1
         yield instance
    def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count):
        """
        Store ``rtif_num`` RTIF rows for one dag_id/task_id, call
        ``RTIF.delete_old_records`` with ``num_to_keep`` and check both the number
        of queries issued and the number of rows left.
        """
        session = settings.Session()
        dag = DAG("test_delete_old_records", start_date=START_DATE)
        with dag:
            task = BashOperator(task_id="test", bash_command="echo {{ ds }}")

        # One record per consecutive day, starting at EXECUTION_DATE.
        records = []
        for day_offset in range(rtif_num):
            task_instance = TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=day_offset))
            records.append(RTIF(task_instance))

        session.add_all(records)
        session.commit()

        def stored_records():
            """Fetch every RTIF row belonging to this test's dag and task."""
            return session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()

        before = stored_records()

        for record in records:
            self.assertIn(record, before)

        self.assertEqual(rtif_num, len(before))

        # Only the deletion itself runs inside the query-count guard.
        with assert_queries_count(expected_query_count):
            RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)
        after = stored_records()
        self.assertEqual(remaining_rtifs, len(after))
Esempio n. 6
0
    def verify_integrity(self, session: Session = None):
        """
        Verifies the DagRun by checking for removed tasks or tasks that are not in the
        database yet. It will set state to removed or add the task if required.

        Existing task instances whose task can no longer be found in the DAG are
        marked ``State.REMOVED`` (unless the run is running or the DAG is partial);
        removed task instances whose task is back in the DAG are reset to
        ``State.NONE``. Task instances are then created for DAG tasks that have none,
        and the session is committed. An ``IntegrityError`` on commit is logged and
        the session rolled back.

        :param session: Sqlalchemy ORM Session. NOTE(review): the default is ``None``
            but the body calls methods on it, so presumably a decorator not visible
            here injects a session -- confirm.
        :type session: Session
        """
        dag = self.get_dag()
        tis = self.get_task_instances(session=session)

        # check for removed or restored tasks
        task_ids = set()
        for ti in tis:
            task_instance_mutation_hook(ti)
            task_ids.add(ti.task_id)
            task = None
            try:
                task = dag.get_task(ti.task_id)
            except AirflowException:
                if ti.state == State.REMOVED:
                    pass  # ti has already been removed, just ignore it
                # Compare states by value: an identity check (``is not``) on strings
                # can report "different" for two equal state values.
                elif self.state != State.RUNNING and not dag.partial:
                    self.log.warning(
                        "Failed to get task '%s' for dag '%s'. "
                        "Marking it as removed.", ti, dag)
                    Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
                    ti.state = State.REMOVED

            should_restore_task = (task is not None) and ti.state == State.REMOVED
            if should_restore_task:
                self.log.info(
                    "Restoring task '%s' which was previously "
                    "removed from DAG '%s'", ti, dag)
                Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
                ti.state = State.NONE
            session.merge(ti)

        # check for missing tasks
        for task in dag.task_dict.values():
            # Tasks starting after this run's execution date are skipped,
            # except when the run is a backfill.
            if task.start_date > self.execution_date and not self.is_backfill:
                continue

            if task.task_id not in task_ids:
                Stats.incr(f"task_instance_created-{task.task_type}", 1, 1)
                ti = TI(task, self.execution_date)
                task_instance_mutation_hook(ti)
                session.add(ti)

        try:
            session.commit()
        except IntegrityError as err:
            self.log.info(str(err))
            # Lazy %-formatting: the message is only built if the record is emitted.
            self.log.info('Hit IntegrityError while creating the TIs for %s - %s.',
                          dag.dag_id, self.execution_date)
            self.log.info('Doing session rollback.')
            session.rollback()
Esempio n. 7
0
    def verify_integrity(self, session: Session = NEW_SESSION):
        """
        Reconcile this DagRun's task instances with the tasks of its DAG.

        Task instances whose task can no longer be fetched from the DAG are marked
        removed, removed task instances whose task is available again are reset,
        and task instances are created for DAG tasks that do not have one yet.

        :param session: Sqlalchemy ORM Session
        :type session: Session
        """
        from airflow.settings import task_instance_mutation_hook

        dag = self.get_dag()
        existing_tis = self.get_task_instances(session=session)

        # Pass 1: walk the existing task instances, flagging removed / restored tasks.
        known_task_ids = set()
        for ti in existing_tis:
            task_instance_mutation_hook(ti)
            known_task_ids.add(ti.task_id)
            try:
                task = dag.get_task(ti.task_id)
            except AirflowException:
                task = None
                # An already-removed ti is left alone; otherwise mark it removed
                # unless the run is running or the DAG is partial.
                if ti.state != State.REMOVED and self.state != State.RUNNING and not dag.partial:
                    self.log.warning("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag)
                    Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
                    ti.state = State.REMOVED

            if task is not None and ti.state == State.REMOVED:
                self.log.info("Restoring task '%s' which was previously removed from DAG '%s'", ti, dag)
                Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
                ti.state = State.NONE
            session.merge(ti)

        # Pass 2: create task instances for DAG tasks that have none yet.
        for task in dag.task_dict.values():
            starts_after_run = task.start_date > self.execution_date and not self.is_backfill
            if starts_after_run or task.task_id in known_task_ids:
                continue

            Stats.incr(f"task_instance_created-{task.task_type}", 1, 1)
            new_ti = TI(task, execution_date=None, run_id=self.run_id)
            task_instance_mutation_hook(new_ti)
            session.add(new_ti)

        try:
            session.flush()
        except IntegrityError as err:
            self.log.info(str(err))
            self.log.info('Hit IntegrityError while creating the TIs for %s- %s', dag.dag_id, self.run_id)
            self.log.info('Doing session rollback.')
            # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
            session.rollback()
Esempio n. 8
0
    def test_open_slots(self):
        """A 5-slot pool with one running and one queued TI reports 3 open slots."""
        pool = Pool(pool='test_pool', slots=5)
        dag = DAG(dag_id='test_open_slots', start_date=DEFAULT_DATE)
        running_op = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool')
        queued_op = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool')
        running_ti = TI(task=running_op, execution_date=DEFAULT_DATE)
        queued_ti = TI(task=queued_op, execution_date=DEFAULT_DATE)
        running_ti.state = State.RUNNING
        queued_ti.state = State.QUEUED

        session = settings.Session
        for row in (pool, running_ti, queued_ti):
            session.add(row)
        session.commit()
        session.close()

        self.assertEqual(3, pool.open_slots())
Esempio n. 9
0
    def test_open_slots(self):
        """
        Check open/running/queued/occupied slot counts and ``slots_stats`` for a
        5-slot pool holding one running and one queued task instance.
        """
        pool = Pool(pool='test_pool', slots=5)
        dag = DAG(dag_id='test_open_slots', start_date=DEFAULT_DATE)
        running_op = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool')
        queued_op = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool')
        running_ti = TI(task=running_op, execution_date=DEFAULT_DATE)
        queued_ti = TI(task=queued_op, execution_date=DEFAULT_DATE)
        running_ti.state = State.RUNNING
        queued_ti.state = State.QUEUED

        session = settings.Session
        for row in (pool, running_ti, queued_ti):
            session.add(row)
        session.commit()
        session.close()

        self.assertEqual(3, pool.open_slots())  # pylint: disable=no-value-for-parameter
        self.assertEqual(1, pool.running_slots())  # pylint: disable=no-value-for-parameter
        self.assertEqual(1, pool.queued_slots())  # pylint: disable=no-value-for-parameter
        self.assertEqual(2, pool.occupied_slots())  # pylint: disable=no-value-for-parameter

        # The stats cover every pool, so the untouched default pool shows up too.
        expected_stats = {
            "default_pool": {
                "open": 128,
                "queued": 0,
                "total": 128,
                "running": 0,
            },
            "test_pool": {
                "open": 3,
                "queued": 1,
                "running": 1,
                "total": 5,
            },
        }
        self.assertEqual(expected_stats, pool.slots_stats())
Esempio n. 10
0
    def test_default_pool_open_slots(self):
        """
        With the default pool set to 5 slots, one running TI plus one queued TI
        declared with ``pool_slots=2`` leave 2 open slots.
        """
        set_default_pool_slots(5)
        self.assertEqual(5, Pool.get_default_pool().open_slots())

        dag = DAG(dag_id='test_default_pool_open_slots', start_date=DEFAULT_DATE)
        running_op = DummyOperator(task_id='dummy1', dag=dag)
        queued_op = DummyOperator(task_id='dummy2', dag=dag, pool_slots=2)
        running_ti = TI(task=running_op, execution_date=DEFAULT_DATE)
        queued_ti = TI(task=queued_op, execution_date=DEFAULT_DATE)
        running_ti.state = State.RUNNING
        queued_ti.state = State.QUEUED

        session = settings.Session
        for task_instance in (running_ti, queued_ti):
            session.add(task_instance)
        session.commit()
        session.close()

        self.assertEqual(2, Pool.get_default_pool().open_slots())
Esempio n. 11
0
    def test_infinite_slots(self):
        """
        A pool created with ``slots=-1`` reports infinite open slots while its
        used, queued and occupied slot counts are still tracked.
        """
        pool = Pool(pool='test_pool', slots=-1)
        dag = DAG(dag_id='test_infinite_slots', start_date=DEFAULT_DATE)
        running_op = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool')
        queued_op = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool')
        running_ti = TI(task=running_op, execution_date=DEFAULT_DATE)
        queued_ti = TI(task=queued_op, execution_date=DEFAULT_DATE)
        running_ti.state = State.RUNNING
        queued_ti.state = State.QUEUED

        session = settings.Session
        for row in (pool, running_ti, queued_ti):
            session.add(row)
        session.commit()
        session.close()

        unlimited = float('inf')
        self.assertEqual(unlimited, pool.open_slots())  # pylint: disable=no-value-for-parameter
        self.assertEqual(1, pool.used_slots())  # pylint: disable=no-value-for-parameter
        self.assertEqual(1, pool.queued_slots())  # pylint: disable=no-value-for-parameter
        self.assertEqual(2, pool.occupied_slots())  # pylint: disable=no-value-for-parameter
    def test_delete_old_records(self):
        """
        Store three RTIF rows on consecutive days for one dag_id/task_id and check
        that ``RTIF.delete_old_records`` with ``num_to_keep=1`` leaves only the row
        with the latest execution date.
        """
        session = settings.Session()
        dag = DAG("test_delete_old_records", start_date=START_DATE)
        with dag:
            task = BashOperator(task_id="test", bash_command="echo {{ ds }}")

        oldest = RTIF(TI(task=task, execution_date=EXECUTION_DATE))
        middle = RTIF(TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=1)))
        newest = RTIF(TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=2)))
        records = (oldest, middle, newest)

        for record in records:
            session.add(record)
        session.commit()

        def stored_records():
            """Fetch every RTIF row belonging to this test's dag and task."""
            return session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()

        before = stored_records()

        for record in records:
            self.assertIn(record, before)
        self.assertEqual(3, len(before))

        # After pruning to a single record, only the newest one may remain.
        RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=1)
        after = stored_records()
        self.assertEqual(1, len(after))
        self.assertEqual(newest.execution_date, after[0].execution_date)
Esempio n. 13
0
    def verify_integrity(self, session=None):
        """
        Verifies the DagRun by checking for removed tasks or tasks that are not in the
        database yet. It will set state to removed or add the task if required.

        Existing task instances whose task can no longer be found in the DAG are
        marked ``State.REMOVED`` (unless the run is running or the DAG is partial);
        removed task instances whose task is back in the DAG are reset to
        ``State.NONE``. Task instances are then created for DAG tasks that have none,
        and the session is committed.

        :param session: Sqlalchemy ORM Session. NOTE(review): the default is ``None``
            but the body calls methods on it, so presumably a decorator not visible
            here injects a session -- confirm.
        """
        dag = self.get_dag()
        tis = self.get_task_instances(session=session)

        # check for removed or restored tasks
        # A set gives O(1) membership tests in the "missing tasks" loop below.
        task_ids = set()
        for ti in tis:
            task_instance_mutation_hook(ti)
            task_ids.add(ti.task_id)
            task = None
            try:
                task = dag.get_task(ti.task_id)
            except AirflowException:
                if ti.state == State.REMOVED:
                    pass  # ti has already been removed, just ignore it
                # Compare states by value: an identity check (``is not``) on strings
                # can report "different" for two equal state values.
                elif self.state != State.RUNNING and not dag.partial:
                    # Lazy %-args: the message is only formatted if it is emitted.
                    self.log.warning("Failed to get task '%s' for dag '%s'. "
                                     "Marking it as removed.", ti, dag)
                    Stats.incr("task_removed_from_dag.{}".format(dag.dag_id),
                               1, 1)
                    ti.state = State.REMOVED

            should_restore_task = (task is not None) and ti.state == State.REMOVED
            if should_restore_task:
                self.log.info("Restoring task '%s' which was previously "
                              "removed from DAG '%s'", ti, dag)
                Stats.incr("task_restored_to_dag.{}".format(dag.dag_id), 1, 1)
                ti.state = State.NONE
            session.merge(ti)

        # check for missing tasks
        for task in dag.task_dict.values():
            # Tasks starting after this run's execution date are skipped,
            # except when the run is a backfill.
            if task.start_date > self.execution_date and not self.is_backfill:
                continue

            if task.task_id not in task_ids:
                Stats.incr(
                    "task_instance_created-{}".format(task.__class__.__name__),
                    1, 1)
                ti = TI(task, self.execution_date)
                task_instance_mutation_hook(ti)
                session.add(ti)

        session.commit()
    def test_skip_all_except(self):
        """
        ``skip_all_except`` called for ``task1`` with ``branch_task_ids=['task2']``
        must leave ``task2`` with no state and mark ``task3`` as skipped.
        """
        dag = DAG('dag_test_skip_all_except', start_date=DEFAULT_DATE)
        with dag:
            task1 = DummyOperator(task_id='task1')
            task2 = DummyOperator(task_id='task2')
            task3 = DummyOperator(task_id='task3')

            task1 >> [task2, task3]

        branching_ti = TI(task1, execution_date=DEFAULT_DATE)
        followed_ti = TI(task2, execution_date=DEFAULT_DATE)
        skipped_ti = TI(task3, execution_date=DEFAULT_DATE)

        SkipMixin().skip_all_except(ti=branching_ti, branch_task_ids=['task2'])

        def refreshed_state(task_instance):
            # Reload the task instance from the database before reading its state.
            task_instance.refresh_from_db()
            return task_instance.state

        assert refreshed_state(followed_ti) == State.NONE
        assert refreshed_state(skipped_ti) == State.SKIPPED
Esempio n. 15
0
    def test_redact(self, redact):
        """
        The rendered fields stored on an RTIF must be the values returned by
        ``redact`` (presumably a mock patched in by a decorator not shown here),
        one return value per templated field.
        """
        dag = DAG("test_ritf_redact", start_date=START_DATE)
        with dag:
            task = BashOperator(
                task_id="test",
                bash_command="echo {{ var.value.api_key }}",
                env={'foo': 'secret', 'other_api_key': 'masked based on key name'},
            )

        # Successive calls to redact return these values in order.
        redact.side_effect = ['val 1', 'val 2']

        task_instance = TI(task=task, execution_date=EXECUTION_DATE)
        record = RTIF(ti=task_instance)
        expected_fields = {'bash_command': 'val 1', 'env': 'val 2'}
        assert record.rendered_fields == expected_fields
Esempio n. 16
0
    def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count):
        """
        Store ``rtif_num`` RTIF rows for one dag_id/task_id, call
        ``RTIF.delete_old_records`` with ``num_to_keep`` and check both the number
        of queries issued and the number of rows left.
        """
        session = settings.Session()
        dag = DAG("test_delete_old_records", start_date=START_DATE)
        with dag:
            task = BashOperator(task_id="test", bash_command="echo {{ ds }}")

        # One record per consecutive day, starting at EXECUTION_DATE.
        records = []
        for day_offset in range(rtif_num):
            task_instance = TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=day_offset))
            records.append(RTIF(task_instance))

        session.add_all(records)
        session.commit()

        def stored_records():
            """Fetch every RTIF row belonging to this test's dag and task."""
            return session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()

        before = stored_records()

        for record in records:
            assert record in before

        assert rtif_num == len(before)

        # On MSSQL one additional query is expected whenever any query is expected at all.
        allowed_queries = expected_query_count
        if session.bind.dialect.name == "mssql" and expected_query_count != 0:
            allowed_queries = expected_query_count + 1

        # Only the deletion itself runs inside the query-count guard.
        with assert_queries_count(allowed_queries):
            RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)
        after = stored_records()
        assert remaining_rtifs == len(after)
    def test_get_k8s_pod_yaml(self, mock_pod_mutation_hook):
        """
        Build an RTIF for a task instance, compare its ``k8s_pod_yaml`` with the
        expected pod definition, persist it and read it back through
        ``RTIF.get_k8s_pod_yaml``.

        A task instance that has no stored RTIF row must yield ``None``.
        """
        dag = DAG("test_get_k8s_pod_yaml", start_date=START_DATE)
        with dag:
            stored_task = BashOperator(task_id="test", bash_command="echo hi")

        stored_ti = TI(task=stored_task, execution_date=EXECUTION_DATE)
        record = RTIF(ti=stored_ti)

        # Constructing the RTIF must have invoked the pod mutation hook exactly once.
        mock_pod_mutation_hook.assert_called_once_with(mock.ANY)

        # The record must carry the same identifying attributes as its task instance.
        for attr in ("dag_id", "task_id", "execution_date"):
            self.assertEqual(getattr(stored_ti, attr), getattr(record, attr))

        annotations = {
            'dag_id': 'test_get_k8s_pod_yaml',
            'execution_date': '2019-01-01T00:00:00+00:00',
            'task_id': 'test',
            'try_number': '1',
        }
        labels = {
            'airflow-worker': 'worker-config',
            'airflow_version': version,
            'dag_id': 'test_get_k8s_pod_yaml',
            'execution_date': '2019-01-01T00_00_00_plus_00_00',
            'kubernetes_executor': 'True',
            'task_id': 'test',
            'try_number': '1',
        }
        base_container = {
            'command': [
                'airflow',
                'tasks',
                'run',
                'test_get_k8s_pod_yaml',
                'test',
                '2019-01-01T00:00:00+00:00',
            ],
            'image': ':',
            'name': 'base',
        }
        expected_pod_yaml = {
            'metadata': {
                'annotations': annotations,
                'labels': labels,
                # The pod name is not pinned by this test.
                'name': mock.ANY,
                'namespace': 'default',
            },
            'spec': {'containers': [base_container]},
        }

        self.assertEqual(expected_pod_yaml, record.k8s_pod_yaml)

        with create_session() as session:
            session.add(record)

        self.assertEqual(expected_pod_yaml, RTIF.get_k8s_pod_yaml(ti=stored_ti))

        # Lookup miss: this task instance is never added to the RTIF table,
        # so fetching its pod yaml is expected to return None.
        with dag:
            unstored_task = BashOperator(task_id="test2", bash_command="echo hello")

        unstored_ti = TI(unstored_task, EXECUTION_DATE)
        self.assertIsNone(RTIF.get_k8s_pod_yaml(ti=unstored_ti))
Esempio n. 18
0
 def create_ti(task: "BaseOperator") -> TI:
     """
     Build a task instance of ``task`` for this run.

     The instance is passed through ``task_instance_mutation_hook`` and counted
     in ``created_counts`` under its ``operator`` attribute before being returned.
     ``self`` and ``created_counts`` are taken from the enclosing scope.
     """
     instance = TI(task, run_id=self.run_id)
     task_instance_mutation_hook(instance)
     created_counts[instance.operator] += 1
     return instance