def test_process_file_missing_manifest(self, mock_manifest_accessor, mock_stats_accessor, mock_processor):
    """Test the process_report_file functionality when manifest is missing."""
    mock_manifest_accessor.get_manifest_by_id.return_value = None
    report_dir = tempfile.mkdtemp()
    path = '{}/{}'.format(report_dir, 'file1.csv')
    schema_name = self.test_schema
    provider = 'AWS'
    provider_uuid = self.aws_test_provider_uuid
    report_dict = {'file': path, 'compression': 'gzip', 'start_date': str(DateAccessor().today())}

    mock_proc = mock_processor()
    mock_stats_acc = mock_stats_accessor().__enter__()
    mock_manifest_acc = mock_manifest_accessor().__enter__()

    _process_report_file(schema_name, provider, provider_uuid, report_dict)

    mock_proc.process.assert_called()
    mock_stats_acc.log_last_started_datetime.assert_called()
    mock_stats_acc.log_last_completed_datetime.assert_called()
    mock_stats_acc.commit.assert_called()
    mock_manifest_acc.mark_manifest_as_updated.assert_not_called()
    shutil.rmtree(report_dir)
def test_process_file_non_initial_ingest(
    self, mock_manifest_accessor, mock_stats_accessor, mock_processor, mock_provider_accessor
):
    """Test the process_report_file functionality on non-initial ingest."""
    report_dir = tempfile.mkdtemp()
    path = "{}/{}".format(report_dir, "file1.csv")
    schema_name = self.schema
    provider = Provider.PROVIDER_AWS
    provider_uuid = self.aws_provider_uuid
    report_dict = {"file": path, "compression": "gzip", "start_date": str(DateAccessor().today())}

    mock_proc = mock_processor()
    mock_stats_acc = mock_stats_accessor().__enter__()
    mock_manifest_acc = mock_manifest_accessor().__enter__()
    mock_provider_acc = mock_provider_accessor().__enter__()
    mock_provider_acc.get_setup_complete.return_value = True

    _process_report_file(schema_name, provider, provider_uuid, report_dict)

    mock_proc.process.assert_called()
    mock_proc.remove_processed_files.assert_called()
    mock_stats_acc.log_last_started_datetime.assert_called()
    mock_stats_acc.log_last_completed_datetime.assert_called()
    mock_manifest_acc.mark_manifest_as_updated.assert_called()
    mock_provider_acc.setup_complete.assert_called()
    shutil.rmtree(report_dir)
def test_process_file_missing_manifest(self, mock_manifest_accessor, mock_stats_accessor, mock_processor):
    """Test the process_report_file functionality when manifest is missing."""
    mock_manifest_accessor.get_manifest_by_id.return_value = None
    report_dir = tempfile.mkdtemp()
    path = "{}/{}".format(report_dir, "file1.csv")
    schema_name = self.schema
    provider = Provider.PROVIDER_AWS
    provider_uuid = self.aws_provider_uuid
    report_dict = {
        "file": path,
        "compression": "gzip",
        "start_date": str(DateHelper().today),
    }

    mock_proc = mock_processor()
    mock_stats_acc = mock_stats_accessor().__enter__()
    mock_manifest_acc = mock_manifest_accessor().__enter__()

    _process_report_file(schema_name, provider, provider_uuid, report_dict)

    mock_proc.process.assert_called()
    mock_stats_acc.log_last_started_datetime.assert_called()
    mock_stats_acc.log_last_completed_datetime.assert_called()
    mock_manifest_acc.mark_manifest_as_updated.assert_not_called()
    shutil.rmtree(report_dir)
def test_process_file_non_initial_ingest(
    self, mock_manifest_accessor, mock_stats_accessor, mock_processor, mock_provider_accessor
):
    """Test the process_report_file functionality on non-initial ingest."""
    report_dir = tempfile.mkdtemp()
    path = '{}/{}'.format(report_dir, 'file1.csv')
    schema_name = self.schema
    provider = 'AWS'
    provider_uuid = self.aws_test_provider_uuid
    report_dict = {
        'file': path,
        'compression': 'gzip',
        'start_date': str(DateAccessor().today()),
    }

    mock_proc = mock_processor()
    mock_stats_acc = mock_stats_accessor().__enter__()
    mock_manifest_acc = mock_manifest_accessor().__enter__()
    mock_provider_acc = mock_provider_accessor().__enter__()
    mock_provider_acc.get_setup_complete.return_value = True

    _process_report_file(schema_name, provider, provider_uuid, report_dict)

    mock_proc.process.assert_called()
    mock_proc.remove_processed_files.assert_called()
    mock_stats_acc.log_last_started_datetime.assert_called()
    mock_stats_acc.log_last_completed_datetime.assert_called()
    mock_manifest_acc.mark_manifest_as_updated.assert_called()
    mock_provider_acc.setup_complete.assert_called()
    shutil.rmtree(report_dir)
def test_process_report_files_with_transaction_atomic_error(self, mock_processor, mock_setup_complete):
    """Test that an exception rolls back the atomic transaction."""
    path = '{}/{}'.format('test', 'file1.csv')
    schema_name = self.schema
    provider = Provider.PROVIDER_AWS
    provider_uuid = self.aws_provider_uuid
    manifest_dict = {
        'assembly_id': '12345',
        'billing_period_start_datetime': DateAccessor().today_with_timezone('UTC'),
        'num_total_files': 2,
        'provider_uuid': self.aws_provider_uuid,
        'task': '170653c0-3e66-4b7e-a764-336496d7ca5a',
    }
    with ReportManifestDBAccessor() as manifest_accessor:
        manifest = manifest_accessor.add(**manifest_dict)
        manifest.save()
        manifest_id = manifest.id
        initial_update_time = manifest.manifest_updated_datetime

    with ReportStatsDBAccessor(path, manifest_id) as report_file_accessor:
        report_file_accessor.get_last_started_datetime()

    report_dict = {
        'file': path,
        'compression': 'gzip',
        'start_date': str(DateAccessor().today()),
        'manifest_id': manifest_id,
    }

    mock_setup_complete.side_effect = Exception

    with self.assertRaises(Exception):
        _process_report_file(schema_name, provider, provider_uuid, report_dict)

    with ReportStatsDBAccessor(path, manifest_id) as report_file_accessor:
        self.assertIsNone(report_file_accessor.get_last_completed_datetime())

    with ReportManifestDBAccessor() as manifest_accessor:
        manifest = manifest_accessor.get_manifest_by_id(manifest_id)
        self.assertEqual(manifest.num_processed_files, 0)
        self.assertEqual(manifest.manifest_updated_datetime, initial_update_time)

    with ProviderDBAccessor(provider_uuid=provider_uuid) as provider_accessor:
        self.assertFalse(provider_accessor.get_setup_complete())
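# A minimal sketch (not the project's actual implementation) of the pattern the
# test above exercises: _process_report_file is assumed to run its status writes
# and processing inside a single django.db.transaction.atomic block, so one
# exception rolls back the stats, manifest, and setup_complete writes together.
# The helper name and its arguments here are hypothetical, for illustration only.
from django.db import transaction


def _process_with_rollback_sketch(stats_accessor, processor, manifest_accessor, manifest):
    with transaction.atomic():
        stats_accessor.log_last_started_datetime()
        processor.process()  # an exception here undoes the started-datetime write too
        stats_accessor.log_last_completed_datetime()
        manifest_accessor.mark_manifest_as_updated(manifest)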
def process_report_file(schema_name, report_path, compression):
    """
    Task to process a Report.

    Args:
        schema_name (String) db schema name
        report_path (String) path to downloaded reports
        compression (String) 'PLAIN' or 'GZIP'

    Returns:
        None

    """
    _process_report_file(schema_name, report_path, compression)
    start_date = DateAccessor().today().date()
    LOG.info(f'Queueing update_summary_tables task for {schema_name}')
    update_summary_tables.delay(schema_name, start_date)
def test_process_file_exception(self, mock_stats_accessor, mock_processor):
    """Test the process_report_file functionality when exception is thrown."""
    report_dir = tempfile.mkdtemp()
    path = "{}/{}".format(report_dir, "file1.csv")
    schema_name = self.schema
    provider = Provider.PROVIDER_AWS
    provider_uuid = self.aws_provider_uuid
    report_dict = {"file": path, "compression": "gzip", "start_date": str(DateAccessor().today())}

    mock_processor.side_effect = ReportProcessorError("mock error")
    mock_stats_acc = mock_stats_accessor().__enter__()

    with self.assertRaises(ReportProcessorError):
        _process_report_file(schema_name, provider, provider_uuid, report_dict)

    mock_stats_acc.log_last_started_datetime.assert_called()
    mock_stats_acc.log_last_completed_datetime.assert_not_called()
    shutil.rmtree(report_dir)
def test_process_file(self, mock_accessor, mock_processor):
    """Test the process_report_file functionality."""
    report_dir = tempfile.mkdtemp()
    path = '{}/{}'.format(report_dir, 'file1.csv')
    request = {'report_path': path, 'compression': 'gzip', 'schema_name': 'testcustomer'}

    mock_proc = mock_processor()
    mock_acc = mock_accessor()

    _process_report_file(**request)

    mock_proc.process.assert_called()
    mock_acc.log_last_started_datetime.assert_called()
    mock_acc.log_last_completed_datetime.assert_called()
    mock_acc.commit.assert_called()
    shutil.rmtree(report_dir)
def test_process_report_files_with_transaction_atomic_error(self, mock_processor, mock_setup_complete):
    """Test that an exception rolls back the atomic transaction."""
    path = "{}/{}".format("test", "file1.csv")
    schema_name = self.schema
    provider = Provider.PROVIDER_AWS
    provider_uuid = self.aws_provider_uuid
    manifest_dict = {
        "assembly_id": "12345",
        "billing_period_start_datetime": DateAccessor().today_with_timezone("UTC"),
        "num_total_files": 2,
        "provider_uuid": self.aws_provider_uuid,
        "task": "170653c0-3e66-4b7e-a764-336496d7ca5a",
    }
    with ReportManifestDBAccessor() as manifest_accessor:
        manifest = manifest_accessor.add(**manifest_dict)
        manifest.save()
        manifest_id = manifest.id
        initial_update_time = manifest.manifest_updated_datetime

    with ReportStatsDBAccessor(path, manifest_id) as report_file_accessor:
        report_file_accessor.get_last_started_datetime()

    report_dict = {
        "file": path,
        "compression": "gzip",
        "start_date": str(DateAccessor().today()),
        "manifest_id": manifest_id,
    }

    mock_setup_complete.side_effect = Exception

    with self.assertRaises(Exception):
        _process_report_file(schema_name, provider, provider_uuid, report_dict)

    with ReportStatsDBAccessor(path, manifest_id) as report_file_accessor:
        self.assertIsNone(report_file_accessor.get_last_completed_datetime())

    with ReportManifestDBAccessor() as manifest_accessor:
        manifest = manifest_accessor.get_manifest_by_id(manifest_id)
        self.assertEqual(manifest.num_processed_files, 0)
        self.assertEqual(manifest.manifest_updated_datetime, initial_update_time)

    with ProviderDBAccessor(provider_uuid=provider_uuid) as provider_accessor:
        self.assertFalse(provider_accessor.get_setup_complete())
def test_process_file_exception(self, mock_stats_accessor, mock_processor):
    """Test the process_report_file functionality when exception is thrown."""
    report_dir = tempfile.mkdtemp()
    path = '{}/{}'.format(report_dir, 'file1.csv')
    schema_name = self.test_schema
    provider = 'AWS'
    provider_uuid = self.aws_test_provider_uuid
    report_dict = {'file': path, 'compression': 'gzip', 'start_date': str(DateAccessor().today())}

    mock_processor.side_effect = ReportProcessorError('mock error')
    mock_stats_acc = mock_stats_accessor().__enter__()

    with self.assertRaises(ReportProcessorError):
        _process_report_file(schema_name, provider, provider_uuid, report_dict)

    mock_stats_acc.log_last_started_datetime.assert_called()
    mock_stats_acc.log_last_completed_datetime.assert_not_called()
    mock_stats_acc.commit.assert_called()
    shutil.rmtree(report_dir)
def process_report(request_id, report):
    """
    Process line item report.

    Returns True when line item processing is complete. This is important because
    the listen_for_messages -> process_messages path must have a positive
    acknowledgement that line item processing is complete before committing.

    If the service goes down in the middle of processing (SIGTERM) we do not want
    a stray kafka commit to prematurely commit the message before processing has
    been completed.

    Args:
        request_id (Str): The request id
        report (Dict) - keys: value
            request_id: String,
            account: String,
            schema_name: String,
            manifest_id: Integer,
            provider_uuid: String,
            provider_type: String,
            current_file: String,
            date: DateTime

    Returns:
        True if line item report processing is complete.

    """
    schema_name = report.get("schema_name")
    manifest_id = report.get("manifest_id")
    provider_uuid = str(report.get("provider_uuid"))
    provider_type = report.get("provider_type")
    date = report.get("date")

    # The create_table flag is used by the ParquetReportProcessor
    # to create a Hive/Trino table.
    report_dict = {
        "file": report.get("current_file"),
        "compression": UNCOMPRESSED,
        "manifest_id": manifest_id,
        "provider_uuid": provider_uuid,
        "request_id": request_id,
        "tracing_id": report.get("tracing_id"),
        "provider_type": "OCP",
        "start_date": date,
        "create_table": True,
    }
    try:
        return _process_report_file(schema_name, provider_type, report_dict)
    except NotImplementedError as err:
        LOG.info(f"NotImplementedError: {str(err)}")
        return True
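# Illustrative consumer-side sketch of the contract described in the docstring
# above: the kafka offset is committed only after process_report() reports
# completion, so a SIGTERM mid-processing never acknowledges an unprocessed
# message. `consumer` and `handle_message` are hypothetical names, not the real
# listen_for_messages/process_messages implementation.
def listen_for_messages_sketch(consumer, handle_message):
    for msg in consumer:  # assumed iterable kafka consumer
        if handle_message(msg):  # eventually calls process_report(); True == done
            consumer.commit()  # positive acknowledgement only after processing completes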
def process_report(request_id, report):
    """
    Process line item report.

    Returns True when line item processing is complete. This is important because
    the listen_for_messages -> process_messages path must have a positive
    acknowledgement that line item processing is complete before committing.

    If the service goes down in the middle of processing (SIGTERM) we do not want
    a stray kafka commit to prematurely commit the message before processing has
    been completed.

    Args:
        request_id (Str): The request id
        report (Dict) - keys: value
            request_id: String,
            account: String,
            schema_name: String,
            manifest_id: Integer,
            provider_uuid: String,
            provider_type: String,
            current_file: String,
            date: DateTime

    Returns:
        True if line item report processing is complete.

    """
    schema_name = report.get("schema_name")
    manifest_id = report.get("manifest_id")
    provider_uuid = report.get("provider_uuid")
    provider_type = report.get("provider_type")
    report_dict = {
        "file": report.get("current_file"),
        "compression": UNCOMPRESSED,
        "manifest_id": manifest_id,
        "provider_uuid": provider_uuid,
    }
    return _process_report_file(schema_name, provider_type, provider_uuid, report_dict)
def get_report_files(
    self, customer_name, authentication, billing_source, provider_type, schema_name, provider_uuid, report_month
):
    """
    Task to download a Report and process the report.

    FIXME: A 2 hour timeout is arbitrarily set for in progress processing requests.
    Once we know a realistic processing time for the largest CUR file in production
    this value can be adjusted or made configurable.

    Args:
        customer_name  (String): Name of the customer owning the cost usage report.
        authentication (String): Credential needed to access cost usage report
                                 in the backend provider.
        billing_source (String): Location of the cost usage report in the backend provider.
        provider_type  (String): Koku defined provider type string. Example: Amazon = 'AWS'
        schema_name    (String): Name of the DB schema

    Returns:
        None

    """
    worker_stats.GET_REPORT_ATTEMPTS_COUNTER.labels(provider_type=provider_type).inc()
    month = parser.parse(report_month)
    reports = _get_report_files(
        self, customer_name, authentication, billing_source, provider_type, provider_uuid, month
    )

    try:
        stmt = (
            f"Reports to be processed:\n"
            f" schema_name: {customer_name}\n"
            f" provider: {provider_type}\n"
            f" provider_uuid: {provider_uuid}\n"
        )
        for report in reports:
            stmt += " file: " + str(report["file"]) + "\n"
        LOG.info(stmt[:-1])
        reports_to_summarize = []
        for report_dict in reports:
            manifest_id = report_dict.get("manifest_id")
            file_name = os.path.basename(report_dict.get("file"))
            with ReportStatsDBAccessor(file_name, manifest_id) as stats:
                started_date = stats.get_last_started_datetime()
                completed_date = stats.get_last_completed_datetime()

            # Skip processing if already in progress.
            if started_date and not completed_date:
                expired_start_date = started_date + datetime.timedelta(hours=2)
                if DateAccessor().today_with_timezone("UTC") < expired_start_date:
                    LOG.info(
                        "Skipping processing task for %s since it was started at: %s.",
                        file_name,
                        str(started_date),
                    )
                    continue

            # Skip processing if complete.
            if started_date and completed_date:
                LOG.info(
                    "Skipping processing task for %s. Started on: %s and completed on: %s.",
                    file_name,
                    str(started_date),
                    str(completed_date),
                )
                continue

            stmt = (
                f"Processing starting:\n"
                f" schema_name: {customer_name}\n"
                f" provider: {provider_type}\n"
                f" provider_uuid: {provider_uuid}\n"
                f' file: {report_dict.get("file")}'
            )
            LOG.info(stmt)
            worker_stats.PROCESS_REPORT_ATTEMPTS_COUNTER.labels(provider_type=provider_type).inc()
            _process_report_file(schema_name, provider_type, provider_uuid, report_dict)
            report_meta = {}
            known_manifest_ids = [report.get("manifest_id") for report in reports_to_summarize]
            if report_dict.get("manifest_id") not in known_manifest_ids:
                report_meta["schema_name"] = schema_name
                report_meta["provider_type"] = provider_type
                report_meta["provider_uuid"] = provider_uuid
                report_meta["manifest_id"] = report_dict.get("manifest_id")
                reports_to_summarize.append(report_meta)
    except ReportProcessorError as processing_error:
        worker_stats.PROCESS_REPORT_ERROR_COUNTER.labels(provider_type=provider_type).inc()
        LOG.error(str(processing_error))
        raise processing_error

    return reports_to_summarize
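# Worked example (hypothetical values, not project code) of the in-progress
# check in get_report_files above: a started-but-not-completed file is skipped
# only while "now" is still inside the 2 hour window the FIXME refers to; after
# that the stale claim expires and the file becomes eligible for reprocessing.
def _in_progress_window_example():
    started_date = datetime.datetime(2020, 1, 1, 10, 0, tzinfo=datetime.timezone.utc)
    now = datetime.datetime(2020, 1, 1, 11, 30, tzinfo=datetime.timezone.utc)
    expired_start_date = started_date + datetime.timedelta(hours=2)
    assert now < expired_start_date  # still inside the window, so processing is skipped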
def get_report_files(  # noqa: C901
    self,
    customer_name,
    authentication,
    billing_source,
    provider_type,
    schema_name,
    provider_uuid,
    report_month,
    report_context,
    tracing_id=None,
):
    """
    Task to download a Report and process the report.

    FIXME: A 2 hour timeout is arbitrarily set for in progress processing requests.
    Once we know a realistic processing time for the largest CUR file in production
    this value can be adjusted or made configurable.

    Args:
        customer_name  (String): Name of the customer owning the cost usage report.
        authentication (String): Credential needed to access cost usage report
                                 in the backend provider.
        billing_source (String): Location of the cost usage report in the backend provider.
        provider_type  (String): Koku defined provider type string. Example: Amazon = 'AWS'
        schema_name    (String): Name of the DB schema

    Returns:
        None

    """
    try:
        worker_stats.GET_REPORT_ATTEMPTS_COUNTER.labels(provider_type=provider_type).inc()
        month = report_month
        if isinstance(report_month, str):
            month = parser.parse(report_month)

        report_file = report_context.get("key")
        cache_key = f"{provider_uuid}:{report_file}"
        tracing_id = report_context.get("assembly_id", "no-tracing-id")
        WorkerCache().add_task_to_cache(cache_key)

        context = {"account": customer_name[4:], "provider_uuid": provider_uuid}
        try:
            report_dict = _get_report_files(
                tracing_id,
                customer_name,
                authentication,
                billing_source,
                provider_type,
                provider_uuid,
                month,
                report_context,
            )
        except (MasuProcessingError, MasuProviderError, ReportDownloaderError) as err:
            worker_stats.REPORT_FILE_DOWNLOAD_ERROR_COUNTER.labels(provider_type=provider_type).inc()
            WorkerCache().remove_task_from_cache(cache_key)
            LOG.warning(log_json(tracing_id, str(err), context))
            return

        stmt = (
            f"Reports to be processed: "
            f" schema_name: {customer_name} "
            f" provider: {provider_type} "
            f" provider_uuid: {provider_uuid}"
        )
        if report_dict:
            stmt += f" file: {report_dict['file']}"
            LOG.info(log_json(tracing_id, stmt, context))
        else:
            WorkerCache().remove_task_from_cache(cache_key)
            return None

        report_meta = {
            "schema_name": schema_name,
            "provider_type": provider_type,
            "provider_uuid": provider_uuid,
            "manifest_id": report_dict.get("manifest_id"),
            "tracing_id": tracing_id,
        }

        try:
            stmt = (
                f"Processing starting: "
                f" schema_name: {customer_name} "
                f" provider: {provider_type} "
                f" provider_uuid: {provider_uuid} "
                f' file: {report_dict.get("file")}'
            )
            LOG.info(log_json(tracing_id, stmt))
            worker_stats.PROCESS_REPORT_ATTEMPTS_COUNTER.labels(provider_type=provider_type).inc()

            report_dict["tracing_id"] = tracing_id
            report_dict["provider_type"] = provider_type

            _process_report_file(schema_name, provider_type, report_dict)
        except (ReportProcessorError, ReportProcessorDBError) as processing_error:
            worker_stats.PROCESS_REPORT_ERROR_COUNTER.labels(provider_type=provider_type).inc()
            LOG.error(log_json(tracing_id, str(processing_error), context))
            WorkerCache().remove_task_from_cache(cache_key)
            raise processing_error
        except NotImplementedError as err:
            LOG.info(log_json(tracing_id, str(err), context))
            WorkerCache().remove_task_from_cache(cache_key)

        WorkerCache().remove_task_from_cache(cache_key)

        return report_meta
    except ReportDownloaderWarning as err:
        LOG.warning(log_json(tracing_id, str(err), context))
        WorkerCache().remove_task_from_cache(cache_key)
    except Exception as err:
        worker_stats.PROCESS_REPORT_ERROR_COUNTER.labels(provider_type=provider_type).inc()
        LOG.error(log_json(tracing_id, str(err), context))
        WorkerCache().remove_task_from_cache(cache_key)
def get_report_files(
    self,
    customer_name,
    authentication,
    billing_source,
    provider_type,
    schema_name,
    provider_uuid,
    report_month,
    report_context,
):
    """
    Task to download a Report and process the report.

    FIXME: A 2 hour timeout is arbitrarily set for in progress processing requests.
    Once we know a realistic processing time for the largest CUR file in production
    this value can be adjusted or made configurable.

    Args:
        customer_name  (String): Name of the customer owning the cost usage report.
        authentication (String): Credential needed to access cost usage report
                                 in the backend provider.
        billing_source (String): Location of the cost usage report in the backend provider.
        provider_type  (String): Koku defined provider type string. Example: Amazon = 'AWS'
        schema_name    (String): Name of the DB schema

    Returns:
        None

    """
    worker_stats.GET_REPORT_ATTEMPTS_COUNTER.labels(provider_type=provider_type).inc()
    month = report_month
    if isinstance(report_month, str):
        month = parser.parse(report_month)
    cache_key = f"{provider_uuid}:{month.date()}"
    WorkerCache().add_task_to_cache(cache_key)
    report_dict = _get_report_files(
        self,
        customer_name,
        authentication,
        billing_source,
        provider_type,
        provider_uuid,
        month,
        cache_key,
        report_context,
    )

    stmt = (
        f"Reports to be processed:\n"
        f" schema_name: {customer_name}\n"
        f" provider: {provider_type}\n"
        f" provider_uuid: {provider_uuid}\n"
    )
    if report_dict:
        stmt += f" file: {report_dict['file']}"
        LOG.info(stmt)
    else:
        return None

    try:
        stmt = (
            f"Processing starting:\n"
            f" schema_name: {customer_name}\n"
            f" provider: {provider_type}\n"
            f" provider_uuid: {provider_uuid}\n"
            f' file: {report_dict.get("file")}'
        )
        LOG.info(stmt)
        worker_stats.PROCESS_REPORT_ATTEMPTS_COUNTER.labels(provider_type=provider_type).inc()
        _process_report_file(schema_name, provider_type, report_dict)
        report_meta = {
            "schema_name": schema_name,
            "provider_type": provider_type,
            "provider_uuid": provider_uuid,
            "manifest_id": report_dict.get("manifest_id"),
        }
    except (ReportProcessorError, ReportProcessorDBError) as processing_error:
        worker_stats.PROCESS_REPORT_ERROR_COUNTER.labels(provider_type=provider_type).inc()
        LOG.error(str(processing_error))
        WorkerCache().remove_task_from_cache(cache_key)
        raise processing_error

    WorkerCache().remove_task_from_cache(cache_key)

    start_date = report_dict.get("start_date")
    manifest_id = report_dict.get("manifest_id")
    if start_date:
        start_date_str = start_date.strftime("%Y-%m-%d")
        convert_to_parquet.delay(
            self.request.id,
            schema_name[4:],
            provider_uuid,
            provider_type,
            start_date_str,
            manifest_id,
            [report_context.get("local_file")],
        )

    return report_meta
def get_report_files(
    self, customer_name, authentication, billing_source, provider_type, schema_name, provider_uuid, report_month
):
    """
    Task to download a Report and process the report.

    FIXME: A 2 hour timeout is arbitrarily set for in progress processing requests.
    Once we know a realistic processing time for the largest CUR file in production
    this value can be adjusted or made configurable.

    Args:
        customer_name  (String): Name of the customer owning the cost usage report.
        authentication (String): Credential needed to access cost usage report
                                 in the backend provider.
        billing_source (String): Location of the cost usage report in the backend provider.
        provider_type  (String): Koku defined provider type string. Example: Amazon = 'AWS'
        schema_name    (String): Name of the DB schema

    Returns:
        None

    """
    worker_stats.GET_REPORT_ATTEMPTS_COUNTER.labels(provider_type=provider_type).inc()
    month = report_month
    if isinstance(report_month, str):
        month = parser.parse(report_month)
    cache_key = f"{provider_uuid}:{month}"
    reports = _get_report_files(
        self, customer_name, authentication, billing_source, provider_type, provider_uuid, month, cache_key
    )

    stmt = (
        f"Reports to be processed:\n"
        f" schema_name: {customer_name}\n"
        f" provider: {provider_type}\n"
        f" provider_uuid: {provider_uuid}\n"
    )
    for report in reports:
        stmt += " file: " + str(report["file"]) + "\n"
    LOG.info(stmt[:-1])
    reports_to_summarize = []
    start_date = None
    for report_dict in reports:
        with transaction.atomic():
            try:
                manifest_id = report_dict.get("manifest_id")
                file_name = os.path.basename(report_dict.get("file"))
                with ReportStatsDBAccessor(file_name, manifest_id) as stats:
                    started_date = stats.get_last_started_datetime()
                    completed_date = stats.get_last_completed_datetime()

                # Skip processing if already in progress.
                if started_date and not completed_date:
                    expired_start_date = started_date + datetime.timedelta(
                        hours=Config.REPORT_PROCESSING_TIMEOUT_HOURS
                    )
                    if DateAccessor().today_with_timezone("UTC") < expired_start_date:
                        LOG.info(
                            "Skipping processing task for %s since it was started at: %s.",
                            file_name,
                            str(started_date),
                        )
                        continue

                stmt = (
                    f"Processing starting:\n"
                    f" schema_name: {customer_name}\n"
                    f" provider: {provider_type}\n"
                    f" provider_uuid: {provider_uuid}\n"
                    f' file: {report_dict.get("file")}'
                )
                LOG.info(stmt)
                if not start_date:
                    start_date = report_dict.get("start_date")
                worker_stats.PROCESS_REPORT_ATTEMPTS_COUNTER.labels(provider_type=provider_type).inc()
                _process_report_file(schema_name, provider_type, provider_uuid, report_dict)
                known_manifest_ids = [report.get("manifest_id") for report in reports_to_summarize]
                if report_dict.get("manifest_id") not in known_manifest_ids:
                    report_meta = {
                        "schema_name": schema_name,
                        "provider_type": provider_type,
                        "provider_uuid": provider_uuid,
                        "manifest_id": report_dict.get("manifest_id"),
                    }
                    reports_to_summarize.append(report_meta)
            except (ReportProcessorError, ReportProcessorDBError) as processing_error:
                worker_stats.PROCESS_REPORT_ERROR_COUNTER.labels(provider_type=provider_type).inc()
                LOG.error(str(processing_error))
                WorkerCache().remove_task_from_cache(cache_key)
                raise processing_error

    WorkerCache().remove_task_from_cache(cache_key)

    if start_date:
        start_date_str = start_date.strftime("%Y-%m-%d")
        convert_to_parquet.delay(
            self.request.id, schema_name[4:], provider_uuid, provider_type, start_date_str, manifest_id
        )

    return reports_to_summarize