def export_data(self, output_dir: str) -> Sequence[str]:
    """Export this resource's data to a Google Cloud Storage directory.

    Calls ``self.wait()`` first, then issues the export request through
    ``self.api_client`` and blocks on the returned long-running operation
    until its result is available.

    Args:
        output_dir (str):
            Required. The Google Cloud Storage location where the output is
            to be written to. In the given directory a new directory will be
            created with name:
            ``export-data-<dataset-display-name>-<timestamp-of-export-call>``
            where timestamp is in YYYYMMDDHHMMSS format. All export output
            will be written into that directory. Inside that directory,
            annotations with the same schema will be grouped into sub
            directories which are named with the corresponding annotations'
            schema title. Inside these sub directories, a schema.yaml will
            be created to describe the output format.

            If the uri doesn't end with '/', a '/' will be automatically
            appended. The directory is created if it doesn't exist.

    Returns:
        exported_files (Sequence[str]):
            All of the files that are exported in this export operation.
    """
    self.wait()

    # TODO(b/171311614): Add support for BigQuery export path
    destination = gca_io.GcsDestination(output_uri_prefix=output_dir)
    config = gca_dataset.ExportDataConfig(gcs_destination=destination)

    _LOGGER.log_action_start_against_resource("Exporting", "data", self)

    operation = self.api_client.export_data(
        name=self.resource_name,
        export_config=config,
    )
    _LOGGER.log_action_started_against_resource_with_lro(
        "Export", "data", self.__class__, operation
    )

    # Block until the long-running export operation has produced a result.
    response = operation.result()
    _LOGGER.log_action_completed_against_resource("data", "export", self)

    return response.exported_files
def test_batch_predict_gcs_source_and_dest_with_timeout(
    self, create_batch_prediction_job_mock, sync
):
    """Check that ``create_request_timeout`` reaches the create call.

    Creates a batch prediction job with a GCS source and a GCS destination
    and ``create_request_timeout=180.0``, then asserts that the mocked
    create call received the expected job proto with ``timeout=180.0``.
    """
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

    # Make SDK batch_predict method call
    job = jobs.BatchPredictionJob.create(
        model_name=_TEST_MODEL_NAME,
        job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
        gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
        gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
        sync=sync,
        create_request_timeout=180.0,
    )
    job.wait_for_resource_creation()
    job.wait()

    # Construct expected request, piece by piece
    job_proto_cls = gca_batch_prediction_job_compat.BatchPredictionJob

    expected_input = job_proto_cls.InputConfig(
        instances_format="jsonl",
        gcs_source=gca_io_compat.GcsSource(
            uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
        ),
    )
    expected_output = job_proto_cls.OutputConfig(
        gcs_destination=gca_io_compat.GcsDestination(
            output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
        ),
        predictions_format="jsonl",
    )
    expected_job = job_proto_cls(
        display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
        model=_TEST_MODEL_NAME,
        input_config=expected_input,
        output_config=expected_output,
    )

    create_batch_prediction_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        batch_prediction_job=expected_job,
        timeout=180.0,
    )
def test_batch_predict_with_all_args(
    self, create_batch_prediction_job_with_explanations_mock, sync
):
    """Check the request built when every optional argument is supplied.

    Creates a batch prediction job with machine resources, explanation
    settings, labels, credentials and ``create_request_timeout=None``,
    then asserts that the mocked create call received the matching job
    proto with ``timeout=None``.
    """
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
    creds = auth_credentials.AnonymousCredentials()

    job = jobs.BatchPredictionJob.create(
        model_name=_TEST_MODEL_NAME,
        job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
        gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
        gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
        predictions_format="csv",
        model_parameters={},
        machine_type=_TEST_MACHINE_TYPE,
        accelerator_type=_TEST_ACCELERATOR_TYPE,
        accelerator_count=_TEST_ACCELERATOR_COUNT,
        starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
        max_replica_count=_TEST_MAX_REPLICA_COUNT,
        generate_explanation=True,
        explanation_metadata=_TEST_EXPLANATION_METADATA,
        explanation_parameters=_TEST_EXPLANATION_PARAMETERS,
        labels=_TEST_LABEL,
        credentials=creds,
        sync=sync,
        create_request_timeout=None,
    )
    job.wait_for_resource_creation()
    job.wait()

    # Construct expected request, piece by piece
    job_proto_cls = gca_batch_prediction_job_compat.BatchPredictionJob

    expected_input = job_proto_cls.InputConfig(
        instances_format="jsonl",
        gcs_source=gca_io_compat.GcsSource(
            uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
        ),
    )
    expected_output = job_proto_cls.OutputConfig(
        gcs_destination=gca_io_compat.GcsDestination(
            output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
        ),
        predictions_format="csv",
    )
    expected_resources = gca_machine_resources_compat.BatchDedicatedResources(
        machine_spec=gca_machine_resources_compat.MachineSpec(
            machine_type=_TEST_MACHINE_TYPE,
            accelerator_type=_TEST_ACCELERATOR_TYPE,
            accelerator_count=_TEST_ACCELERATOR_COUNT,
        ),
        starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
        max_replica_count=_TEST_MAX_REPLICA_COUNT,
    )
    expected_explanation = gca_explanation_compat.ExplanationSpec(
        metadata=_TEST_EXPLANATION_METADATA,
        parameters=_TEST_EXPLANATION_PARAMETERS,
    )
    expected_job = job_proto_cls(
        display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
        model=_TEST_MODEL_NAME,
        input_config=expected_input,
        output_config=expected_output,
        dedicated_resources=expected_resources,
        generate_explanation=True,
        explanation_spec=expected_explanation,
        labels=_TEST_LABEL,
    )

    create_batch_prediction_job_with_explanations_mock.assert_called_once_with(
        parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}",
        batch_prediction_job=expected_job,
        timeout=None,
    )
f"bq://{_TEST_BATCH_PREDICTION_BQ_PREFIX}" ) _TEST_JOB_STATE_SUCCESS = gca_job_state_compat.JobState(4) _TEST_JOB_STATE_RUNNING = gca_job_state_compat.JobState(3) _TEST_JOB_STATE_PENDING = gca_job_state_compat.JobState(2) _TEST_GCS_INPUT_CONFIG = gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( instances_format="jsonl", gcs_source=gca_io_compat.GcsSource(uris=[_TEST_GCS_JSONL_SOURCE_URI]), ) _TEST_GCS_OUTPUT_CONFIG = ( gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( predictions_format="jsonl", gcs_destination=gca_io_compat.GcsDestination( output_uri_prefix=_TEST_GCS_BUCKET_PATH ), ) ) _TEST_BQ_INPUT_CONFIG = gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( instances_format="bigquery", bigquery_source=gca_io_compat.BigQuerySource(input_uri=_TEST_BQ_PATH), ) _TEST_BQ_OUTPUT_CONFIG = ( gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( predictions_format="bigquery", bigquery_destination=gca_io_compat.BigQueryDestination( output_uri=_TEST_BQ_PATH ), )
kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME) _TEST_SERVICE_ACCOUNT = "*****@*****.**" _TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_ID}" _TEST_TIMEOUT = 8000 _TEST_RESTART_JOB_ON_WORKER_RESTART = True _TEST_LABELS = {"my_key": "my_value"} _TEST_BASE_CUSTOM_JOB_PROTO = gca_custom_job_compat.CustomJob( display_name=_TEST_DISPLAY_NAME, job_spec=gca_custom_job_compat.CustomJobSpec( worker_pool_specs=_TEST_WORKER_POOL_SPEC, base_output_directory=gca_io_compat.GcsDestination( output_uri_prefix=_TEST_BASE_OUTPUT_DIR), scheduling=gca_custom_job_compat.Scheduling( timeout=duration_pb2.Duration(seconds=_TEST_TIMEOUT), restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, ), service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK, ), labels=_TEST_LABELS, encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, ) def _get_custom_job_proto(state=None, name=None, error=None): custom_job_proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO) custom_job_proto.name = name
_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default" _TEST_DEFAULT_ENCRYPTION_SPEC = gca_encryption_spec_compat.EncryptionSpec( kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME) _TEST_SERVICE_ACCOUNT = "*****@*****.**" _TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_ID}" _TEST_TIMEOUT = 8000 _TEST_RESTART_JOB_ON_WORKER_RESTART = True _TEST_BASE_CUSTOM_JOB_PROTO = gca_custom_job_compat.CustomJob( display_name=_TEST_DISPLAY_NAME, job_spec=gca_custom_job_compat.CustomJobSpec( worker_pool_specs=_TEST_WORKER_POOL_SPEC, base_output_directory=gca_io_compat.GcsDestination( output_uri_prefix=_TEST_STAGING_BUCKET), scheduling=gca_custom_job_compat.Scheduling( timeout=duration_pb2.Duration(seconds=_TEST_TIMEOUT), restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, ), service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK, ), encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, ) def _get_custom_job_proto(state=None, name=None, error=None, version="v1"): custom_job_proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO) custom_job_proto.name = name custom_job_proto.state = state