Example #1
0
    def test_task_is_running_false(self):
        """A task id absent from the host-specific task list is not reported as running."""
        cache = WorkerCache()
        cache.set_host_specific_task_list([1, 2, 3])

        missing_task = 4
        self.assertFalse(cache.task_is_running(missing_task))
Example #2
0
    def test_task_is_running_false(self):
        """A task that was never added to the cache is not reported as running."""
        cache = WorkerCache()
        for task_id in (1, 2, 3):
            cache.add_task_to_cache(task_id)

        self.assertFalse(cache.task_is_running(4))
Example #3
0
    def test_task_is_running_true(self, mock_inspect):
        """A task that was added to the cache is reported as running."""
        # Stub the inspected worker list so the cache sees a single worker.
        mock_inspect.reserved.return_value = {"celery@kokuworker": ""}

        cache = WorkerCache()
        for task_id in (1, 2, 3):
            cache.add_task_to_cache(task_id)

        self.assertTrue(cache.task_is_running(1))
Example #4
0
class ReportDownloaderBase:
    """
    Download cost reports from a provider.

    Base object class for downloading cost reports from a cloud provider.
    """

    def __init__(self, task, download_path=None, **kwargs):
        """
        Create a downloader.

        Args:
            task          (Object) bound celery object
            download_path (String) filesystem path to store downloaded files;
                          when omitted a new temporary directory (prefix
                          "masu") is created

        Kwargs:
            customer_name     (String) customer name
            access_credential (Dict) provider access credentials
            report_source     (String) cost report source
            provider_type     (String) cloud provider type
            provider_uuid     (String) cloud provider uuid
            report_name       (String) cost report name
            cache_key         (String) key used to track this download in the
                              worker cache; tracking is skipped when absent
            request_id        request identifier, included in log context
            account           account identifier, included in log context

        """
        self._task = task

        if download_path:
            self.download_path = download_path
        else:
            self.download_path = mkdtemp(prefix="masu")
        self.worker_cache = WorkerCache()
        self._cache_key = kwargs.get("cache_key")
        self._provider_uuid = kwargs.get("provider_uuid")
        self.request_id = kwargs.get("request_id")
        self.account = kwargs.get("account")
        self.context = {"request_id": self.request_id, "provider_uuid": self._provider_uuid, "account": self.account}

    def _mark_cache_key_running(self):
        """Record this downloader's cache key in the worker cache, if one was provided."""
        # Without a key there is nothing to track; adding ``None`` would only
        # pollute the worker's running-task list.
        if self._cache_key:
            self.worker_cache.add_task_to_cache(self._cache_key)

    def _get_existing_manifest_db_id(self, assembly_id):
        """Return the database id of the manifest for ``assembly_id``, or None if it does not exist."""
        manifest_id = None
        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.get_manifest(assembly_id, self._provider_uuid)
            if manifest:
                manifest_id = manifest.id
        return manifest_id

    def check_if_manifest_should_be_downloaded(self, assembly_id):
        """Check if we should download this manifest.

        If a cache key was provided and the worker cache reports it as running,
        the manifest is skipped. Otherwise we look for a database record of
        this manifest. A record indicates that we have already downloaded and
        at least begun processing it, so we only re-download when some report
        in the manifest has a null ``last_completed_datetime`` (processing did
        not complete). A manifest with no record is being seen for the first
        time and is always downloaded.

        Whenever this returns True, the cache key (if any) is added to the
        worker cache.

        Args:
            assembly_id (String) manifest assembly identifier

        Returns:
            (Boolean) True if the manifest should be downloaded and processed.
        """
        if self._cache_key and self.worker_cache.task_is_running(self._cache_key):
            msg = f"{self._cache_key} is currently running."
            LOG.info(log_json(self.request_id, msg, self.context))
            return False
        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.get_manifest(assembly_id, self._provider_uuid)

            if manifest:
                # A null `last_completed_datetime` on any report in the manifest
                # means processing is not complete, so reports must be downloaded.
                need_to_download = manifest_accessor.is_last_completed_datetime_null(manifest.id)
                if need_to_download:
                    self._mark_cache_key_running()
                return need_to_download

        # The manifest does not exist, this is the first time we are
        # downloading and processing it.
        self._mark_cache_key_running()
        return True

    def _process_manifest_db_record(self, assembly_id, billing_start, num_of_files):
        """Insert or update the manifest DB record.

        Creates the manifest record (tagged with the bound celery task's
        request id) when none exists, then marks it as updated.

        Args:
            assembly_id   (String) manifest assembly identifier
            billing_start (DateTime) billing period start for a new record
            num_of_files  (Integer) total number of report files in the manifest

        Returns:
            The database id of the manifest record.
        """
        LOG.info("Inserting/updating manifest in database for assembly_id: %s", assembly_id)

        with ReportManifestDBAccessor() as manifest_accessor:
            manifest_entry = manifest_accessor.get_manifest(assembly_id, self._provider_uuid)

            if not manifest_entry:
                msg = f"No manifest entry found in database. Adding for bill period start: {billing_start}"
                LOG.info(log_json(self.request_id, msg, self.context))
                manifest_dict = {
                    "assembly_id": assembly_id,
                    "billing_period_start_datetime": billing_start,
                    "num_total_files": num_of_files,
                    "provider_uuid": self._provider_uuid,
                    "task": self._task.request.id,
                }
                manifest_entry = manifest_accessor.add(**manifest_dict)

            manifest_accessor.mark_manifest_as_updated(manifest_entry)
            manifest_id = manifest_entry.id

        return manifest_id
Example #5
0
class Orchestrator:
    """
    Orchestrator for report processing.

    Top level object which is responsible for:
    * Maintaining a current list of accounts
    * Ensuring that reports are downloaded and processed for all accounts.

    """
    def __init__(self, billing_source=None, provider_uuid=None):
        """
        Orchestrator for processing.

        Args:
            billing_source (String): Individual account to retrieve.
            provider_uuid  (String): Passed through to
                                     AccountsAccessor().get_accounts to
                                     select the accounts to load.

        """
        self._accounts, self._polling_accounts = self.get_accounts(
            billing_source, provider_uuid)
        self.worker_cache = WorkerCache()

    @staticmethod
    def get_accounts(billing_source=None, provider_uuid=None):
        """
        Prepare a list of accounts for the orchestrator to get CUR from.

        If billing_source is not provided all accounts will be returned, otherwise
        only the account for the provided billing_source will be returned.

        Still a work in progress, but works for now.

        Args:
            billing_source (String): Individual account to retrieve.
            provider_uuid  (String): Passed to AccountsAccessor().get_accounts.

        Returns:
            [CostUsageReportAccount] (all), [CostUsageReportAccount] (polling only)

        """
        all_accounts = []
        polling_accounts = []
        try:
            all_accounts = AccountsAccessor().get_accounts(provider_uuid)
        except AccountsAccessorError as error:
            LOG.error("Unable to get accounts. Error: %s", str(error))

        if billing_source:
            # NOTE(review): rebinding `all_accounts` inside the loop is safe
            # (iteration continues over the original list), but the *last*
            # matching account wins, and when nothing matches every account is
            # kept -- confirm that fallback is intended.
            for account in all_accounts:
                if billing_source == account.get("billing_source"):
                    all_accounts = [account]

        for account in all_accounts:
            if AccountsAccessor().is_polling_account(account):
                polling_accounts.append(account)

        return all_accounts, polling_accounts

    @staticmethod
    def get_reports(provider_uuid):
        """
        Get months for provider to process.

        Providers whose setup is not complete (or any provider when
        Config.INGEST_OVERRIDE is set) get Config.INITIAL_INGEST_NUM_MONTHS
        months; otherwise 2 months.

        Args:
            (String) provider uuid to determine if initial setup is complete.

        Returns:
            (List) List of datetime objects.

        """
        with ProviderDBAccessor(
                provider_uuid=provider_uuid) as provider_accessor:
            reports_processed = provider_accessor.get_setup_complete()

        if Config.INGEST_OVERRIDE or not reports_processed:
            number_of_months = Config.INITIAL_INGEST_NUM_MONTHS
        else:
            number_of_months = 2

        return DateAccessor().get_billing_months(number_of_months)

    def start_manifest_processing(self, customer_name, credentials,
                                  data_source, provider_type, schema_name,
                                  provider_uuid, report_month):
        """
        Start processing an account's manifest for the specified report_month.

        Every report file that is neither already processed nor currently
        running is queued as a get_report_files task; the queued tasks are
        grouped into a celery chord whose callback is summarize_reports.

        Args:
            (String) customer_name - customer name
            (String) credentials - credentials object
            (String) data_source - report storage location
            (String) provider_type - cloud provider type
            (String) schema_name - db tenant
            (String) provider_uuid - provider unique identifier
            (Date)   report_month - month to get latest manifest

        Returns:
            ({}) Dictionary containing the following keys:
                manifest_id - (String): Manifest ID for ReportManifestDBAccessor
                assembly_id - (String): UUID identifying report file
                compression - (String): Report compression format
                files       - ([{"key": full_file_path "local_file": "local file name"}]): List of report files.
            A falsy manifest from the downloader is returned unchanged and
            nothing is queued.
        """
        downloader = ReportDownloader(
            customer_name=customer_name,
            credentials=credentials,
            data_source=data_source,
            provider_type=provider_type,
            provider_uuid=provider_uuid,
            report_name=None,
        )
        manifest = downloader.download_manifest(report_month)

        if manifest:
            LOG.info("Saving all manifest file names.")
            record_all_manifest_files(manifest["manifest_id"], [
                report.get("local_file")
                for report in manifest.get("files", [])
            ])

        LOG.info("Found Manifests: %s", manifest)
        if not manifest:
            # Nothing to queue; returning here also avoids calling `.get` on a
            # None manifest below.
            return manifest

        report_tasks = []
        for report_file_dict in manifest.get("files", []):
            local_file = report_file_dict.get("local_file")
            report_file = report_file_dict.get("key")

            # Check if report file is complete or in progress.
            if record_report_status(manifest["manifest_id"], local_file,
                                    "no_request"):
                LOG.info("%s was already processed", local_file)
                continue

            cache_key = f"{provider_uuid}:{report_file}"
            if self.worker_cache.task_is_running(cache_key):
                LOG.info("%s process is in progress", local_file)
                continue

            report_context = manifest.copy()
            report_context["current_file"] = report_file
            report_context["local_file"] = local_file
            report_context["key"] = report_file

            report_tasks.append(
                get_report_files.s(
                    customer_name,
                    credentials,
                    data_source,
                    provider_type,
                    schema_name,
                    provider_uuid,
                    report_month,
                    report_context,
                ))
            LOG.info("Download queued - schema_name: %s.", schema_name)

        if report_tasks:
            async_id = chord(report_tasks, summarize_reports.s())()
            LOG.info("Manifest Processing Async ID: %s", async_id)
        return manifest

    def prepare(self):
        """
        Prepare a processing request for each account.

        Scans the database for providers that have reports that need to be processed.
        Any report it finds is queued to the appropriate celery task to download
        and process those reports. Account labels are refreshed after each
        month that was processed without an error.

        Args:
            None

        Returns:
            None

        """
        for account in self._polling_accounts:
            provider_uuid = account.get("provider_uuid")
            report_months = self.get_reports(provider_uuid)
            for month in report_months:
                LOG.info(
                    "Getting %s report files for account (provider uuid): %s",
                    month.strftime("%B %Y"), provider_uuid)
                # The account dict doubles as the kwargs for start_manifest_processing.
                account["report_month"] = month
                try:
                    self.start_manifest_processing(**account)
                except ReportDownloaderError as err:
                    LOG.warning(
                        "Unable to download manifest for provider: %s. Error: %s.",
                        provider_uuid, err)
                    continue
                except Exception as err:
                    # Broad exception catching is important here because any errors thrown can
                    # block all subsequent account processing.
                    LOG.error(
                        "Unexpected manifest processing error for provider: %s. Error: %s.",
                        provider_uuid, err)
                    continue

                # update labels
                labeler = AccountLabel(
                    auth=account.get("credentials"),
                    schema=account.get("schema_name"),
                    provider_type=account.get("provider_type"),
                )
                account_number, label = labeler.get_label_details()
                if account_number:
                    LOG.info("Account: %s Label: %s updated.", account_number,
                             label)

        return None

    def remove_expired_report_data(self,
                                   simulate=False,
                                   line_items_only=False):
        """
        Remove expired report data for each account.

        Args:
            simulate        (Boolean) Simulate report data removal
            line_items_only (Boolean) Forwarded to the remove_expired_data task

        Returns:
            (List) One dict per account with keys "customer" (customer name)
            and "async_id" (string form of the queued task's AsyncResult).

        """
        async_results = []
        for account in self._accounts:
            LOG.info("Calling remove_expired_data with account: %s", account)
            async_result = remove_expired_data.delay(
                schema_name=account.get("schema_name"),
                provider=account.get("provider_type"),
                simulate=simulate,
                line_items_only=line_items_only,
            )
            LOG.info(
                "Expired data removal queued - schema_name: %s, Task ID: %s",
                account.get("schema_name"),
                str(async_result),
            )
            async_results.append({
                "customer": account.get("customer_name"),
                "async_id": str(async_result)
            })
        return async_results
Example #6
0
class ReportDownloaderBase:
    """
    Download cost reports from a provider.

    Base object class for downloading cost reports from a cloud provider.
    """

    # pylint: disable=unused-argument
    def __init__(self, task, download_path=None, **kwargs):
        """
        Create a downloader.

        Args:
            task          (Object) bound celery object
            download_path (String) filesystem path to store downloaded files;
                          when omitted a new temporary directory (prefix
                          "masu") is created

        Kwargs:
            customer_name     (String) customer name
            access_credential (Dict) provider access credentials
            report_source     (String) cost report source
            provider_type     (String) cloud provider type
            provider_uuid     (String) cloud provider uuid
            report_name       (String) cost report name
            cache_key         (String) key used to track this download in the
                              worker cache; tracking is skipped when absent

        """
        self._task = task

        if download_path:
            self.download_path = download_path
        else:
            self.download_path = mkdtemp(prefix="masu")
        self.worker_cache = WorkerCache()
        self._cache_key = kwargs.get("cache_key")
        self._provider_uuid = kwargs.get("provider_uuid")

    def _mark_cache_key_running(self):
        """Record this downloader's cache key in the worker cache, if one was provided."""
        # Without a key there is nothing to track; adding ``None`` would only
        # pollute the worker's running-task list.
        if self._cache_key:
            self.worker_cache.add_task_to_cache(self._cache_key)

    def _get_existing_manifest_db_id(self, assembly_id):
        """Return the database id of the manifest for ``assembly_id``, or None if it does not exist."""
        manifest_id = None
        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.get_manifest(assembly_id,
                                                      self._provider_uuid)
            if manifest:
                manifest_id = manifest.id
        return manifest_id

    def check_if_manifest_should_be_downloaded(self, assembly_id):
        """Check if we should download this manifest.

        If a cache key was provided and the worker cache reports it as running,
        the manifest is skipped. Otherwise we check for a database record of
        this manifest; a record indicates that we have already downloaded and
        at least begun processing. If not every file has been processed, we
        check the last completed time for a file in this manifest: when none
        is recorded, or it is more than one hour old (UTC), processing did not
        complete, so the manifest is reset and re-downloaded.

        Whenever this returns True, the cache key (if any) is added to the
        worker cache.

        Args:
            assembly_id (String) manifest assembly identifier

        Returns:
            (Boolean) True if the manifest should be downloaded and processed.
        """
        if self._cache_key and self.worker_cache.task_is_running(
                self._cache_key):
            msg = f"{self._cache_key} is currently running."
            LOG.info(msg)
            return False
        today = DateAccessor().today_with_timezone("UTC")
        last_completed_cutoff = today - datetime.timedelta(hours=1)
        with ReportManifestDBAccessor() as manifest_accessor:
            manifest = manifest_accessor.get_manifest(assembly_id,
                                                      self._provider_uuid)

            if manifest:
                manifest_id = manifest.id
                if manifest.num_processed_files < manifest.num_total_files:
                    completed_datetime = manifest_accessor.get_last_report_completed_datetime(
                        manifest_id)
                    if not completed_datetime or completed_datetime < last_completed_cutoff:
                        # Either no completion time was ever recorded, or it has
                        # been more than an hour since we processed a file and we
                        # didn't finish processing. We should download and reprocess.
                        # NOTE(review): reset_manifest presumably clears the
                        # manifest's processing progress -- confirm.
                        manifest_accessor.reset_manifest(manifest_id)
                        self._mark_cache_key_running()
                        return True
                # The manifest exists and is either fully processed or a file
                # finished within the last hour. We should not redownload.
                return False
        # The manifest does not exist, this is the first time we are
        # downloading and processing it.
        self._mark_cache_key_running()
        return True

    def _process_manifest_db_record(self, assembly_id, billing_start,
                                    num_of_files):
        """Insert or update the manifest DB record.

        Creates the manifest record (tagged with the bound celery task's
        request id) when none exists, then marks it as updated.

        Args:
            assembly_id   (String) manifest assembly identifier
            billing_start (DateTime) billing period start for a new record
            num_of_files  (Integer) total number of report files in the manifest

        Returns:
            The database id of the manifest record.
        """
        LOG.info("Inserting manifest database record for assembly_id: %s",
                 assembly_id)

        with ReportManifestDBAccessor() as manifest_accessor:
            manifest_entry = manifest_accessor.get_manifest(
                assembly_id, self._provider_uuid)

            if not manifest_entry:
                LOG.info(
                    "No manifest entry found.  Adding for bill period start: %s",
                    billing_start)
                manifest_dict = {
                    "assembly_id": assembly_id,
                    "billing_period_start_datetime": billing_start,
                    "num_total_files": num_of_files,
                    "provider_uuid": self._provider_uuid,
                    "task": self._task.request.id,
                }
                manifest_entry = manifest_accessor.add(**manifest_dict)

            manifest_accessor.mark_manifest_as_updated(manifest_entry)
            manifest_id = manifest_entry.id

        return manifest_id
Example #7
0
class Orchestrator:
    """
    Orchestrator for report processing.

    Top level object which is responsible for:
    * Maintaining a current list of accounts
    * Ensuring that reports are downloaded and processed for all accounts.

    """
    def __init__(self,
                 billing_source=None,
                 provider_uuid=None,
                 bill_date=None,
                 queue_name=None):
        """
        Orchestrator for processing.

        Args:
            billing_source (String): Individual account to retrieve.
            provider_uuid  (String): Passed to AccountsAccessor().get_accounts
                                     to select the accounts to load.
            bill_date      (Date):   When set, get_reports returns only the
                                     month given by
                                     DateAccessor().get_billing_month_start(bill_date).
            queue_name     (String): When set together with provider_uuid,
                                     report download and summary tasks are
                                     routed to this queue.

        """
        self.worker_cache = WorkerCache()
        self.billing_source = billing_source
        self.bill_date = bill_date
        self.provider_uuid = provider_uuid
        self.queue_name = queue_name
        self._accounts, self._polling_accounts = self.get_accounts(
            self.billing_source, self.provider_uuid)

    @staticmethod
    def get_accounts(billing_source=None, provider_uuid=None):
        """
        Prepare a list of accounts for the orchestrator to get CUR from.

        If billing_source is not provided all accounts will be returned, otherwise
        only the account for the provided billing_source will be returned.

        Still a work in progress, but works for now.

        Args:
            billing_source (String): Individual account to retrieve.
            provider_uuid  (String): Passed to AccountsAccessor().get_accounts.

        Returns:
            [CostUsageReportAccount] (all), [CostUsageReportAccount] (polling only)

        """
        all_accounts = []
        polling_accounts = []
        try:
            all_accounts = AccountsAccessor().get_accounts(provider_uuid)
        except AccountsAccessorError as error:
            LOG.error("Unable to get accounts. Error: %s", str(error))

        if billing_source:
            # NOTE(review): rebinding `all_accounts` inside the loop is safe
            # (iteration continues over the original list), but the *last*
            # matching account wins, and when nothing matches every account is
            # kept -- confirm that fallback is intended.
            for account in all_accounts:
                if billing_source == account.get("billing_source"):
                    all_accounts = [account]

        for account in all_accounts:
            if AccountsAccessor().is_polling_account(account):
                polling_accounts.append(account)

        return all_accounts, polling_accounts

    def get_reports(self, provider_uuid):
        """
        Get months for provider to process.

        With a bill_date set on the orchestrator only that billing month is
        returned. Otherwise providers whose setup is not complete (or any
        provider when Config.INGEST_OVERRIDE is set) get
        Config.INITIAL_INGEST_NUM_MONTHS months and the rest get 2, sorted in
        descending order.

        Args:
            (String) provider uuid to determine if initial setup is complete.

        Returns:
            (List) List of datetime objects.

        """
        if self.bill_date:
            # An explicit bill date pins processing to one month, so the
            # provider's setup state is never consulted; skip the DB query.
            return [DateAccessor().get_billing_month_start(self.bill_date)]

        with ProviderDBAccessor(
                provider_uuid=provider_uuid) as provider_accessor:
            reports_processed = provider_accessor.get_setup_complete()

        if Config.INGEST_OVERRIDE or not reports_processed:
            number_of_months = Config.INITIAL_INGEST_NUM_MONTHS
        else:
            number_of_months = 2

        return sorted(DateAccessor().get_billing_months(number_of_months),
                      reverse=True)

    def start_manifest_processing(self, customer_name, credentials,
                                  data_source, provider_type, schema_name,
                                  provider_uuid, report_month):
        """
        Start processing an account's manifest for the specified report_month.

        Every report file that is neither already processed nor currently
        running is queued as a get_report_files task; the queued tasks are
        grouped into a celery chord whose callback is summarize_reports.

        Args:
            (String) customer_name - customer name
            (String) credentials - credentials object
            (String) data_source - report storage location
            (String) provider_type - cloud provider type
            (String) schema_name - db tenant
            (String) provider_uuid - provider unique identifier
            (Date)   report_month - month to get latest manifest

        Returns:
            ({}) Dictionary containing the following keys:
                manifest_id - (String): Manifest ID for ReportManifestDBAccessor
                assembly_id - (String): UUID identifying report file
                compression - (String): Report compression format
                files       - ([{"key": full_file_path "local_file": "local file name"}]): List of report files.
            (Boolean) - Whether we are processing this manifest
        """
        # Switching initial ingest to use priority queue for QE tests based on QE_SCHEMA flag
        if self.queue_name is not None and self.provider_uuid is not None:
            summary_queue = self.queue_name
            report_queue = self.queue_name
        else:
            summary_queue = SUMMARIZE_REPORTS_QUEUE
            report_queue = GET_REPORT_FILES_QUEUE
        reports_tasks_queued = False
        downloader = ReportDownloader(
            customer_name=customer_name,
            credentials=credentials,
            data_source=data_source,
            provider_type=provider_type,
            provider_uuid=provider_uuid,
            report_name=None,
        )
        manifest = downloader.download_manifest(report_month)
        # The `if manifest:` guard below anticipates a falsy manifest; read
        # fields through an empty dict so a None manifest does not raise here.
        manifest_data = manifest or {}
        tracing_id = manifest_data.get(
            "assembly_id", manifest_data.get("request_id", "no-request-id"))
        report_files = manifest_data.get("files", [])
        filenames = [report.get("local_file") for report in report_files]
        LOG.info(
            log_json(
                tracing_id,
                f"Report with manifest {tracing_id} contains the files: {filenames}"
            ))

        if manifest:
            LOG.debug("Saving all manifest file names.")
            record_all_manifest_files(manifest["manifest_id"], filenames,
                                      tracing_id)

        LOG.info(log_json(tracing_id, f"Found Manifests: {str(manifest)}"))
        report_tasks = []
        last_report_index = len(report_files) - 1
        for i, report_file_dict in enumerate(report_files):
            local_file = report_file_dict.get("local_file")
            report_file = report_file_dict.get("key")

            # Check if report file is complete or in progress.
            if record_report_status(manifest["manifest_id"], local_file,
                                    "no_request"):
                LOG.info(
                    log_json(tracing_id,
                             f"{local_file} was already processed"))
                continue

            cache_key = f"{provider_uuid}:{report_file}"
            if self.worker_cache.task_is_running(cache_key):
                LOG.info(
                    log_json(tracing_id,
                             f"{local_file} process is in progress"))
                continue

            report_context = manifest.copy()
            report_context["current_file"] = report_file
            report_context["local_file"] = local_file
            report_context["key"] = report_file
            # add the tracing id to the report context
            report_context["request_id"] = tracing_id

            if provider_type in [Provider.PROVIDER_OCP, Provider.PROVIDER_GCP
                                 ] or i == last_report_index:
                # This create_table flag is used by the ParquetReportProcessor
                # to create a Hive/Trino table.
                # To reduce the number of times we check Trino/Hive tables, we just do this
                # on the final file of the set (OCP and GCP set it on every file).
                report_context["create_table"] = True
            report_tasks.append(
                get_report_files.s(
                    customer_name,
                    credentials,
                    data_source,
                    provider_type,
                    schema_name,
                    provider_uuid,
                    report_month,
                    report_context,
                ).set(queue=report_queue))
            LOG.info(
                log_json(tracing_id,
                         f"Download queued - schema_name: {schema_name}."))

        if report_tasks:
            reports_tasks_queued = True
            async_id = chord(report_tasks,
                             summarize_reports.s().set(queue=summary_queue))()
            LOG.debug(
                log_json(tracing_id,
                         f"Manifest Processing Async ID: {async_id}"))
        return manifest, reports_tasks_queued

    def prepare(self):
        """
        Prepare a processing request for each account.

        Scans the database for providers that have reports that need to be processed.
        Any report it finds is queued to the appropriate celery task to download
        and process those reports. Account labels are refreshed at most once
        per account, after the first month that queued report tasks.

        Args:
            None

        Returns:
            None

        """
        for account in self._polling_accounts:
            accounts_labeled = False
            provider_uuid = account.get("provider_uuid")
            report_months = self.get_reports(provider_uuid)
            for month in report_months:
                LOG.info(
                    "Getting %s report files for account (provider uuid): %s",
                    month.strftime("%B %Y"), provider_uuid)
                # The account dict doubles as the kwargs for start_manifest_processing.
                account["report_month"] = month
                try:
                    _, reports_tasks_queued = self.start_manifest_processing(
                        **account)
                except ReportDownloaderError as err:
                    LOG.warning(
                        f"Unable to download manifest for provider: {provider_uuid}. Error: {str(err)}."
                    )
                    continue
                except Exception as err:
                    # Broad exception catching is important here because any errors thrown can
                    # block all subsequent account processing.
                    LOG.error(
                        f"Unexpected manifest processing error for provider: {provider_uuid}. Error: {str(err)}."
                    )
                    continue

                # update labels
                if reports_tasks_queued and not accounts_labeled:
                    LOG.info("Running AccountLabel to get account aliases.")
                    labeler = AccountLabel(
                        auth=account.get("credentials"),
                        schema=account.get("schema_name"),
                        provider_type=account.get("provider_type"),
                    )
                    account_number, label = labeler.get_label_details()
                    accounts_labeled = True
                    if account_number:
                        LOG.info("Account: %s Label: %s updated.",
                                 account_number, label)

    def remove_expired_report_data(self, simulate=False):
        """
        Remove expired report data for each account.

        Args:
            simulate (Boolean) Simulate report data removal

        Returns:
            (List) One dict per account with keys "customer" (customer name)
            and "async_id" (string form of the queued task's AsyncResult).

        """
        async_results = []
        for account in self._accounts:
            LOG.info("Calling remove_expired_data with account: %s", account)
            async_result = remove_expired_data.delay(
                schema_name=account.get("schema_name"),
                provider=account.get("provider_type"),
                simulate=simulate)
            LOG.info(
                "Expired data removal queued - schema_name: %s, Task ID: %s",
                account.get("schema_name"),
                str(async_result),
            )
            async_results.append({
                "customer": account.get("customer_name"),
                "async_id": str(async_result)
            })
        return async_results