def test_get_infrastructure_provider_uuid(self):
    """Verify the accessor hands back the infrastructure provider UUID it was given."""
    expected_uuid = self.aws_provider_uuid
    with ProviderDBAccessor(self.ocp_provider_uuid) as ocp_accessor:
        # Link the OCP provider to the AWS provider, then read the link back.
        ocp_accessor.set_infrastructure(expected_uuid, Provider.PROVIDER_AWS)
        actual_uuid = ocp_accessor.get_infrastructure_provider_uuid()
        self.assertEqual(actual_uuid, expected_uuid)
def test_get_infrastructure_type(self):
    """Verify the accessor hands back the infrastructure type it was given."""
    expected_type = 'AWS'
    with ProviderDBAccessor(self.ocp_provider_uuid) as ocp_accessor:
        ocp_accessor.set_infrastructure(self.aws_provider_uuid, expected_type)
        actual_type = ocp_accessor.get_infrastructure_type()
        self.assertEqual(actual_type, expected_type)
def get_date_column_filter(self):
    """Return a filter dict keyed on the provider-appropriate date column.

    Returns:
        (dict): ``{'usage_date_time__gte': self.data_cutoff_date}`` when the
            lower-cased provider type is ``'azure'``, otherwise
            ``{'usage_start__gte': self.data_cutoff_date}``.

    """
    with ProviderDBAccessor(self._provider_uuid) as provider_accessor:
        # Named ``provider_type`` so the ``type`` builtin is not shadowed.
        provider_type = provider_accessor.get_type().lower()

    if provider_type == 'azure':
        return {'usage_date_time__gte': self.data_cutoff_date}
    return {'usage_start__gte': self.data_cutoff_date}
def setUp(self):
    """Prepare the fixture: look up the AWS test provider and keep its id."""
    super().setUp()
    with ProviderDBAccessor(self.aws_test_provider_uuid) as accessor:
        self.provider_id = accessor.get_provider().id
def get_date_column_filter(self):
    """Return a filter dict keyed on the provider-appropriate date column.

    Returns:
        (dict): ``{"usage_date__gte": self.data_cutoff_date}`` for the Azure
            and Azure-local provider types, otherwise
            ``{"usage_start__gte": self.data_cutoff_date}``.

    """
    with ProviderDBAccessor(self._provider_uuid) as provider_accessor:
        # Named ``provider_type`` so the ``type`` builtin is not shadowed.
        provider_type = provider_accessor.get_type()

    if provider_type in (Provider.PROVIDER_AZURE, Provider.PROVIDER_AZURE_LOCAL):
        return {"usage_date__gte": self.data_cutoff_date}
    return {"usage_start__gte": self.data_cutoff_date}
def test_process_report_files_with_transaction_atomic_error(
        self, mock_files, mock_processor):
    """Test that an exception rolls back the atomic transaction.

    The processor mock is configured to raise. Afterwards the test checks
    that no completion time was stored for the file, that the manifest's
    processed-file count and update time are untouched, and that the
    provider is not flagged as set up.
    """
    path = "{}/{}".format("test", "file1.csv")
    mock_files.return_value = [{"file": path, "compression": "GZIP"}]
    schema_name = self.schema
    provider = Provider.PROVIDER_AWS
    provider_uuid = self.aws_provider_uuid
    report_month = DateHelper().today
    manifest_dict = {
        "assembly_id": "12345",
        "billing_period_start_datetime": report_month,
        "num_total_files": 1,
        "provider_uuid": self.aws_provider_uuid,
        "task": "170653c0-3e66-4b7e-a764-336496d7ca5a",
    }

    with ReportManifestDBAccessor() as manifest_accessor:
        manifest = manifest_accessor.add(**manifest_dict)
        manifest.save()
        manifest_id = manifest.id
        initial_update_time = manifest.manifest_updated_datetime

    with ReportStatsDBAccessor("file1.csv", manifest_id) as stats_accessor:
        # The getter has to be called; a bare attribute reference does nothing.
        stats_accessor.get_last_completed_datetime()

    with ReportStatsDBAccessor(path, manifest_id) as report_file_accessor:
        report_file_accessor.get_last_started_datetime()

    mock_processor.side_effect = Exception

    with self.assertRaises(Exception):
        get_report_files(
            customer_name="Fake Customer",
            authentication="auth",
            billing_source="bill",
            provider_type=provider,
            schema_name=schema_name,
            provider_uuid=provider_uuid,
            report_month=report_month,
        )

    with ReportStatsDBAccessor(path, manifest_id) as report_file_accessor:
        self.assertIsNone(report_file_accessor.get_last_completed_datetime())

    with ReportManifestDBAccessor() as manifest_accessor:
        manifest = manifest_accessor.get_manifest_by_id(manifest_id)
        self.assertEqual(manifest.num_processed_files, 0)
        self.assertEqual(manifest.manifest_updated_datetime, initial_update_time)

    with ProviderDBAccessor(provider_uuid=provider_uuid) as provider_accessor:
        self.assertFalse(provider_accessor.get_setup_complete())
def _process_report_file(schema_name, provider, provider_uuid, report_dict):
    """
    Process one report file and record the outcome.

    Logs the request and the current memory figures, stamps the file's
    start time, runs the report processor, stamps the completion time,
    bumps the manifest's processed-file counter and finally marks the
    provider as set up.

    Args:
        schema_name   (String) db schema name
        provider      (String) provider type
        provider_uuid (String) provider uuid
        report_dict   (dict) The report data dict from previous task

    Returns:
        None

    """
    report_path = report_dict.get('file')
    compression = report_dict.get('compression')
    manifest_id = report_dict.get('manifest_id')
    provider_id = report_dict.get('provider_id')
    start_date = report_dict.get('start_date')

    LOG.info(
        f'Processing Report:\n'
        f' schema_name: {schema_name}\n'
        f' provider: {provider}\n'
        f' provider_uuid: {provider_uuid}\n'
        f' file: {report_path}\n'
        f' compression: {compression}\n'
        f' start_date: {start_date}'
    )

    memory = psutil.virtual_memory()
    LOG.info(f'Avaiable memory: {memory.free} bytes ({memory.percent}%)')

    # Only the last path component identifies the file in the stats table.
    file_name = report_path.rsplit('/', 1)[-1]

    with ReportStatsDBAccessor(file_name, manifest_id) as started_stats:
        started_stats.log_last_started_datetime()

    processor = ReportProcessor(
        schema_name=schema_name,
        report_path=report_path,
        compression=compression,
        provider=provider,
        provider_id=provider_id,
        manifest_id=manifest_id,
    )
    processor.process()

    with ReportStatsDBAccessor(file_name, manifest_id) as completed_stats:
        completed_stats.log_last_completed_datetime()

    with ReportManifestDBAccessor() as manifest_accessor:
        manifest = manifest_accessor.get_manifest_by_id(manifest_id)
        if not manifest:
            LOG.error('Unable to find manifest for ID: %s, file %s',
                      manifest_id, file_name)
        else:
            manifest.num_processed_files += 1
            manifest.save()
            manifest_accessor.mark_manifest_as_updated(manifest)

    with ProviderDBAccessor(provider_uuid=provider_uuid) as provider_accessor:
        # Temporary files are only cleaned up once setup was already complete.
        if provider_accessor.get_setup_complete():
            report_dir = path.dirname(report_path)
            files = processor.remove_processed_files(report_dir)
            LOG.info('Temporary files removed: %s', str(files))
        provider_accessor.setup_complete()
def _process_manifest_db_record(self, assembly_id, billing_start, num_of_files, manifest_modified_datetime, **kwargs):
    """Insert or update the manifest DB record.

    Looks up the manifest for ``assembly_id`` and this downloader's provider.
    When none exists a new one is added; when one exists its file count is
    refreshed if it differs and the manifest is marked as updated.

    Args:
        assembly_id: Assembly identifier used to look up the manifest.
        billing_start: Stored as ``billing_period_start_datetime`` on a new manifest.
        num_of_files: Stored as ``num_total_files`` on a new manifest and
            compared with the stored count on an existing one.
        manifest_modified_datetime: Stored as ``manifest_modified_datetime``
            on a new manifest.
        **kwargs: Extra fields merged into the new manifest's values.

    Returns:
        The id of the manifest entry.

    Raises:
        ReportDownloaderError: When adding the manifest fails with what
            ``FKViolation`` reports as truthy, or when the manifest is still
            missing and the provider cannot be found either.
        IntegrityError: When the manifest is still missing although the
            provider exists.

    """
    msg = f"Inserting/updating manifest in database for assembly_id: {assembly_id}"
    LOG.info(log_json(self.tracing_id, msg))

    with ReportManifestDBAccessor() as manifest_accessor:
        manifest_entry = manifest_accessor.get_manifest(
            assembly_id, self._provider_uuid)

        if not manifest_entry:
            msg = f"No manifest entry found in database. Adding for bill period start: {billing_start}"
            LOG.info(log_json(self.tracing_id, msg, self.context))
            manifest_dict = {
                "assembly_id": assembly_id,
                "billing_period_start_datetime": billing_start,
                "num_total_files": num_of_files,
                "provider_uuid": self._provider_uuid,
                "manifest_modified_datetime": manifest_modified_datetime,
            }
            manifest_dict.update(kwargs)
            try:
                manifest_entry = manifest_accessor.add(**manifest_dict)
            except IntegrityError as error:
                # NOTE(review): FKViolation presumably evaluates truthy only for
                # foreign-key violations -- confirm against its definition.
                fk_violation = FKViolation(error)
                if fk_violation:
                    LOG.warning(fk_violation)
                    raise ReportDownloaderError(
                        f"Method: _process_manifest_db_record :: {fk_violation}"
                    )
                # Any other integrity error is treated as a uniqueness
                # collision, so the manifest is looked up once more.
                msg = (
                    f"Manifest entry uniqueness collision: Error {error}. "
                    "Manifest already added, getting manifest_entry_id.")
                LOG.warning(log_json(self.tracing_id, msg, self.context))
                with ReportManifestDBAccessor() as manifest_accessor:
                    manifest_entry = manifest_accessor.get_manifest(
                        assembly_id, self._provider_uuid)

        if not manifest_entry:
            # Still nothing after the add/re-read: report whether the provider
            # itself is missing or only the manifest.
            msg = f"Manifest entry not found for given manifest {manifest_dict}."
            with ProviderDBAccessor(
                    self._provider_uuid) as provider_accessor:
                provider = provider_accessor.get_provider()
                if not provider:
                    msg = f"Provider entry not found for {self._provider_uuid}."
                    LOG.warning(
                        log_json(self.tracing_id, msg, self.context))
                    raise ReportDownloaderError(msg)
            LOG.warning(log_json(self.tracing_id, msg, self.context))
            raise IntegrityError(msg)
        else:
            # Keep the stored file count in step with the supplied one.
            if num_of_files != manifest_entry.num_total_files:
                manifest_accessor.update_number_of_files_for_manifest(
                    manifest_entry)
            manifest_accessor.mark_manifest_as_updated(manifest_entry)
            manifest_id = manifest_entry.id

    return manifest_id
def test_get_customer_uuid(self):
    """Verify get_customer_uuid matches the UUID reported by CustomerDBAccessor."""
    with CustomerDBAccessor(self.customer.id) as customer_accessor:
        expected_uuid = customer_accessor.get_uuid()

    with ProviderDBAccessor(self.aws_provider_uuid) as provider_accessor:
        self.assertEqual(expected_uuid, provider_accessor.get_customer_uuid())
def setUp(self):
    """Create the accessors and seed the OCP report rows used by the tests."""
    super().setUp()
    self.accessor = OCPReportDBAccessor(schema=self.schema)
    self.provider_accessor = ProviderDBAccessor(
        provider_uuid=self.ocp_test_provider_uuid)
    self.report_schema = self.accessor.report_schema
    self.creator = ReportObjectCreator(self.schema)
    self.all_tables = list(OCP_REPORT_TABLE_MAP.values())
    self.provider_uuid = self.provider_accessor.get_provider().uuid

    # One reporting period and one report, with a line item of each kind.
    period = self.creator.create_ocp_report_period(
        provider_uuid=self.provider_uuid)
    ocp_report = self.creator.create_ocp_report(
        period, period.report_period_start)
    self.creator.create_ocp_usage_line_item(period, ocp_report)
    self.creator.create_ocp_storage_line_item(period, ocp_report)
    self.creator.create_ocp_node_label_line_item(period, ocp_report)
def refresh_materialized_views(  # noqa: C901
    schema_name,
    provider_type,
    manifest_id=None,
    provider_uuid="",
    synchronous=False,
    queue_name=None,
    tracing_id=None,
):
    """Refresh the database's materialized views for reporting.

    Unless ``synchronous`` is set, the task is guarded by a worker-cache lock
    keyed on schema, provider type and provider UUID. If the same task is
    already running, this call requeues itself and returns without doing
    any work.

    Args:
        schema_name: Schema whose views are refreshed.
        provider_type: Provider type used for cache invalidation.
        manifest_id: When truthy, the manifest is marked completed afterwards.
        provider_uuid: When truthy, the provider's data-updated timestamp is set.
        synchronous: Skip the worker-cache locking when True.
        queue_name: Queue used when requeuing; falls back to
            REFRESH_MATERIALIZED_VIEWS_QUEUE.
        tracing_id: Identifier passed to ``log_json`` for log messages.

    Returns:
        None

    """
    task_name = "masu.processor.tasks.refresh_materialized_views"
    cache_args = [schema_name, provider_type, provider_uuid]

    if not synchronous:
        worker_cache = WorkerCache()
        if worker_cache.single_task_is_running(task_name, cache_args):
            msg = f"Task {task_name} already running for {cache_args}. Requeuing."
            LOG.info(log_json(tracing_id, msg))
            refresh_materialized_views.s(
                schema_name,
                provider_type,
                manifest_id=manifest_id,
                provider_uuid=provider_uuid,
                synchronous=synchronous,
                queue_name=queue_name,
                tracing_id=tracing_id,
            ).apply_async(queue=queue_name or REFRESH_MATERIALIZED_VIEWS_QUEUE)
            return
        worker_cache.lock_single_task(task_name, cache_args, timeout=settings.WORKER_CACHE_TIMEOUT)

    materialized_views = ()
    try:
        with schema_context(schema_name):
            for view in materialized_views:
                table_name = view._meta.db_table
                with connection.cursor() as cursor:
                    cursor.execute(
                        f"REFRESH MATERIALIZED VIEW CONCURRENTLY {table_name}")
                    LOG.info(log_json(tracing_id, f"Refreshed {table_name}."))

        invalidate_view_cache_for_tenant_and_source_type(
            schema_name, provider_type)

        if provider_uuid:
            ProviderDBAccessor(provider_uuid).set_data_updated_timestamp()
        if manifest_id:
            # Processing for this manifest should be complete after this step
            with ReportManifestDBAccessor() as manifest_accessor:
                manifest = manifest_accessor.get_manifest_by_id(manifest_id)
                manifest_accessor.mark_manifest_as_completed(manifest)
    finally:
        # Release the lock on every exit path (success or any exception) so a
        # failed run cannot leave the task locked; errors propagate unchanged.
        if not synchronous:
            worker_cache.release_single_task(task_name, cache_args)
class OCPUtilTests(MasuTestCase):
    """Test the OCP utility functions."""

    def setUp(self):
        """Shared variables used by ocp common tests."""
        super().setUp()
        self.common_accessor = ReportingCommonDBAccessor()
        self.column_map = self.common_accessor.column_map
        self.accessor = OCPReportDBAccessor(schema=self.schema, column_map=self.column_map)
        self.provider_accessor = ProviderDBAccessor(provider_uuid=self.ocp_test_provider_uuid)
        self.report_schema = self.accessor.report_schema
        self.creator = ReportObjectCreator(self.schema, self.column_map)
        self.all_tables = list(OCP_REPORT_TABLE_MAP.values())

        self.provider_uuid = self.provider_accessor.get_provider().uuid
        reporting_period = self.creator.create_ocp_report_period(provider_uuid=self.provider_uuid)
        report = self.creator.create_ocp_report(
            reporting_period, reporting_period.report_period_start
        )
        self.creator.create_ocp_usage_line_item(reporting_period, report)
        self.creator.create_ocp_storage_line_item(reporting_period, report)

    def test_get_cluster_id_from_provider(self):
        """Test that the cluster ID is returned from OCP provider."""
        cluster_id = utils.get_cluster_id_from_provider(self.ocp_test_provider_uuid)
        self.assertIsNotNone(cluster_id)

    def test_get_cluster_id_from_non_ocp_provider(self):
        """Test that None is returned when getting cluster ID on non-OCP provider."""
        cluster_id = utils.get_cluster_id_from_provider(self.aws_provider_uuid)
        self.assertIsNone(cluster_id)

    def test_get_provider_uuid_from_cluster_id(self):
        """Test that the provider uuid is returned for a cluster ID."""
        cluster_id = self.ocp_provider_resource_name
        provider_uuid = utils.get_provider_uuid_from_cluster_id(cluster_id)
        try:
            UUID(provider_uuid)
        except ValueError:
            self.fail('{} is not a valid uuid.'.format(str(provider_uuid)))

    def test_get_provider_uuid_from_invalid_cluster_id(self):
        """Test that the provider uuid is not returned for an invalid cluster ID."""
        cluster_id = 'bad_cluster_id'
        provider_uuid = utils.get_provider_uuid_from_cluster_id(cluster_id)
        self.assertIsNone(provider_uuid)

    def test_poll_ingest_override_for_provider(self):
        """Test that OCP polling override returns True if insights local path exists."""
        # The context manager removes the directory even when an assertion
        # fails, so a failing run leaves nothing behind on disk.
        with tempfile.TemporaryDirectory() as fake_dir:
            with patch.object(Config, 'INSIGHTS_LOCAL_REPORT_DIR', fake_dir):
                cluster_id = utils.get_cluster_id_from_provider(self.ocp_test_provider_uuid)
                expected_path = '{}/{}/'.format(Config.INSIGHTS_LOCAL_REPORT_DIR, cluster_id)
                os.makedirs(expected_path, exist_ok=True)
                self.assertTrue(utils.poll_ingest_override_for_provider(self.ocp_test_provider_uuid))
def setUp(self):
    """Create the accessors and seed OCP usage and storage rows for the tests."""
    super().setUp()
    self.common_accessor = ReportingCommonDBAccessor()
    self.column_map = self.common_accessor.column_map
    self.accessor = OCPReportDBAccessor(
        schema=self.schema,
        column_map=self.column_map,
    )
    self.provider_accessor = ProviderDBAccessor(
        provider_uuid=self.ocp_test_provider_uuid,
    )
    self.report_schema = self.accessor.report_schema
    self.creator = ReportObjectCreator(self.schema, self.column_map)
    self.all_tables = list(OCP_REPORT_TABLE_MAP.values())
    self.provider_uuid = self.provider_accessor.get_provider().uuid

    # One reporting period and one report, with a usage and a storage line item.
    period = self.creator.create_ocp_report_period(
        provider_uuid=self.provider_uuid,
    )
    ocp_report = self.creator.create_ocp_report(
        period, period.report_period_start,
    )
    self.creator.create_ocp_usage_line_item(period, ocp_report)
    self.creator.create_ocp_storage_line_item(period, ocp_report)
def set_provider_infra_map(self, infra_map):
    """Record, for every provider in the map, the infrastructure it runs on.

    Args:
        infra_map (dict): Keyed by provider UUID. Each value is indexable:
            item 0 is the infrastructure provider UUID and item 1 the
            infrastructure type. The map is the one produced by
            _generate_ocp_infra_map_from_sql.

    """
    for provider_uuid, infra in infra_map.items():
        with ProviderDBAccessor(provider_uuid) as accessor:
            infra_uuid, infra_type = infra[0], infra[1]
            accessor.set_infrastructure(
                infrastructure_provider_uuid=infra_uuid,
                infrastructure_type=infra_type,
            )
def test_process_report_files_with_transaction_atomic_error(
        self, mock_processor, mock_setup_complete):
    """Check that a failure while processing leaves no recorded progress.

    The setup-complete mock is configured to raise. Afterwards the file has
    no completion time, the manifest's processed-file count and update time
    are unchanged, and the provider is not flagged as set up.
    """
    report_path = '/'.join(('test', 'file1.csv'))
    aws_uuid = self.aws_provider_uuid
    manifest_values = {
        'assembly_id': '12345',
        'billing_period_start_datetime': DateAccessor().today_with_timezone('UTC'),
        'num_total_files': 2,
        'provider_uuid': self.aws_provider_uuid,
        'task': '170653c0-3e66-4b7e-a764-336496d7ca5a',
    }

    with ReportManifestDBAccessor() as manifest_accessor:
        created = manifest_accessor.add(**manifest_values)
        created.save()
        manifest_id = created.id
        updated_before = created.manifest_updated_datetime

    with ReportStatsDBAccessor(report_path, manifest_id) as stats:
        stats.get_last_started_datetime()

    report_dict = {
        'file': report_path,
        'compression': 'gzip',
        'start_date': str(DateAccessor().today()),
        'manifest_id': manifest_id,
    }

    mock_setup_complete.side_effect = Exception
    with self.assertRaises(Exception):
        _process_report_file(
            self.schema, Provider.PROVIDER_AWS, aws_uuid, report_dict)

    with ReportStatsDBAccessor(report_path, manifest_id) as stats:
        self.assertIsNone(stats.get_last_completed_datetime())

    with ReportManifestDBAccessor() as manifest_accessor:
        reloaded = manifest_accessor.get_manifest_by_id(manifest_id)
        self.assertEqual(reloaded.num_processed_files, 0)
        self.assertEqual(reloaded.manifest_updated_datetime, updated_before)

    with ProviderDBAccessor(provider_uuid=aws_uuid) as provider_accessor:
        self.assertFalse(provider_accessor.get_setup_complete())
def setUp(self):
    """Reopen closed DB handles, then build the updater, bill rows and manifest."""
    super().setUp()

    # Re-establish whichever of the accessor's handles have been closed.
    db_accessor = self.accessor
    if db_accessor._conn.closed:
        db_accessor._conn = db_accessor._db.connect()
    if db_accessor._pg2_conn.closed:
        db_accessor._pg2_conn = db_accessor._get_psycopg2_connection()
    if db_accessor._cursor.closed:
        db_accessor._cursor = db_accessor._get_psycopg2_cursor()

    aws_uuid = self.aws_test_provider_uuid
    self.provider_accessor = ProviderDBAccessor(provider_uuid=aws_uuid)
    provider_id = self.provider_accessor.get_provider().id
    self.provider_accessor = ProviderDBAccessor(provider_uuid=aws_uuid)
    self.updater = AWSReportChargeUpdater(
        schema=self.test_schema,
        provider_uuid=aws_uuid,
        provider_id=provider_id,
    )

    # Seed one bill with a single, fully linked line item.
    now_utc = DateAccessor().today_with_timezone('UTC')
    creator = self.creator
    bill = creator.create_cost_entry_bill(now_utc)
    cost_entry = creator.create_cost_entry(bill, now_utc)
    product = creator.create_cost_entry_product()
    pricing = creator.create_cost_entry_pricing()
    reservation = creator.create_cost_entry_reservation()
    creator.create_cost_entry_line_item(
        bill, cost_entry, product, pricing, reservation)

    self.manifest = self.manifest_accessor.add(**self.manifest_dict)
    self.manifest_accessor.commit()

    with ProviderDBAccessor(aws_uuid) as provider_accessor:
        self.provider = provider_accessor.get_provider()
def report_data():
    """Queue an update of the report summary tables.

    Reads ``provider_uuid``, ``provider_type``, ``schema``, ``start_date``
    and ``end_date`` from the request's query arguments. A ``provider_uuid``
    of ``'*'`` queues the update for all providers; otherwise the update is
    queued for the given schema and provider.

    Returns:
        A JSON response holding the string form of the queued task under
        REPORT_DATA_KEY, or a JSON ``Error`` payload with status 400 when a
        required argument is missing or the provider types disagree.

    """
    def bad_request(message):
        return jsonify({'Error': message}), 400

    args = request.args
    provider_uuid = args.get('provider_uuid')
    provider_type = args.get('provider_type')
    schema_name = args.get('schema')
    start_date = args.get('start_date')
    end_date = args.get('end_date')

    if provider_uuid is None and provider_type is None:
        return bad_request('provider_uuid or provider_type must be supplied as a parameter.')

    # Resolve the provider type before the remaining checks run.
    all_providers = provider_uuid == '*'
    if all_providers:
        provider = None
    elif provider_uuid:
        with ProviderDBAccessor(provider_uuid) as provider_accessor:
            provider = provider_accessor.get_type()
    else:
        provider = provider_type

    if start_date is None:
        return bad_request('start_date is a required parameter.')

    if all_providers:
        task = update_all_summary_tables.delay(start_date, end_date)
        return jsonify({REPORT_DATA_KEY: str(task)})

    if schema_name is None:
        return bad_request('schema is a required parameter.')
    if provider is None:
        return bad_request('Unable to determine provider type.')
    if provider_type and provider_type != provider:
        return bad_request('provider_uuid and provider_type have mismatched provider types.')

    task = update_summary_tables.delay(
        schema_name, provider, provider_uuid, start_date, end_date)
    return jsonify({REPORT_DATA_KEY: str(task)})
def test_update_summary_tables_with_aws_provider(
        self, mock_utility, mock_ocp, mock_ocp_on_aws, mock_map):
    """Check the OCP-on-AWS summary call made when updating for an AWS provider."""
    # Two stand-in bills with ids 1 and 2.
    fake_bills = [Mock(), Mock()]
    for bill_number, fake_bill in enumerate(fake_bills, start=1):
        fake_bill.id = bill_number
    bill_ids = [str(fake_bill.id) for fake_bill in fake_bills]
    mock_utility.return_value = fake_bills

    day_format = '%Y-%m-%d'
    start_date = self.date_accessor.today_with_timezone('UTC')
    end_date = start_date + datetime.timedelta(days=1)
    start_date_str = start_date.strftime(day_format)
    end_date_str = end_date.strftime(day_format)

    with ProviderDBAccessor(self.aws_provider_uuid) as aws_accessor:
        provider = aws_accessor.get_provider()
    with ProviderDBAccessor(self.ocp_test_provider_uuid) as ocp_accessor:
        cluster_id = ocp_accessor.get_authentication()

    mock_map.return_value = {
        self.ocp_test_provider_uuid: (self.aws_provider_uuid, 'AWS'),
    }

    updater = OCPCloudReportSummaryUpdater(
        schema='acct10001', provider=provider, manifest=None)
    updater.update_summary_tables(start_date_str, end_date_str)

    mock_ocp_on_aws.assert_called_with(
        start_date_str, end_date_str, cluster_id, bill_ids)
def test_set_infrastructure(self):
    """Verify set_infrastructure stores a mapping that a provider row refers to."""
    infra_type = Provider.PROVIDER_AWS
    with ProviderDBAccessor(self.ocp_provider_uuid) as ocp_accessor:
        ocp_accessor.set_infrastructure(self.aws_provider_uuid, infra_type)

    # The mapping row must exist and be referenced by a provider.
    mapping_query = ProviderInfrastructureMap.objects.filter(
        infrastructure_provider_id=self.aws_provider_uuid,
        infrastructure_type=infra_type,
    )
    mapping = mapping_query.first()
    linked_provider = Provider.objects.filter(infrastructure=mapping).first()
    self.assertEqual(mapping.id, linked_provider.infrastructure.id)
def setUpClass(cls):
    """Create the class-wide accessors and object creator used by the tests."""
    common_accessor = ReportingCommonDBAccessor()
    cls.common_accessor = common_accessor
    cls.column_map = common_accessor.column_map
    cls.ocp_provider_uuid = '3c6e687e-1a09-4a05-970c-2ccf44b0952e'
    cls.accessor = OCPReportDBAccessor(
        schema='acct10001',
        column_map=cls.column_map,
    )
    cls.provider_accessor = ProviderDBAccessor(
        provider_uuid=cls.ocp_provider_uuid,
    )
    cls.report_schema = cls.accessor.report_schema
    cls.creator = ReportObjectCreator(
        cls.accessor,
        cls.column_map,
        cls.report_schema.column_types,
    )
    cls.all_tables = list(OCP_REPORT_TABLE_MAP.values())
def test_get_infra_db_key_for_provider_type(self):
    """Check the infrastructure-map db key returned for each provider type."""
    with ProviderDBAccessor(self.ocp_test_provider_uuid) as provider_accessor:
        provider = provider_accessor.get_provider()
    updater = OCPCloudReportSummaryUpdater(
        schema='acct10001', provider=provider, manifest=None)

    # (provider type, expected key); asserted in order, stopping at the first failure.
    expectations = (
        ('AWS', 'aws_uuid'),
        ('AWS-local', 'aws_uuid'),
        ('OCP', 'ocp_uuid'),
        ('WRONG', None),
    )
    for provider_type, expected_key in expectations:
        self.assertEqual(
            updater._get_infra_db_key_for_provider_type(provider_type),
            expected_key,
        )
def test_update_aws_summary_tables_with_no_cluster_info(self, mock_ocp_on_aws, mock_cluster_info):
    """Check that the OCP-on-AWS summary is skipped when there is no cluster info."""
    # The cluster info is only used as a yes/no check, so False is enough.
    mock_cluster_info.return_value = False

    first_day = self.dh.today.date()
    next_day = first_day + datetime.timedelta(days=1)

    with ProviderDBAccessor(self.aws_provider_uuid) as aws_accessor:
        provider = aws_accessor.get_provider()

    updater = OCPCloudParquetReportSummaryUpdater(
        schema="acct10001", provider=provider, manifest=None)
    updater.update_aws_summary_tables(
        self.ocpaws_provider_uuid,
        self.aws_test_provider_uuid,
        str(first_day),
        str(next_day),
    )

    mock_ocp_on_aws.assert_not_called()
def setUp(self):
    """Build the manifest values, the charge updater, bill rows and the manifest."""
    super().setUp()

    month_start = self.date_accessor.today_with_timezone('UTC').replace(day=1)
    self.manifest_dict = {
        'assembly_id': '1234',
        'billing_period_start_datetime': month_start,
        'num_total_files': 2,
        'provider_id': self.aws_provider.id,
    }

    aws_uuid = self.aws_test_provider_uuid
    self.provider_accessor = ProviderDBAccessor(provider_uuid=aws_uuid)
    provider_id = self.provider_accessor.get_provider().id
    self.provider_accessor = ProviderDBAccessor(provider_uuid=aws_uuid)
    self.updater = AWSReportChargeUpdater(
        schema=self.schema,
        provider_uuid=aws_uuid,
        provider_id=provider_id,
    )

    # Seed one bill with a single, fully linked line item.
    now_utc = DateAccessor().today_with_timezone('UTC')
    creator = self.creator
    bill = creator.create_cost_entry_bill(provider_id=provider_id, bill_date=now_utc)
    cost_entry = creator.create_cost_entry(bill, now_utc)
    product = creator.create_cost_entry_product()
    pricing = creator.create_cost_entry_pricing()
    reservation = creator.create_cost_entry_reservation()
    creator.create_cost_entry_line_item(
        bill, cost_entry, product, pricing, reservation)

    self.manifest = self.manifest_accessor.add(**self.manifest_dict)
    self.manifest_accessor.commit()

    with ProviderDBAccessor(aws_uuid) as provider_accessor:
        self.provider = provider_accessor.get_provider()
def test_update_summary_tables_with_aws_provider(self, mock_utility, mock_ocp_on_aws, mock_map):
    """Check the OCP-on-AWS summary call made when updating for an AWS provider."""
    # Two stand-in bills with ids 1 and 2.
    fake_bills = [Mock(), Mock()]
    for bill_number, fake_bill in enumerate(fake_bills, start=1):
        fake_bill.id = bill_number
    bill_ids = [str(fake_bill.id) for fake_bill in fake_bills]
    mock_utility.return_value = fake_bills

    day_format = "%Y-%m-%d"
    start_date = self.dh.today
    end_date = start_date + datetime.timedelta(days=1)
    start_date_str = start_date.strftime(day_format)
    end_date_str = end_date.strftime(day_format)

    with ProviderDBAccessor(self.aws_provider_uuid) as aws_accessor:
        provider = aws_accessor.get_provider()
    with ProviderDBAccessor(self.ocp_test_provider_uuid) as ocp_accessor:
        cluster_id = ocp_accessor.get_credentials().get("cluster_id")

    mock_map.return_value = {
        self.ocp_test_provider_uuid: (self.aws_provider_uuid, Provider.PROVIDER_AWS),
    }

    updater = OCPCloudReportSummaryUpdater(
        schema="acct10001", provider=provider, manifest=None)
    updater.update_summary_tables(start_date_str, end_date_str)

    mock_ocp_on_aws.assert_called_with(
        start_date.date(),
        end_date.date(),
        cluster_id,
        bill_ids,
        decimal.Decimal(0),
    )
def refresh_materialized_views(schema_name, provider_type, manifest_id=None, provider_uuid=None, synchronous=False):
    """Refresh the database's materialized views for reporting.

    Unless ``synchronous`` is set, the call waits until no other run of this
    task holds the worker-cache lock for the schema, takes the lock, and
    releases it again when the refresh finishes or fails.

    Args:
        schema_name: Schema whose views are refreshed.
        provider_type: Selects which groups of views are refreshed; an
            unrecognised type refreshes none.
        manifest_id: When truthy, the manifest is marked completed afterwards.
        provider_uuid: When truthy, the provider's data-updated timestamp is set.
        synchronous: Skip the worker-cache locking when True.

    Returns:
        None

    """
    task_name = "masu.processor.tasks.refresh_materialized_views"
    cache_args = [schema_name]
    if not synchronous:
        worker_cache = WorkerCache()
        while worker_cache.single_task_is_running(task_name, cache_args):
            time.sleep(5)
        worker_cache.lock_single_task(task_name, cache_args)

    try:
        materialized_views = ()
        if provider_type in (Provider.PROVIDER_AWS, Provider.PROVIDER_AWS_LOCAL):
            materialized_views = (AWS_MATERIALIZED_VIEWS
                                  + OCP_ON_AWS_MATERIALIZED_VIEWS
                                  + OCP_ON_INFRASTRUCTURE_MATERIALIZED_VIEWS)
        # Compared with ``==``: ``in (Provider.PROVIDER_OCP)`` has no trailing
        # comma, so it would be a substring test against the string itself.
        elif provider_type == Provider.PROVIDER_OCP:
            materialized_views = (OCP_MATERIALIZED_VIEWS
                                  + OCP_ON_AWS_MATERIALIZED_VIEWS
                                  + OCP_ON_AZURE_MATERIALIZED_VIEWS
                                  + OCP_ON_INFRASTRUCTURE_MATERIALIZED_VIEWS)
        elif provider_type in (Provider.PROVIDER_AZURE, Provider.PROVIDER_AZURE_LOCAL):
            materialized_views = (AZURE_MATERIALIZED_VIEWS
                                  + OCP_ON_AZURE_MATERIALIZED_VIEWS
                                  + OCP_ON_INFRASTRUCTURE_MATERIALIZED_VIEWS)

        with schema_context(schema_name):
            for view in materialized_views:
                table_name = view._meta.db_table
                with connection.cursor() as cursor:
                    cursor.execute(
                        f"REFRESH MATERIALIZED VIEW CONCURRENTLY {table_name}")
                    LOG.info(f"Refreshed {table_name}.")

        invalidate_view_cache_for_tenant_and_source_type(schema_name, provider_type)

        if provider_uuid:
            ProviderDBAccessor(provider_uuid).set_data_updated_timestamp()

        if manifest_id:
            # Processing for this manifest should be complete after this step
            with ReportManifestDBAccessor() as manifest_accessor:
                manifest = manifest_accessor.get_manifest_by_id(manifest_id)
                manifest_accessor.mark_manifest_as_completed(manifest)
    finally:
        # Always give the lock back; otherwise a failed refresh would leave
        # later runs for this schema waiting in the loop above.
        if not synchronous:
            worker_cache.release_single_task(task_name, cache_args)
def get_bills_from_provider(provider_uuid, schema, start_date=None, end_date=None):
    """
    Return the AWS bills for a provider UUID.

    A ``start_date`` given as a datetime or a ``YYYY-MM-DD`` string is moved
    to the first day of its month before filtering. An ``end_date`` given as
    a datetime is formatted as ``YYYY-MM-DD``; any other value is used as is.

    Args:
        provider_uuid (str): Provider UUID.
        schema (str): Tenant schema
        start_date (datetime, str): Start date for bill IDs.
        end_date (datetime, str): End date for bill IDs.

    Returns:
        (list): AWS cost entry bill objects. Empty when the provider is
            missing or is not an AWS type.

    """
    date_format = '%Y-%m-%d'
    if isinstance(start_date, str):
        start_date = datetime.datetime.strptime(start_date, date_format)
    if isinstance(start_date, datetime.datetime):
        # Bills are filtered by billing period start, so use the month's first day.
        start_date = start_date.replace(day=1).strftime(date_format)

    if isinstance(end_date, datetime.datetime):
        end_date = end_date.strftime(date_format)

    with ReportingCommonDBAccessor() as reporting_common:
        column_map = reporting_common.column_map

    with ProviderDBAccessor(provider_uuid) as provider_accessor:
        provider = provider_accessor.get_provider()

    if not provider:
        # An unknown UUID yields no provider; report it instead of failing on
        # the attribute access below.
        LOG.warning('Provider not found for UUID: %s', provider_uuid)
        return []

    if provider.type not in (AMAZON_WEB_SERVICES, AWS_LOCAL_SERVICE_PROVIDER):
        err_msg = 'Provider UUID is not an AWS type. It is {}'.\
            format(provider.type)
        LOG.warning(err_msg)
        return []

    with AWSReportDBAccessor(schema, column_map) as report_accessor:
        bill_table_name = AWS_CUR_TABLE_MAP['bill']
        bill_obj = getattr(report_accessor.report_schema, bill_table_name)
        bills = report_accessor.get_cost_entry_bills_query_by_provider(
            provider.id)
        if start_date:
            bills = bills.filter(bill_obj.billing_period_start >= start_date)
        if end_date:
            bills = bills.filter(bill_obj.billing_period_start <= end_date)
        bills = bills.all()

    return bills
def update_summary_tables(self, start_date, end_date):
    """Populate the OpenShift-on-infrastructure summary tables.

    Args:
        start_date (str) The date to start populating the table.
        end_date   (str) The date to end on.

    Returns
        None

    """
    infra_map = self.get_infra_map()
    ocp_uuids, infra_uuids = self.get_openshift_and_infra_providers_lists(infra_map)

    # A provider missing from the stored map gets the map rebuilt from SQL.
    # For an infrastructure provider that covers all matching clusters.
    unmapped_ocp = (
        self._provider.type == Provider.PROVIDER_OCP
        and self._provider_uuid not in ocp_uuids
    )
    if unmapped_ocp or (
        self._provider.type in Provider.CLOUD_PROVIDER_LIST
        and self._provider_uuid not in infra_uuids
    ):
        infra_map = self._generate_ocp_infra_map_from_sql(start_date, end_date)

    aws_types = (Provider.PROVIDER_AWS, Provider.PROVIDER_AWS_LOCAL)
    azure_types = (Provider.PROVIDER_AZURE, Provider.PROVIDER_AZURE_LOCAL)

    # One pass per OpenShift cluster in the map: every associated cluster for
    # an infrastructure provider, a single pass for an OpenShift provider.
    for ocp_uuid, infra in infra_map.items():
        infra_uuid = infra[0]
        infra_type = infra[1]
        if infra_type in aws_types:
            self.update_aws_summary_tables(ocp_uuid, infra_uuid, start_date, end_date)
        elif infra_type in azure_types:
            self.update_azure_summary_tables(ocp_uuid, infra_uuid, start_date, end_date)

        # Update markup for OpenShift tables
        with ProviderDBAccessor(ocp_uuid) as provider_accessor:
            markup_updater = OCPCostModelCostUpdater(
                self._schema, provider_accessor.provider)
            markup_updater._update_markup_cost(start_date, end_date)

    if infra_map:
        self.refresh_openshift_on_infrastructure_views(
            OCP_ON_INFRASTRUCTURE_MATERIALIZED_VIEWS)
def get_account_alias_from_role_arn(role_arn, session=None):
    """
    Get the account ID and account alias for the given RoleARN.

    The alias defaults to the account ID. It is replaced by the result of
    listing account aliases unless the provider's additional context has
    that lookup switched off. A failed lookup switches it off in the context.

    Args:
        role_arn (String) AWS IAM RoleARN
        session  (Session) Optional session; one is assumed from the role
                 when omitted.

    Returns:
        (tuple): (account ID, alias)

    """
    context_key = "aws_list_account_aliases"
    provider_uuid = get_provider_uuid_from_arn(role_arn)
    if not session:
        session = get_assume_role_session(role_arn)
    iam_client = session.client("iam")

    # The account ID is the second-to-last colon-separated field of the ARN.
    account_id = role_arn.split(":")[-2]
    alias = account_id

    with ProviderDBAccessor(provider_uuid) as provider_accessor:
        context = provider_accessor.get_additional_context()
        may_list_aliases = context.get(context_key, True)

    if not may_list_aliases:
        return (account_id, alias)

    try:
        aliases = iam_client.list_account_aliases().get("AccountAliases", [])
        # Note: Boto3 docs states that you can only have one alias per account
        # so the pop() should be ok...
        if aliases:
            alias = aliases.pop()
        else:
            alias = None
    except ClientError as err:
        LOG.info("Unable to list account aliases. Reason: %s", str(err))
        # Store the failure in the context so the flag read above is False next time.
        context[context_key] = False
        with ProviderDBAccessor(provider_uuid) as provider_accessor:
            provider_accessor.set_additional_context(context)

    return (account_id, alias)
def test_update_azure_summary_tables(self, mock_utility, mock_ocp_on_azure, mock_tag_summary, mock_map):
    """Check the arguments of the OCP-on-Azure summary call."""
    bill_id = 1
    # Stand-in for the bills result: iterable, with first() giving a bill whose id is 1.
    fake_bills = MagicMock()
    fake_bills.__iter__.return_value = [Mock(), Mock()]
    first_bill = Mock()
    first_bill.return_value.id = bill_id
    fake_bills.first = first_bill
    mock_utility.return_value = fake_bills

    start_date = self.dh.today.date()
    end_date = start_date + datetime.timedelta(days=1)

    with ProviderDBAccessor(self.azure_provider_uuid) as azure_accessor:
        provider = azure_accessor.get_provider()
    with ProviderDBAccessor(self.ocp_test_provider_uuid) as ocp_accessor:
        cluster_id = ocp_accessor.get_credentials().get("cluster_id")

    mock_map.return_value = {
        self.ocp_test_provider_uuid: (self.azure_provider_uuid, Provider.PROVIDER_AZURE),
    }

    updater = OCPCloudParquetReportSummaryUpdater(
        schema="acct10001", provider=provider, manifest=None)
    updater.update_azure_summary_tables(
        self.ocp_test_provider_uuid,
        self.azure_test_provider_uuid,
        start_date,
        end_date,
    )

    mock_ocp_on_azure.assert_called_with(
        start_date,
        end_date,
        self.ocp_test_provider_uuid,
        self.azure_test_provider_uuid,
        cluster_id,
        bill_id,
        decimal.Decimal(0),
    )
def setUp(self):
    """Build the manifest values, the Azure charge updater, bill rows and the manifest."""
    super().setUp()

    month_start = self.date_accessor.today_with_timezone('UTC').replace(day=1)
    self.manifest_dict = {
        'assembly_id': '1234',
        'billing_period_start_datetime': month_start,
        'num_total_files': 1,
        'provider_id': self.azure_provider.id,
    }

    azure_uuid = self.azure_test_provider_uuid
    self.provider_accessor = ProviderDBAccessor(provider_uuid=azure_uuid)
    provider_id = self.provider_accessor.get_provider().id
    self.updater = AzureReportChargeUpdater(
        schema=self.schema,
        provider_uuid=azure_uuid,
        provider_id=provider_id,
    )

    # Seed one Azure bill with a single line item.
    now_utc = DateAccessor().today_with_timezone('UTC')
    creator = self.creator
    bill = creator.create_azure_cost_entry_bill(
        provider_id=provider_id, bill_date=now_utc)
    product = creator.create_azure_cost_entry_product()
    meter = creator.create_azure_meter()
    service = creator.create_azure_service()
    creator.create_azure_cost_entry_line_item(bill, product, meter, service)

    self.manifest = self.manifest_accessor.add(**self.manifest_dict)
    self.manifest_accessor.commit()

    with ProviderDBAccessor(azure_uuid) as provider_accessor:
        self.provider = provider_accessor.get_provider()