def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count):
        """
        Check that RTIF.delete_old_records prunes rows of the
        rendered_task_instance_fields table for one dag_id / task_id pair.

        ``rtif_num`` records are stored first, one per consecutive day starting at
        EXECUTION_DATE. The prune call is then made with ``num_to_keep`` inside
        ``assert_queries_count(expected_query_count)``, after which
        ``remaining_rtifs`` rows are expected to be left.
        """
        session = settings.Session()
        dag = DAG("test_delete_old_records", start_date=START_DATE)
        with dag:
            task = BashOperator(task_id="test", bash_command="echo {{ ds }}")

        def stored_records():
            # All RTIF rows currently persisted for this dag/task pair.
            query = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id)
            return query.all()

        # One record per day offset: 0, 1, ..., rtif_num - 1.
        records = []
        for day_offset in range(rtif_num):
            task_instance = TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=day_offset))
            records.append(RTIF(task_instance))

        session.add_all(records)
        session.commit()

        # Sanity check: everything that was just added is visible before pruning.
        before_prune = stored_records()
        for record in records:
            self.assertIn(record, before_prune)
        self.assertEqual(rtif_num, len(before_prune))

        # Prune, then confirm how many rows survived.
        with assert_queries_count(expected_query_count):
            RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)

        after_prune = stored_records()
        self.assertEqual(remaining_rtifs, len(after_prune))
    def test_get_templated_fields(self, templated_field, expected_rendered_field):
        """
        Check the round trip of rendered template fields through RTIF.

        An RTIF built from a task instance must carry the same dag_id, task_id and
        execution_date as that task instance, and its rendered ``bash_command`` must
        equal ``expected_rendered_field``. Once the record is added to the database,
        RTIF.get_templated_fields must return the stored fields; for a task instance
        that was never stored it must return None.
        """
        dag = DAG("test_serialized_rendered_fields", start_date=START_DATE)
        with dag:
            task = BashOperator(task_id="test", bash_command=templated_field)

        ti = TI(task=task, execution_date=EXECUTION_DATE)
        rtif = RTIF(ti=ti)

        # The record mirrors the identifying attributes of its task instance.
        for attr_name in ("dag_id", "task_id", "execution_date"):
            self.assertEqual(getattr(ti, attr_name), getattr(rtif, attr_name))

        rendered_command = rtif.rendered_fields.get("bash_command")
        self.assertEqual(expected_rendered_field, rendered_command)

        with create_session() as session:
            session.add(rtif)

        expected_fields = {"bash_command": expected_rendered_field, "env": None}
        self.assertEqual(expected_fields, RTIF.get_templated_fields(ti=ti))

        # Cover the "not found" branch of get_templated_fields: a task instance
        # with no row in the RTIF table yields None.
        with dag:
            unsaved_task = BashOperator(task_id="test2", bash_command=templated_field)

        unsaved_ti = TI(unsaved_task, EXECUTION_DATE)
        self.assertIsNone(RTIF.get_templated_fields(ti=unsaved_ti))
    def test_delete_old_records(self):
        """
        Check that RTIF.delete_old_records removes older rows of the
        rendered_task_instance_fields table for one dag_id / task_id pair.

        Three records on consecutive execution dates are stored; after pruning
        with ``num_to_keep=1`` a single row must remain, carrying the execution
        date of the last record created.
        """
        session = settings.Session()
        dag = DAG("test_delete_old_records", start_date=START_DATE)
        with dag:
            task = BashOperator(task_id="test", bash_command="echo {{ ds }}")

        def stored_records():
            # All RTIF rows currently persisted for this dag/task pair.
            query = session.query(RTIF)
            return query.filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()

        execution_dates = [
            EXECUTION_DATE,
            EXECUTION_DATE + timedelta(days=1),
            EXECUTION_DATE + timedelta(days=2),
        ]
        records = [RTIF(TI(task=task, execution_date=when)) for when in execution_dates]

        for record in records:
            session.add(record)
        session.commit()

        # Sanity check: all three records are visible before pruning.
        before_prune = stored_records()
        for record in records:
            self.assertIn(record, before_prune)
        self.assertEqual(3, len(before_prune))

        # Prune down to one row and confirm which one survived.
        RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=1)

        after_prune = stored_records()
        self.assertEqual(1, len(after_prune))
        latest_record = records[-1]
        self.assertEqual(latest_record.execution_date, after_prune[0].execution_date)
Example #4
0
    def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count):
        """
        Check that RTIF.delete_old_records prunes rows of the
        rendered_task_instance_fields table for one dag_id / task_id pair.

        ``rtif_num`` records are stored first, one per consecutive day starting at
        EXECUTION_DATE. The prune call is made with ``num_to_keep`` under a query
        count assertion, after which ``remaining_rtifs`` rows are expected to be
        left. The asserted count is ``expected_query_count``, raised by one when
        the session is bound to MSSQL and the expected count is non-zero.
        """
        session = settings.Session()
        dag = DAG("test_delete_old_records", start_date=START_DATE)
        with dag:
            task = BashOperator(task_id="test", bash_command="echo {{ ds }}")

        def stored_records():
            # All RTIF rows currently persisted for this dag/task pair.
            query = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id)
            return query.all()

        # One record per day offset: 0, 1, ..., rtif_num - 1.
        records = []
        for day_offset in range(rtif_num):
            task_instance = TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=day_offset))
            records.append(RTIF(task_instance))

        session.add_all(records)
        session.commit()

        # Sanity check: everything that was just added is visible before pruning.
        before_prune = stored_records()
        for record in records:
            assert record in before_prune
        assert rtif_num == len(before_prune)

        # On the "mssql" dialect one additional query is budgeted for
        # RenderedTaskInstanceFields.delete_old_records, unless no queries
        # are expected at all.
        expected_query_count_based_on_db = expected_query_count
        if session.bind.dialect.name == "mssql" and expected_query_count != 0:
            expected_query_count_based_on_db = expected_query_count + 1

        # Prune, then confirm how many rows survived.
        with assert_queries_count(expected_query_count_based_on_db):
            RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)

        after_prune = stored_records()
        assert remaining_rtifs == len(after_prune)
    def test_write(self):
        """
        Check that RTIF.write stores a record and that writing again for the
        same dag_id / task_id / execution_date replaces the stored rendered fields.
        """
        Variable.set(key="test_key", value="test_val")

        session = settings.Session()

        def fetch_stored(record):
            # Look up (dag_id, task_id, rendered_fields) for the row matching `record`.
            columns = session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields)
            matching = columns.filter(
                RTIF.dag_id == record.dag_id,
                RTIF.task_id == record.task_id,
                RTIF.execution_date == record.execution_date,
            )
            return matching.first()

        # The table is expected to be empty before anything is written.
        existing_rows = session.query(RTIF).all()
        assert [] == existing_rows

        with DAG("test_write", start_date=START_DATE):
            task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")

        rtif = RTIF(TI(task=task, execution_date=EXECUTION_DATE))
        rtif.write()

        result = fetch_stored(rtif)
        expected_first = ('test_write', 'test', {'bash_command': 'echo test_val', 'env': None})
        assert expected_first == result

        # Test that overwrite saves new values to the DB:
        # change the variable, rebuild the same task, and write again.
        Variable.delete("test_key")
        Variable.set(key="test_key", value="test_val_updated")

        with DAG("test_write", start_date=START_DATE):
            updated_task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")

        rtif_updated = RTIF(TI(task=updated_task, execution_date=EXECUTION_DATE))
        rtif_updated.write()

        result_updated = fetch_stored(rtif_updated)
        expected_updated = (
            'test_write',
            'test',
            {'bash_command': 'echo test_val_updated', 'env': None},
        )
        assert expected_updated == result_updated
    def test_get_k8s_pod_yaml(self, mock_pod_mutation_hook):
        """
        Check the round trip of the rendered Kubernetes pod YAML through RTIF.

        Building an RTIF from a task instance must invoke the pod mutation hook
        once and populate ``k8s_pod_yaml`` with the expected structure. Once the
        record is added to the database, RTIF.get_k8s_pod_yaml must return the same
        structure; for a task instance that was never stored it must return None.
        """
        dag = DAG("test_get_k8s_pod_yaml", start_date=START_DATE)
        with dag:
            task = BashOperator(task_id="test", bash_command="echo hi")

        ti = TI(task=task, execution_date=EXECUTION_DATE)
        rtif = RTIF(ti=ti)

        # Test that pod_mutation_hook is called
        mock_pod_mutation_hook.assert_called_once_with(mock.ANY)

        # The record mirrors the identifying attributes of its task instance.
        for attr_name in ("dag_id", "task_id", "execution_date"):
            self.assertEqual(getattr(ti, attr_name), getattr(rtif, attr_name))

        # Expected pod description, assembled piece by piece.
        expected_annotations = {
            'dag_id': 'test_get_k8s_pod_yaml',
            'execution_date': '2019-01-01T00:00:00+00:00',
            'task_id': 'test',
            'try_number': '1',
        }
        expected_labels = {
            'airflow-worker': 'worker-config',
            'airflow_version': version,
            'dag_id': 'test_get_k8s_pod_yaml',
            'execution_date': '2019-01-01T00_00_00_plus_00_00',
            'kubernetes_executor': 'True',
            'task_id': 'test',
            'try_number': '1',
        }
        expected_container = {
            'command': [
                'airflow',
                'tasks',
                'run',
                'test_get_k8s_pod_yaml',
                'test',
                '2019-01-01T00:00:00+00:00',
            ],
            'image': ':',
            'name': 'base',
        }
        expected_pod_yaml = {
            'metadata': {
                'annotations': expected_annotations,
                'labels': expected_labels,
                # The pod name is not pinned by this test.
                'name': mock.ANY,
                'namespace': 'default',
            },
            'spec': {'containers': [expected_container]},
        }

        self.assertEqual(expected_pod_yaml, rtif.k8s_pod_yaml)

        with create_session() as session:
            session.add(rtif)

        self.assertEqual(expected_pod_yaml, RTIF.get_k8s_pod_yaml(ti=ti))

        # Cover the "not found" branch of get_k8s_pod_yaml: a task instance
        # with no row in the RTIF table yields None.
        with dag:
            unsaved_task = BashOperator(task_id="test2", bash_command="echo hello")

        unsaved_ti = TI(unsaved_task, EXECUTION_DATE)
        self.assertIsNone(RTIF.get_k8s_pod_yaml(ti=unsaved_ti))