def test_export_data(self, shared_state):
    """Get an existing dataset, export data to a newly created folder in
    Google Cloud Storage, then verify data was successfully exported."""

    assert shared_state["staging_bucket"]
    assert shared_state["storage_client"]

    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=f"gs://{shared_state['staging_bucket']}",
    )

    text_dataset = aiplatform.TextDataset(dataset_name=_TEST_TEXT_DATASET_ID)

    exported_files = text_dataset.export_data(
        output_dir=f"gs://{shared_state['staging_bucket']}"
    )

    assert len(exported_files)  # Ensure at least one GCS path was returned

    exported_file = exported_files[0]
    bucket, prefix = utils.extract_bucket_and_prefix_from_gcs_path(exported_file)

    storage_client = shared_state["storage_client"]
    bucket = storage_client.get_bucket(bucket)
    blob = bucket.get_blob(prefix)

    assert blob  # Verify the returned GCS export path exists
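
# A minimal sketch of the kind of shared_state fixture the test above assumes;
# the real fixture lives elsewhere in the test suite, and the bucket name below
# is hypothetical. It only illustrates the two keys the test asserts on.
import pytest
from google.cloud import storage


@pytest.fixture(scope="class")
def shared_state():
    state = {
        "storage_client": storage.Client(project=_TEST_PROJECT),
        "staging_bucket": "example-staging-bucket",  # hypothetical bucket name
    }
    yield state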
def _retrieve_gcs_source_columns(
    project: str,
    gcs_csv_file_path: str,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> List[str]:
    """Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage

    Example Usage:

        column_names = _retrieve_gcs_source_columns(
            "project_id",
            "gs://example-bucket/path/to/csv_file"
        )

        # column_names = ["column_1", "column_2"]

    Args:
        project (str):
            Required. Project to initiate the Google Cloud Storage client with.
        gcs_csv_file_path (str):
            Required. A full path to a CSV file stored on Google Cloud Storage.
            Must include "gs://" prefix.
        credentials (auth_credentials.Credentials):
            Credentials to use with the GCS client.

    Returns:
        List[str]
            A list of column names in the CSV file.

    Raises:
        RuntimeError: When the retrieved CSV file is invalid.
    """
    gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
        gcs_csv_file_path
    )
    client = storage.Client(project=project, credentials=credentials)
    bucket = client.bucket(gcs_bucket)
    blob = bucket.blob(gcs_blob)

    # Incrementally download the CSV file until the header is retrieved
    first_new_line_index = -1
    start_index = 0
    increment = 1000
    line = ""

    try:
        logger = logging.getLogger("google.resumable_media._helpers")
        logging_warning_filter = utils.LoggingFilter(logging.INFO)
        logger.addFilter(logging_warning_filter)

        while first_new_line_index == -1:
            line += blob.download_as_bytes(
                start=start_index, end=start_index + increment
            ).decode("utf-8")

            first_new_line_index = line.find("\n")
            start_index += increment

        header_line = line[:first_new_line_index]

        # Split to make it an iterable
        header_line = header_line.split("\n")[:1]

        csv_reader = csv.reader(header_line, delimiter=",")
    except (ValueError, RuntimeError) as err:
        raise RuntimeError(
            "There was a problem extracting the headers from the CSV file at '{}': {}".format(
                gcs_csv_file_path, err
            )
        )
    finally:
        logger.removeFilter(logging_warning_filter)

    return next(csv_reader)
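
# A minimal sketch of the utils.LoggingFilter used above, assuming it simply
# drops records at the given level so that ranged-download messages from
# google.resumable_media stay quiet; the real helper lives in aiplatform.utils
# and its exact semantics may differ.
import logging


class LoggingFilter(logging.Filter):
    def __init__(self, filtered_level: int):
        super().__init__()
        self._filtered_level = filtered_level

    def filter(self, record: logging.LogRecord) -> bool:
        # Returning False suppresses the record
        return record.levelno != self._filtered_level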
def iter_outputs(
    self, bq_max_results: Optional[int] = 100
) -> Union[Iterable[storage.Blob], Iterable[bigquery.table.RowIterator]]:
    """Returns an Iterable object to traverse the output files, either a
    list of GCS Blobs or a BigQuery RowIterator depending on the output
    config set when the BatchPredictionJob was created.

    Args:
        bq_max_results: Optional[int] = 100
            Limit on rows to retrieve from prediction table in BigQuery dataset.
            Only used when retrieving predictions from a bigquery_destination_prefix.
            Default is 100.

    Returns:
        Union[Iterable[storage.Blob], Iterable[bigquery.table.RowIterator]]:
            Either a list of GCS Blob objects within the prediction output
            directory or an iterable BigQuery RowIterator with predictions.

    Raises:
        RuntimeError: If BatchPredictionJob is in a JobState other than SUCCEEDED,
            since outputs cannot be retrieved until the Job has finished.
        NotImplementedError: If BatchPredictionJob succeeded and output_info
            does not have a GCS or BQ output provided.
    """
    if self.state != gca_job_state.JobState.JOB_STATE_SUCCEEDED:
        raise RuntimeError(
            f"Cannot read outputs until BatchPredictionJob has succeeded, "
            f"current state: {self._gca_resource.state}"
        )

    output_info = self._gca_resource.output_info

    # GCS Destination, return Blobs
    if output_info.gcs_output_directory:

        # Build a Storage Client using the same credentials as JobServiceClient
        storage_client = storage.Client(
            credentials=self.api_client._transport._credentials
        )

        gcs_bucket, gcs_prefix = utils.extract_bucket_and_prefix_from_gcs_path(
            output_info.gcs_output_directory
        )

        blobs = storage_client.list_blobs(gcs_bucket, prefix=gcs_prefix)

        return blobs

    # BigQuery Destination, return RowIterator
    elif output_info.bigquery_output_dataset:

        # Build a BigQuery Client using the same credentials as JobServiceClient
        bq_client = bigquery.Client(
            credentials=self.api_client._transport._credentials
        )

        # Format from service is `bq://projectId.bqDatasetId`
        bq_dataset = output_info.bigquery_output_dataset

        if bq_dataset.startswith("bq://"):
            bq_dataset = bq_dataset[5:]

        # Split project ID and BQ dataset ID
        _, bq_dataset_id = bq_dataset.split(".", 1)

        row_iterator = bq_client.list_rows(
            table=f"{bq_dataset_id}.predictions", max_results=bq_max_results
        )

        return row_iterator

    # Unknown Destination type
    else:
        raise NotImplementedError(
            f"Unsupported batch prediction output location, here are details "
            f"on your prediction output:\n{output_info}"
        )
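
# A hypothetical usage sketch (the job resource name below is an assumption,
# not part of the original source): once a job reaches JOB_STATE_SUCCEEDED,
# its predictions can be read back the same way for either destination type.
job = aiplatform.BatchPredictionJob(
    batch_prediction_job_name="projects/my-project/locations/us-central1/batchPredictionJobs/12345"
)

for prediction in job.iter_outputs():
    # A GCS destination yields storage.Blob objects; a BigQuery destination
    # yields rows from the predictions table.
    print(prediction)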
def test_extract_bucket_and_prefix_from_gcs_path(gcs_path: str, expected: tuple):
    # Given a GCS path, ensure correct bucket and prefix are extracted
    assert expected == utils.extract_bucket_and_prefix_from_gcs_path(gcs_path)
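
# A minimal sketch of the helper under test, assuming it strips an optional
# "gs://" scheme, drops any trailing slash, and splits on the first "/"; the
# real implementation lives in aiplatform.utils. The test above presumably
# receives its (gcs_path, expected) pairs via pytest.mark.parametrize.
from typing import Optional, Tuple


def extract_bucket_and_prefix_from_gcs_path(gcs_path: str) -> Tuple[str, Optional[str]]:
    path = gcs_path[5:] if gcs_path.startswith("gs://") else gcs_path
    path = path.rstrip("/")  # a trailing slash carries no prefix information
    bucket, _, prefix = path.partition("/")
    # Prefix is None when the path names only a bucket
    return (bucket, prefix or None)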