Example #1
class TestUpdateSummaryTablesTask(MasuTestCase):
    """Test cases for Processor summary table Celery tasks."""

    @classmethod
    def setUpClass(cls):
        """Set up for the class."""
        super().setUpClass()
        cls.aws_tables = list(AWS_CUR_TABLE_MAP.values())
        cls.ocp_tables = list(OCP_REPORT_TABLE_MAP.values())
        cls.all_tables = cls.aws_tables + cls.ocp_tables
        with ReportingCommonDBAccessor() as report_common_db:
            cls.column_map = report_common_db.column_map

        cls.creator = ReportObjectCreator(cls.schema, cls.column_map)

    def setUp(self):
        """Set up each test."""
        super().setUp()
        self.aws_accessor = AWSReportDBAccessor(schema=self.schema, column_map=self.column_map)
        self.ocp_accessor = OCPReportDBAccessor(schema=self.schema, column_map=self.column_map)

        # Populate some line item data so that the summary tables
        # have something to pull from
        self.start_date = DateAccessor().today_with_timezone("UTC").replace(day=1)
        last_month = self.start_date - relativedelta.relativedelta(months=1)

        for cost_entry_date in (self.start_date, last_month):
            bill = self.creator.create_cost_entry_bill(provider_uuid=self.aws_provider_uuid, bill_date=cost_entry_date)
            cost_entry = self.creator.create_cost_entry(bill, cost_entry_date)
            for family in ["Storage", "Compute Instance", "Database Storage", "Database Instance"]:
                product = self.creator.create_cost_entry_product(family)
                pricing = self.creator.create_cost_entry_pricing()
                reservation = self.creator.create_cost_entry_reservation()
                self.creator.create_cost_entry_line_item(bill, cost_entry, product, pricing, reservation)
        provider_ocp_uuid = self.ocp_test_provider_uuid

        with ProviderDBAccessor(provider_uuid=provider_ocp_uuid) as provider_accessor:
            provider_uuid = provider_accessor.get_provider().uuid

        cluster_id = self.ocp_provider_resource_name
        for period_date in (self.start_date, last_month):
            period = self.creator.create_ocp_report_period(
                provider_uuid=provider_uuid, period_date=period_date, cluster_id=cluster_id
            )
            report = self.creator.create_ocp_report(period, period_date)
            for _ in range(25):
                self.creator.create_ocp_usage_line_item(period, report)

    @patch("masu.processor.tasks.chain")
    @patch("masu.processor.tasks.refresh_materialized_views")
    @patch("masu.processor.tasks.update_charge_info")
    def test_update_summary_tables_aws(self, mock_charge_info, mock_views, mock_chain):
        """Test that the summary table task runs."""
        provider = Provider.PROVIDER_AWS
        provider_aws_uuid = self.aws_provider_uuid

        daily_table_name = AWS_CUR_TABLE_MAP["line_item_daily"]
        summary_table_name = AWS_CUR_TABLE_MAP["line_item_daily_summary"]
        start_date = self.start_date.replace(day=1) + relativedelta.relativedelta(months=-1)

        with schema_context(self.schema):
            daily_query = self.aws_accessor._get_db_obj_query(daily_table_name)
            summary_query = self.aws_accessor._get_db_obj_query(summary_table_name)

            initial_daily_count = daily_query.count()
            initial_summary_count = summary_query.count()

        self.assertEqual(initial_daily_count, 0)
        self.assertEqual(initial_summary_count, 0)

        update_summary_tables(self.schema, provider, provider_aws_uuid, start_date)

        with schema_context(self.schema):
            self.assertNotEqual(daily_query.count(), initial_daily_count)
            self.assertNotEqual(summary_query.count(), initial_summary_count)

        mock_chain.return_value.apply_async.assert_called()

    @patch("masu.processor.tasks.update_charge_info")
    def test_update_summary_tables_aws_end_date(self, mock_charge_info):
        """Test that the summary table task respects a date range."""
        provider = Provider.PROVIDER_AWS
        provider_aws_uuid = self.aws_provider_uuid
        ce_table_name = AWS_CUR_TABLE_MAP["cost_entry"]
        daily_table_name = AWS_CUR_TABLE_MAP["line_item_daily"]
        summary_table_name = AWS_CUR_TABLE_MAP["line_item_daily_summary"]

        start_date = self.start_date.replace(
            day=1, hour=0, minute=0, second=0, microsecond=0
        ) + relativedelta.relativedelta(months=-1)

        end_date = start_date + timedelta(days=10)
        end_date = end_date.replace(hour=23, minute=59, second=59)

        daily_table = getattr(self.aws_accessor.report_schema, daily_table_name)
        summary_table = getattr(self.aws_accessor.report_schema, summary_table_name)
        ce_table = getattr(self.aws_accessor.report_schema, ce_table_name)

        with schema_context(self.schema):
            ce_start_date = ce_table.objects.filter(interval_start__gte=start_date).aggregate(Min("interval_start"))[
                "interval_start__min"
            ]
            ce_end_date = ce_table.objects.filter(interval_start__lte=end_date).aggregate(Max("interval_start"))[
                "interval_start__max"
            ]

        # The summary tables will only include dates where there is data
        expected_start_date = max(start_date, ce_start_date)
        expected_start_date = expected_start_date.replace(hour=0, minute=0, second=0, microsecond=0)
        expected_end_date = min(end_date, ce_end_date)
        expected_end_date = expected_end_date.replace(hour=0, minute=0, second=0, microsecond=0)

        update_summary_tables(self.schema, provider, provider_aws_uuid, start_date, end_date)

        with schema_context(self.schema):
            daily_entry = daily_table.objects.all().aggregate(Min("usage_start"), Max("usage_end"))
            result_start_date = daily_entry["usage_start__min"]
            result_end_date = daily_entry["usage_end__max"]

        self.assertEqual(result_start_date, expected_start_date)
        self.assertEqual(result_end_date, expected_end_date)

        with schema_context(self.schema):
            summary_entry = summary_table.objects.all().aggregate(Min("usage_start"), Max("usage_end"))
            result_start_date = summary_entry["usage_start__min"]
            result_end_date = summary_entry["usage_end__max"]

        self.assertEqual(result_start_date, expected_start_date)
        self.assertEqual(result_end_date, expected_end_date)

    @patch("masu.processor.tasks.chain")
    @patch("masu.processor.tasks.refresh_materialized_views")
    @patch("masu.processor.tasks.update_charge_info")
    @patch("masu.database.cost_model_db_accessor.CostModelDBAccessor._make_rate_by_metric_map")
    @patch("masu.database.cost_model_db_accessor.CostModelDBAccessor.get_markup")
    def test_update_summary_tables_ocp(self, mock_markup, mock_rate_map, mock_charge_info, mock_view, mock_chain):
        """Test that the summary table task runs."""
        markup = {}
        mem_rate = {"tiered_rates": [{"value": "1.5", "unit": "USD"}]}
        cpu_rate = {"tiered_rates": [{"value": "2.5", "unit": "USD"}]}
        rate_metric_map = {"cpu_core_usage_per_hour": cpu_rate, "memory_gb_usage_per_hour": mem_rate}

        mock_markup.return_value = markup
        mock_rate_map.return_value = rate_metric_map

        provider = Provider.PROVIDER_OCP
        provider_ocp_uuid = self.ocp_test_provider_uuid

        daily_table_name = OCP_REPORT_TABLE_MAP["line_item_daily"]
        start_date = self.start_date.replace(
            day=1, hour=0, minute=0, second=0, microsecond=0
        ) + relativedelta.relativedelta(months=-1)
        end_date = start_date + timedelta(days=10)

        with schema_context(self.schema):
            daily_query = self.ocp_accessor._get_db_obj_query(daily_table_name)

            initial_daily_count = daily_query.count()

        self.assertEqual(initial_daily_count, 0)
        update_summary_tables(self.schema, provider, provider_ocp_uuid, start_date, end_date)

        with schema_context(self.schema):
            self.assertNotEqual(daily_query.count(), initial_daily_count)

        update_charge_info(
            schema_name=self.schema, provider_uuid=provider_ocp_uuid, start_date=start_date, end_date=end_date
        )

        table_name = OCP_REPORT_TABLE_MAP["line_item_daily_summary"]
        with ProviderDBAccessor(provider_ocp_uuid) as provider_accessor:
            provider_obj = provider_accessor.get_provider()

        usage_period_qry = self.ocp_accessor.get_usage_period_query_by_provider(provider_obj.uuid)
        with schema_context(self.schema):
            cluster_id = usage_period_qry.first().cluster_id

            items = self.ocp_accessor._get_db_obj_query(table_name).filter(cluster_id=cluster_id)
            for item in items:
                self.assertIsNotNone(item.pod_charge_memory_gigabyte_hours)
                self.assertIsNotNone(item.pod_charge_cpu_core_hours)

            storage_daily_name = OCP_REPORT_TABLE_MAP["storage_line_item_daily"]

            items = self.ocp_accessor._get_db_obj_query(storage_daily_name).filter(cluster_id=cluster_id)
            for item in items:
                self.assertIsNotNone(item.volume_request_storage_byte_seconds)
                self.assertIsNotNone(item.persistentvolumeclaim_usage_byte_seconds)

            storage_summary_name = OCP_REPORT_TABLE_MAP["line_item_daily_summary"]
            items = self.ocp_accessor._get_db_obj_query(storage_summary_name).filter(
                cluster_id=cluster_id, data_source="Storage"
            )
            for item in items:
                self.assertIsNotNone(item.volume_request_storage_gigabyte_months)
                self.assertIsNotNone(item.persistentvolumeclaim_usage_gigabyte_months)

        mock_chain.return_value.apply_async.assert_called()

    @patch("masu.processor.tasks.update_charge_info")
    @patch("masu.database.cost_model_db_accessor.CostModelDBAccessor.get_memory_gb_usage_per_hour_rates")
    @patch("masu.database.cost_model_db_accessor.CostModelDBAccessor.get_cpu_core_usage_per_hour_rates")
    def test_update_summary_tables_ocp_end_date(self, mock_cpu_rate, mock_mem_rate, mock_charge_info):
        """Test that the summary table task respects a date range."""
        mock_cpu_rate.return_value = 1.5
        mock_mem_rate.return_value = 2.5
        provider = Provider.PROVIDER_OCP
        provider_ocp_uuid = self.ocp_test_provider_uuid
        ce_table_name = OCP_REPORT_TABLE_MAP["report"]
        daily_table_name = OCP_REPORT_TABLE_MAP["line_item_daily"]

        start_date = self.start_date.replace(
            day=1, hour=0, minute=0, second=0, microsecond=0
        ) + relativedelta.relativedelta(months=-1)

        end_date = start_date + timedelta(days=10)
        end_date = end_date.replace(hour=23, minute=59, second=59)

        daily_table = getattr(self.ocp_accessor.report_schema, daily_table_name)
        ce_table = getattr(self.ocp_accessor.report_schema, ce_table_name)

        with schema_context(self.schema):
            ce_start_date = ce_table.objects.filter(interval_start__gte=start_date).aggregate(Min("interval_start"))[
                "interval_start__min"
            ]

            ce_end_date = ce_table.objects.filter(interval_start__lte=end_date).aggregate(Max("interval_start"))[
                "interval_start__max"
            ]

        # The summary tables will only include dates where there is data
        expected_start_date = max(start_date, ce_start_date)
        expected_start_date = expected_start_date.replace(hour=0, minute=0, second=0, microsecond=0)
        expected_end_date = min(end_date, ce_end_date)
        expected_end_date = expected_end_date.replace(hour=0, minute=0, second=0, microsecond=0)

        update_summary_tables(self.schema, provider, provider_ocp_uuid, start_date, end_date)
        with schema_context(self.schema):
            daily_entry = daily_table.objects.all().aggregate(Min("usage_start"), Max("usage_end"))
            result_start_date = daily_entry["usage_start__min"]
            result_end_date = daily_entry["usage_end__max"]

        self.assertEqual(result_start_date, expected_start_date)
        self.assertEqual(result_end_date, expected_end_date)

    @patch("masu.processor.tasks.update_summary_tables")
    def test_get_report_data_for_all_providers(self, mock_update):
        """Test GET report_data endpoint with provider_uuid=*."""
        start_date = date.today()
        update_all_summary_tables(start_date)

        mock_update.delay.assert_called_with(ANY, ANY, ANY, str(start_date), ANY)

    def test_refresh_materialized_views(self):
        """Test that materialized views are refreshed."""
        manifest_dict = {
            "assembly_id": "12345",
            "billing_period_start_datetime": DateAccessor().today_with_timezone("UTC"),
            "num_total_files": 2,
            "provider_uuid": self.aws_provider_uuid,
            "task": "170653c0-3e66-4b7e-a764-336496d7ca5a",
        }
        fake_aws = FakeAWSCostData(self.aws_provider)
        generator = AWSReportDataGenerator(self.tenant)
        generator.add_data_to_tenant(fake_aws)

        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.add(**manifest_dict)
            manifest.save()

        refresh_materialized_views(self.schema, Provider.PROVIDER_AWS, manifest_id=manifest.id)

        views_to_check = [view for view in AWS_MATERIALIZED_VIEWS if "Cost" in view._meta.db_table]

        with schema_context(self.schema):
            for view in views_to_check:
                self.assertNotEqual(view.objects.count(), 0)

        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.get_manifest_by_id(manifest.id)
            self.assertIsNotNone(manifest.manifest_completed_datetime)

    def test_vacuum_schema(self):
        """Test that the vacuum schema task runs."""
        logging.disable(logging.NOTSET)
        expected = "INFO:masu.processor.tasks:VACUUM"
        with self.assertLogs("masu.processor.tasks", level="INFO") as logger:
            vacuum_schema(self.schema)
            self.assertIn(expected, logger.output)
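
A note on the mocking pattern above: patching masu.processor.tasks.chain and then asserting on mock_chain.return_value.apply_async verifies that the task built a Celery chain and dispatched it, without executing any downstream task. Below is a minimal, self-contained sketch of the same pattern; Tasks, run, and the task names are illustrative stand-ins, not the real masu API.

from unittest import mock

class Tasks:
    """Stand-in for a module that builds and dispatches a Celery chain."""

    @staticmethod
    def chain(*tasks):  # pretend this is celery.chain
        raise RuntimeError("would need a real broker")

    @classmethod
    def run(cls, schema):
        cls.chain("update_charge_info", "refresh_materialized_views").apply_async()

with mock.patch.object(Tasks, "chain") as mock_chain:
    Tasks.run("acct10001")
    # Same assertion shape as the tests above: the chain was built and
    # dispatched, but nothing actually ran.
    mock_chain.return_value.apply_async.assert_called()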
Example #2
class AWSReportProcessorTest(MasuTestCase):
    """Test Cases for the AWSReportProcessor object."""

    @classmethod
    def setUpClass(cls):
        """Set up the test class with required objects."""
        super().setUpClass()
        cls.test_report_test_path = './koku/masu/test/data/test_cur.csv'
        cls.test_report_gzip_test_path = './koku/masu/test/data/test_cur.csv.gz'

        cls.date_accessor = DateAccessor()
        cls.manifest_accessor = ReportManifestDBAccessor()

        with ReportingCommonDBAccessor() as report_common_db:
            cls.column_map = report_common_db.column_map

        _report_tables = copy.deepcopy(AWS_CUR_TABLE_MAP)
        _report_tables.pop('line_item_daily', None)
        _report_tables.pop('line_item_daily_summary', None)
        _report_tables.pop('tags_summary', None)
        cls.report_tables = list(_report_tables.values())
        # Grab a single row of test data to work with
        with open(cls.test_report_test_path, 'r') as f:
            reader = csv.DictReader(f)
            cls.row = next(reader)

    def setUp(self):
        """Set up shared variables."""
        super().setUp()

        self.temp_dir = tempfile.mkdtemp()
        self.test_report = f'{self.temp_dir}/test_cur.csv'
        self.test_report_gzip = f'{self.temp_dir}/test_cur.csv.gz'

        shutil.copy2(self.test_report_test_path, self.test_report)
        shutil.copy2(self.test_report_gzip_test_path, self.test_report_gzip)

        self.processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )

        billing_start = self.date_accessor.today_with_timezone('UTC').replace(
            year=2018, month=6, day=1, hour=0, minute=0, second=0
        )
        self.assembly_id = '1234'
        self.manifest_dict = {
            'assembly_id': self.assembly_id,
            'billing_period_start_datetime': billing_start,
            'num_total_files': 2,
            'provider_uuid': self.aws_provider_uuid,
        }

        self.accessor = AWSReportDBAccessor(self.schema, self.column_map)
        self.report_schema = self.accessor.report_schema
        self.manifest = self.manifest_accessor.add(**self.manifest_dict)

    def tearDown(self):
        """Return the database to a pre-test state."""
        super().tearDown()

        shutil.rmtree(self.temp_dir)

        self.processor.processed_report.remove_processed_rows()
        self.processor.line_item_columns = None

    def test_initializer(self):
        """Test initializer."""
        self.assertIsNotNone(self.processor._schema)
        self.assertIsNotNone(self.processor._report_path)
        self.assertIsNotNone(self.processor._report_name)
        self.assertIsNotNone(self.processor._compression)
        self.assertEqual(self.processor._datetime_format, Config.AWS_DATETIME_STR_FORMAT)
        self.assertEqual(self.processor._batch_size, Config.REPORT_PROCESSING_BATCH_SIZE)

    def test_initializer_unsupported_compression(self):
        """Assert that an error is raised for an invalid compression."""
        with self.assertRaises(MasuProcessingError):
            AWSReportProcessor(
                schema_name=self.schema,
                report_path=self.test_report,
                compression='unsupported',
                provider_uuid=self.aws_provider_uuid,
            )

    def test_process_default(self):
        """Test the processing of an uncompressed file."""
        counts = {}
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )
        report_db = self.accessor
        report_schema = report_db.report_schema
        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()
            counts[table_name] = count

        bill_date = self.manifest.billing_period_start_datetime.date()

        expected = (
            f'INFO:masu.processor.report_processor_base:Processing bill starting on {bill_date}.\n'
            f' Processing entire month.\n'
            f' schema_name: {self.schema},\n'
            f' provider_uuid: {self.aws_provider_uuid},\n'
            f' manifest_id: {self.manifest.id}'
        )
        # Logging below CRITICAL is disabled in masu/__init__.py; re-enable it here.
        logging.disable(logging.NOTSET)
        with self.assertLogs('masu.processor.report_processor_base', level='INFO') as logger:
            processor.process()
            self.assertIn(expected, logger.output)

        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()

            if table_name in (
                'reporting_awscostentryreservation',
                'reporting_ocpawscostlineitem_daily_summary',
                'reporting_ocpawscostlineitem_project_daily_summary',
            ):
                self.assertTrue(count >= counts[table_name])
            else:
                self.assertTrue(count > counts[table_name])

        self.assertFalse(os.path.exists(self.test_report))

    def test_process_gzip(self):
        """Test the processing of a gzip compressed file."""
        counts = {}
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report_gzip,
            compression=GZIP_COMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )
        report_db = self.accessor
        report_schema = report_db.report_schema
        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()
            counts[table_name] = count

        processor.process()

        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()

            if table_name in (
                'reporting_awscostentryreservation',
                'reporting_ocpawscostlineitem_daily_summary',
                'reporting_ocpawscostlineitem_project_daily_summary',
            ):
                self.assertTrue(count >= counts[table_name])
            else:
                self.assertTrue(count > counts[table_name])

    def test_process_duplicates(self):
        """Test that row duplicates are not inserted into the DB."""
        counts = {}
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )

        # Process for the first time
        processor.process()
        report_db = self.accessor
        report_schema = report_db.report_schema

        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()
            counts[table_name] = count

        # Wipe stale data
        with schema_context(self.schema):
            self.accessor._get_db_obj_query(AWS_CUR_TABLE_MAP['line_item']).delete()

        shutil.copy2(self.test_report_test_path, self.test_report)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )
        # Process for the second time
        processor.process()

        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()
            self.assertEqual(count, counts[table_name])

    def test_process_finalized_rows(self):
        """Test that a finalized bill is processed properly."""
        data = []
        table_name = AWS_CUR_TABLE_MAP['line_item']

        with open(self.test_report, 'r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                data.append(row)

        for row in data:
            row['bill/InvoiceId'] = '12345'

        tmp_file = '/tmp/test_process_finalized_rows.csv'
        field_names = data[0].keys()

        with open(tmp_file, 'w') as f:
            writer = csv.DictWriter(f, fieldnames=field_names)
            writer.writeheader()
            writer.writerows(data)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )

        # Process for the first time
        processor.process()
        report_db = self.accessor
        report_schema = report_db.report_schema

        bill_table_name = AWS_CUR_TABLE_MAP['bill']
        bill_table = getattr(report_schema, bill_table_name)
        with schema_context(self.schema):
            bill = bill_table.objects.first()
            self.assertIsNone(bill.finalized_datetime)

        table = getattr(report_schema, table_name)
        with schema_context(self.schema):
            orig_count = table.objects.count()

        # Wipe stale data
        with schema_context(self.schema):
            self.accessor._get_db_obj_query(table_name).delete()

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=tmp_file,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )
        # Process for the second time
        processor.process()

        with schema_context(self.schema):
            count = table.objects.count()
            self.assertEqual(count, orig_count)
            count = table.objects.filter(invoice_id__isnull=False).count()
            self.assertEqual(count, orig_count)

        with schema_context(self.schema):
            bill = bill_table.objects.first()
            self.assertIsNotNone(bill.finalized_datetime)

    def test_process_finalized_rows_small_batch_size(self):
        """Test that a finalized bill is processed properly on batch size."""
        data = []
        table_name = AWS_CUR_TABLE_MAP['line_item']

        with open(self.test_report, 'r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                data.append(row)

        for row in data:
            row['bill/InvoiceId'] = '12345'

        tmp_file = '/tmp/test_process_finalized_rows.csv'
        field_names = data[0].keys()

        with open(tmp_file, 'w') as f:
            writer = csv.DictWriter(f, fieldnames=field_names)
            writer.writeheader()
            writer.writerows(data)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )

        # Process for the first time
        processor.process()
        report_db = self.accessor
        report_schema = report_db.report_schema

        bill_table_name = AWS_CUR_TABLE_MAP['bill']
        bill_table = getattr(report_schema, bill_table_name)
        with schema_context(self.schema):
            bill = bill_table.objects.first()
            self.assertIsNone(bill.finalized_datetime)

        table = getattr(report_schema, table_name)
        with schema_context(self.schema):
            orig_count = table.objects.count()

        # Wipe stale data
        with schema_context(self.schema):
            self.accessor._get_db_obj_query(table_name).delete()

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=tmp_file,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )
        processor._batch_size = 2
        # Process for the second time
        processor.process()

        with schema_context(self.schema):
            count = table.objects.count()
            self.assertEqual(count, orig_count)
            count = table.objects.filter(invoice_id__isnull=False).count()
            self.assertEqual(count, orig_count)

        with schema_context(self.schema):
            bill = bill_table.objects.first()
            self.assertIsNotNone(bill.finalized_datetime)

    def test_do_not_overwrite_finalized_bill_timestamp(self):
        """Test that a finalized bill timestamp does not get overwritten."""
        data = []
        with open(self.test_report, 'r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                data.append(row)

        for row in data:
            row['bill/InvoiceId'] = '12345'

        tmp_file = '/tmp/test_process_finalized_rows.csv'
        field_names = data[0].keys()

        with open(tmp_file, 'w') as f:
            writer = csv.DictWriter(f, fieldnames=field_names)
            writer.writeheader()
            writer.writerows(data)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )

        # Process for the first time
        processor.process()
        report_db = self.accessor
        report_schema = report_db.report_schema

        bill_table_name = AWS_CUR_TABLE_MAP['bill']
        bill_table = getattr(report_schema, bill_table_name)
        with schema_context(self.schema):
            bill = bill_table.objects.first()

        with open(tmp_file, 'w') as f:
            writer = csv.DictWriter(f, fieldnames=field_names)
            writer.writeheader()
            writer.writerows(data)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=tmp_file,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )
        # Process for the second time
        processor.process()

        with schema_context(self.schema):
            bill = bill_table.objects.first()
            finalized_datetime = bill.finalized_datetime
        self.assertIsNotNone(finalized_datetime)

        with open(tmp_file, 'w') as f:
            writer = csv.DictWriter(f, fieldnames=field_names)
            writer.writeheader()
            writer.writerows(data)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=tmp_file,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )
        # Process for the third time to make sure the timestamp is unchanged
        processor.process()

        with schema_context(self.schema):
            bill = bill_table.objects.first()
        self.assertEqual(bill.finalized_datetime, finalized_datetime)

    def test_get_file_opener_default(self):
        """Test that the default file opener is returned."""
        opener, mode = self.processor._get_file_opener(UNCOMPRESSED)

        self.assertEqual(opener, open)
        self.assertEqual(mode, 'r')

    def test_get_file_opener_gzip(self):
        """Test that the gzip file opener is returned."""
        opener, mode = self.processor._get_file_opener(GZIP_COMPRESSED)

        self.assertEqual(opener, gzip.open)
        self.assertEqual(mode, 'rt')

    def test_update_mappings(self):
        """Test that mappings are updated."""
        test_entry = {'key': 'value'}
        counts = {}
        ce_maps = {
            'cost_entry': self.processor.existing_cost_entry_map,
            'product': self.processor.existing_product_map,
            'pricing': self.processor.existing_pricing_map,
            'reservation': self.processor.existing_reservation_map,
        }

        for name, ce_map in ce_maps.items():
            counts[name] = len(ce_map)
            ce_map.update(test_entry)

        self.processor._update_mappings()

        for name, ce_map in ce_maps.items():
            self.assertGreater(len(ce_map), counts[name])
            for key in test_entry:
                self.assertIn(key, ce_map)

    def test_write_processed_rows_to_csv(self):
        """Test that the CSV bulk upload file contains proper data."""
        bill_id = self.processor._create_cost_entry_bill(self.row, self.accessor)
        cost_entry_id = self.processor._create_cost_entry(self.row, bill_id, self.accessor)
        product_id = self.processor._create_cost_entry_product(self.row, self.accessor)
        pricing_id = self.processor._create_cost_entry_pricing(self.row, self.accessor)
        reservation_id = self.processor._create_cost_entry_reservation(self.row, self.accessor)
        self.processor._create_cost_entry_line_item(
            self.row, cost_entry_id, bill_id, product_id, pricing_id, reservation_id, self.accessor,
        )

        file_obj = self.processor._write_processed_rows_to_csv()

        line_item_data = self.processor.processed_report.line_items.pop()
        # Convert data to CSV format
        expected_values = [str(value) if value else None for value in line_item_data.values()]

        reader = csv.reader(file_obj)
        new_row = next(reader)
        new_row = new_row[0].split('\t')
        actual = {}

        for i, key in enumerate(line_item_data.keys()):
            actual[key] = new_row[i] if new_row[i] else None

        self.assertEqual(actual.keys(), line_item_data.keys())
        self.assertEqual(list(actual.values()), expected_values)

    def test_get_data_for_table(self):
        """Test that a row is disected into appropriate data structures."""
        column_map = self.column_map

        for table_name in self.report_tables:
            expected_columns = sorted(column_map[table_name].values())
            data = self.processor._get_data_for_table(self.row, table_name)

            for key in data:
                self.assertIn(key, expected_columns)

    def test_process_tags(self):
        """Test that tags are properly packaged in a JSON string."""
        row = {
            'resourceTags/user:environment': 'prod',
            'notATag': 'value',
            'resourceTags/System': 'value',
            'resourceTags/system:system_key': 'system_value',
        }
        expected = {'environment': 'prod', 'system_key': 'system_value'}
        actual = json.loads(self.processor._process_tags(row))

        self.assertNotIn(row['notATag'], actual)
        self.assertEqual(expected, actual)

    def test_get_cost_entry_time_interval(self):
        """Test that an interval string is properly split."""
        fmt = Config.AWS_DATETIME_STR_FORMAT
        end = datetime.datetime.utcnow()
        expected_start = (end - datetime.timedelta(days=1)).strftime(fmt)
        expected_end = end.strftime(fmt)
        interval = expected_start + '/' + expected_end

        actual_start, actual_end = self.processor._get_cost_entry_time_interval(interval)

        self.assertEqual(expected_start, actual_start)
        self.assertEqual(expected_end, actual_end)

    def test_create_cost_entry_bill(self):
        """Test that a cost entry bill id is returned."""
        table_name = AWS_CUR_TABLE_MAP['bill']

        bill_id = self.processor._create_cost_entry_bill(self.row, self.accessor)

        self.assertIsNotNone(bill_id)

        query = self.accessor._get_db_obj_query(table_name)
        id_in_db = query.order_by('-id').first().id
        provider_uuid = query.order_by('-id').first().provider_id

        self.assertEqual(bill_id, id_in_db)
        self.assertIsNotNone(provider_uuid)

    def test_create_cost_entry_bill_existing(self):
        """Test that a cost entry bill id is returned from an existing bill."""
        table_name = AWS_CUR_TABLE_MAP['bill']

        bill_id = self.processor._create_cost_entry_bill(self.row, self.accessor)

        query = self.accessor._get_db_obj_query(table_name)
        bill = query.first()

        self.processor.current_bill = bill

        new_bill_id = self.processor._create_cost_entry_bill(self.row, self.accessor)

        self.assertEqual(bill_id, new_bill_id)

        self.processor.current_bill = None

    def test_create_cost_entry(self):
        """Test that a cost entry id is returned."""
        table_name = AWS_CUR_TABLE_MAP['cost_entry']

        bill_id = self.processor._create_cost_entry_bill(self.row, self.accessor)

        cost_entry_id = self.processor._create_cost_entry(self.row, bill_id, self.accessor)

        self.assertIsNotNone(cost_entry_id)

        query = self.accessor._get_db_obj_query(table_name)
        id_in_db = query.order_by('-id').first().id

        self.assertEqual(cost_entry_id, id_in_db)

    def test_create_cost_entry_existing(self):
        """Test that a cost entry id is returned from an existing entry."""
        bill_id = self.processor._create_cost_entry_bill(self.row, self.accessor)

        interval = self.row.get('identity/TimeInterval')
        start, _ = self.processor._get_cost_entry_time_interval(interval)
        key = (bill_id, start)
        expected_id = random.randint(1, 9)
        self.processor.existing_cost_entry_map[key] = expected_id

        cost_entry_id = self.processor._create_cost_entry(self.row, bill_id, self.accessor)
        self.assertEqual(cost_entry_id, expected_id)

    def test_create_cost_entry_line_item(self):
        """Test that line item data is returned properly."""
        bill_id = self.processor._create_cost_entry_bill(self.row, self.accessor)
        cost_entry_id = self.processor._create_cost_entry(self.row, bill_id, self.accessor)
        product_id = self.processor._create_cost_entry_product(self.row, self.accessor)
        pricing_id = self.processor._create_cost_entry_pricing(self.row, self.accessor)
        reservation_id = self.processor._create_cost_entry_reservation(self.row, self.accessor)

        self.processor._create_cost_entry_line_item(
            self.row, cost_entry_id, bill_id, product_id, pricing_id, reservation_id, self.accessor,
        )

        line_item = None
        if self.processor.processed_report.line_items:
            line_item = self.processor.processed_report.line_items[-1]

        self.assertIsNotNone(line_item)
        self.assertIn('tags', line_item)
        self.assertEqual(line_item.get('cost_entry_id'), cost_entry_id)
        self.assertEqual(line_item.get('cost_entry_bill_id'), bill_id)
        self.assertEqual(line_item.get('cost_entry_product_id'), product_id)
        self.assertEqual(line_item.get('cost_entry_pricing_id'), pricing_id)
        self.assertEqual(line_item.get('cost_entry_reservation_id'), reservation_id)

        self.assertIsNotNone(self.processor.line_item_columns)

    def test_create_cost_entry_product(self):
        """Test that a cost entry product id is returned."""
        table_name = AWS_CUR_TABLE_MAP['product']

        product_id = self.processor._create_cost_entry_product(self.row, self.accessor)

        self.assertIsNotNone(product_id)

        query = self.accessor._get_db_obj_query(table_name)
        id_in_db = query.order_by('-id').first().id

        self.assertEqual(product_id, id_in_db)

    def test_create_cost_entry_product_already_processed(self):
        """Test that an already processed product id is returned."""
        expected_id = random.randint(1, 9)
        sku = self.row.get('product/sku')
        product_name = self.row.get('product/ProductName')
        region = self.row.get('product/region')
        key = (sku, product_name, region)
        self.processor.processed_report.products.update({key: expected_id})

        product_id = self.processor._create_cost_entry_product(self.row, self.accessor)

        self.assertEqual(product_id, expected_id)

    def test_create_cost_entry_product_existing(self):
        """Test that a previously existing product id is returned."""
        expected_id = random.randint(1, 9)
        sku = self.row.get('product/sku')
        product_name = self.row.get('product/ProductName')
        region = self.row.get('product/region')
        key = (sku, product_name, region)
        self.processor.existing_product_map.update({key: expected_id})

        product_id = self.processor._create_cost_entry_product(self.row, self.accessor)

        self.assertEqual(product_id, expected_id)

    def test_create_cost_entry_pricing(self):
        """Test that a cost entry pricing id is returned."""
        table_name = AWS_CUR_TABLE_MAP['pricing']

        pricing_id = self.processor._create_cost_entry_pricing(self.row, self.accessor)

        self.assertIsNotNone(pricing_id)

        query = self.accessor._get_db_obj_query(table_name)
        id_in_db = query.order_by('-id').first().id

        self.assertEqual(pricing_id, id_in_db)

    def test_create_cost_entry_pricing_already_processed(self):
        """Test that an already processed pricing id is returned."""
        expected_id = random.randint(1, 9)

        key = '{term}-{unit}'.format(term=self.row['pricing/term'], unit=self.row['pricing/unit'])
        self.processor.processed_report.pricing.update({key: expected_id})

        pricing_id = self.processor._create_cost_entry_pricing(self.row, self.accessor)

        self.assertEqual(pricing_id, expected_id)

    def test_create_cost_entry_pricing_existing(self):
        """Test that a previously existing pricing id is returned."""
        expected_id = random.randint(1, 9)

        key = '{term}-{unit}'.format(term=self.row['pricing/term'], unit=self.row['pricing/unit'])
        self.processor.existing_pricing_map.update({key: expected_id})

        pricing_id = self.processor._create_cost_entry_pricing(self.row, self.accessor)

        self.assertEqual(pricing_id, expected_id)

    def test_create_cost_entry_reservation(self):
        """Test that a cost entry reservation id is returned."""
        # Ensure a reservation exists on the row
        arn = 'TestARN'
        row = copy.deepcopy(self.row)
        row['reservation/ReservationARN'] = arn

        table_name = AWS_CUR_TABLE_MAP['reservation']

        reservation_id = self.processor._create_cost_entry_reservation(row, self.accessor)

        self.assertIsNotNone(reservation_id)

        query = self.accessor._get_db_obj_query(table_name)
        id_in_db = query.order_by('-id').first().id

        self.assertEqual(reservation_id, id_in_db)

    def test_create_cost_entry_reservation_update(self):
        """Test that a cost entry reservation id is returned."""
        # Ensure a reservation exists on the row
        arn = 'TestARN'
        row = copy.deepcopy(self.row)
        row['reservation/ReservationARN'] = arn
        row['reservation/NumberOfReservations'] = 1

        table_name = AWS_CUR_TABLE_MAP['reservation']

        reservation_id = self.processor._create_cost_entry_reservation(row, self.accessor)

        self.assertIsNotNone(reservation_id)

        with schema_context(self.schema):
            query = self.accessor._get_db_obj_query(table_name)
            id_in_db = query.order_by('-id').first().id

        self.assertEqual(reservation_id, id_in_db)

        row['lineItem/LineItemType'] = 'RIFee'
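        # Per the semantics this test encodes, an 'RIFee' line item carries an
        # updated reservation count, so reprocessing updates the existing
        # reservation row in place instead of inserting a new one.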
        res_count = row['reservation/NumberOfReservations']
        row['reservation/NumberOfReservations'] = res_count + 1
        reservation_id = self.processor._create_cost_entry_reservation(row, self.accessor)

        self.assertEqual(reservation_id, id_in_db)

        with schema_context(self.schema):
            db_row = query.filter(id=id_in_db).first()
        self.assertEqual(db_row.number_of_reservations, row['reservation/NumberOfReservations'])

    def test_create_cost_entry_reservation_already_processed(self):
        """Test that an already processed reservation id is returned."""
        expected_id = random.randint(1, 9)
        arn = self.row.get('reservation/ReservationARN')
        self.processor.processed_report.reservations.update({arn: expected_id})

        reservation_id = self.processor._create_cost_entry_reservation(self.row, self.accessor)

        self.assertEqual(reservation_id, expected_id)

    def test_create_cost_entry_reservation_existing(self):
        """Test that a previously existing reservation id is returned."""
        expected_id = random.randint(1, 9)
        arn = self.row.get('reservation/ReservationARN')
        self.processor.existing_reservation_map.update({arn: expected_id})

        product_id = self.processor._create_cost_entry_reservation(self.row, self.accessor)

        self.assertEqual(product_id, expected_id)

    def test_check_for_finalized_bill_bill_is_finalized(self):
        """Verify that a file with invoice_id is marked as finalzed."""
        data = []

        with open(self.test_report, 'r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                data.append(row)

        for row in data:
            row['bill/InvoiceId'] = '12345'

        tmp_file = '/tmp/test_process_finalized_rows.csv'
        field_names = data[0].keys()

        with open(tmp_file, 'w') as f:
            writer = csv.DictWriter(f, fieldnames=field_names)
            writer.writeheader()
            writer.writerows(data)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=tmp_file,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )

        result = processor._check_for_finalized_bill()

        self.assertTrue(result)

    def test_check_for_finalized_bill_bill_not_finalized(self):
        """Verify that a file without invoice_id is not marked as finalzed."""
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )

        result = processor._check_for_finalized_bill()

        self.assertFalse(result)
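
    # Taken together, the two tests above pin down _check_for_finalized_bill:
    # a report counts as finalized exactly when its rows carry a
    # 'bill/InvoiceId' value.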

    def test_delete_line_items_success(self):
        """Test that data is deleted before processing a manifest."""
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )
        processor.process()
        result = processor._delete_line_items(AWSReportDBAccessor, self.column_map)

        with schema_context(self.schema):
            bills = self.accessor.get_cost_entry_bills()
            for bill_id in bills.values():
                line_item_query = self.accessor.get_lineitem_query_for_billid(bill_id)
                self.assertTrue(result)
                self.assertEqual(line_item_query.count(), 0)

    def test_delete_line_items_not_first_file_in_manifest(self):
        """Test that data is not deleted once a file has been processed."""
        self.manifest.num_processed_files = 1
        self.manifest.save()
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )
        processor.process()
        result = processor._delete_line_items(AWSReportDBAccessor, self.column_map)
        with schema_context(self.schema):
            bills = self.accessor.get_cost_entry_bills()
            for bill_id in bills.values():
                line_item_query = self.accessor.get_lineitem_query_for_billid(bill_id)
                self.assertFalse(result)
                self.assertNotEqual(line_item_query.count(), 0)

    def test_delete_line_items_no_manifest(self):
        """Test that no data is deleted without a manifest id."""
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )
        processor.process()
        result = processor._delete_line_items(AWSReportDBAccessor, self.column_map)
        with schema_context(self.schema):
            bills = self.accessor.get_cost_entry_bills()
            for bill_id in bills.values():
                line_item_query = self.accessor.get_lineitem_query_for_billid(bill_id)
                self.assertFalse(result)
                self.assertNotEqual(line_item_query.count(), 0)

    @patch('masu.processor.report_processor_base.ReportProcessorBase._should_process_full_month')
    def test_delete_line_items_use_data_cutoff_date(self, mock_should_process):
        """Test that only three days of data are deleted."""
        mock_should_process.return_value = True

        today = self.date_accessor.today_with_timezone('UTC').replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        first_of_month = today.replace(day=1)
        first_of_next_month = first_of_month + relativedelta(months=1)
        days_in_month = [today - relativedelta(days=i) for i in range(today.day)]

        self.manifest.billing_period_start_datetime = first_of_month
        self.manifest.save()

        data = []

        with open(self.test_report, 'r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                data.append(row)

        for row in data:
            row['lineItem/UsageStartDate'] = random.choice(days_in_month)
            row['bill/BillingPeriodStartDate'] = first_of_month
            row['bill/BillingPeriodEndDate'] = first_of_next_month

        tmp_file = '/tmp/test_delete_data_cutoff.csv'
        field_names = data[0].keys()

        with open(tmp_file, 'w') as f:
            writer = csv.DictWriter(f, fieldnames=field_names)
            writer.writeheader()
            writer.writerows(data)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=tmp_file,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )
        processor.process()

        # Get latest data date.
        with schema_context(self.schema):
            bills = self.accessor.get_cost_entry_bills()
            for bill_id in bills.values():
                line_item_query = self.accessor.get_lineitem_query_for_billid(bill_id)
                undeleted_max_date = line_item_query.aggregate(max_date=Max('usage_start'))

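        # With full-month processing now reported as False, the delete should
        # reach back only to processor.data_cutoff_date (about three days).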
        mock_should_process.return_value = False
        processor._delete_line_items(AWSReportDBAccessor, self.column_map, is_finalized=False)

        with schema_context(self.schema):
            bills = self.accessor.get_cost_entry_bills()
            for bill_id in bills.values():
                line_item_query = self.accessor.get_lineitem_query_for_billid(bill_id)
                if today.day <= 3:
                    self.assertEqual(line_item_query.count(), 0)
                else:
                    max_date = line_item_query.aggregate(max_date=Max('usage_start'))
                    self.assertLess(max_date.get('max_date').date(), processor.data_cutoff_date)
                    self.assertLess(
                        max_date.get('max_date').date(), undeleted_max_date.get('max_date').date()
                    )
                    self.assertNotEqual(line_item_query.count(), 0)

    @patch('masu.processor.report_processor_base.DateAccessor')
    def test_data_cutoff_date_not_start_of_month(self, mock_date):
        """Test that the data_cuttof_date respects month boundaries."""
        today = self.date_accessor.today_with_timezone('UTC').replace(day=10)
        expected = today.date() - relativedelta(days=2)

        mock_date.return_value.today_with_timezone.return_value = today

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )
        self.assertEqual(expected, processor.data_cutoff_date)

    @patch('masu.processor.report_processor_base.DateAccessor')
    def test_data_cutoff_date_start_of_month(self, mock_date):
        """Test that the data_cuttof_date respects month boundaries."""
        today = self.date_accessor.today_with_timezone('UTC')
        first_of_month = today.replace(day=1)

        mock_date.return_value.today_with_timezone.return_value = first_of_month

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )
        self.assertEqual(first_of_month.date(), processor.data_cutoff_date)

    @patch('masu.processor.report_processor_base.ReportManifestDBAccessor')
    def test_should_process_full_month_first_manifest_for_bill(self, mock_manifest_accessor):
        """Test that we process data for a new bill/manifest completely."""
        mock_manifest = Mock()
        today = self.date_accessor.today_with_timezone('UTC')
        mock_manifest.billing_period_start_datetime = today
        mock_manifest.num_processed_files = 1
        mock_manifest.num_total_files = 2
        mock_accessor = mock_manifest_accessor.return_value.__enter__.return_value
        mock_accessor.get_manifest_by_id.return_value = mock_manifest
        mock_accessor.get_manifest_list_for_provider_and_bill_date.return_value = [mock_manifest]
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )

        self.assertTrue(processor._should_process_full_month())

    @patch('masu.processor.report_processor_base.ReportManifestDBAccessor')
    def test_should_process_full_month_not_first_manifest_for_bill(self, mock_manifest_accessor):
        """Test that we process a window of data for the bill/manifest."""
        mock_manifest = Mock()
        today = self.date_accessor.today_with_timezone('UTC')
        mock_manifest.billing_period_start_datetime = today
        mock_manifest.num_processed_files = 1
        mock_manifest.num_total_files = 1
        mock_accessor = mock_manifest_accessor.return_value.__enter__.return_value
        mock_accessor.get_manifest_by_id.return_value = mock_manifest
        mock_accessor.get_manifest_list_for_provider_and_bill_date.return_value = [
            mock_manifest,
            mock_manifest,
        ]
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )

        self.assertFalse(processor._should_process_full_month())

    @patch('masu.processor.report_processor_base.ReportManifestDBAccessor')
    def test_should_process_full_month_manifest_for_not_current_month(self, mock_manifest_accessor):
        """Test that we process this manifest completely."""
        mock_manifest = Mock()
        last_month = self.date_accessor.today_with_timezone('UTC') - relativedelta(months=1)
        mock_manifest.billing_period_start_datetime = last_month
        mock_accessor = mock_manifest_accessor.return_value.__enter__.return_value
        mock_accessor.get_manifest_by_id.return_value = mock_manifest
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )

        self.assertTrue(processor._should_process_full_month())

    def test_should_process_full_month_no_manifest(self):
        """Test that we process this manifest completely."""
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )

        self.assertTrue(processor._should_process_full_month())

    def test_should_process_row_within_cutoff_date(self):
        """Test that we correctly determine a row should be processed."""
        today = self.date_accessor.today_with_timezone('UTC')
        row = {'lineItem/UsageStartDate': today.isoformat()}

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )

        should_process = processor._should_process_row(row, 'lineItem/UsageStartDate', False)

        self.assertTrue(should_process)

    def test_should_process_row_outside_cutoff_date(self):
        """Test that we correctly determine a row should not be processed."""
        today = self.date_accessor.today_with_timezone('UTC')
        usage_start = today - relativedelta(days=10)
        row = {'lineItem/UsageStartDate': usage_start.isoformat()}

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )

        should_process = processor._should_process_row(row, 'lineItem/UsageStartDate', False)

        self.assertFalse(should_process)

    def test_should_process_is_full_month(self):
        """Test that we correctly determine a row should be processed."""
        today = self.date_accessor.today_with_timezone('UTC')
        usage_start = today - relativedelta(days=10)
        row = {'lineItem/UsageStartDate': usage_start.isoformat()}

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )

        should_process = processor._should_process_row(row, 'lineItem/UsageStartDate', True)

        self.assertTrue(should_process)

    def test_should_process_is_finalized(self):
        """Test that we correctly determine a row should be processed."""
        today = self.date_accessor.today_with_timezone('UTC')
        usage_start = today - relativedelta(days=10)
        row = {'lineItem/UsageStartDate': usage_start.isoformat()}

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )

        should_process = processor._should_process_row(
            row, 'lineItem/UsageStartDate', False, is_finalized=True
        )

        self.assertTrue(should_process)

    def test_get_date_column_filter(self):
        """Test that the Azure specific filter is returned."""
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )
        date_filter = processor.get_date_column_filter()

        self.assertIn('usage_start__gte', date_filter)
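
The four _should_process_row tests above pin down the cutoff behavior: full-month and finalized runs process every row, while incremental runs skip rows older than a recent cutoff. A minimal sketch of that logic, assuming a two-day window and stdlib date parsing (the tests only show that "today" passes and "ten days ago" does not):

from datetime import datetime, timedelta, timezone

def should_process_row(row, date_column, is_full_month, is_finalized=False):
    # Full-month and finalized runs always process the row.
    if is_full_month or is_finalized:
        return True
    # Otherwise keep only rows at or after a recent cutoff (window assumed).
    cutoff = datetime.now(timezone.utc) - timedelta(days=2)
    return datetime.fromisoformat(row[date_column]) >= cutoff
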
Example #3
class AWSReportProcessorTest(MasuTestCase):
    """Test Cases for the AWSReportProcessor object."""
    @classmethod
    def setUpClass(cls):
        """Set up the test class with required objects."""
        super().setUpClass()
        cls.test_report = './koku/masu/test/data/test_cur.csv'
        cls.test_report_gzip = './koku/masu/test/data/test_cur.csv.gz'

        cls.date_accessor = DateAccessor()
        cls.manifest_accessor = ReportManifestDBAccessor()

        with ReportingCommonDBAccessor() as report_common_db:
            cls.column_map = report_common_db.column_map

        _report_tables = copy.deepcopy(AWS_CUR_TABLE_MAP)
        _report_tables.pop('line_item_daily', None)
        _report_tables.pop('line_item_daily_summary', None)
        _report_tables.pop('tags_summary', None)
        cls.report_tables = list(_report_tables.values())
        # Grab a single row of test data to work with
        with open(cls.test_report, 'r') as f:
            reader = csv.DictReader(f)
            cls.row = next(reader)

    def setUp(self):
        super().setUp()

        self.processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_id=self.aws_provider.id,
        )

        billing_start = self.date_accessor.today_with_timezone('UTC').replace(
            year=2018, month=6, day=1, hour=0, minute=0, second=0)
        self.assembly_id = '1234'
        self.manifest_dict = {
            'assembly_id': self.assembly_id,
            'billing_period_start_datetime': billing_start,
            'num_total_files': 2,
            'provider_id': self.aws_provider.id,
        }

        self.accessor = AWSReportDBAccessor(self.schema, self.column_map)
        self.report_schema = self.accessor.report_schema
        self.manifest = self.manifest_accessor.add(**self.manifest_dict)
        self.manifest_accessor.commit()

    def tearDown(self):
        """Return the database to a pre-test state."""
        super().tearDown()

        self.processor.processed_report.remove_processed_rows()
        self.processor.line_item_columns = None

    def test_initializer(self):
        """Test initializer."""
        self.assertIsNotNone(self.processor._schema_name)
        self.assertIsNotNone(self.processor._report_path)
        self.assertIsNotNone(self.processor._report_name)
        self.assertIsNotNone(self.processor._compression)
        self.assertEqual(self.processor._datetime_format,
                         Config.AWS_DATETIME_STR_FORMAT)
        self.assertEqual(self.processor._batch_size,
                         Config.REPORT_PROCESSING_BATCH_SIZE)

    def test_initializer_unsupported_compression(self):
        """Assert that an error is raised for an invalid compression."""
        with self.assertRaises(MasuProcessingError):
            AWSReportProcessor(
                schema_name=self.schema,
                report_path=self.test_report,
                compression='unsupported',
                provider_id=self.aws_provider.id,
            )

    def test_process_default(self):
        """Test the processing of an uncompressed file."""
        counts = {}
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_id=self.aws_provider.id,
            manifest_id=self.manifest.id,
        )
        report_db = self.accessor
        report_schema = report_db.report_schema
        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()
            counts[table_name] = count

        bill_date = self.manifest.billing_period_start_datetime.date()
        expected = (
            f'INFO:masu.processor.aws.aws_report_processor:'
            f'Deleting data for schema: {self.schema} and bill date: {bill_date}'
        )
        logging.disable(
            logging.NOTSET
        )  # We are currently disabling all logging below CRITICAL in masu/__init__.py
        with self.assertLogs('masu.processor.aws.aws_report_processor',
                             level='INFO') as logger:
            processor.process()
            self.assertIn(expected, logger.output)

        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()

            if table_name in (
                    'reporting_awscostentryreservation',
                    'reporting_ocpawscostlineitem_daily_summary',
                    'reporting_ocpawscostlineitem_project_daily_summary',
            ):
                self.assertTrue(count >= counts[table_name])
            else:
                self.assertTrue(count > counts[table_name])
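            # Reservation and the OCP-on-AWS summary tables may gain no rows
            # from a single CUR file, hence the >= check above; every other
            # table must strictly grow.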

    def test_process_gzip(self):
        """Test the processing of a gzip compressed file."""
        counts = {}
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report_gzip,
            compression=GZIP_COMPRESSED,
            provider_id=self.aws_provider.id,
        )
        report_db = self.accessor
        report_schema = report_db.report_schema
        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()
            counts[table_name] = count

        processor.process()

        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()

            if table_name in (
                    'reporting_awscostentryreservation',
                    'reporting_ocpawscostlineitem_daily_summary',
                    'reporting_ocpawscostlineitem_project_daily_summary',
            ):
                self.assertTrue(count >= counts[table_name])
            else:
                self.assertTrue(count > counts[table_name])

    def test_process_duplicates(self):
        """Test that row duplicates are not inserted into the DB."""
        counts = {}
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_id=self.aws_provider.id,
        )

        # Process for the first time
        processor.process()
        report_db = self.accessor
        report_schema = report_db.report_schema

        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()
            counts[table_name] = count

        # Wipe stale data
        with schema_context(self.schema):
            self.accessor._get_db_obj_query(
                AWS_CUR_TABLE_MAP['line_item']).delete()

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_id=self.aws_provider.id,
        )
        # Process for the second time
        processor.process()

        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()
            self.assertTrue(count == counts[table_name])

    def test_process_finalized_rows(self):
        """Test that a finalized bill is processed properly."""
        data = []
        table_name = AWS_CUR_TABLE_MAP['line_item']

        with open(self.test_report, 'r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                data.append(row)

        for row in data:
            row['bill/InvoiceId'] = '12345'

        tmp_file = '/tmp/test_process_finalized_rows.csv'
        field_names = data[0].keys()

        with open(tmp_file, 'w') as f:
            writer = csv.DictWriter(f, fieldnames=field_names)
            writer.writeheader()
            writer.writerows(data)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_id=self.aws_provider.id,
        )

        # Process for the first time
        processor.process()
        report_db = self.accessor
        report_schema = report_db.report_schema

        bill_table_name = AWS_CUR_TABLE_MAP['bill']
        bill_table = getattr(report_schema, bill_table_name)
        with schema_context(self.schema):
            bill = bill_table.objects.first()
            self.assertIsNone(bill.finalized_datetime)

        table = getattr(report_schema, table_name)
        with schema_context(self.schema):
            orig_count = table.objects.count()

        # Wipe stale data
        with schema_context(self.schema):
            self.accessor._get_db_obj_query(table_name).delete()

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=tmp_file,
            compression=UNCOMPRESSED,
            provider_id=self.aws_provider.id,
        )
        # Process for the second time
        processor.process()

        with schema_context(self.schema):
            count = table.objects.count()
            self.assertTrue(count == orig_count)
            count = table.objects.filter(invoice_id__isnull=False).count()
            self.assertTrue(count == orig_count)

        with schema_context(self.schema):
            bill = bill_table.objects.first()
            self.assertIsNotNone(bill.finalized_datetime)

    def test_process_finalized_rows_small_batch_size(self):
        """Test that a finalized bill is processed properly on batch size."""
        data = []
        table_name = AWS_CUR_TABLE_MAP['line_item']

        with open(self.test_report, 'r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                data.append(row)

        for row in data:
            row['bill/InvoiceId'] = '12345'

        tmp_file = '/tmp/test_process_finalized_rows.csv'
        field_names = data[0].keys()

        with open(tmp_file, 'w') as f:
            writer = csv.DictWriter(f, fieldnames=field_names)
            writer.writeheader()
            writer.writerows(data)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_id=self.aws_provider.id,
        )

        # Process for the first time
        processor.process()
        report_db = self.accessor
        report_schema = report_db.report_schema

        bill_table_name = AWS_CUR_TABLE_MAP['bill']
        bill_table = getattr(report_schema, bill_table_name)
        with schema_context(self.schema):
            bill = bill_table.objects.first()
            self.assertIsNone(bill.finalized_datetime)

        table = getattr(report_schema, table_name)
        with schema_context(self.schema):
            orig_count = table.objects.count()

        # Wipe stale data
        with schema_context(self.schema):
            self.accessor._get_db_obj_query(table_name).delete()

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=tmp_file,
            compression=UNCOMPRESSED,
            provider_id=self.aws_provider.id,
        )
        processor._batch_size = 2
        # Process for the second time
        processor.process()

        with schema_context(self.schema):
            count = table.objects.count()
            self.assertTrue(count == orig_count)
            count = table.objects.filter(invoice_id__isnull=False).count()
            self.assertTrue(count == orig_count)

        with schema_context(self.schema):
            bill = bill_table.objects.first()
            self.assertIsNotNone(bill.finalized_datetime)

    def test_do_not_overwrite_finalized_bill_timestamp(self):
        """Test that a finalized bill timestamp does not get overwritten."""
        data = []
        with open(self.test_report, 'r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                data.append(row)

        for row in data:
            row['bill/InvoiceId'] = '12345'

        tmp_file = '/tmp/test_process_finalized_rows.csv'
        field_names = data[0].keys()

        with open(tmp_file, 'w') as f:
            writer = csv.DictWriter(f, fieldnames=field_names)
            writer.writeheader()
            writer.writerows(data)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_id=self.aws_provider.id,
        )

        # Process for the first time
        processor.process()
        report_db = self.accessor
        report_schema = report_db.report_schema

        bill_table_name = AWS_CUR_TABLE_MAP['bill']
        bill_table = getattr(report_schema, bill_table_name)
        with schema_context(self.schema):
            bill = bill_table.objects.first()

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=tmp_file,
            compression=UNCOMPRESSED,
            provider_id=self.aws_provider.id,
        )
        # Process for the second time
        processor.process()

        # Re-fetch the bill so we capture the stored timestamp, not a stale
        # in-memory object.
        with schema_context(self.schema):
            bill = bill_table.objects.first()
            finalized_datetime = bill.finalized_datetime

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=tmp_file,
            compression=UNCOMPRESSED,
            provider_id=self.aws_provider.id,
        )
        # Process for the third time to make sure the timestamp is the same
        processor.process()

        with schema_context(self.schema):
            bill = bill_table.objects.first()
        self.assertEqual(bill.finalized_datetime, finalized_datetime)

    def test_get_file_opener_default(self):
        """Test that the default file opener is returned."""
        opener, mode = self.processor._get_file_opener(UNCOMPRESSED)

        self.assertEqual(opener, open)
        self.assertEqual(mode, 'r')

    def test_get_file_opener_gzip(self):
        """Test that the gzip file opener is returned."""
        opener, mode = self.processor._get_file_opener(GZIP_COMPRESSED)

        self.assertEqual(opener, gzip.open)
        self.assertEqual(mode, 'rt')

    def test_update_mappings(self):
        """Test that mappings are updated."""
        test_entry = {'key': 'value'}
        counts = {}
        ce_maps = {
            'cost_entry': self.processor.existing_cost_entry_map,
            'product': self.processor.existing_product_map,
            'pricing': self.processor.existing_pricing_map,
            'reservation': self.processor.existing_reservation_map,
        }

        for name, ce_map in ce_maps.items():
            counts[name] = len(ce_map.values())
            ce_map.update(test_entry)

        self.processor._update_mappings()

        for name, ce_map in ce_maps.items():
            self.assertTrue(len(ce_map.values()) > counts[name])
            for key in test_entry:
                self.assertIn(key, ce_map)

    def test_write_processed_rows_to_csv(self):
        """Test that the CSV bulk upload file contains proper data."""
        bill_id = self.processor._create_cost_entry_bill(
            self.row, self.accessor)
        cost_entry_id = self.processor._create_cost_entry(
            self.row, bill_id, self.accessor)
        product_id = self.processor._create_cost_entry_product(
            self.row, self.accessor)
        pricing_id = self.processor._create_cost_entry_pricing(
            self.row, self.accessor)
        reservation_id = self.processor._create_cost_entry_reservation(
            self.row, self.accessor)
        self.processor._create_cost_entry_line_item(
            self.row,
            cost_entry_id,
            bill_id,
            product_id,
            pricing_id,
            reservation_id,
            self.accessor,
        )

        file_obj = self.processor._write_processed_rows_to_csv()

        line_item_data = self.processor.processed_report.line_items.pop()
        # Convert data to CSV format
        expected_values = [
            str(value) if value else None for value in line_item_data.values()
        ]

        reader = csv.reader(file_obj)
        new_row = next(reader)
        new_row = new_row[0].split('\t')
        actual = {}

        for i, key in enumerate(line_item_data.keys()):
            actual[key] = new_row[i] if new_row[i] else None

        self.assertEqual(actual.keys(), line_item_data.keys())
        self.assertEqual(list(actual.values()), expected_values)
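        # The split('\t') above implies the bulk-upload file is tab-delimited,
        # presumably to match the downstream bulk-load format (an assumption
        # this test does not verify).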

    def test_get_data_for_table(self):
        """Test that a row is disected into appropriate data structures."""
        column_map = self.column_map

        for table_name in self.report_tables:
            expected_columns = sorted(column_map[table_name].values())
            data = self.processor._get_data_for_table(self.row, table_name)

            for key in data:
                self.assertIn(key, expected_columns)

    def test_process_tags(self):
        """Test that tags are properly packaged in a JSON string."""
        row = {
            'resourceTags/user:environment': 'prod',
            'notATag': 'value',
            'resourceTags/System': 'value',
            'resourceTags/system:system_key': 'system_value',
        }
        expected = {'environment': 'prod', 'system_key': 'system_value'}
        actual = json.loads(self.processor._process_tags(row))

        self.assertNotIn(row['notATag'], actual)
        self.assertEqual(expected, actual)
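        # _process_tags evidently keeps only namespaced keys of the form
        # 'resourceTags/<namespace>:<key>' and emits {'<key>': value} as JSON:
        # 'resourceTags/user:environment' -> {'environment': 'prod'}, while the
        # bare 'resourceTags/System' and non-tag columns are dropped.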

    def test_get_cost_entry_time_interval(self):
        """Test that an interval string is properly split."""
        fmt = Config.AWS_DATETIME_STR_FORMAT
        end = datetime.datetime.utcnow()
        expected_start = (end - datetime.timedelta(days=1)).strftime(fmt)
        expected_end = end.strftime(fmt)
        interval = expected_start + '/' + expected_end

        actual_start, actual_end = self.processor._get_cost_entry_time_interval(
            interval)

        self.assertEqual(expected_start, actual_start)
        self.assertEqual(expected_end, actual_end)
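        # The interval is the CUR 'identity/TimeInterval' value, e.g.
        # '2018-06-01T00:00:00Z/2018-06-02T00:00:00Z'; splitting on '/' yields
        # (start, end), which is presumably all the helper does beyond
        # normalizing to Config.AWS_DATETIME_STR_FORMAT.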

    def test_create_cost_entry_bill(self):
        """Test that a cost entry bill id is returned."""
        table_name = AWS_CUR_TABLE_MAP['bill']
        table = getattr(self.report_schema, table_name)

        bill_id = self.processor._create_cost_entry_bill(
            self.row, self.accessor)

        self.assertIsNotNone(bill_id)

        query = self.accessor._get_db_obj_query(table_name)
        id_in_db = query.order_by('-id').first().id
        provider_id = query.order_by('-id').first().provider_id

        self.assertEqual(bill_id, id_in_db)
        self.assertIsNotNone(provider_id)

    def test_create_cost_entry_bill_existing(self):
        """Test that a cost entry bill id is returned from an existing bill."""
        table_name = AWS_CUR_TABLE_MAP['bill']
        table = getattr(self.report_schema, table_name)

        bill_id = self.processor._create_cost_entry_bill(
            self.row, self.accessor)

        query = self.accessor._get_db_obj_query(table_name)
        bill = query.first()

        self.processor.current_bill = bill

        new_bill_id = self.processor._create_cost_entry_bill(
            self.row, self.accessor)

        self.assertEqual(bill_id, new_bill_id)

        self.processor.current_bill = None

    def test_create_cost_entry(self):
        """Test that a cost entry id is returned."""
        table_name = AWS_CUR_TABLE_MAP['cost_entry']
        table = getattr(self.report_schema, table_name)

        bill_id = self.processor._create_cost_entry_bill(
            self.row, self.accessor)

        cost_entry_id = self.processor._create_cost_entry(
            self.row, bill_id, self.accessor)
        self.accessor.commit()

        self.assertIsNotNone(cost_entry_id)

        query = self.accessor._get_db_obj_query(table_name)
        id_in_db = query.order_by('-id').first().id

        self.assertEqual(cost_entry_id, id_in_db)

    def test_create_cost_entry_existing(self):
        """Test that a cost entry id is returned from an existing entry."""
        table_name = AWS_CUR_TABLE_MAP['cost_entry']
        table = getattr(self.report_schema, table_name)

        bill_id = self.processor._create_cost_entry_bill(
            self.row, self.accessor)
        self.accessor.commit()

        interval = self.row.get('identity/TimeInterval')
        start, _ = self.processor._get_cost_entry_time_interval(interval)
        key = (bill_id, start)
        expected_id = random.randint(1, 9)
        self.processor.existing_cost_entry_map[key] = expected_id

        cost_entry_id = self.processor._create_cost_entry(
            self.row, bill_id, self.accessor)
        self.assertEqual(cost_entry_id, expected_id)

    def test_create_cost_entry_line_item(self):
        """Test that line item data is returned properly."""
        bill_id = self.processor._create_cost_entry_bill(
            self.row, self.accessor)
        cost_entry_id = self.processor._create_cost_entry(
            self.row, bill_id, self.accessor)
        product_id = self.processor._create_cost_entry_product(
            self.row, self.accessor)
        pricing_id = self.processor._create_cost_entry_pricing(
            self.row, self.accessor)
        reservation_id = self.processor._create_cost_entry_reservation(
            self.row, self.accessor)

        self.accessor.commit()

        self.processor._create_cost_entry_line_item(
            self.row,
            cost_entry_id,
            bill_id,
            product_id,
            pricing_id,
            reservation_id,
            self.accessor,
        )

        line_item = None
        if self.processor.processed_report.line_items:
            line_item = self.processor.processed_report.line_items[-1]

        self.assertIsNotNone(line_item)
        self.assertIn('tags', line_item)
        self.assertEqual(line_item.get('cost_entry_id'), cost_entry_id)
        self.assertEqual(line_item.get('cost_entry_bill_id'), bill_id)
        self.assertEqual(line_item.get('cost_entry_product_id'), product_id)
        self.assertEqual(line_item.get('cost_entry_pricing_id'), pricing_id)
        self.assertEqual(line_item.get('cost_entry_reservation_id'),
                         reservation_id)

        self.assertIsNotNone(self.processor.line_item_columns)

    def test_create_cost_entry_product(self):
        """Test that a cost entry product id is returned."""
        table_name = AWS_CUR_TABLE_MAP['product']
        table = getattr(self.report_schema, table_name)

        product_id = self.processor._create_cost_entry_product(
            self.row, self.accessor)

        self.accessor.commit()

        self.assertIsNotNone(product_id)

        query = self.accessor._get_db_obj_query(table_name)
        id_in_db = query.order_by('-id').first().id

        self.assertEqual(product_id, id_in_db)

    def test_create_cost_entry_product_already_processed(self):
        """Test that an already processed product id is returned."""
        expected_id = random.randint(1, 9)
        sku = self.row.get('product/sku')
        product_name = self.row.get('product/ProductName')
        region = self.row.get('product/region')
        key = (sku, product_name, region)
        self.processor.processed_report.products.update({key: expected_id})

        product_id = self.processor._create_cost_entry_product(
            self.row, self.accessor)

        self.assertEqual(product_id, expected_id)

    def test_create_cost_entry_product_existing(self):
        """Test that a previously existing product id is returned."""
        expected_id = random.randint(1, 9)
        sku = self.row.get('product/sku')
        product_name = self.row.get('product/ProductName')
        region = self.row.get('product/region')
        key = (sku, product_name, region)
        self.processor.existing_product_map.update({key: expected_id})

        product_id = self.processor._create_cost_entry_product(
            self.row, self.accessor)

        self.assertEqual(product_id, expected_id)

    def test_create_cost_entry_pricing(self):
        """Test that a cost entry pricing id is returned."""
        table_name = AWS_CUR_TABLE_MAP['pricing']
        table = getattr(self.report_schema, table_name)

        pricing_id = self.processor._create_cost_entry_pricing(
            self.row, self.accessor)

        self.accessor.commit()

        self.assertIsNotNone(pricing_id)

        query = self.accessor._get_db_obj_query(table_name)
        id_in_db = query.order_by('-id').first().id

        self.assertEqual(pricing_id, id_in_db)

    def test_create_cost_entry_pricing_already_processed(self):
        """Test that an already processed pricing id is returned."""
        expected_id = random.randint(1, 9)

        key = '{term}-{unit}'.format(term=self.row['pricing/term'],
                                     unit=self.row['pricing/unit'])
        self.processor.processed_report.pricing.update({key: expected_id})

        pricing_id = self.processor._create_cost_entry_pricing(
            self.row, self.accessor)

        self.assertEqual(pricing_id, expected_id)

    def test_create_cost_entry_pricing_existing(self):
        """Test that a previously existing pricing id is returned."""
        expected_id = random.randint(1, 9)

        key = '{term}-{unit}'.format(term=self.row['pricing/term'],
                                     unit=self.row['pricing/unit'])
        self.processor.existing_pricing_map.update({key: expected_id})

        pricing_id = self.processor._create_cost_entry_pricing(
            self.row, self.accessor)

        self.assertEqual(pricing_id, expected_id)

    def test_create_cost_entry_reservation(self):
        """Test that a cost entry reservation id is returned."""
        # Ensure a reservation exists on the row
        arn = 'TestARN'
        row = copy.deepcopy(self.row)
        row['reservation/ReservationARN'] = arn

        table_name = AWS_CUR_TABLE_MAP['reservation']
        table = getattr(self.report_schema, table_name)

        reservation_id = self.processor._create_cost_entry_reservation(
            row, self.accessor)

        self.accessor.commit()

        self.assertIsNotNone(reservation_id)

        query = self.accessor._get_db_obj_query(table_name)
        id_in_db = query.order_by('-id').first().id

        self.assertEqual(reservation_id, id_in_db)

    def test_create_cost_entry_reservation_update(self):
        """Test that a cost entry reservation id is returned."""
        # Ensure a reservation exists on the row
        arn = 'TestARN'
        row = copy.deepcopy(self.row)
        row['reservation/ReservationARN'] = arn
        row['reservation/NumberOfReservations'] = 1

        table_name = AWS_CUR_TABLE_MAP['reservation']
        table = getattr(self.report_schema, table_name)

        reservation_id = self.processor._create_cost_entry_reservation(
            row, self.accessor)

        self.accessor.commit()

        self.assertIsNotNone(reservation_id)

        with schema_context(self.schema):
            query = self.accessor._get_db_obj_query(table_name)
            id_in_db = query.order_by('-id').first().id

        self.assertEqual(reservation_id, id_in_db)

        row['lineItem/LineItemType'] = 'RIFee'
        res_count = row['reservation/NumberOfReservations']
        row['reservation/NumberOfReservations'] = res_count + 1
        reservation_id = self.processor._create_cost_entry_reservation(
            row, self.accessor)
        self.accessor.commit()

        self.assertEqual(reservation_id, id_in_db)

        db_row = query.filter(id=id_in_db).first()
        self.assertEqual(db_row.number_of_reservations,
                         row['reservation/NumberOfReservations'])

    def test_create_cost_entry_reservation_already_processed(self):
        """Test that an already processed reservation id is returned."""
        expected_id = random.randint(1, 9)
        arn = self.row.get('reservation/ReservationARN')
        self.processor.processed_report.reservations.update({arn: expected_id})

        reservation_id = self.processor._create_cost_entry_reservation(
            self.row, self.accessor)

        self.assertEqual(reservation_id, expected_id)

    def test_create_cost_entry_reservation_existing(self):
        """Test that a previously existing reservation id is returned."""
        expected_id = random.randint(1, 9)
        arn = self.row.get('reservation/ReservationARN')
        self.processor.existing_reservation_map.update({arn: expected_id})

        product_id = self.processor._create_cost_entry_reservation(
            self.row, self.accessor)

        self.assertEqual(product_id, expected_id)

    def test_remove_temp_cur_files(self):
        """Test to remove temporary cost usage files."""
        cur_dir = tempfile.mkdtemp()

        manifest_data = {"assemblyId": "6e019de5-a41d-4cdb-b9a0-99bfba9a9cb5"}
        manifest = '{}/{}'.format(cur_dir, 'koku-Manifest.json')
        with open(manifest, 'w') as outfile:
            json.dump(manifest_data, outfile)

        file_list = [
            {
                'file': '6e019de5-a41d-4cdb-b9a0-99bfba9a9cb5-koku-1.csv.gz',
                'processed_date': datetime.datetime(year=2018, month=5, day=3),
            },
            {
                'file': '6e019de5-a41d-4cdb-b9a0-99bfba9a9cb5-koku-2.csv.gz',
                'processed_date': datetime.datetime(year=2018, month=5, day=3),
            },
            {
                'file': '2aeb9169-2526-441c-9eca-d7ed015d52bd-koku-1.csv.gz',
                'processed_date': datetime.datetime(year=2018, month=5, day=2),
            },
            {
                'file': '6c8487e8-c590-4e6a-b2c2-91a2375c0bad-koku-1.csv.gz',
                'processed_date': datetime.datetime(year=2018, month=5, day=1),
            },
            {
                'file': '6c8487e8-c590-4e6a-b2c2-91a2375d0bed-koku-1.csv.gz',
                'processed_date': None,
            },
        ]
        expected_delete_list = []
        for item in file_list:
            path = '{}/{}'.format(cur_dir, item['file'])
            f = open(path, 'w')
            obj = self.manifest_accessor.get_manifest(self.assembly_id,
                                                      self.aws_provider.id)
            with ReportStatsDBAccessor(item['file'], obj.id) as stats:
                stats.update(last_completed_datetime=item['processed_date'])
            f.close()
            if (not item['file'].startswith(manifest_data.get('assemblyId'))
                    and item['processed_date']):
                expected_delete_list.append(path)

        removed_files = self.processor.remove_temp_cur_files(cur_dir)
        self.assertEqual(sorted(removed_files), sorted(expected_delete_list))
        shutil.rmtree(cur_dir)
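        # remove_temp_cur_files evidently deletes only files that are both
        # outside the current manifest's assemblyId and already processed
        # (non-null processed date), mirroring expected_delete_list above.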

    def test_check_for_finalized_bill_bill_is_finalized(self):
        """Verify that a file with invoice_id is marked as finalzed."""
        data = []
        table_name = AWS_CUR_TABLE_MAP['line_item']

        with open(self.test_report, 'r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                data.append(row)

        for row in data:
            row['bill/InvoiceId'] = '12345'

        tmp_file = '/tmp/test_process_finalized_rows.csv'
        field_names = data[0].keys()

        with open(tmp_file, 'w') as f:
            writer = csv.DictWriter(f, fieldnames=field_names)
            writer.writeheader()
            writer.writerows(data)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=tmp_file,
            compression=UNCOMPRESSED,
            provider_id=self.aws_provider.id,
        )

        result = processor._check_for_finalized_bill()

        self.assertTrue(result)

    def test_check_for_finalized_bill_bill_not_finalized(self):
        """Verify that a file without invoice_id is not marked as finalzed."""

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_id=self.aws_provider.id,
        )

        result = processor._check_for_finalized_bill()

        self.assertFalse(result)

    def test_delete_line_items_success(self):
        """Test that data is deleted before processing a manifest."""
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_id=self.aws_provider.id,
            manifest_id=self.manifest.id,
        )
        processor.process()
        result = processor._delete_line_items()

        with schema_context(self.schema):
            bills = self.accessor.get_cost_entry_bills()
            for bill_id in bills.values():
                line_item_query = self.accessor.get_lineitem_query_for_billid(
                    bill_id)
                self.assertTrue(result)
                self.assertEqual(line_item_query.count(), 0)

    def test_delete_line_items_not_first_file_in_manifest(self):
        """Test that data is not deleted once a file has been processed."""
        self.manifest.num_processed_files = 1
        self.manifest.save()
        self.manifest_accessor.commit()
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_id=self.aws_provider.id,
            manifest_id=self.manifest.id,
        )
        processor.process()
        result = processor._delete_line_items()
        with schema_context(self.schema):
            bills = self.accessor.get_cost_entry_bills()
            for bill_id in bills.values():
                line_item_query = self.accessor.get_lineitem_query_for_billid(
                    bill_id)
                self.assertFalse(result)
                self.assertNotEqual(line_item_query.count(), 0)

    def test_delete_line_items_no_manifest(self):
        """Test that no data is deleted without a manifest id."""
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_id=self.aws_provider.id,
        )
        processor.process()
        result = processor._delete_line_items()
        with schema_context(self.schema):
            bills = self.accessor.get_cost_entry_bills()
            for bill_id in bills.values():
                line_item_query = self.accessor.get_lineitem_query_for_billid(
                    bill_id)
                self.assertFalse(result)
                self.assertNotEqual(line_item_query.count(), 0)
예제 #4
Example #4
class TestUpdateSummaryTablesTask(MasuTestCase):
    """Test cases for Processor summary table Celery tasks."""

    @classmethod
    def setUpClass(cls):
        """Setup for the class."""
        super().setUpClass()
        cls.aws_tables = list(AWS_CUR_TABLE_MAP.values())
        cls.ocp_tables = list(OCP_REPORT_TABLE_MAP.values())
        cls.all_tables = list(AWS_CUR_TABLE_MAP.values()) + \
            list(OCP_REPORT_TABLE_MAP.values())
        report_common_db = ReportingCommonDBAccessor()
        cls.column_map = report_common_db.column_map
        report_common_db.close_session()

    @classmethod
    def tearDownClass(cls):
        """Tear down the test class."""
        super().tearDownClass()

    def setUp(self):
        """Set up each test."""
        super().setUp()
        self.schema_name = self.test_schema
        self.aws_accessor = AWSReportDBAccessor(schema=self.schema_name,
                                                column_map=self.column_map)
        self.ocp_accessor = OCPReportDBAccessor(schema=self.schema_name,
                                                column_map=self.column_map)

        self.creator = ReportObjectCreator(
            self.aws_accessor,
            self.column_map,
            self.aws_accessor.report_schema.column_types
        )

        # Populate some line item data so that the summary tables
        # have something to pull from
        self.start_date = DateAccessor().today_with_timezone('UTC').replace(day=1)
        last_month = self.start_date - relativedelta.relativedelta(months=1)

        for cost_entry_date in (self.start_date, last_month):
            bill = self.creator.create_cost_entry_bill(cost_entry_date)
            cost_entry = self.creator.create_cost_entry(bill, cost_entry_date)
            for family in ['Storage', 'Compute Instance', 'Database Storage',
                           'Database Instance']:
                product = self.creator.create_cost_entry_product(family)
                pricing = self.creator.create_cost_entry_pricing()
                reservation = self.creator.create_cost_entry_reservation()
                self.creator.create_cost_entry_line_item(
                    bill,
                    cost_entry,
                    product,
                    pricing,
                    reservation
                )
        provider_ocp_uuid = self.ocp_test_provider_uuid

        with ProviderDBAccessor(provider_uuid=provider_ocp_uuid) as provider_accessor:
            provider_id = provider_accessor.get_provider().id

        cluster_id = self.ocp_provider_resource_name
        for period_date in (self.start_date, last_month):
            period = self.creator.create_ocp_report_period(period_date, provider_id=provider_id,
                                                           cluster_id=cluster_id)
            report = self.creator.create_ocp_report(period, period_date)
            for _ in range(25):
                self.creator.create_ocp_usage_line_item(period, report)

    def tearDown(self):
        """Return the database to a pre-test state."""
        for table_name in self.aws_tables:
            tables = self.aws_accessor._get_db_obj_query(table_name).all()
            for table in tables:
                self.aws_accessor._session.delete(table)
        self.aws_accessor.commit()
        for table_name in self.ocp_tables:
            tables = self.ocp_accessor._get_db_obj_query(table_name).all()
            for table in tables:
                self.ocp_accessor._session.delete(table)
        self.ocp_accessor.commit()

        self.aws_accessor._session.rollback()
        self.aws_accessor.close_connections()
        self.aws_accessor.close_session()
        self.ocp_accessor.close_connections()
        self.ocp_accessor.close_session()
        super().tearDown()

    @patch('masu.processor.tasks.update_charge_info')
    def test_update_summary_tables_aws(self, mock_charge_info):
        """Test that the summary table task runs."""
        provider = 'AWS'
        provider_aws_uuid = self.aws_test_provider_uuid

        daily_table_name = AWS_CUR_TABLE_MAP['line_item_daily']
        summary_table_name = AWS_CUR_TABLE_MAP['line_item_daily_summary']
        start_date = self.start_date.replace(day=1) + relativedelta.relativedelta(months=-1)

        daily_query = self.aws_accessor._get_db_obj_query(daily_table_name)
        summary_query = self.aws_accessor._get_db_obj_query(summary_table_name)

        initial_daily_count = daily_query.count()
        initial_summary_count = summary_query.count()

        self.assertEqual(initial_daily_count, 0)
        self.assertEqual(initial_summary_count, 0)

        update_summary_tables(self.schema_name, provider, provider_aws_uuid, start_date)

        self.assertNotEqual(daily_query.count(), initial_daily_count)
        self.assertNotEqual(summary_query.count(), initial_summary_count)

    @patch('masu.processor.tasks.update_charge_info')
    def test_update_summary_tables_aws_end_date(self, mock_charge_info):
        """Test that the summary table task respects a date range."""
        provider = 'AWS'
        provider_aws_uuid = self.aws_test_provider_uuid
        ce_table_name = AWS_CUR_TABLE_MAP['cost_entry']
        daily_table_name = AWS_CUR_TABLE_MAP['line_item_daily']
        summary_table_name = AWS_CUR_TABLE_MAP['line_item_daily_summary']

        start_date = self.start_date.replace(day=1,
                                             hour=0,
                                             minute=0,
                                             second=0,
                                             microsecond=0) + relativedelta.relativedelta(months=-1)

        end_date = start_date + timedelta(days=10)
        end_date = end_date.replace(hour=23, minute=59, second=59)

        daily_table = getattr(self.aws_accessor.report_schema, daily_table_name)
        summary_table = getattr(self.aws_accessor.report_schema, summary_table_name)
        ce_table = getattr(self.aws_accessor.report_schema, ce_table_name)

        ce_start_date = self.aws_accessor._session\
            .query(func.min(ce_table.interval_start))\
            .filter(ce_table.interval_start >= start_date).first()[0]

        ce_end_date = self.aws_accessor._session\
            .query(func.max(ce_table.interval_start))\
            .filter(ce_table.interval_start <= end_date).first()[0]

        # The summary tables will only include dates where there is data
        expected_start_date = max(start_date, ce_start_date)
        expected_start_date = expected_start_date.replace(hour=0, minute=0,
                                                          second=0,
                                                          microsecond=0)
        expected_end_date = min(end_date, ce_end_date)
        expected_end_date = expected_end_date.replace(hour=0, minute=0,
                                                      second=0, microsecond=0)

        update_summary_tables(self.schema_name, provider, provider_aws_uuid, start_date, end_date)

        result_start_date, result_end_date = self.aws_accessor._session.query(
            func.min(daily_table.usage_start),
            func.max(daily_table.usage_end)
        ).first()

        self.assertEqual(result_start_date, expected_start_date)
        self.assertEqual(result_end_date, expected_end_date)

        result_start_date, result_end_date = self.aws_accessor._session.query(
            func.min(summary_table.usage_start),
            func.max(summary_table.usage_end)
        ).first()

        self.assertEqual(result_start_date, expected_start_date)
        self.assertEqual(result_end_date, expected_end_date)

    @patch('masu.processor.tasks.update_charge_info')
    @patch('masu.database.ocp_rate_db_accessor.OCPRateDBAccessor.get_memory_gb_usage_per_hour_rates')
    @patch('masu.database.ocp_rate_db_accessor.OCPRateDBAccessor.get_cpu_core_usage_per_hour_rates')
    def test_update_summary_tables_ocp(self, mock_cpu_rate, mock_mem_rate, mock_charge_info):
        """Test that the summary table task runs."""
        mem_rate = {'tiered_rate': [{'value': '1.5', 'unit': 'USD'}]}
        cpu_rate = {'tiered_rate': [{'value': '2.5', 'unit': 'USD'}]}

        mock_cpu_rate.return_value = cpu_rate
        mock_mem_rate.return_value = mem_rate

        provider = 'OCP'
        provider_ocp_uuid = self.ocp_test_provider_uuid

        daily_table_name = OCP_REPORT_TABLE_MAP['line_item_daily']
        start_date = self.start_date.replace(day=1) + relativedelta.relativedelta(months=-1)

        daily_query = self.ocp_accessor._get_db_obj_query(daily_table_name)

        initial_daily_count = daily_query.count()

        self.assertEqual(initial_daily_count, 0)
        update_summary_tables(self.schema_name, provider, provider_ocp_uuid, start_date)

        self.assertNotEqual(daily_query.count(), initial_daily_count)

        update_charge_info(schema_name=self.test_schema, provider_uuid=provider_ocp_uuid)

        table_name = OCP_REPORT_TABLE_MAP['line_item_daily_summary']
        with ProviderDBAccessor(provider_ocp_uuid) as provider_accessor:
            provider_obj = provider_accessor.get_provider()

        usage_period_qry = self.ocp_accessor.get_usage_period_query_by_provider(provider_obj.id)
        cluster_id = usage_period_qry.first().cluster_id

        items = self.ocp_accessor._get_db_obj_query(table_name).filter_by(cluster_id=cluster_id)
        for item in items:
            self.assertIsNotNone(item.pod_charge_memory_gigabyte_hours)
            self.assertIsNotNone(item.pod_charge_cpu_core_hours)

        storage_daily_name = OCP_REPORT_TABLE_MAP['storage_line_item_daily']
        items = self.ocp_accessor._get_db_obj_query(storage_daily_name).filter_by(cluster_id=cluster_id)
        for item in items:
            self.assertIsNotNone(item.volume_request_storage_byte_seconds)
            self.assertIsNotNone(item.persistentvolumeclaim_usage_byte_seconds)

        storage_summary_name = OCP_REPORT_TABLE_MAP['storage_line_item_daily_summary']
        items = self.ocp_accessor._get_db_obj_query(storage_summary_name).filter_by(cluster_id=cluster_id)
        for item in items:
            self.assertIsNotNone(item.volume_request_storage_gigabyte_months)
            self.assertIsNotNone(item.persistentvolumeclaim_usage_gigabyte_months)

    @patch('masu.processor.tasks.update_charge_info')
    @patch('masu.database.ocp_rate_db_accessor.OCPRateDBAccessor.get_memory_gb_usage_per_hour_rates')
    @patch('masu.database.ocp_rate_db_accessor.OCPRateDBAccessor.get_cpu_core_usage_per_hour_rates')
    def test_update_summary_tables_ocp_end_date(self, mock_cpu_rate, mock_mem_rate, mock_charge_info):
        """Test that the summary table task respects a date range."""
        mock_cpu_rate.return_value = 1.5
        mock_mem_rate.return_value = 2.5
        provider = 'OCP'
        provider_ocp_uuid = self.ocp_test_provider_uuid
        ce_table_name = OCP_REPORT_TABLE_MAP['report']
        daily_table_name = OCP_REPORT_TABLE_MAP['line_item_daily']

        start_date = self.start_date.replace(day=1,
                                             hour=0, minute=0, second=0,
                                             microsecond=0) + relativedelta.relativedelta(months=-1)

        end_date = start_date + timedelta(days=10)
        end_date = end_date.replace(hour=23, minute=59, second=59)

        daily_table = getattr(self.ocp_accessor.report_schema, daily_table_name)
        ce_table = getattr(self.ocp_accessor.report_schema, ce_table_name)

        ce_start_date = self.ocp_accessor._session\
            .query(func.min(ce_table.interval_start))\
            .filter(ce_table.interval_start >= start_date).first()[0]

        ce_end_date = self.ocp_accessor._session\
            .query(func.max(ce_table.interval_start))\
            .filter(ce_table.interval_start <= end_date).first()[0]

        # The summary tables will only include dates where there is data
        expected_start_date = max(start_date, ce_start_date)
        expected_start_date = expected_start_date.replace(hour=0, minute=0,
                                                          second=0,
                                                          microsecond=0)
        expected_end_date = min(end_date, ce_end_date)
        expected_end_date = expected_end_date.replace(hour=0, minute=0,
                                                      second=0, microsecond=0)

        update_summary_tables(self.schema_name, provider, provider_ocp_uuid, start_date, end_date)
        result_start_date, result_end_date = self.ocp_accessor._session.query(
            func.min(daily_table.usage_start),
            func.max(daily_table.usage_end)
        ).first()

        self.assertEqual(result_start_date, expected_start_date)
        self.assertEqual(result_end_date, expected_end_date)

    def test_update_charge_info_aws(self):
        """Test that update_charge_info is not called for AWS."""
        update_charge_info(schema_name=self.test_schema,
                           provider_uuid=self.aws_test_provider_uuid)
        # FIXME: no asserts on test

    @patch('masu.processor.tasks.update_summary_tables')
    def test_get_report_data_for_all_providers(self, mock_update):
        """Test GET report_data endpoint with provider_uuid=*."""
        start_date = date.today()
        update_all_summary_tables(start_date)

        mock_update.delay.assert_called_with(
            ANY, ANY, ANY, str(start_date), ANY)
예제 #5
Example #5
File: test_tasks.py  Project: ebpetway/koku
class TestUpdateSummaryTablesTask(MasuTestCase):
    """Test cases for Processor summary table Celery tasks."""
    @classmethod
    def setUpClass(cls):
        """Set up for the class."""
        super().setUpClass()
        cls.aws_tables = list(AWS_CUR_TABLE_MAP.values())
        cls.ocp_tables = list(OCP_REPORT_TABLE_MAP.values())
        cls.all_tables = list(AWS_CUR_TABLE_MAP.values()) + list(
            OCP_REPORT_TABLE_MAP.values())

        cls.creator = ReportObjectCreator(cls.schema)

    def setUp(self):
        """Set up each test."""
        super().setUp()
        self.aws_accessor = AWSReportDBAccessor(schema=self.schema)
        self.ocp_accessor = OCPReportDBAccessor(schema=self.schema)

        # Populate some line item data so that the summary tables
        # have something to pull from
        self.start_date = DateHelper().today.replace(day=1)

    @patch("masu.processor.tasks.chain")
    @patch("masu.processor.tasks.refresh_materialized_views")
    @patch("masu.processor.tasks.update_cost_model_costs")
    def test_update_summary_tables_aws(self, mock_charge_info, mock_views,
                                       mock_chain):
        """Test that the summary table task runs."""
        provider = Provider.PROVIDER_AWS
        provider_aws_uuid = self.aws_provider_uuid

        daily_table_name = AWS_CUR_TABLE_MAP["line_item_daily"]
        summary_table_name = AWS_CUR_TABLE_MAP["line_item_daily_summary"]
        start_date = self.start_date.replace(
            day=1) + relativedelta.relativedelta(months=-1)

        with schema_context(self.schema):
            daily_query = self.aws_accessor._get_db_obj_query(daily_table_name)
            summary_query = self.aws_accessor._get_db_obj_query(
                summary_table_name)
            daily_query.delete()
            summary_query.delete()

            initial_daily_count = daily_query.count()
            initial_summary_count = summary_query.count()

        self.assertEqual(initial_daily_count, 0)
        self.assertEqual(initial_summary_count, 0)

        update_summary_tables(self.schema, provider, provider_aws_uuid,
                              start_date)

        with schema_context(self.schema):
            self.assertNotEqual(daily_query.count(), initial_daily_count)
            self.assertNotEqual(summary_query.count(), initial_summary_count)

        mock_chain.return_value.apply_async.assert_called()

    @patch("masu.processor.tasks.update_cost_model_costs")
    def test_update_summary_tables_aws_end_date(self, mock_charge_info):
        """Test that the summary table task respects a date range."""
        provider = Provider.PROVIDER_AWS_LOCAL
        provider_aws_uuid = self.aws_provider_uuid
        ce_table_name = AWS_CUR_TABLE_MAP["cost_entry"]
        daily_table_name = AWS_CUR_TABLE_MAP["line_item_daily"]
        summary_table_name = AWS_CUR_TABLE_MAP["line_item_daily_summary"]

        start_date = DateHelper().last_month_start

        end_date = DateHelper().last_month_end

        daily_table = getattr(self.aws_accessor.report_schema,
                              daily_table_name)
        summary_table = getattr(self.aws_accessor.report_schema,
                                summary_table_name)
        ce_table = getattr(self.aws_accessor.report_schema, ce_table_name)
        with schema_context(self.schema):
            daily_table.objects.all().delete()
            summary_table.objects.all().delete()
            ce_start_date = ce_table.objects.filter(
                interval_start__gte=start_date.date()).aggregate(
                    Min("interval_start"))["interval_start__min"]
            ce_end_date = ce_table.objects.filter(
                interval_start__lte=end_date.date()).aggregate(
                    Max("interval_start"))["interval_start__max"]

        # The summary tables will only include dates where there is data
        expected_start_date = max(start_date, ce_start_date)
        expected_start_date = expected_start_date.replace(hour=0,
                                                          minute=0,
                                                          second=0,
                                                          microsecond=0)
        expected_end_date = min(end_date, ce_end_date)
        expected_end_date = expected_end_date.replace(hour=0,
                                                      minute=0,
                                                      second=0,
                                                      microsecond=0)

        update_summary_tables(self.schema, provider, provider_aws_uuid,
                              start_date, end_date)

        with schema_context(self.schema):
            daily_entry = daily_table.objects.all().aggregate(
                Min("usage_start"), Max("usage_end"))
            result_start_date = daily_entry["usage_start__min"]
            result_end_date = daily_entry["usage_end__max"]

        self.assertEqual(result_start_date, expected_start_date.date())
        self.assertEqual(result_end_date, expected_end_date.date())

        with schema_context(self.schema):
            summary_entry = summary_table.objects.all().aggregate(
                Min("usage_start"), Max("usage_end"))
            result_start_date = summary_entry["usage_start__min"]
            result_end_date = summary_entry["usage_end__max"]

        self.assertEqual(result_start_date, expected_start_date.date())
        self.assertEqual(result_end_date, expected_end_date.date())

    @patch("masu.processor.tasks.chain")
    @patch("masu.processor.tasks.refresh_materialized_views")
    @patch("masu.processor.tasks.update_cost_model_costs")
    @patch("masu.processor.ocp.ocp_cost_model_cost_updater.CostModelDBAccessor"
           )
    def test_update_summary_tables_ocp(self, mock_cost_model, mock_charge_info,
                                       mock_view, mock_chain):
        """Test that the summary table task runs."""
        infrastructure_rates = {
            "cpu_core_usage_per_hour": 1.5,
            "memory_gb_usage_per_hour": 2.5,
            "storage_gb_usage_per_month": 0.5,
        }
        markup = {}

        mock_cost_model.return_value.__enter__.return_value.infrastructure_rates = infrastructure_rates
        mock_cost_model.return_value.__enter__.return_value.supplementary_rates = {}
        mock_cost_model.return_value.__enter__.return_value.markup = markup

        provider = Provider.PROVIDER_OCP
        provider_ocp_uuid = self.ocp_test_provider_uuid

        daily_table_name = OCP_REPORT_TABLE_MAP["line_item_daily"]
        start_date = DateHelper().last_month_start
        end_date = DateHelper().last_month_end

        with schema_context(self.schema):
            daily_query = self.ocp_accessor._get_db_obj_query(daily_table_name)
            daily_query.delete()

            initial_daily_count = daily_query.count()

        self.assertEqual(initial_daily_count, 0)
        update_summary_tables(self.schema, provider, provider_ocp_uuid,
                              start_date, end_date)

        with schema_context(self.schema):
            self.assertNotEqual(daily_query.count(), initial_daily_count)

        update_cost_model_costs(schema_name=self.schema,
                                provider_uuid=provider_ocp_uuid,
                                start_date=start_date,
                                end_date=end_date)

        table_name = OCP_REPORT_TABLE_MAP["line_item_daily_summary"]
        with ProviderDBAccessor(provider_ocp_uuid) as provider_accessor:
            provider_obj = provider_accessor.get_provider()

        usage_period_qry = self.ocp_accessor.get_usage_period_query_by_provider(
            provider_obj.uuid)
        with schema_context(self.schema):
            cluster_id = usage_period_qry.first().cluster_id

            items = self.ocp_accessor._get_db_obj_query(table_name).filter(
                usage_start__gte=start_date,
                usage_start__lte=end_date,
                cluster_id=cluster_id,
                data_source="Pod")
            for item in items:
                self.assertNotEqual(item.infrastructure_usage_cost.get("cpu"),
                                    0)
                self.assertNotEqual(
                    item.infrastructure_usage_cost.get("memory"), 0)

            storage_daily_name = OCP_REPORT_TABLE_MAP[
                "storage_line_item_daily"]

            items = self.ocp_accessor._get_db_obj_query(
                storage_daily_name).filter(cluster_id=cluster_id)
            for item in items:
                self.assertIsNotNone(item.volume_request_storage_byte_seconds)
                self.assertIsNotNone(
                    item.persistentvolumeclaim_usage_byte_seconds)

            storage_summary_name = OCP_REPORT_TABLE_MAP[
                "line_item_daily_summary"]
            items = self.ocp_accessor._get_db_obj_query(
                storage_summary_name).filter(cluster_id=cluster_id,
                                             data_source="Storage")
            for item in items:
                self.assertIsNotNone(
                    item.volume_request_storage_gigabyte_months)
                self.assertIsNotNone(
                    item.persistentvolumeclaim_usage_gigabyte_months)

        mock_chain.return_value.apply_async.assert_called()

    @patch("masu.processor.tasks.update_cost_model_costs")
    @patch(
        "masu.database.cost_model_db_accessor.CostModelDBAccessor.get_memory_gb_usage_per_hour_rates"
    )
    @patch(
        "masu.database.cost_model_db_accessor.CostModelDBAccessor.get_cpu_core_usage_per_hour_rates"
    )
    def test_update_summary_tables_ocp_end_date(self, mock_cpu_rate,
                                                mock_mem_rate,
                                                mock_cost_update):
        """Test that the summary table task respects a date range."""
        mock_cpu_rate.return_value = 1.5
        mock_mem_rate.return_value = 2.5
        provider = Provider.PROVIDER_OCP
        provider_ocp_uuid = self.ocp_test_provider_uuid
        ce_table_name = OCP_REPORT_TABLE_MAP["report"]
        daily_table_name = OCP_REPORT_TABLE_MAP["line_item_daily"]

        start_date = DateHelper().last_month_start
        end_date = DateHelper().last_month_end
        daily_table = getattr(self.ocp_accessor.report_schema,
                              daily_table_name)
        ce_table = getattr(self.ocp_accessor.report_schema, ce_table_name)

        with schema_context(self.schema):
            daily_table.objects.all().delete()
            ce_start_date = ce_table.objects.filter(
                interval_start__gte=start_date.date()).aggregate(
                    Min("interval_start"))["interval_start__min"]

            ce_end_date = ce_table.objects.filter(
                interval_start__lte=end_date.date()).aggregate(
                    Max("interval_start"))["interval_start__max"]

        # The summary tables will only include dates where there is data
        expected_start_date = max(start_date, ce_start_date)
        expected_end_date = min(end_date, ce_end_date)

        update_summary_tables(self.schema, provider, provider_ocp_uuid,
                              start_date, end_date)
        with schema_context(self.schema):
            daily_entry = daily_table.objects.all().aggregate(
                Min("usage_start"), Max("usage_end"))
            result_start_date = daily_entry["usage_start__min"]
            result_end_date = daily_entry["usage_end__max"]

        self.assertEqual(result_start_date, expected_start_date.date())
        self.assertEqual(result_end_date, expected_end_date.date())

    @patch("masu.processor.tasks.update_summary_tables")
    def test_get_report_data_for_all_providers(self, mock_update):
        """Test GET report_data endpoint with provider_uuid=*."""
        start_date = date.today()
        update_all_summary_tables(start_date)

        mock_update.delay.assert_called_with(ANY, ANY, ANY, str(start_date),
                                             ANY)
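        # (mock.ANY placeholders match any argument; only the stringified
        # start date is asserted exactly.)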

    def test_refresh_materialized_views(self):
        """Test that materialized views are refreshed."""
        manifest_dict = {
            "assembly_id": "12345",
            "billing_period_start_datetime": DateHelper().today,
            "num_total_files": 2,
            "provider_uuid": self.aws_provider_uuid,
            "task": "170653c0-3e66-4b7e-a764-336496d7ca5a",
        }

        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.add(**manifest_dict)
            manifest.save()

        refresh_materialized_views(self.schema,
                                   Provider.PROVIDER_AWS,
                                   manifest_id=manifest.id)

        views_to_check = [
            view for view in AWS_MATERIALIZED_VIEWS
            if "Cost" in view._meta.db_table
        ]

        with schema_context(self.schema):
            for view in views_to_check:
                self.assertNotEqual(view.objects.count(), 0)

        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.get_manifest_by_id(manifest.id)
            self.assertIsNotNone(manifest.manifest_completed_datetime)

    @patch("masu.processor.tasks.connection")
    def test_vacuum_schema(self, mock_conn):
        """Test that the vacuum schema task runs."""
        logging.disable(logging.NOTSET)
        mock_conn.cursor.return_value.__enter__.return_value.fetchall.return_value = [
            ("table", )
        ]
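        # The mocked cursor yields a single table-name row in place of the
        # query the task would normally run against the schema.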
        expected = "INFO:masu.processor.tasks:VACUUM ANALYZE acct10001.table"
        with self.assertLogs("masu.processor.tasks", level="INFO") as logger:
            vacuum_schema(self.schema)
            self.assertIn(expected, logger.output)

    @patch("masu.processor.tasks.connection")
    def test_autovacuum_tune_schema_default_table(self, mock_conn):
        """Test that the autovacuum tuning runs."""
        logging.disable(logging.NOTSET)

        # Make sure that the AUTOVACUUM_TUNING environment variable is unset!
        if "AUTOVACUUM_TUNING" in os.environ:
            del os.environ["AUTOVACUUM_TUNING"]

        mock_conn.cursor.return_value.__enter__.return_value.fetchall.return_value = [
            ("cost_model", 20000000, {})
        ]
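        # Each mocked row is (table_name, row_count, current_options); the
        # task derives the scale factor from the row count.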
        expected = (
            "INFO:masu.processor.tasks:ALTER TABLE acct10001.cost_model set (autovacuum_vacuum_scale_factor = 0.01);"
        )
        with self.assertLogs("masu.processor.tasks", level="INFO") as logger:
            autovacuum_tune_schema(self.schema)
            self.assertIn(expected, logger.output)

        mock_conn.cursor.return_value.__enter__.return_value.fetchall.return_value = [
            ("cost_model", 2000000, {})
        ]
        expected = (
            "INFO:masu.processor.tasks:ALTER TABLE acct10001.cost_model set (autovacuum_vacuum_scale_factor = 0.02);"
        )
        with self.assertLogs("masu.processor.tasks", level="INFO") as logger:
            autovacuum_tune_schema(self.schema)
            self.assertIn(expected, logger.output)

        mock_conn.cursor.return_value.__enter__.return_value.fetchall.return_value = [
            ("cost_model", 200000, {})
        ]
        expected = (
            "INFO:masu.processor.tasks:ALTER TABLE acct10001.cost_model set (autovacuum_vacuum_scale_factor = 0.05);"
        )
        with self.assertLogs("masu.processor.tasks", level="INFO") as logger:
            autovacuum_tune_schema(self.schema)
            self.assertIn(expected, logger.output)

        mock_conn.cursor.return_value.__enter__.return_value.fetchall.return_value = [
            ("cost_model", 200000, {
                "autovacuum_vacuum_scale_factor": Decimal("0.05")
            })
        ]
        expected = "INFO:masu.processor.tasks:Altered autovacuum_vacuum_scale_factor on 0 tables"
        with self.assertLogs("masu.processor.tasks", level="INFO") as logger:
            autovacuum_tune_schema(self.schema)
            self.assertIn(expected, logger.output)

        mock_conn.cursor.return_value.__enter__.return_value.fetchall.return_value = [
            ("cost_model", 20000, {
                "autovacuum_vacuum_scale_factor": Decimal("0.02")
            })
        ]
        expected = "INFO:masu.processor.tasks:ALTER TABLE acct10001.cost_model reset (autovacuum_vacuum_scale_factor);"
        with self.assertLogs("masu.processor.tasks", level="INFO") as logger:
            autovacuum_tune_schema(self.schema)
            self.assertIn(expected, logger.output)

    @patch("masu.processor.tasks.connection")
    def test_autovacuum_tune_schema_custom_table(self, mock_conn):
        """Test that the autovacuum tuning runs."""
        logging.disable(logging.NOTSET)
        scale_table = [(10000000, "0.0001"), (1000000, "0.004"),
                       (100000, "0.011")]
        os.environ["AUTOVACUUM_TUNING"] = json.dumps(scale_table)
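        # AUTOVACUUM_TUNING supplies custom (row-count, scale-factor)
        # thresholds in place of the defaults exercised above.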

        mock_conn.cursor.return_value.__enter__.return_value.fetchall.return_value = [
            ("cost_model", 20000000, {})
        ]
        expected = (
            "INFO:masu.processor.tasks:ALTER TABLE acct10001.cost_model set (autovacuum_vacuum_scale_factor = 0.0001);"
        )
        with self.assertLogs("masu.processor.tasks", level="INFO") as logger:
            autovacuum_tune_schema(self.schema)
            self.assertIn(expected, logger.output)

        mock_conn.cursor.return_value.__enter__.return_value.fetchall.return_value = [
            ("cost_model", 2000000, {})
        ]
        expected = (
            "INFO:masu.processor.tasks:ALTER TABLE acct10001.cost_model set (autovacuum_vacuum_scale_factor = 0.004);"
        )
        with self.assertLogs("masu.processor.tasks", level="INFO") as logger:
            autovacuum_tune_schema(self.schema)
            self.assertIn(expected, logger.output)

        mock_conn.cursor.return_value.__enter__.return_value.fetchall.return_value = [
            ("cost_model", 200000, {})
        ]
        expected = (
            "INFO:masu.processor.tasks:ALTER TABLE acct10001.cost_model set (autovacuum_vacuum_scale_factor = 0.011);"
        )
        with self.assertLogs("masu.processor.tasks", level="INFO") as logger:
            autovacuum_tune_schema(self.schema)
            self.assertIn(expected, logger.output)

        mock_conn.cursor.return_value.__enter__.return_value.fetchall.return_value = [
            ("cost_model", 200000, {
                "autovacuum_vacuum_scale_factor": Decimal("0.011")
            })
        ]
        expected = "INFO:masu.processor.tasks:Altered autovacuum_vacuum_scale_factor on 0 tables"
        with self.assertLogs("masu.processor.tasks", level="INFO") as logger:
            autovacuum_tune_schema(self.schema)
            self.assertIn(expected, logger.output)

        mock_conn.cursor.return_value.__enter__.return_value.fetchall.return_value = [
            ("cost_model", 20000, {
                "autovacuum_vacuum_scale_factor": Decimal("0.004")
            })
        ]
        expected = "INFO:masu.processor.tasks:ALTER TABLE acct10001.cost_model reset (autovacuum_vacuum_scale_factor);"
        with self.assertLogs("masu.processor.tasks", level="INFO") as logger:
            autovacuum_tune_schema(self.schema)
            self.assertIn(expected, logger.output)

        del os.environ["AUTOVACUUM_TUNING"]

    @patch("masu.processor.tasks.connection")
    def test_autovacuum_tune_schema_manual_setting(self, mock_conn):
        """Test that the autovacuum tuning runs."""
        logging.disable(logging.NOTSET)

        # Make sure that the AUTOVACUUM_TUNING environment variable is unset!
        if "AUTOVACUUM_TUNING" in os.environ:
            del os.environ["AUTOVACUUM_TUNING"]

        mock_conn.cursor.return_value.__enter__.return_value.fetchall.return_value = [
            ("cost_model", 200000, {
                "autovacuum_vacuum_scale_factor": Decimal("0.04")
            })
        ]
        expected = "INFO:masu.processor.tasks:Altered autovacuum_vacuum_scale_factor on 0 tables"
        with self.assertLogs("masu.processor.tasks", level="INFO") as logger:
            autovacuum_tune_schema(self.schema)
            self.assertIn(expected, logger.output)

        mock_conn.cursor.return_value.__enter__.return_value.fetchall.return_value = [
            ("cost_model", 200000, {
                "autovacuum_vacuum_scale_factor": Decimal("0.06")
            })
        ]
        expected = (
            "INFO:masu.processor.tasks:ALTER TABLE acct10001.cost_model set (autovacuum_vacuum_scale_factor = 0.05);"
        )
        with self.assertLogs("masu.processor.tasks", level="INFO") as logger:
            autovacuum_tune_schema(self.schema)
            self.assertIn(expected, logger.output)

    @patch("masu.processor.tasks.connection")
    def test_autovacuum_tune_schema_invalid_setting(self, mock_conn):
        """Test that the autovacuum tuning runs."""
        logging.disable(logging.NOTSET)

        # Make sure that the AUTOVACUUM_TUNING environment variable is unset!
        if "AUTOVACUUM_TUNING" in os.environ:
            del os.environ["AUTOVACUUM_TUNING"]

        # This invalid setting should be treated as though there was no setting
        mock_conn.cursor.return_value.__enter__.return_value.fetchall.return_value = [
            ("cost_model", 20000000, {
                "autovacuum_vacuum_scale_factor": ""
            })
        ]
        expected = (
            "INFO:masu.processor.tasks:ALTER TABLE acct10001.cost_model set (autovacuum_vacuum_scale_factor = 0.01);"
        )
        with self.assertLogs("masu.processor.tasks", level="INFO") as logger:
            autovacuum_tune_schema(self.schema)
            self.assertIn(expected, logger.output)

    def test_autovacuum_tune_schedule(self):
        """Test that autovacuum tuning is scheduled an hour before vacuuming."""
        vh = next(
            iter(koku_celery.app.conf.beat_schedule["vacuum-schemas"]
                 ["schedule"].hour))
        avh = next(
            iter(koku_celery.app.conf.beat_schedule["autovacuum-tune-schemas"]
                 ["schedule"].hour))
        self.assertEqual(avh, 23 if vh == 0 else vh - 1)
Example #6
class TestUpdateSummaryTablesTask(MasuTestCase):
    """Test cases for Processor summary table Celery tasks."""
    @classmethod
    def setUpClass(cls):
        """Setup for the class."""
        super().setUpClass()
        cls.aws_tables = list(AWS_CUR_TABLE_MAP.values())
        cls.ocp_tables = list(OCP_REPORT_TABLE_MAP.values())
        cls.all_tables = list(AWS_CUR_TABLE_MAP.values()) + list(
            OCP_REPORT_TABLE_MAP.values())
        with ReportingCommonDBAccessor() as report_common_db:
            cls.column_map = report_common_db.column_map

        cls.creator = ReportObjectCreator(cls.schema, cls.column_map)

    def setUp(self):
        """Set up each test."""
        super().setUp()
        self.aws_accessor = AWSReportDBAccessor(schema=self.schema,
                                                column_map=self.column_map)
        self.ocp_accessor = OCPReportDBAccessor(schema=self.schema,
                                                column_map=self.column_map)

        # Populate some line item data so that the summary tables
        # have something to pull from
        self.start_date = DateAccessor().today_with_timezone('UTC').replace(
            day=1)
        last_month = self.start_date - relativedelta.relativedelta(months=1)

        for cost_entry_date in (self.start_date, last_month):
            bill = self.creator.create_cost_entry_bill(
                provider_id=self.aws_provider.id, bill_date=cost_entry_date)
            cost_entry = self.creator.create_cost_entry(bill, cost_entry_date)
            for family in [
                    'Storage',
                    'Compute Instance',
                    'Database Storage',
                    'Database Instance',
            ]:
                product = self.creator.create_cost_entry_product(family)
                pricing = self.creator.create_cost_entry_pricing()
                reservation = self.creator.create_cost_entry_reservation()
                self.creator.create_cost_entry_line_item(
                    bill, cost_entry, product, pricing, reservation)
        provider_ocp_uuid = self.ocp_test_provider_uuid

        with ProviderDBAccessor(
                provider_uuid=provider_ocp_uuid) as provider_accessor:
            provider_id = provider_accessor.get_provider().id

        cluster_id = self.ocp_provider_resource_name
        for period_date in (self.start_date, last_month):
            period = self.creator.create_ocp_report_period(
                period_date, provider_id=provider_id, cluster_id=cluster_id)
            report = self.creator.create_ocp_report(period, period_date)
            for _ in range(25):
                self.creator.create_ocp_usage_line_item(period, report)

    @patch('masu.processor.tasks.update_cost_summary_table')
    @patch('masu.processor.tasks.update_charge_info')
    def test_update_summary_tables_aws(self, mock_charge_info,
                                       mock_cost_summary):
        """Test that the summary table task runs."""
        provider = 'AWS'
        provider_aws_uuid = self.aws_test_provider_uuid

        daily_table_name = AWS_CUR_TABLE_MAP['line_item_daily']
        summary_table_name = AWS_CUR_TABLE_MAP['line_item_daily_summary']
        start_date = self.start_date.replace(
            day=1) + relativedelta.relativedelta(months=-1)

        with schema_context(self.schema):
            daily_query = self.aws_accessor._get_db_obj_query(daily_table_name)
            summary_query = self.aws_accessor._get_db_obj_query(
                summary_table_name)

            initial_daily_count = daily_query.count()
            initial_summary_count = summary_query.count()

        self.assertEqual(initial_daily_count, 0)
        self.assertEqual(initial_summary_count, 0)

        update_summary_tables(self.schema, provider, provider_aws_uuid,
                              start_date)

        with schema_context(self.schema):
            self.assertNotEqual(daily_query.count(), initial_daily_count)
            self.assertNotEqual(summary_query.count(), initial_summary_count)

        mock_charge_info.apply_async.assert_called()
        mock_cost_summary.si.assert_called()

    @patch('masu.processor.tasks.update_charge_info')
    def test_update_summary_tables_aws_end_date(self, mock_charge_info):
        """Test that the summary table task respects a date range."""
        provider = 'AWS'
        provider_aws_uuid = self.aws_test_provider_uuid
        ce_table_name = AWS_CUR_TABLE_MAP['cost_entry']
        daily_table_name = AWS_CUR_TABLE_MAP['line_item_daily']
        summary_table_name = AWS_CUR_TABLE_MAP['line_item_daily_summary']

        start_date = self.start_date.replace(
            day=1, hour=0, minute=0, second=0,
            microsecond=0) + relativedelta.relativedelta(months=-1)

        end_date = start_date + timedelta(days=10)
        end_date = end_date.replace(hour=23, minute=59, second=59)

        daily_table = getattr(self.aws_accessor.report_schema,
                              daily_table_name)
        summary_table = getattr(self.aws_accessor.report_schema,
                                summary_table_name)
        ce_table = getattr(self.aws_accessor.report_schema, ce_table_name)

        with schema_context(self.schema):
            ce_start_date = ce_table.objects\
                .filter(interval_start__gte=start_date)\
                .aggregate(Min('interval_start'))['interval_start__min']
            ce_end_date = ce_table.objects\
                .filter(interval_start__lte=end_date)\
                .aggregate(Max('interval_start'))['interval_start__max']

        # The summary tables will only include dates where there is data
        expected_start_date = max(start_date, ce_start_date)
        expected_start_date = expected_start_date.replace(hour=0,
                                                          minute=0,
                                                          second=0,
                                                          microsecond=0)
        expected_end_date = min(end_date, ce_end_date)
        expected_end_date = expected_end_date.replace(hour=0,
                                                      minute=0,
                                                      second=0,
                                                      microsecond=0)

        update_summary_tables(self.schema, provider, provider_aws_uuid,
                              start_date, end_date)

        with schema_context(self.schema):
            daily_entry = daily_table.objects.all().aggregate(
                Min('usage_start'), Max('usage_end'))
            result_start_date = daily_entry['usage_start__min']
            result_end_date = daily_entry['usage_end__max']

        self.assertEqual(result_start_date, expected_start_date)
        self.assertEqual(result_end_date, expected_end_date)

        with schema_context(self.schema):
            summary_entry = summary_table.objects.all().aggregate(
                Min('usage_start'), Max('usage_end'))
            result_start_date = summary_entry['usage_start__min']
            result_end_date = summary_entry['usage_end__max']

        self.assertEqual(result_start_date, expected_start_date)
        self.assertEqual(result_end_date, expected_end_date)

    @patch('masu.processor.tasks.update_cost_summary_table')
    @patch('masu.processor.tasks.update_charge_info')
    @patch(
        'masu.database.cost_model_db_accessor.CostModelDBAccessor._make_rate_by_metric_map'
    )
    @patch(
        'masu.database.cost_model_db_accessor.CostModelDBAccessor.get_markup')
    def test_update_summary_tables_ocp(self, mock_markup, mock_rate_map,
                                       mock_charge_info, mock_cost_summary):
        """Test that the summary table task runs."""
        markup = {}
        mem_rate = {'tiered_rates': [{'value': '1.5', 'unit': 'USD'}]}
        cpu_rate = {'tiered_rates': [{'value': '2.5', 'unit': 'USD'}]}
        rate_metric_map = {
            'cpu_core_usage_per_hour': cpu_rate,
            'memory_gb_usage_per_hour': mem_rate
        }

        mock_markup.return_value = markup
        mock_rate_map.return_value = rate_metric_map

        provider = 'OCP'
        provider_ocp_uuid = self.ocp_test_provider_uuid

        daily_table_name = OCP_REPORT_TABLE_MAP['line_item_daily']
        start_date = self.start_date.replace(
            day=1) + relativedelta.relativedelta(months=-1)

        with schema_context(self.schema):
            daily_query = self.ocp_accessor._get_db_obj_query(daily_table_name)

            initial_daily_count = daily_query.count()

        self.assertEqual(initial_daily_count, 0)
        update_summary_tables(self.schema, provider, provider_ocp_uuid,
                              start_date)

        with schema_context(self.schema):
            self.assertNotEqual(daily_query.count(), initial_daily_count)

        update_charge_info(schema_name=self.schema,
                           provider_uuid=provider_ocp_uuid)

        table_name = OCP_REPORT_TABLE_MAP['line_item_daily_summary']
        with ProviderDBAccessor(provider_ocp_uuid) as provider_accessor:
            provider_obj = provider_accessor.get_provider()

        usage_period_qry = self.ocp_accessor.get_usage_period_query_by_provider(
            provider_obj.id)
        with schema_context(self.schema):
            cluster_id = usage_period_qry.first().cluster_id

            items = self.ocp_accessor._get_db_obj_query(table_name).filter(
                cluster_id=cluster_id)
            for item in items:
                self.assertIsNotNone(item.pod_charge_memory_gigabyte_hours)
                self.assertIsNotNone(item.pod_charge_cpu_core_hours)

            storage_daily_name = OCP_REPORT_TABLE_MAP[
                'storage_line_item_daily']

            items = self.ocp_accessor._get_db_obj_query(
                storage_daily_name).filter(cluster_id=cluster_id)
            for item in items:
                self.assertIsNotNone(item.volume_request_storage_byte_seconds)
                self.assertIsNotNone(
                    item.persistentvolumeclaim_usage_byte_seconds)

            storage_summary_name = OCP_REPORT_TABLE_MAP[
                'line_item_daily_summary']
            items = self.ocp_accessor._get_db_obj_query(
                storage_summary_name).filter(cluster_id=cluster_id,
                                             data_source='Storage')
            for item in items:
                self.assertIsNotNone(
                    item.volume_request_storage_gigabyte_months)
                self.assertIsNotNone(
                    item.persistentvolumeclaim_usage_gigabyte_months)

        mock_charge_info.apply_async.assert_called()
        mock_cost_summary.si.assert_called()

    @patch('masu.processor.tasks.update_charge_info')
    @patch(
        'masu.database.cost_model_db_accessor.CostModelDBAccessor.get_memory_gb_usage_per_hour_rates'
    )
    @patch(
        'masu.database.cost_model_db_accessor.CostModelDBAccessor.get_cpu_core_usage_per_hour_rates'
    )
    def test_update_summary_tables_ocp_end_date(self, mock_cpu_rate,
                                                mock_mem_rate,
                                                mock_charge_info):
        """Test that the summary table task respects a date range."""
        mock_cpu_rate.return_value = 1.5
        mock_mem_rate.return_value = 2.5
        provider = 'OCP'
        provider_ocp_uuid = self.ocp_test_provider_uuid
        ce_table_name = OCP_REPORT_TABLE_MAP['report']
        daily_table_name = OCP_REPORT_TABLE_MAP['line_item_daily']

        start_date = self.start_date.replace(
            day=1, hour=0, minute=0, second=0,
            microsecond=0) + relativedelta.relativedelta(months=-1)

        end_date = start_date + timedelta(days=10)
        end_date = end_date.replace(hour=23, minute=59, second=59)

        daily_table = getattr(self.ocp_accessor.report_schema,
                              daily_table_name)
        ce_table = getattr(self.ocp_accessor.report_schema, ce_table_name)

        with schema_context(self.schema):
            ce_start_date = ce_table.objects\
                .filter(interval_start__gte=start_date)\
                .aggregate(Min('interval_start'))['interval_start__min']

            ce_end_date = ce_table.objects\
                .filter(interval_start__lte=end_date)\
                .aggregate(Max('interval_start'))['interval_start__max']

        # The summary tables will only include dates where there is data
        expected_start_date = max(start_date, ce_start_date)
        expected_start_date = expected_start_date.replace(hour=0,
                                                          minute=0,
                                                          second=0,
                                                          microsecond=0)
        expected_end_date = min(end_date, ce_end_date)
        expected_end_date = expected_end_date.replace(hour=0,
                                                      minute=0,
                                                      second=0,
                                                      microsecond=0)

        update_summary_tables(self.schema, provider, provider_ocp_uuid,
                              start_date, end_date)
        with schema_context(self.schema):
            daily_entry = daily_table.objects.all().aggregate(
                Min('usage_start'), Max('usage_end'))
            result_start_date = daily_entry['usage_start__min']
            result_end_date = daily_entry['usage_end__max']

        self.assertEqual(result_start_date, expected_start_date)
        self.assertEqual(result_end_date, expected_end_date)

    @patch('masu.processor.tasks.update_summary_tables')
    def test_get_report_data_for_all_providers(self, mock_update):
        """Test GET report_data endpoint with provider_uuid=*."""
        start_date = date.today()
        update_all_summary_tables(start_date)

        mock_update.delay.assert_called_with(ANY, ANY, ANY, str(start_date),
                                             ANY)

    @patch(
        'masu.database.ocp_report_db_accessor.OCPReportDBAccessor.populate_cost_summary_table'
    )
    def test_update_cost_summary_table(self, mock_update):
        """Tests that the updater updates the cost summary table."""
        provider = 'OCP'
        provider_ocp_uuid = self.ocp_test_provider_uuid
        manifest_id = None
        start_date = self.start_date.replace(
            day=1, hour=0, minute=0, second=0,
            microsecond=0) + relativedelta.relativedelta(months=-1)

        update_cost_summary_table(self.schema,
                                  provider_ocp_uuid,
                                  start_date=start_date)

        mock_update.assert_called()
Example #7
class AWSReportProcessorTest(MasuTestCase):
    """Test Cases for the AWSReportProcessor object."""
    @classmethod
    def setUpClass(cls):
        """Set up the test class with required objects."""
        super().setUpClass()
        cls.test_report_test_path = "./koku/masu/test/data/test_cur.csv"
        cls.test_report_gzip_test_path = "./koku/masu/test/data/test_cur.csv.gz"

        cls.date_accessor = DateAccessor()
        cls.manifest_accessor = ReportManifestDBAccessor()

        _report_tables = copy.deepcopy(AWS_CUR_TABLE_MAP)
        _report_tables.pop("line_item_daily", None)
        _report_tables.pop("line_item_daily_summary", None)
        _report_tables.pop("tags_summary", None)
        _report_tables.pop("enabled_tag_keys", None)
        _report_tables.pop("ocp_on_aws_tags_summary", None)
        cls.report_tables = list(_report_tables.values())
        # Grab a single row of test data to work with
        with open(cls.test_report_test_path) as f:
            reader = csv.DictReader(f)
            cls.row = next(reader)

    def setUp(self):
        """Set up shared variables."""
        super().setUp()

        self.temp_dir = tempfile.mkdtemp()
        self.test_report = f"{self.temp_dir}/test_cur.csv"
        self.test_report_gzip = f"{self.temp_dir}/test_cur.csv.gz"

        shutil.copy2(self.test_report_test_path, self.test_report)
        shutil.copy2(self.test_report_gzip_test_path, self.test_report_gzip)

        self.processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )

        billing_start = self.date_accessor.today_with_timezone("UTC").replace(
            year=2018, month=6, day=1, hour=0, minute=0, second=0)
        self.assembly_id = "1234"
        self.manifest_dict = {
            "assembly_id": self.assembly_id,
            "billing_period_start_datetime": billing_start,
            "num_total_files": 2,
            "provider_uuid": self.aws_provider_uuid,
        }

        self.accessor = AWSReportDBAccessor(self.schema)
        self.report_schema = self.accessor.report_schema
        self.manifest = self.manifest_accessor.add(**self.manifest_dict)

    def tearDown(self):
        """Return the database to a pre-test state."""
        super().tearDown()

        shutil.rmtree(self.temp_dir)

        self.processor.processed_report.remove_processed_rows()
        self.processor.line_item_columns = None

    def test_initializer(self):
        """Test initializer."""
        self.assertIsNotNone(self.processor._schema)
        self.assertIsNotNone(self.processor._report_path)
        self.assertIsNotNone(self.processor._report_name)
        self.assertIsNotNone(self.processor._compression)
        self.assertEqual(self.processor._datetime_format,
                         Config.AWS_DATETIME_STR_FORMAT)
        self.assertEqual(self.processor._batch_size,
                         Config.REPORT_PROCESSING_BATCH_SIZE)

    def test_initializer_unsupported_compression(self):
        """Assert that an error is raised for an invalid compression."""
        with self.assertRaises(MasuProcessingError):
            AWSReportProcessor(
                schema_name=self.schema,
                report_path=self.test_report,
                compression="unsupported",
                provider_uuid=self.aws_provider_uuid,
            )

    def test_process_default(self):
        """Test the processing of an uncompressed file."""
        counts = {}
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )
        report_db = self.accessor
        report_schema = report_db.report_schema
        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()
            counts[table_name] = count

        bill_date = self.manifest.billing_period_start_datetime.date()

        expected = (
            f"INFO:masu.processor.report_processor_base:Processing bill starting on {bill_date}.\n"
            f" Processing entire month.\n"
            f" schema_name: {self.schema},\n"
            f" provider_uuid: {self.aws_provider_uuid},\n"
            f" manifest_id: {self.manifest.id}")
        logging.disable(
            logging.NOTSET
        )  # We are currently disabling all logging below CRITICAL in masu/__init__.py
        with self.assertLogs("masu.processor.report_processor_base",
                             level="INFO") as logger:
            processor.process()
            self.assertIn(expected, logger.output)

        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()

            if table_name in (
                    "reporting_awscostentryreservation",
                    "reporting_awscostentrypricing",
                    "reporting_ocpawscostlineitem_daily_summary_p",
                    "reporting_ocpawscostlineitem_project_daily_summary_p",
            ):
                self.assertGreaterEqual(count, counts[table_name])
            else:
                self.assertGreater(count, counts[table_name])

        self.assertFalse(os.path.exists(self.test_report))

    def test_process_no_file_on_disk(self):
        """Test the processing of when the file is not found on disk."""
        counts = {}
        base_name = "test_no_cur.csv"
        no_report = f"{self.temp_dir}/{base_name}"
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=no_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )
        report_db = self.accessor
        report_schema = report_db.report_schema
        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()
            counts[table_name] = count

        expected = ("INFO:masu.processor.aws.aws_report_processor:"
                    f"Skip processing for file: {base_name} and "
                    f"schema: {self.schema} as it was not found on disk.")
        logging.disable(
            logging.NOTSET
        )  # We are currently disabling all logging below CRITICAL in masu/__init__.py
        with self.assertLogs("masu.processor.aws.aws_report_processor",
                             level="INFO") as logger:
            processor.process()
            self.assertIn(expected, logger.output)

    def test_process_gzip(self):
        """Test the processing of a gzip compressed file."""
        counts = {}
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report_gzip,
            compression=GZIP_COMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )
        report_db = self.accessor
        report_schema = report_db.report_schema
        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()
            counts[table_name] = count

        processor.process()

        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()

            if table_name in (
                    "reporting_awscostentryreservation",
                    "reporting_ocpawscostlineitem_daily_summary_p",
                    "reporting_ocpawscostlineitem_project_daily_summary_p",
                    "reporting_awscostentrypricing",
            ):
                self.assertGreaterEqual(count, counts[table_name])
            else:
                self.assertGreater(count, counts[table_name])

    def test_process_duplicates(self):
        """Test that row duplicates are not inserted into the DB."""
        counts = {}
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )

        report_db = self.accessor
        report_schema = report_db.report_schema
        with schema_context(self.schema):
            table_name = AWS_CUR_TABLE_MAP["line_item"]
            table = getattr(report_schema, table_name)
            initial_line_item_count = table.objects.count()

        # Process for the first time
        processor.process()

        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()
            counts[table_name] = count
            if table_name == AWS_CUR_TABLE_MAP["line_item"]:
                counts[table_name] -= initial_line_item_count

        # Wipe stale data
        with schema_context(self.schema):
            self.accessor._get_db_obj_query(
                AWS_CUR_TABLE_MAP["line_item"]).delete()

        shutil.copy2(self.test_report_test_path, self.test_report)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )
        # Process for the second time
        processor.process()

        for table_name in self.report_tables:
            table = getattr(report_schema, table_name)
            with schema_context(self.schema):
                count = table.objects.count()
            self.assertEqual(count, counts[table_name])

    def test_process_finalized_rows(self):
        """Test that a finalized bill is processed properly."""
        data = []
        table_name = AWS_CUR_TABLE_MAP["line_item"]

        with open(self.test_report) as f:
            reader = csv.DictReader(f)
            for row in reader:
                data.append(row)

        for row in data:
            row["bill/InvoiceId"] = "12345"

        tmp_file = "/tmp/test_process_finalized_rows.csv"
        field_names = data[0].keys()

        with open(tmp_file, "w") as f:
            writer = csv.DictWriter(f, fieldnames=field_names)
            writer.writeheader()
            writer.writerows(data)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )

        report_db = self.accessor
        report_schema = report_db.report_schema
        with schema_context(self.schema):
            table = getattr(report_schema, table_name)
            initial_line_item_count = table.objects.count()

        # Process for the first time
        processor.process()

        bill_table_name = AWS_CUR_TABLE_MAP["bill"]
        bill_table = getattr(report_schema, bill_table_name)
        with schema_context(self.schema):
            bill = bill_table.objects.first()
            self.assertIsNone(bill.finalized_datetime)

        table = getattr(report_schema, table_name)
        with schema_context(self.schema):
            orig_count = table.objects.count()

        # Wipe stale data
        with schema_context(self.schema):
            self.accessor._get_db_obj_query(table_name).delete()

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=tmp_file,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )
        # Process for the second time
        processor.process()

        with schema_context(self.schema):
            count = table.objects.count()
            self.assertEqual(count, orig_count - initial_line_item_count)
            count = table.objects.filter(invoice_id__isnull=False).count()
            self.assertEqual(count, orig_count - initial_line_item_count)

        with schema_context(self.schema):
            final_count = bill_table.objects.filter(
                finalized_datetime__isnull=False).count()
            self.assertEqual(final_count, 1)

    def test_process_finalized_rows_small_batch_size(self):
        """Test that a finalized bill is processed properly on batch size."""
        data = []
        table_name = AWS_CUR_TABLE_MAP["line_item"]

        with open(self.test_report) as f:
            reader = csv.DictReader(f)
            for row in reader:
                data.append(row)

        for row in data:
            row["bill/InvoiceId"] = "12345"

        tmp_file = "/tmp/test_process_finalized_rows.csv"
        field_names = data[0].keys()

        with open(tmp_file, "w") as f:
            writer = csv.DictWriter(f, fieldnames=field_names)
            writer.writeheader()
            writer.writerows(data)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )

        report_db = self.accessor
        report_schema = report_db.report_schema
        table = getattr(report_schema, table_name)
        with schema_context(self.schema):
            initial_line_item_count = table.objects.count()

        # Process for the first time
        processor.process()

        bill_table_name = AWS_CUR_TABLE_MAP["bill"]
        bill_table = getattr(report_schema, bill_table_name)
        with schema_context(self.schema):
            bill = bill_table.objects.first()
            self.assertIsNone(bill.finalized_datetime)

        table = getattr(report_schema, table_name)
        with schema_context(self.schema):
            orig_count = table.objects.count()

        # Wipe stale data
        with schema_context(self.schema):
            self.accessor._get_db_obj_query(table_name).delete()

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=tmp_file,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )
        processor._batch_size = 2
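        # A tiny batch size forces multiple flushes so finalization is
        # exercised across batch boundaries.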
        # Process for the second time
        processor.process()

        with schema_context(self.schema):
            count = table.objects.count()
            self.assertEqual(count, orig_count - initial_line_item_count)
            count = table.objects.filter(invoice_id__isnull=False).count()
            self.assertEqual(count, orig_count - initial_line_item_count)

        with schema_context(self.schema):
            final_count = bill_table.objects.filter(
                finalized_datetime__isnull=False).count()
            self.assertEqual(final_count, 1)

    def test_do_not_overwrite_finalized_bill_timestamp(self):
        """Test that a finalized bill timestamp does not get overwritten."""
        data = []
        with open(self.test_report) as f:
            reader = csv.DictReader(f)
            for row in reader:
                data.append(row)

        for row in data:
            row["bill/InvoiceId"] = "12345"

        tmp_file = "/tmp/test_process_finalized_rows.csv"
        field_names = data[0].keys()

        with open(tmp_file, "w") as f:
            writer = csv.DictWriter(f, fieldnames=field_names)
            writer.writeheader()
            writer.writerows(data)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )

        # Process for the first time
        processor.process()
        report_db = self.accessor
        report_schema = report_db.report_schema

        bill_table_name = AWS_CUR_TABLE_MAP["bill"]
        bill_table = getattr(report_schema, bill_table_name)
        with schema_context(self.schema):
            bill = bill_table.objects.first()

        with open(tmp_file, "w") as f:
            writer = csv.DictWriter(f, fieldnames=field_names)
            writer.writeheader()
            writer.writerows(data)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=tmp_file,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )
        # Process for the second time
        processor.process()

        with schema_context(self.schema):
            bill.refresh_from_db()  # re-read the persisted timestamp
        finalized_datetime = bill.finalized_datetime

        with open(tmp_file, "w") as f:
            writer = csv.DictWriter(f, fieldnames=field_names)
            writer.writeheader()
            writer.writerows(data)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=tmp_file,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )
        # Process for the third time to make sure the timestamp is the same
        processor.process()
        with schema_context(self.schema):
            bill.refresh_from_db()
        self.assertEqual(bill.finalized_datetime, finalized_datetime)

    def test_get_file_opener_default(self):
        """Test that the default file opener is returned."""
        opener, mode = self.processor._get_file_opener(UNCOMPRESSED)

        self.assertEqual(opener, open)
        self.assertEqual(mode, "r")

    def test_get_file_opener_gzip(self):
        """Test that the gzip file opener is returned."""
        opener, mode = self.processor._get_file_opener(GZIP_COMPRESSED)

        self.assertEqual(opener, gzip.open)
        self.assertEqual(mode, "rt")

    def test_update_mappings(self):
        """Test that mappings are updated."""
        test_entry = {"key": "value"}
        counts = {}
        ce_maps = {
            "cost_entry": self.processor.existing_cost_entry_map,
            "product": self.processor.existing_product_map,
            "pricing": self.processor.existing_pricing_map,
            "reservation": self.processor.existing_reservation_map,
        }

        for name, ce_map in ce_maps.items():
            counts[name] = len(ce_map.values())
            ce_map.update(test_entry)

        self.processor._update_mappings()

        for name, ce_map in ce_maps.items():
            self.assertGreater(len(ce_map), counts[name])
            for key in test_entry:
                self.assertIn(key, ce_map)

    def test_write_processed_rows_to_csv(self):
        """Test that the CSV bulk upload file contains proper data."""
        bill_id = self.processor._create_cost_entry_bill(
            self.row, self.accessor)
        cost_entry_id = self.processor._create_cost_entry(
            self.row, bill_id, self.accessor)
        product_id = self.processor._create_cost_entry_product(
            self.row, self.accessor)
        pricing_id = self.processor._create_cost_entry_pricing(
            self.row, self.accessor)
        reservation_id = self.processor._create_cost_entry_reservation(
            self.row, self.accessor)
        self.processor._create_cost_entry_line_item(self.row, cost_entry_id,
                                                    bill_id, product_id,
                                                    pricing_id, reservation_id,
                                                    self.accessor)

        file_obj = self.processor._write_processed_rows_to_csv()

        line_item_data = self.processor.processed_report.line_items.pop()
        # Convert data to CSV format
        expected_values = [
            str(value) if value else None for value in line_item_data.values()
        ]

        reader = csv.reader(file_obj)
        new_row = next(reader)
        actual = {}
        for i, key in enumerate(line_item_data.keys()):
            actual[key] = new_row[i] if new_row[i] else None

        self.assertEqual(actual.keys(), line_item_data.keys())
        self.assertEqual(list(actual.values()), expected_values)

    def test_get_data_for_table(self):
        """Test that a row is disected into appropriate data structures."""

        for table_name in self.report_tables:
            expected_columns = sorted(REPORT_COLUMN_MAP[table_name].values())
            data = self.processor._get_data_for_table(self.row, table_name)

            for key in data:
                self.assertIn(key, expected_columns)

    def test_process_tags(self):
        """Test that tags are properly packaged in a JSON string."""
        row = {
            "resourceTags/user:environment": "prod",
            "notATag": "value",
            "resourceTags/System": "value",
            "resourceTags/system:system_key": "system_value",
        }
        expected = {"environment": "prod"}
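        # Only "resourceTags/user:"-prefixed keys should survive; system and
        # non-tag columns are filtered out.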
        actual = json.loads(self.processor._process_tags(row))

        self.assertNotIn(row["notATag"], actual)
        self.assertEqual(expected, actual)

    def test_get_cost_entry_time_interval(self):
        """Test that an interval string is properly split."""
        fmt = Config.AWS_DATETIME_STR_FORMAT
        end = datetime.datetime.utcnow()
        expected_start = (end - datetime.timedelta(days=1)).strftime(fmt)
        expected_end = end.strftime(fmt)
        interval = expected_start + "/" + expected_end
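        # AWS CUR identity/TimeInterval values take the form "<start>/<end>".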

        actual_start, actual_end = self.processor._get_cost_entry_time_interval(
            interval)

        self.assertEqual(expected_start, actual_start)
        self.assertEqual(expected_end, actual_end)

    def test_create_cost_entry_bill(self):
        """Test that a cost entry bill id is returned."""
        table_name = AWS_CUR_TABLE_MAP["bill"]

        bill_id = self.processor._create_cost_entry_bill(
            self.row, self.accessor)

        self.assertIsNotNone(bill_id)

        query = self.accessor._get_db_obj_query(table_name)
        id_in_db = query.order_by("-id").first().id
        provider_uuid = query.order_by("-id").first().provider_id

        self.assertEqual(bill_id, id_in_db)
        self.assertIsNotNone(provider_uuid)

    def test_create_cost_entry_bill_existing(self):
        """Test that a cost entry bill id is returned from an existing bill."""
        table_name = AWS_CUR_TABLE_MAP["bill"]

        bill_id = self.processor._create_cost_entry_bill(
            self.row, self.accessor)

        query = self.accessor._get_db_obj_query(table_name)
        bill = query.first()

        self.processor.current_bill = bill

        new_bill_id = self.processor._create_cost_entry_bill(
            self.row, self.accessor)

        self.assertEqual(bill_id, new_bill_id)

        self.processor.current_bill = None

    def test_create_cost_entry(self):
        """Test that a cost entry id is returned."""
        table_name = AWS_CUR_TABLE_MAP["cost_entry"]

        bill_id = self.processor._create_cost_entry_bill(
            self.row, self.accessor)

        cost_entry_id = self.processor._create_cost_entry(
            self.row, bill_id, self.accessor)

        self.assertIsNotNone(cost_entry_id)

        query = self.accessor._get_db_obj_query(table_name)
        id_in_db = query.order_by("-id").first().id

        self.assertEqual(cost_entry_id, id_in_db)

    def test_create_cost_entry_existing(self):
        """Test that a cost entry id is returned from an existing entry."""
        bill_id = self.processor._create_cost_entry_bill(
            self.row, self.accessor)

        interval = self.row.get("identity/TimeInterval")
        start, _ = self.processor._get_cost_entry_time_interval(interval)
        key = (bill_id, start)
        expected_id = random.randint(1, 9)
        self.processor.existing_cost_entry_map[key] = expected_id
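        # With the map seeded, the processor should reuse the cached id
        # rather than insert a new cost entry.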

        cost_entry_id = self.processor._create_cost_entry(
            self.row, bill_id, self.accessor)
        self.assertEqual(cost_entry_id, expected_id)

    def test_create_cost_entry_line_item(self):
        """Test that line item data is returned properly."""
        bill_id = self.processor._create_cost_entry_bill(
            self.row, self.accessor)
        cost_entry_id = self.processor._create_cost_entry(
            self.row, bill_id, self.accessor)
        product_id = self.processor._create_cost_entry_product(
            self.row, self.accessor)
        pricing_id = self.processor._create_cost_entry_pricing(
            self.row, self.accessor)
        reservation_id = self.processor._create_cost_entry_reservation(
            self.row, self.accessor)

        self.processor._create_cost_entry_line_item(self.row, cost_entry_id,
                                                    bill_id, product_id,
                                                    pricing_id, reservation_id,
                                                    self.accessor)

        line_item = None
        if self.processor.processed_report.line_items:
            line_item = self.processor.processed_report.line_items[-1]

        self.assertIsNotNone(line_item)
        self.assertIn("tags", line_item)
        self.assertEqual(line_item.get("cost_entry_id"), cost_entry_id)
        self.assertEqual(line_item.get("cost_entry_bill_id"), bill_id)
        self.assertEqual(line_item.get("cost_entry_product_id"), product_id)
        self.assertEqual(line_item.get("cost_entry_pricing_id"), pricing_id)
        self.assertEqual(line_item.get("cost_entry_reservation_id"),
                         reservation_id)

        self.assertIsNotNone(self.processor.line_item_columns)

    def test_create_cost_entry_product(self):
        """Test that a cost entry product id is returned."""
        table_name = AWS_CUR_TABLE_MAP["product"]

        product_id = self.processor._create_cost_entry_product(
            self.row, self.accessor)

        self.assertIsNotNone(product_id)

        query = self.accessor._get_db_obj_query(table_name)
        id_in_db = query.order_by("-id").first().id

        self.assertEqual(product_id, id_in_db)

    def test_create_cost_entry_product_already_processed(self):
        """Test that an already processed product id is returned."""
        expected_id = random.randint(1, 9)
        sku = self.row.get("product/sku")
        product_name = self.row.get("product/ProductName")
        region = self.row.get("product/region")
        key = (sku, product_name, region)
        self.processor.processed_report.products.update({key: expected_id})

        product_id = self.processor._create_cost_entry_product(
            self.row, self.accessor)

        self.assertEqual(product_id, expected_id)

    def test_create_cost_entry_product_existing(self):
        """Test that a previously existing product id is returned."""
        expected_id = random.randint(1, 9)
        sku = self.row.get("product/sku")
        product_name = self.row.get("product/ProductName")
        region = self.row.get("product/region")
        key = (sku, product_name, region)
        self.processor.existing_product_map.update({key: expected_id})

        product_id = self.processor._create_cost_entry_product(
            self.row, self.accessor)

        self.assertEqual(product_id, expected_id)

    def test_create_cost_entry_pricing(self):
        """Test that a cost entry pricing id is returned."""
        table_name = AWS_CUR_TABLE_MAP["pricing"]

        pricing_id = self.processor._create_cost_entry_pricing(
            self.row, self.accessor)

        self.assertIsNotNone(pricing_id)

        with schema_context(self.schema):
            query = self.accessor._get_db_obj_query(table_name)
            id_in_db = query.order_by("-id").first().id
            self.assertEqual(pricing_id, id_in_db)

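    # Pricing rows use a "<term>-<unit>" string key ("OnDemand-Hrs" would be
    # a plausible example), again with a per-run cache
    # (processed_report.pricing) and a preloaded cache (existing_pricing_map).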
    def test_create_cost_entry_pricing_already_processed(self):
        """Test that an already processed pricing id is returned."""
        expected_id = random.randint(1, 9)

        key = "{term}-{unit}".format(term=self.row["pricing/term"],
                                     unit=self.row["pricing/unit"])
        self.processor.processed_report.pricing.update({key: expected_id})

        pricing_id = self.processor._create_cost_entry_pricing(
            self.row, self.accessor)

        self.assertEqual(pricing_id, expected_id)

    def test_create_cost_entry_pricing_existing(self):
        """Test that a previously existing pricing id is returned."""
        expected_id = random.randint(1, 9)

        key = "{term}-{unit}".format(term=self.row["pricing/term"],
                                     unit=self.row["pricing/unit"])
        self.processor.existing_pricing_map.update({key: expected_id})

        pricing_id = self.processor._create_cost_entry_pricing(
            self.row, self.accessor)

        self.assertEqual(pricing_id, expected_id)

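    # Reservations are cached by their ARN. The deepcopy below keeps the
    # shared fixture row untouched so other tests see the original values.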
    def test_create_cost_entry_reservation(self):
        """Test that a cost entry reservation id is returned."""
        # Ensure a reservation exists on the row
        arn = "TestARN"
        row = copy.deepcopy(self.row)
        row["reservation/ReservationARN"] = arn

        table_name = AWS_CUR_TABLE_MAP["reservation"]

        reservation_id = self.processor._create_cost_entry_reservation(
            row, self.accessor)

        self.assertIsNotNone(reservation_id)

        query = self.accessor._get_db_obj_query(table_name)
        id_in_db = query.order_by("-id").first().id

        self.assertEqual(reservation_id, id_in_db)

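    # An "RIFee" line item is expected to update the existing reservation row
    # in place (here, bumping number_of_reservations) rather than insert a new
    # record, so the id returned by the second call matches the first.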
    def test_create_cost_entry_reservation_update(self):
        """Test that a cost entry reservation id is returned."""
        # Ensure a reservation exists on the row
        arn = "TestARN"
        row = copy.deepcopy(self.row)
        row["reservation/ReservationARN"] = arn
        row["reservation/NumberOfReservations"] = 1

        table_name = AWS_CUR_TABLE_MAP["reservation"]

        reservation_id = self.processor._create_cost_entry_reservation(
            row, self.accessor)

        self.assertIsNotNone(reservation_id)

        with schema_context(self.schema):
            query = self.accessor._get_db_obj_query(table_name)
            id_in_db = query.order_by("-id").first().id

        self.assertEqual(reservation_id, id_in_db)

        row["lineItem/LineItemType"] = "RIFee"
        res_count = row["reservation/NumberOfReservations"]
        row["reservation/NumberOfReservations"] = res_count + 1
        reservation_id = self.processor._create_cost_entry_reservation(
            row, self.accessor)

        self.assertEqual(reservation_id, id_in_db)

        db_row = query.filter(id=id_in_db).first()
        self.assertEqual(db_row.number_of_reservations,
                         row["reservation/NumberOfReservations"])

    def test_create_cost_entry_reservation_already_processed(self):
        """Test that an already processed reservation id is returned."""
        expected_id = random.randint(1, 9)
        arn = self.row.get("reservation/ReservationARN")
        self.processor.processed_report.reservations.update({arn: expected_id})

        reservation_id = self.processor._create_cost_entry_reservation(
            self.row, self.accessor)

        self.assertEqual(reservation_id, expected_id)

    def test_create_cost_entry_reservation_existing(self):
        """Test that a previously existing reservation id is returned."""
        expected_id = random.randint(1, 9)
        arn = self.row.get("reservation/ReservationARN")
        self.processor.existing_reservation_map.update({arn: expected_id})

        reservation_id = self.processor._create_cost_entry_reservation(
            self.row, self.accessor)

        self.assertEqual(reservation_id, expected_id)

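    # A bill is treated as finalized once its rows carry a bill/InvoiceId
    # value; AWS is understood to populate the invoice id only on the final
    # report for a billing period.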
    def test_check_for_finalized_bill_bill_is_finalized(self):
        """Verify that a file with invoice_id is marked as finalzed."""
        data = []

        with open(self.test_report) as f:
            reader = csv.DictReader(f)
            for row in reader:
                data.append(row)

        for row in data:
            row["bill/InvoiceId"] = "12345"

        tmp_file = "/tmp/test_process_finalized_rows.csv"
        field_names = data[0].keys()

        with open(tmp_file, "w") as f:
            writer = csv.DictWriter(f, fieldnames=field_names)
            writer.writeheader()
            writer.writerows(data)

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=tmp_file,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )

        result = processor._check_for_finalized_bill()

        self.assertTrue(result)

    def test_check_for_finalized_bill_bill_not_finalized(self):
        """Verify that a file without invoice_id is not marked as finalzed."""
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )

        result = processor._check_for_finalized_bill()

        self.assertFalse(result)

    def test_check_for_finalized_bill_empty_bill(self):
        """Verify that an empty file is not marked as finalzed."""
        tmp_file = "/tmp/test_process_finalized_rows.csv"

        with open(tmp_file, "w"):
            pass

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=tmp_file,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )
        result = processor._check_for_finalized_bill()
        self.assertFalse(result)

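    # When a manifest does not warrant a full-month reprocess, deletion is
    # expected to be limited to rows on or after data_cutoff_date, leaving
    # older line items in place. Early in a month the window can cover the
    # whole bill, which is why the assertions below branch on today.day.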
    @patch(
        "masu.processor.report_processor_base.ReportProcessorBase._should_process_full_month"
    )
    def test_delete_line_items_use_data_cutoff_date(self, mock_should_process):
        """Test that only three days of data are deleted."""
        mock_should_process.return_value = True

        today = self.date_accessor.today_with_timezone("UTC").replace(
            hour=0, minute=0, second=0, microsecond=0)
        first_of_month = today.replace(day=1)

        self.manifest.billing_period_start_datetime = first_of_month
        self.manifest.save()

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path="",
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )

        # Get latest data date.
        bill_ids = []
        with schema_context(self.schema):
            bills = self.accessor.get_cost_entry_bills()
            for key, value in bills.items():
                if key[1] == DateHelper().this_month_start:
                    bill_ids.append(value)

        for bill_id in bill_ids:
            with schema_context(self.schema):
                line_item_query = self.accessor.get_lineitem_query_for_billid(
                    bill_id)
                undeleted_max_date = line_item_query.aggregate(
                    max_date=Max("usage_start"))

            mock_should_process.return_value = False
            processor._delete_line_items(AWSReportDBAccessor,
                                         is_finalized=False)

            with schema_context(self.schema):
                line_item_query = self.accessor.get_lineitem_query_for_billid(
                    bill_id)
                if today.day <= 3:
                    self.assertEqual(line_item_query.count(), 0)
                else:
                    max_date = line_item_query.aggregate(
                        max_date=Max("usage_start"))
                    self.assertLess(
                        max_date.get("max_date").date(),
                        processor.data_cutoff_date)
                    self.assertLess(
                        max_date.get("max_date").date(),
                        undeleted_max_date.get("max_date").date())
                    self.assertNotEqual(line_item_query.count(), 0)

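    # A sketch of the cutoff the next two tests pin down (the real logic lives
    # in ReportProcessorBase):
    #
    #     cutoff = max(today - 2 days, first_of_current_month)
    #
    # i.e. a two-day overlap with already-processed data, clamped so the
    # window never reaches into the previous month.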
    @patch("masu.processor.report_processor_base.DateAccessor")
    def test_data_cutoff_date_not_start_of_month(self, mock_date):
        """Test that the data_cuttof_date respects month boundaries."""
        today = self.date_accessor.today_with_timezone("UTC").replace(day=10)
        expected = today.date() - relativedelta(days=2)

        mock_date.return_value.today_with_timezone.return_value = today

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )
        self.assertEqual(expected, processor.data_cutoff_date)

    @patch("masu.processor.report_processor_base.DateAccessor")
    def test_data_cutoff_date_start_of_month(self, mock_date):
        """Test that the data_cuttof_date respects month boundaries."""
        today = self.date_accessor.today_with_timezone("UTC")
        first_of_month = today.replace(day=1)

        mock_date.return_value.today_with_timezone.return_value = first_of_month

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )
        self.assertEqual(first_of_month.date(), processor.data_cutoff_date)

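    # The full-month decision, as exercised by the next four tests, roughly
    # follows this sketch (not the actual implementation):
    #
    #     if manifest is None or manifest month != current month:
    #         process the full month
    #     else:
    #         process fully only for the first manifest seen for this bill
    #
    # Later manifests for the same bill fall back to the cutoff window.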
    @patch("masu.processor.report_processor_base.ReportManifestDBAccessor")
    def test_should_process_full_month_first_manifest_for_bill(
            self, mock_manifest_accessor):
        """Test that we process data for a new bill/manifest completely."""
        mock_manifest = Mock()
        today = self.date_accessor.today_with_timezone("UTC")
        mock_manifest.billing_period_start_datetime = today
        mock_manifest.num_processed_files = 1
        mock_manifest.num_total_files = 2
        mock_manifest_accessor.return_value.__enter__.return_value.get_manifest_by_id.return_value = mock_manifest
        mock_manifest_accessor.return_value.__enter__.return_value.get_manifest_list_for_provider_and_bill_date.return_value = [  # noqa
            mock_manifest
        ]
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )

        self.assertTrue(processor._should_process_full_month())

    @patch("masu.processor.report_processor_base.ReportManifestDBAccessor")
    def test_should_process_full_month_not_first_manifest_for_bill(
            self, mock_manifest_accessor):
        """Test that we process a window of data for the bill/manifest."""
        mock_manifest = Mock()
        today = self.date_accessor.today_with_timezone("UTC")
        mock_manifest.billing_period_start_datetime = today
        mock_manifest.num_processed_files = 1
        mock_manifest.num_total_files = 1
        mock_manifest_accessor.return_value.__enter__.return_value.get_manifest_by_id.return_value = mock_manifest
        mock_manifest_accessor.return_value.__enter__.return_value.get_manifest_list_for_provider_and_bill_date.return_value = [  # noqa
            mock_manifest,
            mock_manifest,
        ]
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )

        self.assertFalse(processor._should_process_full_month())

    @patch("masu.processor.report_processor_base.ReportManifestDBAccessor")
    def test_should_process_full_month_manifest_for_not_current_month(
            self, mock_manifest_accessor):
        """Test that we process this manifest completely."""
        mock_manifest = Mock()
        last_month = self.date_accessor.today_with_timezone(
            "UTC") - relativedelta(months=1)
        mock_manifest.billing_period_start_datetime = last_month
        mock_manifest_accessor.return_value.__enter__.return_value.get_manifest_by_id.return_value = mock_manifest
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )

        self.assertTrue(processor._should_process_full_month())

    def test_should_process_full_month_no_manifest(self):
        """Test that we process this manifest completely."""
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
        )

        self.assertTrue(processor._should_process_full_month())

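    # Row filtering: a row is kept when its usage start falls on or after
    # data_cutoff_date, or unconditionally when processing a full month or a
    # finalized bill.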
    def test_should_process_row_within_cutoff_date(self):
        """Test that a row within the cutoff window is processed."""
        today = self.date_accessor.today_with_timezone("UTC")
        row = {"lineItem/UsageStartDate": today.isoformat()}

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )

        should_process = processor._should_process_row(
            row, "lineItem/UsageStartDate", False)

        self.assertTrue(should_process)

    def test_should_process_row_outside_cutoff_date(self):
        """Test that a row outside the cutoff window is not processed."""
        today = self.date_accessor.today_with_timezone("UTC")
        usage_start = today - relativedelta(days=10)
        row = {"lineItem/UsageStartDate": usage_start.isoformat()}

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )

        should_process = processor._should_process_row(
            row, "lineItem/UsageStartDate", False)

        self.assertFalse(should_process)

    def test_should_process_is_full_month(self):
        """Test that we correctly determine a row should be processed."""
        today = self.date_accessor.today_with_timezone("UTC")
        usage_start = today - relativedelta(days=10)
        row = {"lineItem/UsageStartDate": usage_start.isoformat()}

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )

        should_process = processor._should_process_row(
            row, "lineItem/UsageStartDate", True)

        self.assertTrue(should_process)

    def test_should_process_is_finalized(self):
        """Test that we correctly determine a row should be processed."""
        today = self.date_accessor.today_with_timezone("UTC")
        usage_start = today - relativedelta(days=10)
        row = {"lineItem/UsageStartDate": usage_start.isoformat()}

        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )

        should_process = processor._should_process_row(
            row, "lineItem/UsageStartDate", False, is_finalized=True)

        self.assertTrue(should_process)

    def test_get_date_column_filter(self):
        """Test that the Azure specific filter is returned."""
        processor = AWSReportProcessor(
            schema_name=self.schema,
            report_path=self.test_report,
            compression=UNCOMPRESSED,
            provider_uuid=self.aws_provider_uuid,
            manifest_id=self.manifest.id,
        )
        date_filter = processor.get_date_column_filter()

        self.assertIn("usage_start__gte", date_filter)

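    # Memory parsing is expected to split a value such as "4GiB" or "4 GB"
    # into a numeric amount and a unit string, tolerating an optional space,
    # and to map "NA" or missing values to (None, None).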
    def test_process_memory_value(self):
        """Test that product data has memory properly parsed."""

        data = {"memory": None}
        result = self.processor._process_memory_value(data)
        self.assertIsNone(result.get("memory"))
        self.assertIsNone(result.get("memory_unit"))

        data = {"memory": "NA"}
        result = self.processor._process_memory_value(data)
        self.assertIsNone(result.get("memory"))
        self.assertIsNone(result.get("memory_unit"))

        data = {"memory": "4GiB"}
        result = self.processor._process_memory_value(data)
        self.assertEqual(result.get("memory"), 4)
        self.assertEqual(result.get("memory_unit"), "GiB")

        data = {"memory": "4 GB"}
        result = self.processor._process_memory_value(data)
        self.assertEqual(result.get("memory"), 4)
        self.assertEqual(result.get("memory_unit"), "GB")

        data = {"memory": "4"}
        result = self.processor._process_memory_value(data)
        self.assertEqual(result.get("memory"), 4)
        self.assertIsNone(result.get("memory_unit"))