예제 #1
0
 def test_get_infrastructure_provider_uuid(self):
     """Verify the accessor reports the infrastructure provider UUID it was given.

     The OCP provider is linked to the AWS provider through
     ``set_infrastructure`` and the link is read back via the same accessor.
     """
     infra_uuid = self.aws_provider_uuid
     with ProviderDBAccessor(self.ocp_provider_uuid) as ocp_accessor:
         ocp_accessor.set_infrastructure(infra_uuid, Provider.PROVIDER_AWS)
         stored_uuid = ocp_accessor.get_infrastructure_provider_uuid()
         self.assertEqual(stored_uuid, self.aws_provider_uuid)
예제 #2
0
 def test_get_infrastructure_type(self):
     """Verify the accessor reports the infrastructure type it was given."""
     expected_type = 'AWS'
     with ProviderDBAccessor(self.ocp_provider_uuid) as ocp_accessor:
         ocp_accessor.set_infrastructure(self.aws_provider_uuid, expected_type)
         stored_type = ocp_accessor.get_infrastructure_type()
         self.assertEqual(stored_type, expected_type)
예제 #3
0
 def get_date_column_filter(self):
     """Build the cutoff-date filter keyed on the provider's date column.

     Returns:
         dict: ``{'usage_date_time__gte': cutoff}`` when the lower-cased
         provider type is ``'azure'``, otherwise
         ``{'usage_start__gte': cutoff}``; ``cutoff`` is
         ``self.data_cutoff_date``.

     """
     with ProviderDBAccessor(self._provider_uuid) as provider_accessor:
         provider_type = provider_accessor.get_type().lower()
     if provider_type == 'azure':
         column_lookup = 'usage_date_time__gte'
     else:
         column_lookup = 'usage_start__gte'
     return {column_lookup: self.data_cutoff_date}
예제 #4
0
    def setUp(self):
        """Run the parent set up and remember the AWS test provider's id."""
        super().setUp()

        aws_uuid = self.aws_test_provider_uuid
        with ProviderDBAccessor(aws_uuid) as accessor:
            self.provider_id = accessor.get_provider().id
예제 #5
0
 def get_date_column_filter(self):
     """Build the cutoff-date filter keyed on the provider's date column.

     Returns:
         dict: ``{"usage_date__gte": cutoff}`` for the Azure and Azure-local
         provider types, otherwise ``{"usage_start__gte": cutoff}``;
         ``cutoff`` is ``self.data_cutoff_date``.

     """
     with ProviderDBAccessor(self._provider_uuid) as provider_accessor:
         provider_type = provider_accessor.get_type()
     azure_types = (Provider.PROVIDER_AZURE, Provider.PROVIDER_AZURE_LOCAL)
     if provider_type not in azure_types:
         return {"usage_start__gte": self.data_cutoff_date}
     return {"usage_date__gte": self.data_cutoff_date}
예제 #6
0
파일: test_tasks.py 프로젝트: ebpetway/koku
    def test_process_report_files_with_transaction_atomic_error(
            self, mock_files, mock_processor):
        """Test that an exception rolls back the atomic transaction.

        ``mock_processor`` is configured to raise and ``get_report_files`` is
        expected to propagate the exception. Afterwards the report stats, the
        manifest and the provider must show no sign of completed work.
        """
        path = "{}/{}".format("test", "file1.csv")
        mock_files.return_value = [{"file": path, "compression": "GZIP"}]
        schema_name = self.schema
        provider = Provider.PROVIDER_AWS
        provider_uuid = self.aws_provider_uuid
        report_month = DateHelper().today
        manifest_dict = {
            "assembly_id": "12345",
            "billing_period_start_datetime": report_month,
            "num_total_files": 1,
            "provider_uuid": self.aws_provider_uuid,
            "task": "170653c0-3e66-4b7e-a764-336496d7ca5a",
        }
        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.add(**manifest_dict)
            manifest.save()
            manifest_id = manifest.id
            # Baseline used below to prove the manifest was not touched.
            initial_update_time = manifest.manifest_updated_datetime

        with ReportStatsDBAccessor("file1.csv", manifest_id) as stats_accessor:
            # The getter must be called; a bare attribute reference is a no-op.
            stats_accessor.get_last_completed_datetime()

        with ReportStatsDBAccessor(path, manifest_id) as report_file_accessor:
            report_file_accessor.get_last_started_datetime()

        mock_processor.side_effect = Exception

        # Only the call under test sits inside assertRaises, so an error in
        # the argument setup cannot satisfy the assertion by accident.
        with self.assertRaises(Exception):
            get_report_files(
                customer_name="Fake Customer",
                authentication="auth",
                billing_source="bill",
                provider_type=provider,
                schema_name=schema_name,
                provider_uuid=provider_uuid,
                report_month=report_month,
            )

        with ReportStatsDBAccessor(path, manifest_id) as report_file_accessor:
            self.assertIsNone(
                report_file_accessor.get_last_completed_datetime())

        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.get_manifest_by_id(manifest_id)
            self.assertEqual(manifest.num_processed_files, 0)
            self.assertEqual(manifest.manifest_updated_datetime,
                             initial_update_time)

        with ProviderDBAccessor(
                provider_uuid=provider_uuid) as provider_accessor:
            self.assertFalse(provider_accessor.get_setup_complete())
예제 #7
0
def _process_report_file(schema_name, provider, provider_uuid, report_dict):
    """
    Process a single report file and record the outcome.

    The file is handed to ``ReportProcessor``; start and completion times are
    logged through ``ReportStatsDBAccessor``, the manifest's processed-file
    counter is incremented and the provider is marked as set up.

    Args:
        schema_name   (String) db schema name
        provider      (String) provider type
        provider_uuid (String) provider uuid
        report_dict   (dict) The report data dict from previous task; the keys
                      read are start_date, file, compression, manifest_id and
                      provider_id

    Returns:
        None

    """
    start_date = report_dict.get('start_date')
    report_path = report_dict.get('file')
    compression = report_dict.get('compression')
    manifest_id = report_dict.get('manifest_id')
    provider_id = report_dict.get('provider_id')

    summary_lines = [
        'Processing Report:',
        f' schema_name: {schema_name}',
        f' provider: {provider}',
        f' provider_uuid: {provider_uuid}',
        f' file: {report_path}',
        f' compression: {compression}',
        f' start_date: {start_date}',
    ]
    LOG.info('\n'.join(summary_lines))

    memory = psutil.virtual_memory()
    LOG.info(f'Avaiable memory: {memory.free} bytes ({memory.percent}%)')

    # Stats rows are keyed on the bare file name, not the full path.
    file_name = report_path.rsplit('/', 1)[-1]
    with ReportStatsDBAccessor(file_name, manifest_id) as started_recorder:
        started_recorder.log_last_started_datetime()

    processor = ReportProcessor(
        schema_name=schema_name,
        report_path=report_path,
        compression=compression,
        provider=provider,
        provider_id=provider_id,
        manifest_id=manifest_id,
    )
    processor.process()

    with ReportStatsDBAccessor(file_name, manifest_id) as completed_recorder:
        completed_recorder.log_last_completed_datetime()

    with ReportManifestDBAccessor() as manifest_accessor:
        manifest = manifest_accessor.get_manifest_by_id(manifest_id)
        if not manifest:
            LOG.error('Unable to find manifest for ID: %s, file %s', manifest_id, file_name)
        else:
            manifest.num_processed_files += 1
            manifest.save()
            manifest_accessor.mark_manifest_as_updated(manifest)

    with ProviderDBAccessor(provider_uuid=provider_uuid) as provider_accessor:
        # Processed files are only removed once setup was already complete.
        if provider_accessor.get_setup_complete():
            removed = processor.remove_processed_files(path.dirname(report_path))
            LOG.info('Temporary files removed: %s', str(removed))
        provider_accessor.setup_complete()
예제 #8
0
    def _process_manifest_db_record(self, assembly_id, billing_start,
                                    num_of_files, manifest_modified_datetime,
                                    **kwargs):
        """Insert or update the manifest DB record and return its id.

        Args:
            assembly_id: assembly id used, with the provider uuid, to look up
                the manifest.
            billing_start: stored as ``billing_period_start_datetime`` when a
                new manifest is added.
            num_of_files: stored as ``num_total_files`` on a new manifest and
                compared with the stored value of an existing one.
            manifest_modified_datetime: stored on a new manifest.
            **kwargs: extra fields merged into the values of a new manifest.

        Returns:
            The ``id`` of the manifest entry.

        Raises:
            ReportDownloaderError: the insert hit a foreign-key violation, or
                neither the manifest nor the provider could be found.
            IntegrityError: no manifest could be found although the provider
                exists.
        """
        msg = f"Inserting/updating manifest in database for assembly_id: {assembly_id}"
        LOG.info(log_json(self.tracing_id, msg))

        with ReportManifestDBAccessor() as manifest_accessor:
            manifest_entry = manifest_accessor.get_manifest(
                assembly_id, self._provider_uuid)

            if not manifest_entry:
                msg = f"No manifest entry found in database. Adding for bill period start: {billing_start}"
                LOG.info(log_json(self.tracing_id, msg, self.context))
                manifest_dict = {
                    "assembly_id": assembly_id,
                    "billing_period_start_datetime": billing_start,
                    "num_total_files": num_of_files,
                    "provider_uuid": self._provider_uuid,
                    "manifest_modified_datetime": manifest_modified_datetime,
                }
                manifest_dict.update(kwargs)
                try:
                    manifest_entry = manifest_accessor.add(**manifest_dict)
                except IntegrityError as error:
                    fk_violation = FKViolation(error)
                    if fk_violation:
                        LOG.warning(fk_violation)
                        # Chain the cause so the database error stays visible.
                        raise ReportDownloaderError(
                            f"Method: _process_manifest_db_record :: {fk_violation}"
                        ) from error
                    msg = (
                        f"Manifest entry uniqueness collision: Error {error}. "
                        "Manifest already added, getting manifest_entry_id.")
                    LOG.warning(log_json(self.tracing_id, msg, self.context))
                    # A separate name keeps the outer accessor bound, so the
                    # update calls below do not run on an accessor whose
                    # context has already exited.
                    with ReportManifestDBAccessor() as retry_accessor:
                        manifest_entry = retry_accessor.get_manifest(
                            assembly_id, self._provider_uuid)
            if not manifest_entry:
                msg = f"Manifest entry not found for given manifest {manifest_dict}."
                with ProviderDBAccessor(
                        self._provider_uuid) as provider_accessor:
                    provider = provider_accessor.get_provider()
                    if not provider:
                        msg = f"Provider entry not found for {self._provider_uuid}."
                        LOG.warning(
                            log_json(self.tracing_id, msg, self.context))
                        raise ReportDownloaderError(msg)
                LOG.warning(log_json(self.tracing_id, msg, self.context))
                raise IntegrityError(msg)
            else:
                if num_of_files != manifest_entry.num_total_files:
                    manifest_accessor.update_number_of_files_for_manifest(
                        manifest_entry)
                manifest_accessor.mark_manifest_as_updated(manifest_entry)
                manifest_id = manifest_entry.id

        return manifest_id
예제 #9
0
    def test_get_customer_uuid(self):
        """Check that the provider accessor returns the customer's UUID.

        The expected value is read through ``CustomerDBAccessor`` for
        ``self.customer.id`` and compared with what the AWS provider's
        accessor reports.
        """
        expected_uuid = None
        with CustomerDBAccessor(self.customer.id) as customer_accessor:
            expected_uuid = customer_accessor.get_uuid()

        with ProviderDBAccessor(self.aws_provider_uuid) as provider_accessor:
            reported_uuid = provider_accessor.get_customer_uuid()
            self.assertEqual(expected_uuid, reported_uuid)
예제 #10
0
파일: test_common.py 프로젝트: xJustin/koku
    def setUp(self):
        """Create the accessors and seed the OCP report rows the tests share.

        A report period is created for the OCP test provider, then a report
        in that period, then one usage, storage and node-label line item.
        """
        super().setUp()
        schema = self.schema
        self.accessor = OCPReportDBAccessor(schema=schema)
        self.provider_accessor = ProviderDBAccessor(
            provider_uuid=self.ocp_test_provider_uuid)
        self.report_schema = self.accessor.report_schema
        self.creator = ReportObjectCreator(schema)
        self.all_tables = [*OCP_REPORT_TABLE_MAP.values()]

        self.provider_uuid = self.provider_accessor.get_provider().uuid
        period = self.creator.create_ocp_report_period(
            provider_uuid=self.provider_uuid)
        report = self.creator.create_ocp_report(
            period, period.report_period_start)
        line_item_factories = (
            self.creator.create_ocp_usage_line_item,
            self.creator.create_ocp_storage_line_item,
            self.creator.create_ocp_node_label_line_item,
        )
        for create_line_item in line_item_factories:
            create_line_item(period, report)
예제 #11
0
def refresh_materialized_views(  # noqa: C901
    schema_name,
    provider_type,
    manifest_id=None,
    provider_uuid="",
    synchronous=False,
    queue_name=None,
    tracing_id=None,
):
    """Refresh the database's materialized views for reporting.

    Args:
        schema_name: schema the refresh runs in; part of the lock key.
        provider_type: used for cache invalidation; part of the lock key.
        manifest_id: when truthy, the manifest is marked as completed.
        provider_uuid: when truthy, the provider's data-updated timestamp is
            set; part of the lock key.
        synchronous: when false, the work is guarded by a ``WorkerCache``
            single-task lock and the task re-queues itself if that lock is
            already held.
        queue_name: queue used when re-queuing; falls back to
            ``REFRESH_MATERIALIZED_VIEWS_QUEUE``.
        tracing_id: id passed to ``log_json`` for log messages.

    Returns:
        None
    """
    task_name = "masu.processor.tasks.refresh_materialized_views"
    cache_args = [schema_name, provider_type, provider_uuid]
    worker_cache = None
    if not synchronous:
        worker_cache = WorkerCache()
        if worker_cache.single_task_is_running(task_name, cache_args):
            msg = f"Task {task_name} already running for {cache_args}. Requeuing."
            LOG.info(log_json(tracing_id, msg))
            refresh_materialized_views.s(
                schema_name,
                provider_type,
                manifest_id=manifest_id,
                provider_uuid=provider_uuid,
                synchronous=synchronous,
                queue_name=queue_name,
                tracing_id=tracing_id,
            ).apply_async(queue=queue_name or REFRESH_MATERIALIZED_VIEWS_QUEUE)
            return
        worker_cache.lock_single_task(task_name,
                                      cache_args,
                                      timeout=settings.WORKER_CACHE_TIMEOUT)
    # The tuple is empty, so the refresh loop below currently does nothing.
    materialized_views = ()
    try:
        with schema_context(schema_name):
            for view in materialized_views:
                table_name = view._meta.db_table
                with connection.cursor() as cursor:
                    cursor.execute(
                        f"REFRESH MATERIALIZED VIEW CONCURRENTLY {table_name}")
                    LOG.info(log_json(tracing_id, f"Refreshed {table_name}."))

        invalidate_view_cache_for_tenant_and_source_type(
            schema_name, provider_type)

        if provider_uuid:
            ProviderDBAccessor(provider_uuid).set_data_updated_timestamp()
        if manifest_id:
            # Processing for this manifest should be complete after this step
            with ReportManifestDBAccessor() as manifest_accessor:
                manifest = manifest_accessor.get_manifest_by_id(manifest_id)
                manifest_accessor.mark_manifest_as_completed(manifest)
    finally:
        # Release the lock on every exit path, success or failure, so a
        # failed run cannot leave the task locked.
        if not synchronous:
            worker_cache.release_single_task(task_name, cache_args)
예제 #12
0
class OCPUtilTests(MasuTestCase):
    """Test the OCP utility functions."""

    def setUp(self):
        """Shared variables used by ocp common tests.

        Creates the report and provider accessors, then seeds a report period,
        a report and one usage and one storage line item for the OCP test
        provider.
        """
        super().setUp()
        self.common_accessor = ReportingCommonDBAccessor()
        self.column_map = self.common_accessor.column_map
        self.accessor = OCPReportDBAccessor(schema=self.schema, column_map=self.column_map)
        self.provider_accessor = ProviderDBAccessor(provider_uuid=self.ocp_test_provider_uuid)
        self.report_schema = self.accessor.report_schema
        self.creator = ReportObjectCreator(self.schema, self.column_map)
        self.all_tables = list(OCP_REPORT_TABLE_MAP.values())

        self.provider_uuid = self.provider_accessor.get_provider().uuid
        reporting_period = self.creator.create_ocp_report_period(provider_uuid=self.provider_uuid)
        report = self.creator.create_ocp_report(
            reporting_period, reporting_period.report_period_start
        )
        self.creator.create_ocp_usage_line_item(reporting_period, report)
        self.creator.create_ocp_storage_line_item(reporting_period, report)

    def test_get_cluster_id_from_provider(self):
        """Test that the cluster ID is returned from OCP provider."""
        cluster_id = utils.get_cluster_id_from_provider(self.ocp_test_provider_uuid)
        self.assertIsNotNone(cluster_id)

    def test_get_cluster_id_from_non_ocp_provider(self):
        """Test that None is returned when getting cluster ID on non-OCP provider."""
        cluster_id = utils.get_cluster_id_from_provider(self.aws_provider_uuid)
        self.assertIsNone(cluster_id)

    def test_get_provider_uuid_from_cluster_id(self):
        """Test that the provider uuid is returned for a cluster ID."""
        cluster_id = self.ocp_provider_resource_name
        provider_uuid = utils.get_provider_uuid_from_cluster_id(cluster_id)
        try:
            UUID(provider_uuid)
        except ValueError:
            self.fail('{} is not a valid uuid.'.format(str(provider_uuid)))

    def test_get_provider_uuid_from_invalid_cluster_id(self):
        """Test that the provider uuid is not returned for an invalid cluster ID."""
        cluster_id = 'bad_cluster_id'
        provider_uuid = utils.get_provider_uuid_from_cluster_id(cluster_id)
        self.assertIsNone(provider_uuid)

    def test_poll_ingest_override_for_provider(self):
        """Test that OCP polling override returns True if insights local path exists."""
        # TemporaryDirectory removes the directory even when the assertion
        # fails, so a failing run does not leave files behind.
        with tempfile.TemporaryDirectory() as fake_dir:
            with patch.object(Config, 'INSIGHTS_LOCAL_REPORT_DIR', fake_dir):
                cluster_id = utils.get_cluster_id_from_provider(self.ocp_test_provider_uuid)
                expected_path = '{}/{}/'.format(Config.INSIGHTS_LOCAL_REPORT_DIR, cluster_id)
                os.makedirs(expected_path, exist_ok=True)
                self.assertTrue(utils.poll_ingest_override_for_provider(self.ocp_test_provider_uuid))
예제 #13
0
    def setUp(self):
        """Create the accessors and seed the OCP report rows the tests share.

        A report period is created for the OCP test provider, then a report
        in that period, then one usage and one storage line item.
        """
        super().setUp()
        self.common_accessor = ReportingCommonDBAccessor()
        column_map = self.common_accessor.column_map
        self.column_map = column_map
        self.accessor = OCPReportDBAccessor(
            schema=self.schema, column_map=column_map)
        self.provider_accessor = ProviderDBAccessor(
            provider_uuid=self.ocp_test_provider_uuid)
        self.report_schema = self.accessor.report_schema
        self.creator = ReportObjectCreator(self.schema, column_map)
        self.all_tables = [*OCP_REPORT_TABLE_MAP.values()]

        self.provider_uuid = self.provider_accessor.get_provider().uuid
        period = self.creator.create_ocp_report_period(
            provider_uuid=self.provider_uuid)
        report = self.creator.create_ocp_report(
            period, period.report_period_start)
        line_item_factories = (
            self.creator.create_ocp_usage_line_item,
            self.creator.create_ocp_storage_line_item,
        )
        for create_line_item in line_item_factories:
            create_line_item(period, report)
예제 #14
0
    def set_provider_infra_map(self, infra_map):
        """Record the infrastructure behind every provider in ``infra_map``.

        Each key is handed to ``ProviderDBAccessor``; element 0 of its value
        is used as the infrastructure provider UUID and element 1 as the
        infrastructure type.

        The infra_map is the one built by _generate_ocp_infra_map_from_sql.
        """
        for provider_key, infra_tuple in infra_map.items():
            infra_uuid = infra_tuple[0]
            infra_type = infra_tuple[1]
            with ProviderDBAccessor(provider_key) as accessor:
                accessor.set_infrastructure(
                    infrastructure_provider_uuid=infra_uuid,
                    infrastructure_type=infra_type,
                )
예제 #15
0
    def test_process_report_files_with_transaction_atomic_error(
            self, mock_processor, mock_setup_complete):
        """Check that a failure while processing leaves no completed state.

        ``mock_setup_complete`` is configured to raise, so
        ``_process_report_file`` is expected to fail. Afterwards the report
        stats, the manifest and the provider must be unchanged.
        """
        path = '/'.join(('test', 'file1.csv'))
        schema_name = self.schema
        provider = Provider.PROVIDER_AWS
        provider_uuid = self.aws_provider_uuid
        billing_start = DateAccessor().today_with_timezone('UTC')
        manifest_dict = {
            'assembly_id': '12345',
            'billing_period_start_datetime': billing_start,
            'num_total_files': 2,
            'provider_uuid': self.aws_provider_uuid,
            'task': '170653c0-3e66-4b7e-a764-336496d7ca5a',
        }
        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.add(**manifest_dict)
            manifest.save()
            manifest_id = manifest.id
            # Baseline used below to prove the manifest was not touched.
            initial_update_time = manifest.manifest_updated_datetime

        with ReportStatsDBAccessor(path, manifest_id) as stats_accessor:
            stats_accessor.get_last_started_datetime()

        report_dict = {
            'file': path,
            'compression': 'gzip',
            'start_date': str(DateAccessor().today()),
            'manifest_id': manifest_id,
        }

        mock_setup_complete.side_effect = Exception

        with self.assertRaises(Exception):
            _process_report_file(
                schema_name, provider, provider_uuid, report_dict)

        with ReportStatsDBAccessor(path, manifest_id) as stats_accessor:
            completed_at = stats_accessor.get_last_completed_datetime()
            self.assertIsNone(completed_at)

        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.get_manifest_by_id(manifest_id)
            self.assertEqual(manifest.num_processed_files, 0)
            self.assertEqual(
                manifest.manifest_updated_datetime, initial_update_time)

        with ProviderDBAccessor(provider_uuid=provider_uuid) as provider_accessor:
            setup_done = provider_accessor.get_setup_complete()
            self.assertFalse(setup_done)
예제 #16
0
    def setUp(self):
        """Prepare connections, the charge updater and seed data for a test.

        Closed connection, psycopg2 connection and cursor handles on
        ``self.accessor`` are reopened, an ``AWSReportChargeUpdater`` is built
        for the AWS test provider, one bill with a single line item is
        created, and a manifest is added and committed.
        """
        super().setUp()
        accessor = self.accessor
        if accessor._conn.closed:
            accessor._conn = accessor._db.connect()
        if accessor._pg2_conn.closed:
            accessor._pg2_conn = accessor._get_psycopg2_connection()
        if accessor._cursor.closed:
            accessor._cursor = accessor._get_psycopg2_cursor()

        aws_uuid = self.aws_test_provider_uuid
        self.provider_accessor = ProviderDBAccessor(provider_uuid=aws_uuid)
        provider_id = self.provider_accessor.get_provider().id
        # NOTE(review): the accessor is re-created right after the lookup;
        # this looks redundant - confirm before removing.
        self.provider_accessor = ProviderDBAccessor(provider_uuid=aws_uuid)
        self.updater = AWSReportChargeUpdater(
            schema=self.test_schema,
            provider_uuid=aws_uuid,
            provider_id=provider_id,
        )

        now_utc = DateAccessor().today_with_timezone('UTC')
        bill = self.creator.create_cost_entry_bill(now_utc)
        cost_entry = self.creator.create_cost_entry(bill, now_utc)
        product = self.creator.create_cost_entry_product()
        pricing = self.creator.create_cost_entry_pricing()
        reservation = self.creator.create_cost_entry_reservation()
        self.creator.create_cost_entry_line_item(
            bill, cost_entry, product, pricing, reservation)

        self.manifest = self.manifest_accessor.add(**self.manifest_dict)
        self.manifest_accessor.commit()

        with ProviderDBAccessor(aws_uuid) as provider_accessor:
            self.provider = provider_accessor.get_provider()
예제 #17
0
def report_data():
    """Queue an update of the report summary tables from request parameters.

    Reads ``provider_uuid``, ``provider_type``, ``schema``, ``start_date`` and
    ``end_date`` from ``request.args``. A ``provider_uuid`` of ``'*'`` queues
    an update for all providers; otherwise a single update is queued for the
    given schema and provider.

    Returns:
        A JSON response holding the string form of the async result under
        ``REPORT_DATA_KEY``, or a ``(JSON error, 400)`` tuple when a required
        parameter is missing or the provider types disagree.
    """
    def _bad_request(message):
        # Every validation failure answers with the same payload shape.
        return jsonify({'Error': message}), 400

    params = request.args
    provider_uuid = params.get('provider_uuid')
    provider_type = params.get('provider_type')
    schema_name = params.get('schema')
    start_date = params.get('start_date')
    end_date = params.get('end_date')

    if provider_uuid is None and provider_type is None:
        return _bad_request('provider_uuid or provider_type must be supplied as a parameter.')

    all_providers = provider_uuid == '*'
    provider = None
    if not all_providers:
        if provider_uuid:
            with ProviderDBAccessor(provider_uuid) as provider_accessor:
                provider = provider_accessor.get_type()
        else:
            provider = provider_type

    if start_date is None:
        return _bad_request('start_date is a required parameter.')

    if all_providers:
        async_result = update_all_summary_tables.delay(start_date, end_date)
        return jsonify({REPORT_DATA_KEY: str(async_result)})

    if schema_name is None:
        return _bad_request('schema is a required parameter.')

    if provider is None:
        return _bad_request('Unable to determine provider type.')

    if provider_type and provider_type != provider:
        return _bad_request('provider_uuid and provider_type have mismatched provider types.')

    async_result = update_summary_tables.delay(
        schema_name, provider, provider_uuid, start_date, end_date)
    return jsonify({REPORT_DATA_KEY: str(async_result)})
예제 #18
0
 def test_update_summary_tables_with_aws_provider(
     self, mock_utility, mock_ocp, mock_ocp_on_aws, mock_map
 ):
     """Check the OCP-on-AWS summary call made for an AWS provider.

     The updater is built from the AWS provider, with an infra map tying the
     OCP test provider to it. ``mock_ocp_on_aws`` must be called with the
     date strings, the OCP provider's authentication value and the ids of
     the bills supplied by ``mock_utility``.
     """
     fake_bills = []
     for bill_id in (1, 2):
         fake_bill = Mock()
         fake_bill.id = bill_id
         fake_bills.append(fake_bill)
     bill_ids = [str(fake_bill.id) for fake_bill in fake_bills]
     mock_utility.return_value = fake_bills

     start_date = self.date_accessor.today_with_timezone('UTC')
     end_date = start_date + datetime.timedelta(days=1)
     date_format = '%Y-%m-%d'
     start_date_str = start_date.strftime(date_format)
     end_date_str = end_date.strftime(date_format)

     with ProviderDBAccessor(self.aws_provider_uuid) as aws_accessor:
         provider = aws_accessor.get_provider()
     with ProviderDBAccessor(self.ocp_test_provider_uuid) as ocp_accessor:
         cluster_id = ocp_accessor.get_authentication()

     infra_link = (self.aws_provider_uuid, 'AWS')
     mock_map.return_value = {self.ocp_test_provider_uuid: infra_link}

     updater = OCPCloudReportSummaryUpdater(schema='acct10001', provider=provider, manifest=None)
     updater.update_summary_tables(start_date_str, end_date_str)
     mock_ocp_on_aws.assert_called_with(start_date_str, end_date_str, cluster_id, bill_ids)
예제 #19
0
    def test_set_infrastructure(self):
        """Check that ``set_infrastructure`` links the provider to its infra map.

        After the call, the ``ProviderInfrastructureMap`` row for the AWS
        provider and type must be the one referenced by a ``Provider`` row.
        """
        infrastructure_type = Provider.PROVIDER_AWS
        with ProviderDBAccessor(self.ocp_provider_uuid) as ocp_accessor:
            ocp_accessor.set_infrastructure(self.aws_provider_uuid, infrastructure_type)

        matching_maps = ProviderInfrastructureMap.objects.filter(
            infrastructure_provider_id=self.aws_provider_uuid,
            infrastructure_type=infrastructure_type,
        )
        mapping = matching_maps.first()

        linked_providers = Provider.objects.filter(infrastructure=mapping)
        mapping_on_provider = linked_providers.first()
        self.assertEqual(mapping.id, mapping_on_provider.infrastructure.id)
 def setUpClass(cls):
     """Create the class-level accessors and helpers shared by the tests.

     Attributes set on ``cls``: the common accessor and its column map, the
     OCP provider uuid, the OCP report accessor and its report schema, the
     provider accessor, the report object creator and the list of OCP
     report tables.
     """
     common_accessor = ReportingCommonDBAccessor()
     column_map = common_accessor.column_map
     cls.common_accessor = common_accessor
     cls.column_map = column_map
     cls.ocp_provider_uuid = '3c6e687e-1a09-4a05-970c-2ccf44b0952e'

     report_accessor = OCPReportDBAccessor(
         schema='acct10001', column_map=column_map)
     cls.accessor = report_accessor
     cls.provider_accessor = ProviderDBAccessor(
         provider_uuid=cls.ocp_provider_uuid)
     cls.report_schema = report_accessor.report_schema
     cls.creator = ReportObjectCreator(
         report_accessor, column_map, cls.report_schema.column_types)
     cls.all_tables = [*OCP_REPORT_TABLE_MAP.values()]
예제 #21
0
 def test_get_infra_db_key_for_provider_type(self):
     """Check the db key returned for each infrastructure provider type.

     ``'AWS'`` and ``'AWS-local'`` must map to ``'aws_uuid'``, ``'OCP'`` to
     ``'ocp_uuid'`` and an unknown type to ``None``.
     """
     with ProviderDBAccessor(self.ocp_test_provider_uuid) as provider_accessor:
         provider = provider_accessor.get_provider()
     updater = OCPCloudReportSummaryUpdater(schema='acct10001', provider=provider, manifest=None)

     expected_keys = (
         ('AWS', 'aws_uuid'),
         ('AWS-local', 'aws_uuid'),
         ('OCP', 'ocp_uuid'),
         ('WRONG', None),
     )
     for provider_type, expected_key in expected_keys:
         self.assertEqual(updater._get_infra_db_key_for_provider_type(provider_type), expected_key)
예제 #22
0
 def test_update_aws_summary_tables_with_no_cluster_info(self, mock_ocp_on_aws, mock_cluster_info):
     """Check that the AWS summary update is skipped without cluster info.

     ``mock_cluster_info`` reports ``False`` and ``mock_ocp_on_aws`` must
     then never be called.
     """
     # The cluster-info result is only used as a yes/no answer, so False is enough.
     mock_cluster_info.return_value = False

     one_day = datetime.timedelta(days=1)
     start_date = self.dh.today.date()
     end_date = start_date + one_day

     with ProviderDBAccessor(self.aws_provider_uuid) as aws_accessor:
         provider = aws_accessor.get_provider()

     updater = OCPCloudParquetReportSummaryUpdater(schema="acct10001", provider=provider, manifest=None)
     updater.update_aws_summary_tables(
         self.ocpaws_provider_uuid,
         self.aws_test_provider_uuid,
         str(start_date),
         str(end_date),
     )
     mock_ocp_on_aws.assert_not_called()
    def setUp(self):
        """Prepare the manifest values, charge updater and seed data for a test.

        The manifest dict is built for the first day of the current UTC month,
        an ``AWSReportChargeUpdater`` is created for the AWS test provider,
        one bill with a single line item is created, and the manifest is
        added and committed.
        """
        super().setUp()

        first_of_month = self.date_accessor.today_with_timezone('UTC').replace(day=1)
        self.manifest_dict = {
            'assembly_id': '1234',
            'billing_period_start_datetime': first_of_month,
            'num_total_files': 2,
            'provider_id': self.aws_provider.id,
        }

        aws_uuid = self.aws_test_provider_uuid
        self.provider_accessor = ProviderDBAccessor(provider_uuid=aws_uuid)
        provider_id = self.provider_accessor.get_provider().id
        # NOTE(review): the accessor is re-created right after the lookup;
        # this looks redundant - confirm before removing.
        self.provider_accessor = ProviderDBAccessor(provider_uuid=aws_uuid)
        self.updater = AWSReportChargeUpdater(
            schema=self.schema,
            provider_uuid=aws_uuid,
            provider_id=provider_id,
        )

        now_utc = DateAccessor().today_with_timezone('UTC')
        bill = self.creator.create_cost_entry_bill(
            provider_id=provider_id, bill_date=now_utc)
        cost_entry = self.creator.create_cost_entry(bill, now_utc)
        product = self.creator.create_cost_entry_product()
        pricing = self.creator.create_cost_entry_pricing()
        reservation = self.creator.create_cost_entry_reservation()
        self.creator.create_cost_entry_line_item(
            bill, cost_entry, product, pricing, reservation)

        self.manifest = self.manifest_accessor.add(**self.manifest_dict)
        self.manifest_accessor.commit()

        with ProviderDBAccessor(aws_uuid) as provider_accessor:
            self.provider = provider_accessor.get_provider()
예제 #24
0
 def test_update_summary_tables_with_aws_provider(self, mock_utility, mock_ocp_on_aws, mock_map):
     """Check the OCP-on-AWS summary call made for an AWS provider.

     The updater is built from the AWS provider, with an infra map tying the
     OCP test provider to it. ``mock_ocp_on_aws`` must be called with the
     start and end dates, the ``cluster_id`` from the OCP provider's
     credentials, the ids of the bills supplied by ``mock_utility`` and a
     zero ``Decimal``.
     """
     fake_bills = []
     for bill_id in (1, 2):
         fake_bill = Mock()
         fake_bill.id = bill_id
         fake_bills.append(fake_bill)
     bill_ids = [str(fake_bill.id) for fake_bill in fake_bills]
     mock_utility.return_value = fake_bills

     start_date = self.dh.today
     end_date = start_date + datetime.timedelta(days=1)
     date_format = "%Y-%m-%d"
     start_date_str = start_date.strftime(date_format)
     end_date_str = end_date.strftime(date_format)

     with ProviderDBAccessor(self.aws_provider_uuid) as aws_accessor:
         provider = aws_accessor.get_provider()
     with ProviderDBAccessor(self.ocp_test_provider_uuid) as ocp_accessor:
         credentials = ocp_accessor.get_credentials()
     cluster_id = credentials.get("cluster_id")

     infra_link = (self.aws_provider_uuid, Provider.PROVIDER_AWS)
     mock_map.return_value = {self.ocp_test_provider_uuid: infra_link}

     updater = OCPCloudReportSummaryUpdater(schema="acct10001", provider=provider, manifest=None)
     updater.update_summary_tables(start_date_str, end_date_str)
     mock_ocp_on_aws.assert_called_with(
         start_date.date(), end_date.date(), cluster_id, bill_ids, decimal.Decimal(0)
     )
예제 #25
0
파일: tasks.py 프로젝트: tohjustin/koku
def refresh_materialized_views(schema_name,
                               provider_type,
                               manifest_id=None,
                               provider_uuid=None,
                               synchronous=False):
    """Refresh the database's materialized views for reporting.

    Args:
        schema_name: Schema in which the views are refreshed.
        provider_type: Provider type constant (e.g. Provider.PROVIDER_AWS)
            used to select which materialized views are refreshed.
        manifest_id: Optional manifest ID; when given, the manifest is
            marked as completed after the refresh.
        provider_uuid: Optional provider UUID; when given, its data-updated
            timestamp is set after the refresh.
        synchronous (bool): When False (default), wait for and take the
            WorkerCache single-task lock for this schema before refreshing.
            When True, no lock is taken.

    Returns:
        None

    """
    task_name = "masu.processor.tasks.refresh_materialized_views"
    cache_args = [schema_name]
    worker_cache = None
    if not synchronous:
        worker_cache = WorkerCache()
        # Poll until no other refresh holds the lock for this schema.
        while worker_cache.single_task_is_running(task_name, cache_args):
            time.sleep(5)

        worker_cache.lock_single_task(task_name, cache_args)

    try:
        materialized_views = ()
        if provider_type in (Provider.PROVIDER_AWS,
                             Provider.PROVIDER_AWS_LOCAL):
            materialized_views = (AWS_MATERIALIZED_VIEWS +
                                  OCP_ON_AWS_MATERIALIZED_VIEWS +
                                  OCP_ON_INFRASTRUCTURE_MATERIALIZED_VIEWS)
        # Use equality here: "in (Provider.PROVIDER_OCP)" is not a tuple
        # (no trailing comma), so it would be a substring test on the
        # constant and raise TypeError for a None provider_type.
        elif provider_type == Provider.PROVIDER_OCP:
            materialized_views = (OCP_MATERIALIZED_VIEWS +
                                  OCP_ON_AWS_MATERIALIZED_VIEWS +
                                  OCP_ON_AZURE_MATERIALIZED_VIEWS +
                                  OCP_ON_INFRASTRUCTURE_MATERIALIZED_VIEWS)
        elif provider_type in (Provider.PROVIDER_AZURE,
                               Provider.PROVIDER_AZURE_LOCAL):
            materialized_views = (AZURE_MATERIALIZED_VIEWS +
                                  OCP_ON_AZURE_MATERIALIZED_VIEWS +
                                  OCP_ON_INFRASTRUCTURE_MATERIALIZED_VIEWS)

        with schema_context(schema_name):
            for view in materialized_views:
                table_name = view._meta.db_table
                with connection.cursor() as cursor:
                    cursor.execute(
                        f"REFRESH MATERIALIZED VIEW CONCURRENTLY {table_name}")
                    LOG.info(f"Refreshed {table_name}.")

        invalidate_view_cache_for_tenant_and_source_type(schema_name,
                                                         provider_type)

        if provider_uuid:
            # Use the accessor as a context manager, like the rest of the code.
            with ProviderDBAccessor(provider_uuid) as provider_accessor:
                provider_accessor.set_data_updated_timestamp()
        if manifest_id:
            # Processing for this manifest should be complete after this step
            with ReportManifestDBAccessor() as manifest_accessor:
                manifest = manifest_accessor.get_manifest_by_id(manifest_id)
                manifest_accessor.mark_manifest_as_completed(manifest)
    finally:
        # Always release the lock taken above, even when the refresh fails,
        # so later refreshes for this schema are not left waiting on it.
        if worker_cache is not None:
            worker_cache.release_single_task(task_name, cache_args)
예제 #26
0
파일: common.py 프로젝트: project-koku/masu
def get_bills_from_provider(provider_uuid,
                            schema,
                            start_date=None,
                            end_date=None):
    """
    Return the AWS cost entry bills for a provider UUID.

    Args:
        provider_uuid (str): Provider UUID.
        schema (str): Tenant schema
        start_date (datetime, date, str): Start date for bill IDs. It is
            moved to the first day of its month before filtering.
        end_date (datetime, date, str) End date for bill IDs.

    Returns:
        (list): AWS cost entry bill objects. An empty list is returned when
            the provider is not an AWS type.

    """
    if isinstance(start_date, str):
        start_date = datetime.datetime.strptime(start_date, '%Y-%m-%d')
    # datetime.datetime subclasses datetime.date, so this one check covers
    # datetimes, plain dates and the string parsed just above.
    if isinstance(start_date, datetime.date):
        # Normalize to the first day of the month before filtering on
        # billing_period_start below.
        start_date = start_date.replace(day=1).strftime('%Y-%m-%d')
    if isinstance(end_date, datetime.date):
        end_date = end_date.strftime('%Y-%m-%d')

    with ReportingCommonDBAccessor() as reporting_common:
        column_map = reporting_common.column_map

    with ProviderDBAccessor(provider_uuid) as provider_accessor:
        provider = provider_accessor.get_provider()

    if provider.type not in (AMAZON_WEB_SERVICES, AWS_LOCAL_SERVICE_PROVIDER):
        err_msg = 'Provider UUID is not an AWS type.  It is {}'.format(
            provider.type)
        LOG.warning(err_msg)
        return []

    with AWSReportDBAccessor(schema, column_map) as report_accessor:
        bill_table_name = AWS_CUR_TABLE_MAP['bill']
        bill_obj = getattr(report_accessor.report_schema, bill_table_name)
        bills = report_accessor.get_cost_entry_bills_query_by_provider(
            provider.id)
        # Both bounds are applied to the billing period start column.
        if start_date:
            bills = bills.filter(bill_obj.billing_period_start >= start_date)
        if end_date:
            bills = bills.filter(bill_obj.billing_period_start <= end_date)
        bills = bills.all()

    return bills
예제 #27
0
    def update_summary_tables(self, start_date, end_date):
        """Run the OpenShift-on-infrastructure summary updates for every mapped cluster.

        Args:
            start_date (str) The date to start populating the table.
            end_date   (str) The date to end on.

        Returns
            None

        """
        infra_map = self.get_infra_map()
        ocp_uuids, cloud_uuids = self.get_openshift_and_infra_providers_lists(
            infra_map)

        # Rebuild the map from SQL when this provider is missing from it:
        # either an OpenShift provider that is not a key, or an
        # infrastructure provider that no cluster points at (in which case
        # all of the matching clusters should run).
        if (
            self._provider.type == Provider.PROVIDER_OCP
            and self._provider_uuid not in ocp_uuids
        ) or (
            self._provider.type in Provider.CLOUD_PROVIDER_LIST
            and self._provider_uuid not in cloud_uuids
        ):
            infra_map = self._generate_ocp_infra_map_from_sql(
                start_date, end_date)

        aws_types = (Provider.PROVIDER_AWS, Provider.PROVIDER_AWS_LOCAL)
        azure_types = (Provider.PROVIDER_AZURE, Provider.PROVIDER_AZURE_LOCAL)

        # For an infrastructure provider (e.g. AWS) this iterates over all
        # associated OpenShift clusters; for an OpenShift provider it should
        # run just once.
        for ocp_provider_uuid, infra_tuple in infra_map.items():
            infra_provider_uuid, infra_provider_type = (infra_tuple[0],
                                                        infra_tuple[1])
            if infra_provider_type in aws_types:
                self.update_aws_summary_tables(
                    ocp_provider_uuid, infra_provider_uuid,
                    start_date, end_date)
            elif infra_provider_type in azure_types:
                self.update_azure_summary_tables(
                    ocp_provider_uuid, infra_provider_uuid,
                    start_date, end_date)

            # Apply markup to the OpenShift tables for this cluster.
            with ProviderDBAccessor(ocp_provider_uuid) as ocp_accessor:
                markup_updater = OCPCostModelCostUpdater(
                    self._schema, ocp_accessor.provider)
                markup_updater._update_markup_cost(start_date, end_date)

        if not infra_map:
            return
        self.refresh_openshift_on_infrastructure_views(
            OCP_ON_INFRASTRUCTURE_MATERIALIZED_VIEWS)
예제 #28
0
파일: common.py 프로젝트: project-koku/koku
def get_account_alias_from_role_arn(role_arn, session=None):
    """
    Get the account ID and account alias for a given RoleARN.

    Args:
        role_arn     (String) AWS IAM RoleARN
        session      Optional session used to create the IAM client. When
                     falsy, one is obtained from get_assume_role_session.

    Returns:
        (tuple): (account ID, alias). The alias equals the account ID when
                 alias listing is disabled in the provider's additional
                 context or the listing call fails, and is None when the
                 listing succeeds but returns no aliases.

    """
    provider_uuid = get_provider_uuid_from_arn(role_arn)
    context_key = "aws_list_account_aliases"
    session = session or get_assume_role_session(role_arn)
    iam_client = session.client("iam")

    # The account ID is the second-to-last ":"-separated field of the ARN.
    account_id = role_arn.split(":")[-2]

    with ProviderDBAccessor(provider_uuid) as provider_accessor:
        context = provider_accessor.get_additional_context()
        list_aliases = context.get(context_key, True)

    if not list_aliases:
        return (account_id, account_id)

    alias = account_id
    try:
        aliases = iam_client.list_account_aliases().get("AccountAliases", [])
        # Note: Boto3 docs states that you can only have one alias per account
        # so the pop() should be ok...
        if aliases:
            alias = aliases.pop()
        else:
            alias = None
    except ClientError as err:
        LOG.info("Unable to list account aliases.  Reason: %s", str(err))
        # Record the failure in the provider's additional context; the flag
        # is read back above on later calls and skips the listing.
        context[context_key] = False
        with ProviderDBAccessor(provider_uuid) as provider_accessor:
            provider_accessor.set_additional_context(context)

    return (account_id, alias)
예제 #29
0
    def test_update_azure_summary_tables(self, mock_utility, mock_ocp_on_azure,
                                         mock_tag_summary, mock_map):
        """Check the arguments update_azure_summary_tables passes to the mocked OCP-on-Azure call."""
        bill_id = 1
        first_bill_lookup = Mock()
        first_bill_lookup.return_value.id = bill_id

        # Bills stand-in: iterable, with .first() returning a bill with id 1.
        fake_bills = MagicMock()
        fake_bills.__iter__.return_value = [Mock(), Mock()]
        fake_bills.first = first_bill_lookup
        mock_utility.return_value = fake_bills

        start_date = self.dh.today.date()
        end_date = start_date + datetime.timedelta(days=1)

        with ProviderDBAccessor(self.azure_provider_uuid) as azure_accessor:
            provider = azure_accessor.get_provider()
        with ProviderDBAccessor(self.ocp_test_provider_uuid) as ocp_accessor:
            ocp_credentials = ocp_accessor.get_credentials()
        cluster_id = ocp_credentials.get("cluster_id")

        # Map the OCP provider onto the Azure provider it runs on.
        azure_infra = (self.azure_provider_uuid, Provider.PROVIDER_AZURE)
        mock_map.return_value = {self.ocp_test_provider_uuid: azure_infra}

        updater = OCPCloudParquetReportSummaryUpdater(
            schema="acct10001", provider=provider, manifest=None)
        updater.update_azure_summary_tables(
            self.ocp_test_provider_uuid,
            self.azure_test_provider_uuid,
            start_date,
            end_date,
        )

        expected_call_args = (
            start_date,
            end_date,
            self.ocp_test_provider_uuid,
            self.azure_test_provider_uuid,
            cluster_id,
            bill_id,
            decimal.Decimal(0),
        )
        mock_ocp_on_azure.assert_called_with(*expected_call_args)
    def setUp(self):
        """Prepare the Azure charge updater, a seeded bill line item, a manifest and the provider."""
        super().setUp()

        utc_today = self.date_accessor.today_with_timezone('UTC')
        billing_start = utc_today.replace(day=1)
        self.manifest_dict = {
            'assembly_id': '1234',
            'billing_period_start_datetime': billing_start,
            'num_total_files': 1,
            'provider_id': self.azure_provider.id,
        }

        azure_uuid = self.azure_test_provider_uuid

        # This accessor is kept on the instance rather than used in a
        # ``with`` block.
        self.provider_accessor = ProviderDBAccessor(provider_uuid=azure_uuid)
        provider_id = self.provider_accessor.get_provider().id

        self.updater = AzureReportChargeUpdater(schema=self.schema,
                                                provider_uuid=azure_uuid,
                                                provider_id=provider_id)

        # Seed one bill dated today (UTC) with a single line item.
        today = DateAccessor().today_with_timezone('UTC')
        bill = self.creator.create_azure_cost_entry_bill(
            provider_id=provider_id,
            bill_date=today,
        )
        product = self.creator.create_azure_cost_entry_product()
        meter = self.creator.create_azure_meter()
        service = self.creator.create_azure_service()
        self.creator.create_azure_cost_entry_line_item(
            bill, product, meter, service)

        self.manifest = self.manifest_accessor.add(**self.manifest_dict)
        self.manifest_accessor.commit()

        with ProviderDBAccessor(azure_uuid) as provider_accessor:
            self.provider = provider_accessor.get_provider()