Beispiel #1
0
 def setUp(self):
     """Test set up.

     Looks up the provider identified by ``self.aws_provider_uuid`` and stores
     its ``uuid`` on ``self.provider_uuid``.
     """
     super().setUp()
     provider_accessor = ProviderDBAccessor(self.aws_provider_uuid)
     try:
         provider = provider_accessor.get_provider()
         self.provider_uuid = provider.uuid
     finally:
         # Release the accessor's session even if the provider lookup raises.
         provider_accessor.close_session()
Beispiel #2
0
class OCPUtilTests(MasuTestCase):
    """Test the OCP utility functions."""

    def setUp(self):
        """Shared variables used by ocp common tests.

        Creates, for the OCP test provider, one reporting period with one
        report, one usage line item and one storage line item.
        """
        super().setUp()
        self.common_accessor = ReportingCommonDBAccessor()
        self.column_map = self.common_accessor.column_map
        self.accessor = OCPReportDBAccessor(schema=self.schema, column_map=self.column_map)
        self.provider_accessor = ProviderDBAccessor(provider_uuid=self.ocp_test_provider_uuid)
        self.report_schema = self.accessor.report_schema
        self.creator = ReportObjectCreator(self.schema, self.column_map)
        self.all_tables = list(OCP_REPORT_TABLE_MAP.values())

        self.provider_uuid = self.provider_accessor.get_provider().uuid
        reporting_period = self.creator.create_ocp_report_period(provider_uuid=self.provider_uuid)
        report = self.creator.create_ocp_report(
            reporting_period, reporting_period.report_period_start
        )
        self.creator.create_ocp_usage_line_item(reporting_period, report)
        self.creator.create_ocp_storage_line_item(reporting_period, report)

    def test_get_cluster_id_from_provider(self):
        """Test that the cluster ID is returned from OCP provider."""
        cluster_id = utils.get_cluster_id_from_provider(self.ocp_test_provider_uuid)
        self.assertIsNotNone(cluster_id)

    def test_get_cluster_id_from_non_ocp_provider(self):
        """Test that None is returned when getting cluster ID on non-OCP provider."""
        cluster_id = utils.get_cluster_id_from_provider(self.aws_provider_uuid)
        self.assertIsNone(cluster_id)

    def test_get_provider_uuid_from_cluster_id(self):
        """Test that the provider uuid is returned for a cluster ID."""
        cluster_id = self.ocp_provider_resource_name
        provider_uuid = utils.get_provider_uuid_from_cluster_id(cluster_id)
        # Validate the string form so that None (TypeError) or a UUID instance
        # (AttributeError) cannot turn this check into a test error instead of
        # a clean failure.
        try:
            UUID(str(provider_uuid))
        except ValueError:
            self.fail('{} is not a valid uuid.'.format(str(provider_uuid)))

    def test_get_provider_uuid_from_invalid_cluster_id(self):
        """Test that the provider uuid is not returned for an invalid cluster ID."""
        cluster_id = 'bad_cluster_id'
        provider_uuid = utils.get_provider_uuid_from_cluster_id(cluster_id)
        self.assertIsNone(provider_uuid)

    def test_poll_ingest_override_for_provider(self):
        """Test that OCP polling override returns True if insights local path exists."""
        # TemporaryDirectory removes the directory even when the assertion fails;
        # a trailing shutil.rmtree would be skipped in that case.
        with tempfile.TemporaryDirectory() as fake_dir:
            with patch.object(Config, 'INSIGHTS_LOCAL_REPORT_DIR', fake_dir):
                cluster_id = utils.get_cluster_id_from_provider(self.ocp_test_provider_uuid)
                expected_path = '{}/{}/'.format(Config.INSIGHTS_LOCAL_REPORT_DIR, cluster_id)
                os.makedirs(expected_path, exist_ok=True)
                self.assertTrue(utils.poll_ingest_override_for_provider(self.ocp_test_provider_uuid))
class AzureReportChargeUpdaterTest(MasuTestCase):
    """Exercise AzureReportChargeUpdater against freshly seeded Azure report data."""

    @classmethod
    def setUpClass(cls):
        """Build the accessors and helpers shared by every test in this class."""
        super().setUpClass()
        with ReportingCommonDBAccessor() as common_accessor:
            cls.column_map = common_accessor.column_map

        cls.accessor = AzureReportDBAccessor('acct10001', cls.column_map)
        cls.report_schema = cls.accessor.report_schema
        cls.all_tables = list(AZURE_REPORT_TABLE_MAP.values())
        cls.creator = ReportObjectCreator(cls.schema, cls.column_map)
        cls.date_accessor = DateAccessor()
        cls.manifest_accessor = ReportManifestDBAccessor()

    def setUp(self):
        """Seed one Azure bill with a line item, plus a manifest, for the current month."""
        super().setUp()

        month_start = self.date_accessor.today_with_timezone('UTC').replace(day=1)
        self.manifest_dict = dict(
            assembly_id='1234',
            billing_period_start_datetime=month_start,
            num_total_files=1,
            provider_id=self.azure_provider.id,
        )

        azure_uuid = self.azure_test_provider_uuid
        self.provider_accessor = ProviderDBAccessor(provider_uuid=azure_uuid)
        azure_provider_id = self.provider_accessor.get_provider().id

        self.updater = AzureReportChargeUpdater(
            schema=self.schema,
            provider_uuid=azure_uuid,
            provider_id=azure_provider_id,
        )

        now = DateAccessor().today_with_timezone('UTC')
        creator = self.creator
        bill = creator.create_azure_cost_entry_bill(provider_id=azure_provider_id, bill_date=now)
        product = creator.create_azure_cost_entry_product()
        meter = creator.create_azure_meter()
        service = creator.create_azure_service()
        creator.create_azure_cost_entry_line_item(bill, product, meter, service)

        self.manifest = self.manifest_accessor.add(**self.manifest_dict)
        self.manifest_accessor.commit()

        with ProviderDBAccessor(azure_uuid) as lookup:
            self.provider = lookup.get_provider()

    def test_azure_update_summary_charge_info(self):
        """Check that updating charge info stamps derived_cost_datetime on this month's bill."""
        first_of_month = self.date_accessor.today_with_timezone('UTC').replace(day=1).date()

        self.updater.update_summary_charge_info()

        with AzureReportDBAccessor(self.schema, self.column_map) as accessor:
            bills = accessor.get_cost_entry_bills_by_date(first_of_month)
            self.assertIsNotNone(bills[0].derived_cost_datetime)
Beispiel #4
0
class OCPUtilTests(MasuTestCase):
    """Test the OCP utility functions."""
    def setUp(self):
        """Shared variables used by ocp common tests.

        Creates, for the OCP test provider, one reporting period with one
        report plus usage, storage and node-label line items.
        """
        super().setUp()
        self.accessor = OCPReportDBAccessor(schema=self.schema)
        self.provider_accessor = ProviderDBAccessor(
            provider_uuid=self.ocp_test_provider_uuid)
        self.report_schema = self.accessor.report_schema
        self.creator = ReportObjectCreator(self.schema)
        self.all_tables = list(OCP_REPORT_TABLE_MAP.values())

        self.provider_uuid = self.provider_accessor.get_provider().uuid
        reporting_period = self.creator.create_ocp_report_period(
            provider_uuid=self.provider_uuid)
        report = self.creator.create_ocp_report(
            reporting_period, reporting_period.report_period_start)
        self.creator.create_ocp_usage_line_item(reporting_period, report)
        self.creator.create_ocp_storage_line_item(reporting_period, report)
        self.creator.create_ocp_node_label_line_item(reporting_period, report)

    def test_get_cluster_id_from_provider(self):
        """Test that the cluster ID is returned from OCP provider."""
        cluster_id = utils.get_cluster_id_from_provider(
            self.ocp_test_provider_uuid)
        self.assertIsNotNone(cluster_id)

    def test_get_cluster_id_from_non_ocp_provider(self):
        """Test that None is returned when getting cluster ID on non-OCP provider."""
        cluster_id = utils.get_cluster_id_from_provider(self.aws_provider_uuid)
        self.assertIsNone(cluster_id)

    def test_get_cluster_alias_from_cluster_id(self):
        """Test that the cluster alias is returned from cluster_id."""
        cluster_id = self.ocp_cluster_id
        cluster_alias = utils.get_cluster_alias_from_cluster_id(cluster_id)
        self.assertIsNotNone(cluster_alias)

    def test_get_provider_uuid_from_cluster_id(self):
        """Test that the provider uuid is returned for a cluster ID."""
        cluster_id = self.ocp_cluster_id
        provider_uuid = utils.get_provider_uuid_from_cluster_id(cluster_id)
        # Validate the string form so that None (TypeError) or a UUID instance
        # (AttributeError) produce a clean test failure rather than an error.
        try:
            UUID(str(provider_uuid))
        except ValueError:
            self.fail("{} is not a valid uuid.".format(str(provider_uuid)))

    def test_get_provider_uuid_from_invalid_cluster_id(self):
        """Test that the provider uuid is not returned for an invalid cluster ID."""
        cluster_id = "bad_cluster_id"
        provider_uuid = utils.get_provider_uuid_from_cluster_id(cluster_id)
        self.assertIsNone(provider_uuid)

    def test_poll_ingest_override_for_provider(self):
        """Test that OCP polling override returns True if insights local path exists."""
        # TemporaryDirectory cleans up even when the assertion fails.
        with tempfile.TemporaryDirectory() as fake_dir:
            with patch.object(Config, "INSIGHTS_LOCAL_REPORT_DIR", fake_dir):
                cluster_id = utils.get_cluster_id_from_provider(
                    self.ocp_test_provider_uuid)
                expected_path = f"{Config.INSIGHTS_LOCAL_REPORT_DIR}/{cluster_id}/"
                os.makedirs(expected_path, exist_ok=True)
                self.assertTrue(
                    utils.poll_ingest_override_for_provider(
                        self.ocp_test_provider_uuid))

    def test_process_openshift_datetime(self):
        """Test that process_openshift_datetime parses an OpenShift timestamp string."""
        expected_dt_str = "2020-07-01 00:00:00"
        expected = pd.to_datetime(expected_dt_str)
        dt = utils.process_openshift_datetime("2020-07-01 00:00:00 +0000 UTC")
        self.assertEqual(expected, dt)
Beispiel #5
0
class OCPUtilTests(MasuTestCase):
    """Test the OCP utility functions."""
    def setUp(self):
        """Shared variables used by ocp common tests.

        Creates, for the OCP test provider, one reporting period with one
        report plus usage, storage and node-label line items.
        """
        super().setUp()
        self.accessor = OCPReportDBAccessor(schema=self.schema)
        self.provider_accessor = ProviderDBAccessor(
            provider_uuid=self.ocp_test_provider_uuid)
        self.report_schema = self.accessor.report_schema
        self.creator = ReportObjectCreator(self.schema)
        self.all_tables = list(OCP_REPORT_TABLE_MAP.values())

        self.provider_uuid = self.provider_accessor.get_provider().uuid
        reporting_period = self.creator.create_ocp_report_period(
            provider_uuid=self.provider_uuid)
        report = self.creator.create_ocp_report(
            reporting_period, reporting_period.report_period_start)
        self.creator.create_ocp_usage_line_item(reporting_period, report)
        self.creator.create_ocp_storage_line_item(reporting_period, report)
        self.creator.create_ocp_node_label_line_item(reporting_period, report)

    def _assert_single_value(self, frame, column, expected):
        """Assert that ``frame[column]`` holds exactly one value, equal to ``expected``.

        Uses ``Series.item()`` (raises ValueError unless there is exactly one
        element) instead of the deprecated ``Series.bool()``. It also reports
        the actual value on mismatch.
        """
        self.assertEqual(frame[column].item(), expected)

    def test_get_cluster_id_from_provider(self):
        """Test that the cluster ID is returned from OCP provider."""
        cluster_id = utils.get_cluster_id_from_provider(
            self.ocp_test_provider_uuid)
        self.assertIsNotNone(cluster_id)

    def test_get_cluster_id_with_no_authentication(self):
        """Test that a None is correctly returned if authentication is not present."""
        # Remove test provider authentication
        Provider.objects.filter(uuid=self.ocp_test_provider_uuid).update(
            authentication=None)
        ocp_provider = Provider.objects.get(uuid=self.ocp_test_provider_uuid)
        self.assertIsNone(ocp_provider.authentication)
        # Assert if authentication is empty we return none instead of an error
        cluster_id = utils.get_cluster_id_from_provider(
            self.ocp_test_provider_uuid)
        self.assertIsNone(cluster_id)

    def test_get_cluster_id_from_non_ocp_provider(self):
        """Test that None is returned when getting cluster ID on non-OCP provider."""
        cluster_id = utils.get_cluster_id_from_provider(self.aws_provider_uuid)
        self.assertIsNone(cluster_id)

    def test_get_cluster_alias_from_cluster_id(self):
        """Test that the cluster alias is returned from cluster_id."""
        cluster_id = self.ocp_cluster_id
        cluster_alias = utils.get_cluster_alias_from_cluster_id(cluster_id)
        self.assertIsNotNone(cluster_alias)

    def test_get_provider_uuid_from_cluster_id(self):
        """Test that the provider uuid is returned for a cluster ID."""
        cluster_id = self.ocp_cluster_id
        provider_uuid = utils.get_provider_uuid_from_cluster_id(cluster_id)
        # Validate the string form so that None (TypeError) or a UUID instance
        # (AttributeError) produce a clean test failure rather than an error.
        try:
            UUID(str(provider_uuid))
        except ValueError:
            self.fail(f"{str(provider_uuid)} is not a valid uuid.")

    def test_get_provider_uuid_from_invalid_cluster_id(self):
        """Test that the provider uuid is not returned for an invalid cluster ID."""
        cluster_id = "bad_cluster_id"
        provider_uuid = utils.get_provider_uuid_from_cluster_id(cluster_id)
        self.assertIsNone(provider_uuid)

    def test_poll_ingest_override_for_provider(self):
        """Test that OCP polling override returns True if insights local path exists."""
        # TemporaryDirectory cleans up even when the assertion fails.
        with tempfile.TemporaryDirectory() as fake_dir:
            with patch.object(Config, "INSIGHTS_LOCAL_REPORT_DIR", fake_dir):
                cluster_id = utils.get_cluster_id_from_provider(
                    self.ocp_test_provider_uuid)
                expected_path = f"{Config.INSIGHTS_LOCAL_REPORT_DIR}/{cluster_id}/"
                os.makedirs(expected_path, exist_ok=True)
                self.assertTrue(
                    utils.poll_ingest_override_for_provider(
                        self.ocp_test_provider_uuid))

    def test_process_openshift_datetime(self):
        """Test that process_openshift_datetime parses an OpenShift timestamp string."""
        expected_dt_str = "2020-07-01 00:00:00"
        expected = pd.to_datetime(expected_dt_str)
        dt = utils.process_openshift_datetime("2020-07-01 00:00:00 +0000 UTC")
        self.assertEqual(expected, dt)

    def test_ocp_generate_daily_data(self):
        """Test that OCP data is aggregated to daily."""
        usage = random.randint(1, 10)
        capacity = random.randint(1, 10)
        namespace = "project_1"
        pod = "pod_1"
        node = "node_1"
        resource_id = "123"
        pvc = "pvc_1"
        label = '{"key": "value"}'

        interval_start = datetime.datetime(2021, 6, 7, 1, 0, 0)
        next_hour = datetime.datetime(2021, 6, 7, 2, 0, 0)
        next_day = datetime.datetime(2021, 6, 8, 1, 0, 0)
        # Two hourly rows on the first day, one on the second day.
        hour_starts = [interval_start, next_hour, next_day]

        def _interval_columns(start):
            """Return the report-period and one-hour interval columns for ``start``."""
            return {
                "report_period_start": datetime.datetime(2021, 6, 1, 0, 0, 0),
                "report_period_end": datetime.datetime(2021, 6, 1, 0, 0, 0),
                "interval_start": start,
                "interval_end": start + datetime.timedelta(hours=1),
            }

        base_pod_data = {
            "pod": pod,
            "namespace": namespace,
            "node": node,
            "resource_id": resource_id,
            "pod_usage_cpu_core_seconds": usage,
            "pod_request_cpu_core_seconds": usage,
            "pod_limit_cpu_core_seconds": usage,
            "pod_usage_memory_byte_seconds": usage,
            "pod_request_memory_byte_seconds": usage,
            "pod_limit_memory_byte_seconds": usage,
            "node_capacity_cpu_cores": capacity,
            "node_capacity_cpu_core_seconds": capacity,
            "node_capacity_memory_bytes": capacity,
            "node_capacity_memory_byte_seconds": capacity,
            "pod_labels": label,
        }

        base_storage_data = {
            "namespace": namespace,
            "pod": pod,
            "persistentvolumeclaim": pvc,
            "persistentvolume": pvc,
            "storageclass": "gold",
            "persistentvolumeclaim_capacity_bytes": capacity,
            "persistentvolumeclaim_capacity_byte_seconds": capacity,
            "volume_request_storage_byte_seconds": usage,
            "persistentvolumeclaim_usage_byte_seconds": usage,
            "persistentvolume_labels": label,
            "persistentvolumeclaim_labels": label,
        }

        base_node_data = {"node": node, "node_labels": label}

        base_namespace_data = {
            "namespace": namespace,
            "namespace_labels": label
        }

        # (report_type, row data, expected first-day values, expected second-day values).
        # The first day has two hourly rows, so the *_seconds columns checked
        # below are expected to double. The point-in-time columns checked
        # (node_capacity_cpu_cores, persistentvolumeclaim_capacity_bytes) and
        # the labels are expected to stay unchanged.
        cases = [
            (
                "pod_usage",
                base_pod_data,
                {
                    "pod_usage_cpu_core_seconds": usage * 2,
                    "pod_usage_memory_byte_seconds": usage * 2,
                    "node_capacity_cpu_cores": capacity,
                },
                {
                    "pod_usage_cpu_core_seconds": usage,
                    "pod_usage_memory_byte_seconds": usage,
                    "node_capacity_cpu_cores": capacity,
                },
            ),
            (
                "storage_usage",
                base_storage_data,
                {
                    "persistentvolumeclaim_usage_byte_seconds": usage * 2,
                    "volume_request_storage_byte_seconds": usage * 2,
                    "persistentvolumeclaim_capacity_byte_seconds": capacity * 2,
                    "persistentvolumeclaim_capacity_bytes": capacity,
                },
                {
                    "persistentvolumeclaim_usage_byte_seconds": usage,
                    "volume_request_storage_byte_seconds": usage,
                    "persistentvolumeclaim_capacity_byte_seconds": capacity,
                    "persistentvolumeclaim_capacity_bytes": capacity,
                },
            ),
            (
                "node_labels",
                base_node_data,
                {"node": node, "node_labels": label},
                {"node": node, "node_labels": label},
            ),
            (
                "namespace_labels",
                base_namespace_data,
                {"namespace": namespace, "namespace_labels": label},
                {"namespace": namespace, "namespace_labels": label},
            ),
        ]

        for report_type, data, first_expected, second_expected in cases:
            with self.subTest(report_type=report_type):
                df = pd.DataFrame(
                    [{**_interval_columns(start), **data} for start in hour_starts])
                daily_df = utils.ocp_generate_daily_data(df, report_type)

                first_day = daily_df[daily_df["interval_start"] == str(
                    interval_start.date())]
                second_day = daily_df[daily_df["interval_start"] == str(
                    next_day.date())]

                # Assert that there is only 1 record per day
                self.assertEqual(first_day.shape[0], 1)
                self.assertEqual(second_day.shape[0], 1)

                for column, expected in first_expected.items():
                    self._assert_single_value(first_day, column, expected)
                for column, expected in second_expected.items():
                    self._assert_single_value(second_day, column, expected)

    def test_match_openshift_labels(self):
        """Test that a label match returns."""
        matched_tags = [{"key": "value"}, {"other_key": "other_value"}]

        tag_dicts = [
            {
                "tag": json.dumps({"key": "value"}),
                "expected": '"key": "value"'
            },
            {
                "tag": json.dumps({"key": "other_value"}),
                "expected": ""
            },
            {
                "tag": json.dumps({
                    "key": "value",
                    "other_key": "other_value"
                }),
                "expected": '"key": "value","other_key": "other_value"',
            },
        ]

        for tag_dict in tag_dicts:
            td = tag_dict.get("tag")
            expected = tag_dict.get("expected")
            with self.subTest(tag=td):
                result = utils.match_openshift_labels(td, matched_tags)
                self.assertEqual(result, expected)

    def test_match_openshift_labels_null_value(self):
        """Test that a label match doesn't return null tag values."""
        matched_tags = [{"key": "value"}, {"other_key": "other_value"}]

        tag_dicts = [
            {
                "tag": json.dumps({"key": "value"}),
                "expected": '"key": "value"'
            },
            {
                "tag": json.dumps({"key": "other_value"}),
                "expected": ""
            },
            {
                "tag": json.dumps({
                    "key": "value",
                    "other_key": "other_value"
                }),
                "expected": '"key": "value","other_key": "other_value"',
            },
            {
                "tag": json.dumps({
                    "key": "value",
                    "other_key": None
                }),
                "expected": '"key": "value"'
            },
        ]

        for tag_dict in tag_dicts:
            td = tag_dict.get("tag")
            expected = tag_dict.get("expected")
            with self.subTest(tag=td):
                result = utils.match_openshift_labels(td, matched_tags)
                self.assertEqual(result, expected)

    def test_get_report_details(self):
        """Test that we handle manifest files properly.

        Covers three cases: no manifest file present (INFO log), a valid
        manifest file, and ``open`` raising OSError (ERROR log).
        """
        with tempfile.TemporaryDirectory() as manifest_path:
            manifest_file = f"{manifest_path}/manifest.json"
            with self.assertLogs("masu.util.ocp.common",
                                 level="INFO") as logger:
                expected = f"INFO:masu.util.ocp.common:No manifest available at {manifest_file}"
                utils.get_report_details(manifest_path)
                self.assertIn(expected, logger.output)

            with open(manifest_file, "w") as f:
                data = {"key": "value"}
                json.dump(data, f)
            utils.get_report_details(manifest_path)

            with patch("masu.util.ocp.common.open") as mock_open:
                mock_open.side_effect = OSError
                with self.assertLogs("masu.util.ocp.common",
                                     level="INFO") as logger:
                    expected = "ERROR:masu.util.ocp.common:Unable to extract manifest data"
                    utils.get_report_details(manifest_path)
                    self.assertIn(expected, logger.output[0])
Beispiel #6
0
class AWSReportChargeUpdaterTest(MasuTestCase):
    """Test Cases for the AWSReportChargeUpdater object."""
    @classmethod
    def setUpClass(cls):
        """Set up the test class with required objects."""
        super().setUpClass()
        with ReportingCommonDBAccessor() as report_common_db:
            cls.column_map = report_common_db.column_map

        cls.accessor = AWSReportDBAccessor('acct10001', cls.column_map)

        cls.report_schema = cls.accessor.report_schema
        cls.session = cls.accessor._session

        cls.all_tables = list(AWS_CUR_TABLE_MAP.values())

        cls.creator = ReportObjectCreator(cls.accessor, cls.column_map,
                                          cls.report_schema.column_types)

        cls.date_accessor = DateAccessor()
        billing_start = cls.date_accessor.today_with_timezone('UTC').replace(
            day=1)
        cls.manifest_dict = {
            'assembly_id': '1234',
            'billing_period_start_datetime': billing_start,
            'num_total_files': 2,
            'provider_id': 1,
        }
        cls.manifest_accessor = ReportManifestDBAccessor()

    @classmethod
    def tearDownClass(cls):
        """Tear down the test class by closing the shared accessors."""
        cls.manifest_accessor.close_session()
        cls.accessor.close_connections()
        cls.accessor.close_session()
        super().tearDownClass()

    def setUp(self):
        """Set up each test: reconnect if needed, then seed one AWS bill and a manifest."""
        super().setUp()
        # The class-level accessor is shared across tests; reopen any
        # connection or cursor that was closed since it was created.
        if self.accessor._conn.closed:
            self.accessor._conn = self.accessor._db.connect()
        if self.accessor._pg2_conn.closed:
            self.accessor._pg2_conn = self.accessor._get_psycopg2_connection()
        if self.accessor._cursor.closed:
            self.accessor._cursor = self.accessor._get_psycopg2_cursor()

        # A single accessor is enough; re-creating it would discard (and
        # never close) the one used for the lookup.
        self.provider_accessor = ProviderDBAccessor(
            provider_uuid=self.aws_test_provider_uuid)
        provider_id = self.provider_accessor.get_provider().id
        self.updater = AWSReportChargeUpdater(
            schema=self.test_schema,
            provider_uuid=self.aws_test_provider_uuid,
            provider_id=provider_id,
        )
        today = DateAccessor().today_with_timezone('UTC')
        bill = self.creator.create_cost_entry_bill(today)
        cost_entry = self.creator.create_cost_entry(bill, today)
        product = self.creator.create_cost_entry_product()
        pricing = self.creator.create_cost_entry_pricing()
        reservation = self.creator.create_cost_entry_reservation()
        self.creator.create_cost_entry_line_item(bill, cost_entry, product,
                                                 pricing, reservation)

        self.manifest = self.manifest_accessor.add(**self.manifest_dict)
        self.manifest_accessor.commit()

        with ProviderDBAccessor(
                self.aws_test_provider_uuid) as provider_accessor:
            self.provider = provider_accessor.get_provider()

    def tearDown(self):
        """Return the database to a pre-test state.

        Deletes every row of every AWS CUR report table and every manifest.
        """
        super().tearDown()

        for table_name in self.all_tables:
            tables = self.accessor._get_db_obj_query(table_name).all()
            for table in tables:
                self.accessor._session.delete(table)
        self.accessor.commit()

        manifests = self.manifest_accessor._get_db_obj_query().all()
        for manifest in manifests:
            self.manifest_accessor.delete(manifest)
        self.manifest_accessor.commit()

    def test_update_summary_charge_info(self):
        """Test to verify AWS derived cost summary is calculated."""
        start_date = self.date_accessor.today_with_timezone('UTC')
        bill_date = start_date.replace(day=1).date()

        self.updater.update_summary_charge_info()
        with AWSReportDBAccessor('acct10001', self.column_map) as accessor:
            bill = accessor.get_cost_entry_bills_by_date(bill_date)[0]
            self.assertIsNotNone(bill.derived_cost_datetime)
class AWSReportChargeUpdaterTest(MasuTestCase):
    """Test Cases for the AWSReportChargeUpdater object."""
    @classmethod
    def setUpClass(cls):
        """Set up the test class with required objects."""
        super().setUpClass()
        with ReportingCommonDBAccessor() as report_common_db:
            cls.column_map = report_common_db.column_map

        cls.accessor = AWSReportDBAccessor('acct10001', cls.column_map)

        cls.report_schema = cls.accessor.report_schema

        cls.all_tables = list(AWS_CUR_TABLE_MAP.values())

        cls.creator = ReportObjectCreator(cls.schema, cls.column_map)

        cls.date_accessor = DateAccessor()
        cls.manifest_accessor = ReportManifestDBAccessor()

    def setUp(self):
        """Set up each test: seed one AWS bill with a line item and a manifest."""
        super().setUp()

        billing_start = self.date_accessor.today_with_timezone('UTC').replace(
            day=1)
        self.manifest_dict = {
            'assembly_id': '1234',
            'billing_period_start_datetime': billing_start,
            'num_total_files': 2,
            'provider_id': self.aws_provider.id,
        }

        # One accessor is enough; constructing a second one would discard the
        # first without closing it.
        self.provider_accessor = ProviderDBAccessor(
            provider_uuid=self.aws_test_provider_uuid)
        provider_id = self.provider_accessor.get_provider().id
        self.updater = AWSReportChargeUpdater(
            schema=self.schema,
            provider_uuid=self.aws_test_provider_uuid,
            provider_id=provider_id,
        )
        today = DateAccessor().today_with_timezone('UTC')
        bill = self.creator.create_cost_entry_bill(provider_id=provider_id,
                                                   bill_date=today)
        cost_entry = self.creator.create_cost_entry(bill, today)
        product = self.creator.create_cost_entry_product()
        pricing = self.creator.create_cost_entry_pricing()
        reservation = self.creator.create_cost_entry_reservation()
        self.creator.create_cost_entry_line_item(bill, cost_entry, product,
                                                 pricing, reservation)

        self.manifest = self.manifest_accessor.add(**self.manifest_dict)
        self.manifest_accessor.commit()

        with ProviderDBAccessor(
                self.aws_test_provider_uuid) as provider_accessor:
            self.provider = provider_accessor.get_provider()

    @patch(
        'masu.database.cost_model_db_accessor.CostModelDBAccessor.get_markup')
    def test_update_summary_charge_info(self, mock_markup):
        """Test to verify AWS derived cost summary is calculated.

        The cost-model markup lookup is patched to a 10 percent markup.
        """
        markup = {'value': 10, 'unit': 'percent'}
        mock_markup.return_value = markup
        start_date = self.date_accessor.today_with_timezone('UTC')
        bill_date = start_date.replace(day=1).date()

        self.updater.update_summary_charge_info()
        with AWSReportDBAccessor('acct10001', self.column_map) as accessor:
            bill = accessor.get_cost_entry_bills_by_date(bill_date)[0]
            self.assertIsNotNone(bill.derived_cost_datetime)