def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count):
        """
        Check that RTIF.delete_old_records prunes the rendered_task_instance_fields
        rows of one dag_id/task_id pair so that only ``num_to_keep`` remain, and
        that it does so with ``expected_query_count`` queries.
        """
        session = settings.Session()
        dag = DAG("test_delete_old_records", start_date=START_DATE)
        with dag:
            task = BashOperator(task_id="test", bash_command="echo {{ ds }}")

        def _stored_rows():
            # Every RTIF row that belongs to this test's dag/task pair.
            return (
                session.query(RTIF)
                .filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id)
                .all()
            )

        # One RTIF per consecutive day, starting at EXECUTION_DATE.
        rtif_list = []
        for day_offset in range(rtif_num):
            ti = TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=day_offset))
            rtif_list.append(RTIF(ti))

        session.add_all(rtif_list)
        session.commit()

        stored = _stored_rows()
        for rtif in rtif_list:
            self.assertIn(rtif, stored)
        self.assertEqual(rtif_num, len(stored))

        # Pruning must remove the old rows, keep only 'num_to_keep', and stay within the query budget.
        with assert_queries_count(expected_query_count):
            RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)
        self.assertEqual(remaining_rtifs, len(_stored_rows()))
 def test_bulk_sync_to_db(self):
     """Serialize three fresh DAGs in one bulk call and expect 10 queries (per assert_queries_count)."""
     dag_ids = ("dag_1", "dag_2", "dag_3")
     dags = [DAG(dag_id) for dag_id in dag_ids]
     with assert_queries_count(10):
         SDM.bulk_sync_to_db(dags)
Beispiel #3
0
 def test_count_number_queries(self, tasks_count):
     """Creating a running DAG run is expected to take 3 queries for every ``tasks_count`` tried."""
     dag = DAG('test_dagrun_query_count', start_date=DEFAULT_DATE)
     task_ids = [f'dummy_task_{i}' for i in range(tasks_count)]
     for task_id in task_ids:
         DummyOperator(task_id=task_id, owner='test', dag=dag)

     with assert_queries_count(3):
         dag.create_dagrun(run_id="test_dagrun_query_count", state=State.RUNNING)
Beispiel #4
0
    def test_get_dag_with_dag_serialization(self):
        """
        A DagBag that reads from the DB should serve its cached Serialized DAG until
        'min_serialized_dag_fetch_interval' seconds have passed, and re-read the
        updated Serialized DAG from the DB once they have.
        """
        dag_id = "example_bash_operator"

        with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 0)):
            source_dag = DagBag(include_examples=True).dags.get(dag_id)
            SerializedDagModel.write_dag(dag=source_dag)

            dag_bag = DagBag(read_dags_from_db=True)
            first_fetch = dag_bag.get_dag(dag_id)
            first_fetch_time = dag_bag.dags_last_fetched[dag_id]
            self.assertEqual(source_dag.tags, first_fetch.tags)
            self.assertEqual(first_fetch_time, tz.datetime(2020, 1, 5, 0, 0, 0))

        # Still inside the fetch interval: the cached DAG is returned without any DB query.
        with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 4)), assert_queries_count(0):
            self.assertEqual(dag_bag.get_dag(dag_id).tags, ["example", "example2"])

        # Persist a modified copy of the DAG.
        with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 6)):
            source_dag.tags += ["new_tag"]
            SerializedDagModel.write_dag(dag=source_dag)

        # Interval elapsed: get_dag must go back to the DB for the new version.
        with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 8)), assert_queries_count(2):
            second_fetch = dag_bag.get_dag(dag_id)
            second_fetch_time = dag_bag.dags_last_fetched[dag_id]

        self.assertCountEqual(second_fetch.tags, ["example", "example2", "new_tag"])
        self.assertGreater(second_fetch_time, first_fetch_time)
Beispiel #5
0
    def test_get_all_roles_with_permissions(self):
        """_get_all_roles_with_permissions yields a name -> role mapping (including 'Admin') in one query."""
        sm = self.security_manager
        with assert_queries_count(1):
            roles = sm._get_all_roles_with_permissions()

        assert isinstance(roles, dict)
        for name in roles:
            assert isinstance(name, str)
            assert isinstance(roles[name], sm.role_model)

        assert 'Admin' in roles
Beispiel #6
0
    def test_get_all_permissions(self):
        """get_all_permissions returns a set of 2-tuples such as ('can_read', 'Connections') in one query."""
        with assert_queries_count(1):
            all_perms = self.security_manager.get_all_permissions()

        assert isinstance(all_perms, set)
        for entry in all_perms:
            assert isinstance(entry, tuple) and len(entry) == 2

        assert ('can_read', 'Connections') in all_perms
Beispiel #7
0
    def test_get_all_non_dag_permissionviews(self):
        """
        _get_all_non_dag_permissionviews maps (permission name, view-model name)
        string pairs to permission-view objects and is loaded in a single query.
        """
        sm = self.security_manager
        with assert_queries_count(1):
            pvs = sm._get_all_non_dag_permissionviews()

        assert isinstance(pvs, dict)
        for key, perm_view in pvs.items():
            perm_name, viewmodel_name = key
            assert isinstance(perm_name, str)
            assert isinstance(viewmodel_name, str)
            assert isinstance(perm_view, sm.permissionview_model)

        assert ('can_read', 'Connections') in pvs
    def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes):
        """
        Run a LocalTaskJob for a single DummyOperator task and check how many DB
        queries ``job.run()`` issues while the mocked task runner reports the
        parametrized ``return_codes`` from successive ``return_code()`` calls.
        """
        unique_prefix = str(uuid.uuid4())
        dag = DAG(dag_id=f'{unique_prefix}_test_number_of_queries', start_date=DEFAULT_DATE)
        task = DummyOperator(task_id='test_state_succeeded1', dag=dag)

        dag.clear()
        dag.create_dagrun(run_id=unique_prefix, state=State.NONE)

        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)

        # Mock only honours ``side_effect`` (singular). Any other name, such as
        # ``side_effects``, is silently accepted as a plain child attribute and
        # ignored, which would leave ``return_codes`` with no effect at all.
        # NOTE(review): confirm the query budget below still holds for every
        # ``return_codes`` parametrization now that the codes reach the runner.
        mock_get_task_runner.return_value.return_code.side_effect = return_codes

        job = LocalTaskJob(task_instance=ti, executor=MockExecutor())
        with assert_queries_count(13):
            job.run()
Beispiel #9
0
 def test_should_not_do_database_queries(self):
     """
     Load every example DAG file, except those whose path ends with an entry of
     NO_DB_QUERY_EXCEPTION, into a DagBag and assert that parsing issues no
     database queries.
     """
     example_dags = glob(f"{ROOT_FOLDER}/airflow/**/example_dags/example_*.py", recursive=True)
     # A file is skipped when it ends with ANY exempted suffix. The predicate must
     # be ``not any(...)``: ``any(not ...)`` is true as soon as one suffix fails to
     # match, so with two or more exemptions every file would be kept.
     example_dags = [
         dag_file
         for dag_file in example_dags
         if not any(dag_file.endswith(suffix) for suffix in NO_DB_QUERY_EXCEPTION)
     ]
     assert 0 != len(example_dags)
     for filepath in example_dags:
         relative_filepath = os.path.relpath(filepath, ROOT_FOLDER)
         with self.subTest(f"File {relative_filepath} shouldn't do database queries"):
             with assert_queries_count(0):
                 DagBag(
                     dag_folder=filepath,
                     include_examples=False,
                 )
Beispiel #10
0
    def test_bulk_sync_roles_modelview(self):
        """
        bulk_sync_roles creates the role with every requested (action, view) pair,
        and a repeated sync with identical input only performs its two lookups.
        """
        role_name = 'MyRole2'
        view_name = 'SomeModelView'
        actions = [
            'can_list',
            'can_show',
            'can_add',
            permissions.ACTION_CAN_EDIT,
            permissions.ACTION_CAN_DELETE,
        ]
        role_perms = [(action, view_name) for action in actions]
        mock_roles = [{'role': role_name, 'perms': role_perms}]

        self.security_manager.bulk_sync_roles(mock_roles)

        role = self.appbuilder.sm.find_role(role_name)
        assert role is not None
        assert len(role_perms) == len(role.permissions)

        # Re-syncing unchanged roles should short circuit: one query for permissionview, one for roles.
        with assert_queries_count(2):
            self.security_manager.bulk_sync_roles(mock_roles)
Beispiel #11
0
    def test_create_dag_specific_permissions(self):
        """
        create_dag_specific_permissions adds the 'can_read' permission for a newly
        stored DagModel, and a second call with nothing to add only runs its two lookups.
        """
        dag_id = 'some_dag_id'
        sm = self.security_manager
        read_perm = ('can_read', sm.prefixed_dag_id(dag_id))
        assert read_perm not in sm.get_all_permissions()

        self.session.add(
            DagModel(dag_id=dag_id, fileloc='/tmp/dag_.py', schedule_interval='2 2 * * *', is_paused=True)
        )
        self.session.commit()

        sm.create_dag_specific_permissions()
        self.session.commit()

        assert read_perm in sm.get_all_permissions()

        # Permissions already exist: expect one query for DagModels and one for all perms.
        with assert_queries_count(2):
            sm.create_dag_specific_permissions()
Beispiel #12
0
    def test_create_dag_specific_permissions(self, dagbag_mock):
        """
        create_dag_specific_permissions reads DAGs through a DB-backed DagBag, adds
        'can_read'/'can_edit' permissions for each, calls _sync_dag_view_permissions
        only for the DAG that defines access_control, and needs a single query once
        the permissions already exist.
        """
        access_control = {'Public': {permissions.ACTION_CAN_READ}}
        dags = [
            DAG('has_access_control', access_control=access_control),
            DAG('no_access_control'),
        ]

        collect_dags_from_db_mock = mock.Mock()
        dagbag = mock.Mock()
        dagbag.dags = {dag.dag_id: dag for dag in dags}
        dagbag.collect_dags_from_db = collect_dags_from_db_mock
        dagbag_mock.return_value = dagbag

        sm = self.security_manager
        sm._sync_dag_view_permissions = mock.Mock()

        for dag in dags:
            resource = permissions.resource_name_for_dag(dag.dag_id)
            existing = sm.get_all_permissions()
            for action in ('can_read', 'can_edit'):
                assert (action, resource) not in existing

        sm.create_dag_specific_permissions()

        dagbag_mock.assert_called_once_with(read_dags_from_db=True)
        collect_dags_from_db_mock.assert_called_once_with()

        for dag in dags:
            resource = permissions.resource_name_for_dag(dag.dag_id)
            existing = sm.get_all_permissions()
            for action in ('can_read', 'can_edit'):
                assert (action, resource) in existing

        sm._sync_dag_view_permissions.assert_called_once_with(
            permissions.resource_name_for_dag('has_access_control'),
            access_control,
        )

        del dagbag.dags["has_access_control"]
        # One query to fetch all perms; the DagBag itself is mocked.
        with assert_queries_count(1):
            sm.create_dag_specific_permissions()
Beispiel #13
0
    def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count):
        """
        Test that old records are deleted from rendered_task_instance_fields table
        for a given task_id and dag_id.

        ``expected_query_count`` is the query budget for non-MSSQL backends. On
        MSSQL the budget is one higher whenever ``expected_query_count`` is non-zero.
        """
        session = settings.Session()
        dag = DAG("test_delete_old_records", start_date=START_DATE)
        with dag:
            task = BashOperator(task_id="test", bash_command="echo {{ ds }}")

        rtif_list = [
            RTIF(TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=num)))
            for num in range(rtif_num)
        ]

        session.add_all(rtif_list)
        session.commit()

        result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()

        for rtif in rtif_list:
            assert rtif in result

        assert rtif_num == len(result)

        # Verify old records are deleted and only 'num_to_keep' records are kept.
        # On MSSQL, RenderedTaskInstanceFields.delete_old_records fires one extra
        # query (unless no query is expected at all), so only that dialect's
        # budget is bumped; every other backend uses expected_query_count as-is.
        expected_query_count_based_on_db = (
            expected_query_count + 1
            if session.bind.dialect.name == "mssql" and expected_query_count != 0
            else expected_query_count
        )

        with assert_queries_count(expected_query_count_based_on_db):
            RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)
        result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
        assert remaining_rtifs == len(result)
Beispiel #14
0
    def test_bulk_sync_to_db(self):
        """
        Exercise DAG.bulk_sync_to_db through an initial insert, unchanged re-syncs,
        a tag addition and a tag removal. After each change, check both the query
        budget and the DagModel / DagTag rows that end up in the DB.
        """
        clear_db_dags()
        dag_ids = [f'dag-bulk-sync-{i}' for i in range(4)]
        dags = [DAG(dag_id, start_date=DEFAULT_DATE, tags=["test-dag"]) for dag_id in dag_ids]

        def _assert_db_state(expected_tags):
            # Every DAG must be stored, each carrying exactly ``expected_tags``.
            with create_session() as session:
                self.assertEqual(
                    set(dag_ids),
                    {row[0] for row in session.query(DagModel.dag_id).all()},
                )
                self.assertEqual(
                    {(dag_id, tag) for dag_id in dag_ids for tag in expected_tags},
                    set(session.query(DagTag.dag_id, DagTag.name).all()),
                )

        with assert_queries_count(3):
            DAG.bulk_sync_to_db(dags)
        _assert_db_state(["test-dag"])

        # Re-syncing DAGs that did not change should need fewer queries.
        for _ in range(2):
            with assert_queries_count(2):
                DAG.bulk_sync_to_db(dags)

        # Add a second tag to every DAG.
        for dag in dags:
            dag.tags.append("test-dag2")
        with assert_queries_count(3):
            DAG.bulk_sync_to_db(dags)
        _assert_db_state(["test-dag", "test-dag2"])

        # Drop the original tag from every DAG.
        for dag in dags:
            dag.tags.remove("test-dag")
        with assert_queries_count(3):
            DAG.bulk_sync_to_db(dags)
        _assert_db_state(["test-dag2"])
Beispiel #15
0
def test_index(admin_client):
    """Load the index page (following redirects) with the expected 44 queries and check it lists DAGs."""
    with assert_queries_count(44):
        response = admin_client.get('/', follow_redirects=True)
    check_content_in_response('DAGs', response)