Example No. 1
    def test_export_data(self, shared_state):
        """Get an existing dataset, export data to a newly created folder in
        Google Cloud Storage, then verify data was successfully exported."""

        assert shared_state["staging_bucket"]
        assert shared_state["storage_client"]

        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            staging_bucket=f"gs://{shared_state['staging_bucket']}",
        )

        text_dataset = aiplatform.TextDataset(dataset_name=_TEST_TEXT_DATASET_ID)

        exported_files = text_dataset.export_data(
            output_dir=f"gs://{shared_state['staging_bucket']}"
        )

        assert len(exported_files)  # Ensure at least one GCS path was returned

        exported_file = exported_files[0]
        bucket, prefix = utils.extract_bucket_and_prefix_from_gcs_path(exported_file)

        storage_client = shared_state["storage_client"]

        bucket = storage_client.get_bucket(bucket)
        blob = bucket.get_blob(prefix)

        assert blob  # Verify the returned GCS export path exists
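
Outside the pytest harness, the same export flow can be exercised directly. The sketch below is a minimal, hypothetical usage; the project, location, dataset ID, and bucket values are placeholders rather than values from the test above.

from google.cloud import aiplatform

# Placeholder project, location, dataset ID and bucket -- substitute real values.
aiplatform.init(
    project="my-project",
    location="us-central1",
    staging_bucket="gs://my-staging-bucket",
)

dataset = aiplatform.TextDataset(dataset_name="1234567890123456789")

# export_data returns the GCS paths of the exported files.
exported_files = dataset.export_data(output_dir="gs://my-staging-bucket/exports")
for path in exported_files:
    print(path)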
Example No. 2
    def _retrieve_gcs_source_columns(
        project: str,
        gcs_csv_file_path: str,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> List[str]:
        """Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage

        Example Usage:

            column_names = _retrieve_gcs_source_columns(
                "project_id",
                "gs://example-bucket/path/to/csv_file"
            )

            # column_names = ["column_1", "column_2"]

        Args:
            project (str):
                Required. Project to initiate the Google Cloud Storage client with.
            gcs_csv_file_path (str):
                Required. A full path to a CSV file stored on Google Cloud Storage.
                Must include "gs://" prefix.
            credentials (auth_credentials.Credentials):
                Optional. Credentials to use with the GCS Client.
        Returns:
            List[str]
                A list of column names in the CSV file.

        Raises:
            RuntimeError: When the retrieved CSV file is invalid.
        """

        gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
            gcs_csv_file_path)
        client = storage.Client(project=project, credentials=credentials)
        bucket = client.bucket(gcs_bucket)
        blob = bucket.blob(gcs_blob)

        # Incrementally download the CSV file until the header is retrieved
        first_new_line_index = -1
        start_index = 0
        increment = 1000
        line = ""

        try:
            logger = logging.getLogger("google.resumable_media._helpers")
            logging_warning_filter = utils.LoggingFilter(logging.INFO)
            logger.addFilter(logging_warning_filter)

            while first_new_line_index == -1:
                line += blob.download_as_bytes(
                    start=start_index, end=start_index + increment
                ).decode("utf-8")

                first_new_line_index = line.find("\n")
                start_index += increment

            header_line = line[:first_new_line_index]

            # Wrap the header in a single-element list so csv.reader can iterate over it
            header_line = header_line.split("\n")[:1]

            csv_reader = csv.reader(header_line, delimiter=",")
        except (ValueError, RuntimeError) as err:
            raise RuntimeError(
                "There was a problem extracting the headers from the CSV file at '{}': {}"
                .format(gcs_csv_file_path, err))
        finally:
            logger.removeFilter(logging_warning_filter)

        return next(csv_reader)
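
The interesting part of the helper above is the header-probing loop: it downloads the blob in fixed-size ranges until the first newline appears, then feeds only the header line to csv.reader. The sketch below reproduces that idea against an in-memory byte string; read_range is a simplified local stand-in for a ranged download, not the google-cloud-storage API.

import csv

csv_bytes = b"column_1,column_2,column_3\nvalue_a,value_b,value_c\n"

def read_range(data: bytes, start: int, length: int) -> bytes:
    # Simplified stand-in for a ranged blob download: `length` bytes from `start`.
    return data[start:start + length]

first_new_line_index = -1
start_index = 0
increment = 10  # Deliberately small so the loop runs more than once.
line = ""

while first_new_line_index == -1:
    line += read_range(csv_bytes, start_index, increment).decode("utf-8")
    first_new_line_index = line.find("\n")
    start_index += increment

header_line = line[:first_new_line_index]
column_names = next(csv.reader([header_line], delimiter=","))
print(column_names)  # ['column_1', 'column_2', 'column_3']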
Example No. 3
    def iter_outputs(
        self,
        bq_max_results: Optional[int] = 100
    ) -> Union[Iterable[storage.Blob], Iterable[bigquery.table.RowIterator]]:
        """Returns an Iterable object to traverse the output files, either a list
        of GCS Blobs or a BigQuery RowIterator depending on the output config set
        when the BatchPredictionJob was created.

        Args:
            bq_max_results (Optional[int]):
                Limit on rows to retrieve from prediction table in BigQuery dataset.
                Only used when retrieving predictions from a bigquery_destination_prefix.
                Default is 100.

        Returns:
            Union[Iterable[storage.Blob], Iterable[bigquery.table.RowIterator]]:
                Either a list of GCS Blob objects within the prediction output
                directory or an iterable BigQuery RowIterator with predictions.

        Raises:
            RuntimeError:
                If BatchPredictionJob is in a JobState other than SUCCEEDED,
                since outputs cannot be retrieved until the Job has finished.
            NotImplementedError:
                If BatchPredictionJob succeeded and output_info does not have a
                GCS or BQ output provided.
        """

        if self.state != gca_job_state.JobState.JOB_STATE_SUCCEEDED:
            raise RuntimeError(
                f"Cannot read outputs until BatchPredictionJob has succeeded, "
                f"current state: {self._gca_resource.state}")

        output_info = self._gca_resource.output_info

        # GCS Destination, return Blobs
        if output_info.gcs_output_directory:

            # Build a Storage Client using the same credentials as JobServiceClient
            storage_client = storage.Client(
                credentials=self.api_client._transport._credentials)

            gcs_bucket, gcs_prefix = utils.extract_bucket_and_prefix_from_gcs_path(
                output_info.gcs_output_directory)

            blobs = storage_client.list_blobs(gcs_bucket, prefix=gcs_prefix)

            return blobs

        # BigQuery Destination, return RowIterator
        elif output_info.bigquery_output_dataset:

            # Build a BigQuery Client using the same credentials as JobServiceClient
            bq_client = bigquery.Client(
                credentials=self.api_client._transport._credentials)

            # Format from service is `bq://projectId.bqDatasetId`
            bq_dataset = output_info.bigquery_output_dataset

            if bq_dataset.startswith("bq://"):
                bq_dataset = bq_dataset[5:]

            # Split project ID and BQ dataset ID
            _, bq_dataset_id = bq_dataset.split(".", 1)

            row_iterator = bq_client.list_rows(
                table=f"{bq_dataset_id}.predictions",
                max_results=bq_max_results)

            return row_iterator

        # Unknown Destination type
        else:
            raise NotImplementedError(
                "Unsupported batch prediction output location, here are details "
                f"on your prediction output:\n{output_info}")
Example No. 4
def test_extract_bucket_and_prefix_from_gcs_path(gcs_path: str, expected: tuple):
    # Given a GCS path, ensure correct bucket and prefix are extracted
    assert expected == utils.extract_bucket_and_prefix_from_gcs_path(gcs_path)
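
The parametrized fixture values for gcs_path and expected are not shown in this snippet. As an illustration of the (bucket, prefix) split assumed by the test, a direct call might look like the sketch below; the path is a placeholder.

from google.cloud.aiplatform import utils

bucket, prefix = utils.extract_bucket_and_prefix_from_gcs_path(
    "gs://example-bucket/path/to/csv_file"
)
print(bucket)  # example-bucket
print(prefix)  # path/to/csv_file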