Code Example #1
File: test_hash.py  Project: zanetworker/koku
    def test_shake_length_required(self):
        """Test that shake algorithms require a length."""
        hash_function = "shake_128"
        with self.assertRaises(HasherError):
            Hasher(hash_function=hash_function)

        Hasher(hash_function=hash_function, length=16)
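
The length requirement comes from hashlib itself: the SHAKE extendable-output functions take the digest size as an argument to hexdigest(), unlike the fixed-length algorithms, so a hasher built on them cannot emit output without a length. A minimal standalone illustration using only the standard library:

import hashlib

# SHAKE digests take the output length (in bytes) explicitly...
print(hashlib.shake_128(b"payload").hexdigest(16))
# ...while fixed-length digests take no such argument.
print(hashlib.sha256(b"payload").hexdigest())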
Code Example #2
class HasherUtilTests(MasuTestCase):
    """Test the hashing utility class."""
    def setUp(self):
        super().setUp()
        self.encoding = 'utf-8'
        self.hash_function = random.choice(list(hashlib.algorithms_guaranteed))
        if 'shake' in self.hash_function:
            self.hasher = Hasher(
                hash_function=self.hash_function,
                length=random.randint(8, 64),
                encoding=self.encoding,
            )
        else:
            self.hasher = Hasher(hash_function=self.hash_function,
                                 encoding=self.encoding)
        self.string_to_hash = ''.join(
            [random.choice(string.ascii_letters) for _ in range(16)])

    def test_initializer(self):
        """Test that the proper variables are initialized."""
        hash_function = getattr(hashlib, self.hash_function)
        self.assertEqual(self.hasher.hash_function, hash_function)
        self.assertEqual(self.hasher.encoding, self.encoding)

    def test_hash_string_to_hex(self):
        """Test that the string hash function returns a hex string."""
        result = self.hasher.hash_string_to_hex(self.string_to_hash)
        self.assertIsInstance(result, str)
        for char in result:
            self.assertIn(char, string.hexdigits)

    def test_reliable_hash(self):
        """Test that the hasher creates a true hash."""
        result_one = self.hasher.hash_string_to_hex(self.string_to_hash)
        result_two = self.hasher.hash_string_to_hex(self.string_to_hash)

        self.assertEqual(result_one, result_two)

    def test_shake_length_required(self):
        """Test that shake algorithms require a length."""
        hash_function = 'shake_128'
        with self.assertRaises(HasherError):
            Hasher(hash_function=hash_function)

        Hasher(hash_function=hash_function, length=16)

    def test_unsupported_algorithm(self):
        """Test that an exception is raised for unsupported algorithms."""
        bad_algorithm = 'zuul'

        with self.assertRaises(HasherError):
            Hasher(hash_function=bad_algorithm)
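
The setUp above draws a random name from hashlib.algorithms_guaranteed, the set of algorithms available on every platform, which is why the shake branch exists: two of the guaranteed names are extendable-output functions that need a length. A quick check of which names take that branch:

import hashlib

# Both SHAKE variants are in the guaranteed set on current CPython releases.
print(sorted(a for a in hashlib.algorithms_guaranteed if "shake" in a))
# ['shake_128', 'shake_256']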
Code Example #3
File: test_hash.py  Project: zanetworker/koku
    def setUp(self):
        """Shared variables for hasher tests."""
        super().setUp()
        self.encoding = "utf-8"
        self.hash_function = random.choice(list(hashlib.algorithms_guaranteed))
        if "shake" in self.hash_function:
            self.hasher = Hasher(
                hash_function=self.hash_function, length=random.randint(8, 64), encoding=self.encoding
            )
        else:
            self.hasher = Hasher(hash_function=self.hash_function, encoding=self.encoding)
        self.string_to_hash = "".join([random.choice(string.ascii_letters) for _ in range(16)])
Code Example #4
File: test_hash.py  Project: project-koku/koku
    @patch("masu.util.hash.hashlib", spec=hashlib)
    def test_guaranteed_algorithms(self, mock_hashlib):
        """Test that an exception is raised for a guaranteed algorithm."""
        bad_algorithm = "test_hash_function"
        mock_hashlib.algorithms_guaranteed = [bad_algorithm]
        mock_hashlib.test_hash_function = None

        with self.assertRaises(HasherError):
            Hasher(hash_function=bad_algorithm)
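
The patch swaps the hashlib reference inside masu.util.hash for a mock whose algorithms_guaranteed lists a name that has no matching constructor attribute. The same condition can be reproduced with the standard library alone; a small sketch (the algorithm name is made up for the demonstration):

from unittest.mock import patch
import hashlib

with patch.object(hashlib, "algorithms_guaranteed", {"test_hash_function"}):
    name = "test_hash_function"
    print(name in hashlib.algorithms_guaranteed)  # True: the name is "guaranteed"
    print(getattr(hashlib, name, None))           # None: no constructor exists for it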
Code Example #5
File: test_hash.py  Project: zanetworker/koku
    def test_unsupported_algorithm(self):
        """Test that an exception is raised for unsupported algorithms."""
        bad_algorithm = "zuul"

        with self.assertRaises(HasherError):
            Hasher(hash_function=bad_algorithm)
Code Example #6
File: test_hash.py  Project: project-koku/koku
class HasherUtilTests(MasuTestCase):
    """Test the hashing utility class."""
    def setUp(self):
        """Shared variables for hasher tests."""
        super().setUp()
        self.encoding = "utf-8"
        self.hash_function = random.choice(list(hashlib.algorithms_guaranteed))
        if "shake" in self.hash_function:
            self.hasher = Hasher(hash_function=self.hash_function,
                                 length=random.randint(8, 64),
                                 encoding=self.encoding)
        else:
            self.hasher = Hasher(hash_function=self.hash_function,
                                 encoding=self.encoding)
        self.string_to_hash = "".join(
            [random.choice(string.ascii_letters) for _ in range(16)])

    def test_initializer(self):
        """Test that the proper variables are initialized."""
        hash_function = getattr(hashlib, self.hash_function)
        self.assertEqual(self.hasher.hash_function, hash_function)
        self.assertEqual(self.hasher.encoding, self.encoding)

    def test_hash_string_to_hex(self):
        """Test that the string hash function returns a hex string."""
        result = self.hasher.hash_string_to_hex(self.string_to_hash)
        self.assertIsInstance(result, str)
        for char in result:
            self.assertIn(char, string.hexdigits)

    def test_reliable_hash(self):
        """Test that the hasher creates a true hash."""
        result_one = self.hasher.hash_string_to_hex(self.string_to_hash)
        result_two = self.hasher.hash_string_to_hex(self.string_to_hash)

        self.assertEqual(result_one, result_two)

    def test_shake_length_required(self):
        """Test that shake algorithms require a length."""
        hash_function = "shake_128"
        with self.assertRaises(HasherError):
            Hasher(hash_function=hash_function)

        Hasher(hash_function=hash_function, length=16)

    def test_unsupported_algorithm(self):
        """Test that an exception is raised for unsupported algorithms."""
        bad_algorithm = "zuul"

        with self.assertRaises(HasherError):
            Hasher(hash_function=bad_algorithm)

    @patch("masu.util.hash.hashlib", spec=hashlib)
    def test_guaranteed_algorithms(self, mock_hashlib):
        """Test that an exception is raised for a guaranteed algorithm."""
        bad_algorithm = "test_hash_function"
        mock_hashlib.algorithms_guaranteed = [bad_algorithm]
        mock_hashlib.test_hash_function = None

        with self.assertRaises(HasherError):
            Hasher(hash_function=bad_algorithm)
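
Taken together, these tests pin down the Hasher contract: the algorithm name must exist in hashlib, shake variants require a length, the configured constructor and encoding are exposed as attributes, and hashing is deterministic. What follows is a minimal sketch reconstructed from those tests; it is not the actual masu.util.hash implementation, which may differ in details such as error messages:

import hashlib

class HasherError(Exception):
    """Raised for unsupported or misconfigured hash algorithms."""

class Hasher:
    """Minimal sketch of the hashing utility exercised by the tests above."""

    def __init__(self, hash_function, length=None, encoding="utf-8"):
        if hash_function not in hashlib.algorithms_guaranteed:
            raise HasherError(f"{hash_function} is not a supported algorithm.")
        self._needs_length = "shake" in hash_function
        if self._needs_length and length is None:
            raise HasherError(f"{hash_function} requires a length.")
        self.hash_function = getattr(hashlib, hash_function, None)
        if self.hash_function is None:
            # Covers the mocked case: a "guaranteed" name with no constructor.
            raise HasherError(f"{hash_function} is not available in hashlib.")
        self.length = length
        self.encoding = encoding

    def hash_string_to_hex(self, string_to_hash):
        """Hash a string and return its hex representation."""
        digest = self.hash_function(string_to_hash.encode(self.encoding))
        # SHAKE digests require an explicit output length; fixed digests do not.
        if self._needs_length:
            return digest.hexdigest(self.length)
        return digest.hexdigest()

Because the digest is recomputed from scratch on every call, hashing the same string twice yields identical hex output, which is exactly what test_reliable_hash asserts.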
Code Example #7
File: report_processor.py  Project: LaVLaS/masu
class ReportProcessor:
    """Cost Usage Report processor."""
    def __init__(self, schema_name, report_path, compression):
        """Initialize the report processor.

        Args:
            schema_name (str): The name of the customer schema to process into
            report_path (str): Where the report file lives in the file system
            compression (CONST): How the report file is compressed.
                Accepted values: UNCOMPRESSED, GZIP_COMPRESSED

        """
        if compression.upper() not in ALLOWED_COMPRESSIONS:
            err_msg = f'Compression {compression} is not supported.'
            raise MasuProcessingError(err_msg)

        self._schema_name = schema_name
        self._report_path = report_path
        self._compression = compression.upper()
        self._report_name = path.basename(report_path)
        self._datetime_format = Config.AWS_DATETIME_STR_FORMAT
        self._batch_size = Config.REPORT_PROCESSING_BATCH_SIZE

        self.processed_report = ProcessedReport()

        # Gather database accessors
        self.report_common_db = ReportingCommonDBAccessor()
        self.column_map = self.report_common_db.column_map
        self.report_common_db.close_session()

        self.report_db = ReportDBAccessor(schema=self._schema_name,
                                          column_map=self.column_map)
        self.report_schema = self.report_db.report_schema

        self.temp_table = self.report_db.create_temp_table(
            AWS_CUR_TABLE_MAP['line_item'])
        self.line_item_columns = None

        self.hasher = Hasher(hash_function='sha256')
        self.hash_columns = self._get_line_item_hash_columns()

        self.current_bill = self.report_db.get_current_cost_entry_bill()
        self.existing_cost_entry_map = self.report_db.get_cost_entries()
        self.existing_product_map = self.report_db.get_products()
        self.existing_pricing_map = self.report_db.get_pricing()
        self.existing_reservation_map = self.report_db.get_reservations()

        LOG.info('Initialized report processor for file: %s and schema: %s',
                 self._report_name, self._schema_name)

    @property
    def line_item_conflict_columns(self):
        """Create a property to check conflict on line items."""
        return ['hash', 'cost_entry_id']

    @property
    def line_item_condition_column(self):
        """Create a property with condition to check for line item inserts."""
        return 'invoice_id'

    def process(self):
        """Process CUR file.

        Returns:
            (None)

        """
        row_count = 0
        bill_id = None
        opener, mode = self._get_file_opener(self._compression)
        # pylint: disable=invalid-name
        with opener(self._report_path, mode) as f:
            LOG.info('File %s opened for processing', str(f))
            reader = csv.DictReader(f)
            for row in reader:
                if bill_id is None:
                    bill_id = self._create_cost_entry_bill(row)

                cost_entry_id = self._create_cost_entry(row, bill_id)
                product_id = self._create_cost_entry_product(row)
                pricing_id = self._create_cost_entry_pricing(row)
                reservation_id = self._create_cost_entry_reservation(row)

                self._create_cost_entry_line_item(row, cost_entry_id, bill_id,
                                                  product_id, pricing_id,
                                                  reservation_id)

                if len(self.processed_report.line_items) >= self._batch_size:
                    self._save_to_db()

                    self.report_db.merge_temp_table(
                        AWS_CUR_TABLE_MAP['line_item'], self.temp_table,
                        self.line_item_columns,
                        self.line_item_condition_column,
                        self.line_item_conflict_columns)

                    LOG.info('Saving report rows %d to %d for %s', row_count,
                             row_count + len(self.processed_report.line_items),
                             self._report_name)
                    row_count += len(self.processed_report.line_items)

                    self._update_mappings()

            if self.processed_report.line_items:
                self._save_to_db()

                self.report_db.merge_temp_table(
                    AWS_CUR_TABLE_MAP['line_item'], self.temp_table,
                    self.line_item_columns, self.line_item_condition_column,
                    self.line_item_conflict_columns)

                LOG.info('Saving report rows %d to %d for %s', row_count,
                         row_count + len(self.processed_report.line_items),
                         self._report_name)

                row_count += len(self.processed_report.line_items)

            self.report_db.close_session()
            self.report_db.close_connections()

        LOG.info('Completed report processing for file: %s and schema: %s',
                 self._report_name, self._schema_name)
        return

    # pylint: disable=inconsistent-return-statements, no-self-use
    def _get_file_opener(self, compression):
        """Get the file opener for the file's compression.

        Args:
            compression (str): The compression format for the file.

        Returns:
            (file opener, str): The proper file stream handler for the
                compression and the read mode for the file

        """
        if compression == UNCOMPRESSED:
            return open, 'r'
        elif compression == GZIP_COMPRESSED:
            return gzip.open, 'rt'

    def _save_to_db(self):
        """Save current batch of records to the database."""
        columns = tuple(self.processed_report.line_items[0].keys())
        csv_file = self._write_processed_rows_to_csv()

        # This will commit all pricing, products, and reservations
        # on the session
        self.report_db.commit()

        self.report_db.bulk_insert_rows(csv_file, self.temp_table, columns)

    def _update_mappings(self):
        """Update cache of database objects for reference."""
        self.existing_cost_entry_map.update(self.processed_report.cost_entries)
        self.existing_product_map.update(self.processed_report.products)
        self.existing_pricing_map.update(self.processed_report.pricing)
        self.existing_reservation_map.update(
            self.processed_report.reservations)

        self.processed_report.remove_processed_rows()

    def _write_processed_rows_to_csv(self):
        """Output CSV content to file stream object."""
        values = [
            tuple(item.values()) for item in self.processed_report.line_items
        ]

        file_obj = io.StringIO()
        writer = csv.writer(file_obj,
                            delimiter='\t',
                            quoting=csv.QUOTE_NONE,
                            quotechar='')
        writer.writerows(values)
        file_obj.seek(0)

        return file_obj

    def _get_data_for_table(self, row, table_name):
        """Extract the data from a row for a specific table.

        Args:
            row (dict): A dictionary representation of a CSV file row
            table_name (str): The DB table fields are required for

        Returns:
            (dict): The data from the row keyed on the DB table's column names

        """
        # Memory can come as a single number or a number with a unit
        # e.g. "1" vs. "1 Gb" so it gets special cased.
        if 'product/memory' in row and row['product/memory'] is not None:
            memory_list = row['product/memory'].split(' ')
            if len(memory_list) > 1:
                memory, unit = row['product/memory'].split(' ')
            else:
                memory = memory_list[0]
                unit = None
            row['product/memory'] = memory
            row['product/memory_unit'] = unit

        column_map = self.column_map[table_name]

        return {
            column_map[key]: value
            for key, value in row.items() if key in column_map
        }

    # pylint: disable=no-self-use
    def _process_tags(self, row, tag_suffix='resourceTags'):
        """Return a JSON string of AWS resource tags.

        Args:
            row (dict): A dictionary representation of a CSV file row
            tag_suffix (str): A specifier used to identify a value as a tag

        Returns:
            (str): A JSON string of AWS resource tags

        """
        return json.dumps({
            key: value
            for key, value in row.items() if tag_suffix in key and row[key]
        })

    # pylint: disable=no-self-use
    def _get_cost_entry_time_interval(self, interval):
        """Split the cost entry time interval into start and end.

        Args:
            interval (str): The time interval from the cost usage report.

        Returns:
            (str, str): Separated start and end strings

        """
        start, end = interval.split('/')
        return start, end

    def _create_cost_entry_bill(self, row):
        """Create a cost entry bill object.

        Args:
            row (dict): A dictionary representation of a CSV file row

        Returns:
            (str): A cost entry bill object id

        """
        table_name = AWS_CUR_TABLE_MAP['bill']
        start_date = row.get('bill/BillingPeriodStartDate')

        current_start = None
        if self.current_bill is not None:
            current_start = self.current_bill.billing_period_start.strftime(
                self._datetime_format)

        if current_start is not None and start_date == current_start:
            self.processed_report.bill_id = self.current_bill.id
            return self.current_bill.id

        data = self._get_data_for_table(row, table_name)

        bill_id = self.report_db.insert_on_conflict_do_nothing(
            table_name, data)
        self.processed_report.bill_id = bill_id

        return bill_id

    def _create_cost_entry(self, row, bill_id):
        """Create a cost entry object.

        Args:
            row (dict): A dictionary representation of a CSV file row
            bill_id (str): The current cost entry bill id

        Returns:
            (str): The DB id of the cost entry object

        """
        table_name = AWS_CUR_TABLE_MAP['cost_entry']
        interval = row.get('identity/TimeInterval')
        start, end = self._get_cost_entry_time_interval(interval)

        if start in self.processed_report.cost_entries:
            return self.processed_report.cost_entries[start]
        elif start in self.existing_cost_entry_map:
            return self.existing_cost_entry_map[start]

        data = {
            'bill_id': bill_id,
            'interval_start': start,
            'interval_end': end
        }

        cost_entry_id = self.report_db.insert_on_conflict_do_nothing(
            table_name, data)
        self.processed_report.cost_entries[start] = cost_entry_id

        return cost_entry_id

    # pylint: disable=too-many-arguments
    def _create_cost_entry_line_item(self, row, cost_entry_id, bill_id,
                                     product_id, pricing_id, reservation_id):
        """Create a cost entry line item object.

        Args:
            row (dict): A dictionary representation of a CSV file row
            cost_entry_id (str): A processed cost entry object id
            bill_id (str): A processed cost entry bill object id
            product_id (str): A processed product object id
            pricing_id (str): A processed pricing object id
            reservation_id (str): A processed reservation object id

        Returns:
            (None)

        """
        table_name = AWS_CUR_TABLE_MAP['line_item']
        data = self._get_data_for_table(row, table_name)
        data = self.report_db.clean_data(data, table_name)

        data['tags'] = self._process_tags(row)
        data['cost_entry_id'] = cost_entry_id
        data['cost_entry_bill_id'] = bill_id
        data['cost_entry_product_id'] = product_id
        data['cost_entry_pricing_id'] = pricing_id
        data['cost_entry_reservation_id'] = reservation_id

        data_str = self._create_line_item_hash_string(data)
        data['hash'] = self.hasher.hash_string_to_hex(data_str)

        self.processed_report.line_items.append(data)

        if self.line_item_columns is None:
            self.line_item_columns = list(data.keys())

    def _create_cost_entry_pricing(self, row):
        """Create a cost entry pricing object.

        Args:
            row (dict): A dictionary representation of a CSV file row

        Returns:
            (str): The DB id of the pricing object

        """
        table_name = AWS_CUR_TABLE_MAP['pricing']

        term = row.get('pricing/term') if row.get('pricing/term') else 'None'
        unit = row.get('pricing/unit') if row.get('pricing/unit') else 'None'

        key = '{term}-{unit}'.format(term=term, unit=unit)
        if key in self.processed_report.pricing:
            return self.processed_report.pricing[key]
        elif key in self.existing_pricing_map:
            return self.existing_pricing_map[key]

        data = self._get_data_for_table(row, table_name)
        value_set = set(data.values())
        if value_set == {''}:
            return

        pricing_id = self.report_db.insert_on_conflict_do_nothing(
            table_name, data)
        self.processed_report.pricing[key] = pricing_id

        return pricing_id

    def _create_cost_entry_product(self, row):
        """Create a cost entry product object.

        Args:
            row (dict): A dictionary representation of a CSV file row

        Returns:
            (str): The DB id of the product object

        """
        table_name = AWS_CUR_TABLE_MAP['product']
        sku = row.get('product/sku')
        product_name = row.get('product/ProductName')
        region = row.get('product/region')
        key = (sku, product_name, region)

        if key in self.processed_report.products:
            return self.processed_report.products[key]
        elif key in self.existing_product_map:
            return self.existing_product_map[key]

        data = self._get_data_for_table(row, table_name)
        value_set = set(data.values())
        if value_set == {''}:
            return

        product_id = self.report_db.insert_on_conflict_do_nothing(
            table_name,
            data,
            conflict_columns=['sku', 'product_name', 'region'])
        self.processed_report.products[key] = product_id

        return product_id

    def _create_cost_entry_reservation(self, row):
        """Create a cost entry reservation object.

        Args:
            row (dict): A dictionary representation of a CSV file row

        Returns:
            (str): The DB id of the reservation object

        """
        table_name = AWS_CUR_TABLE_MAP['reservation']
        arn = row.get('reservation/ReservationARN')
        line_item_type = row.get('lineItem/LineItemType', '').lower()
        reservation_id = None

        if arn in self.processed_report.reservations:
            reservation_id = self.processed_report.reservations.get(arn)
        elif arn in self.existing_reservation_map:
            reservation_id = self.existing_reservation_map[arn]

        if reservation_id is None or line_item_type == 'rifee':
            data = self._get_data_for_table(row, table_name)
            value_set = set(data.values())
            if value_set == {''}:
                return
        else:
            return reservation_id

        # Special rows with additional reservation information
        if line_item_type == 'rifee':
            reservation_id = self.report_db.insert_on_conflict_do_update(
                table_name,
                data,
                conflict_columns=['reservation_arn'],
                set_columns=list(data.keys()))
        else:
            reservation_id = self.report_db.insert_on_conflict_do_nothing(
                table_name, data, conflict_columns=['reservation_arn'])
        self.processed_report.reservations[arn] = reservation_id

        return reservation_id

    def _get_line_item_hash_columns(self):
        """Get the column list used for creating a line item hash."""
        all_columns = self.column_map[AWS_CUR_TABLE_MAP['line_item']].values()
        # Invoice id is populated when a bill is finalized so we don't want to
        # use it to determine row uniqueness
        return [column for column in all_columns if column != 'invoice_id']

    def _create_line_item_hash_string(self, data):
        """Build the string to be hashed using line item data.

        Args:
            data (dict): The processed line item data dictionary

        Returns:
            (str): A string representation of the data

        """
        data = stringify_json_data(copy.deepcopy(data))
        data = [data.get(column, 'None') for column in self.hash_columns]
        return ':'.join(data)
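
Two details of the class above are easy to miss. The line-item hash excludes invoice_id (see _get_line_item_hash_columns), and merge_temp_table upserts on the ('hash', 'cost_entry_id') pair, so reprocessing a report is idempotent even after a bill is finalized and its rows gain an invoice id. Driving the processor then takes just a constructor call and process(). A hedged usage sketch; the import path and argument values are illustrative assumptions, not taken from the code above:

# Hypothetical caller; the module path and values are assumed for illustration.
from masu.processor.report_processor import ReportProcessor

processor = ReportProcessor(
    schema_name="acme10001",               # hypothetical customer schema
    report_path="/tmp/cur_report.csv.gz",  # hypothetical downloaded report
    compression="GZIP_COMPRESSED",         # or "UNCOMPRESSED" for plain CSV
)
processor.process()  # batches line items and upserts them into the schema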