def export_data(self, output_dir: str) -> Sequence[str]:
        """Exports data to output dir to GCS.

        Args:
            output_dir (str):
                Required. The Google Cloud Storage location where the output is to
                be written to. In the given directory a new directory will be
                created with name:
                ``export-data-<dataset-display-name>-<timestamp-of-export-call>``
                where timestamp is in YYYYMMDDHHMMSS format. All export
                output will be written into that directory. Inside that
                directory, annotations with the same schema will be grouped
                into subdirectories named with the corresponding annotations'
                schema title. Inside these subdirectories, a schema.yaml file
                will be created to describe the output format.

                If the URI doesn't end with '/', a '/' will be automatically
                appended. The directory is created if it doesn't exist.

        Returns:
            exported_files (Sequence[str]):
                All of the files that are exported in this export operation.
        """
        self.wait()

        # TODO(b/171311614): Add support for BigQuery export path
        export_data_config = gca_dataset.ExportDataConfig(
            gcs_destination=gca_io.GcsDestination(
                output_uri_prefix=output_dir))

        _LOGGER.log_action_start_against_resource("Exporting", "data", self)

        export_lro = self.api_client.export_data(
            name=self.resource_name, export_config=export_data_config)

        _LOGGER.log_action_started_against_resource_with_lro(
            "Export", "data", self.__class__, export_lro)

        export_data_response = export_lro.result()

        _LOGGER.log_action_completed_against_resource("data", "exported", self)

        return export_data_response.exported_files
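
# A minimal usage sketch of export_data (the project, bucket, and dataset ID
# below are hypothetical; any managed dataset class that exposes export_data,
# e.g. aiplatform.ImageDataset, works the same way):
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

dataset = aiplatform.ImageDataset(
    "projects/my-project/locations/us-central1/datasets/123"
)
# Blocks until the export LRO completes, then returns the exported file URIs.
exported_files = dataset.export_data(output_dir="gs://my-bucket/exports")
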
    def test_batch_predict_gcs_source_and_dest_with_timeout(
        self, create_batch_prediction_job_mock, sync
    ):
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

        # Make SDK batch_predict method call
        batch_prediction_job = jobs.BatchPredictionJob.create(
            model_name=_TEST_MODEL_NAME,
            job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
            gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
            gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
            sync=sync,
            create_request_timeout=180.0,
        )

        batch_prediction_job.wait_for_resource_creation()

        batch_prediction_job.wait()

        # Construct expected request
        expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob(
            display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
            model=_TEST_MODEL_NAME,
            input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig(
                instances_format="jsonl",
                gcs_source=gca_io_compat.GcsSource(
                    uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
                ),
            ),
            output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig(
                gcs_destination=gca_io_compat.GcsDestination(
                    output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
                ),
                predictions_format="jsonl",
            ),
        )

        create_batch_prediction_job_mock.assert_called_once_with(
            parent=_TEST_PARENT,
            batch_prediction_job=expected_gapic_batch_prediction_job,
            timeout=180.0,
        )
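
# The `sync` argument used by the tests in this file is supplied by pytest and
# is not shown in these snippets. A minimal sketch of one way to provide it
# (the real suite may use @pytest.mark.parametrize on each test instead):
import pytest


@pytest.fixture(params=[True, False])
def sync(request):
    # Run each test twice: once with sync=True, once with sync=False.
    return request.param
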
    def test_batch_predict_with_all_args(
        self, create_batch_prediction_job_with_explanations_mock, sync
    ):
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
        creds = auth_credentials.AnonymousCredentials()

        batch_prediction_job = jobs.BatchPredictionJob.create(
            model_name=_TEST_MODEL_NAME,
            job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
            gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
            gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
            predictions_format="csv",
            model_parameters={},
            machine_type=_TEST_MACHINE_TYPE,
            accelerator_type=_TEST_ACCELERATOR_TYPE,
            accelerator_count=_TEST_ACCELERATOR_COUNT,
            starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
            max_replica_count=_TEST_MAX_REPLICA_COUNT,
            generate_explanation=True,
            explanation_metadata=_TEST_EXPLANATION_METADATA,
            explanation_parameters=_TEST_EXPLANATION_PARAMETERS,
            labels=_TEST_LABEL,
            credentials=creds,
            sync=sync,
            create_request_timeout=None,
        )

        batch_prediction_job.wait_for_resource_creation()

        batch_prediction_job.wait()

        # Construct expected request
        expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob(
            display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
            model=_TEST_MODEL_NAME,
            input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig(
                instances_format="jsonl",
                gcs_source=gca_io_compat.GcsSource(
                    uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
                ),
            ),
            output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig(
                gcs_destination=gca_io_compat.GcsDestination(
                    output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
                ),
                predictions_format="csv",
            ),
            dedicated_resources=gca_machine_resources_compat.BatchDedicatedResources(
                machine_spec=gca_machine_resources_compat.MachineSpec(
                    machine_type=_TEST_MACHINE_TYPE,
                    accelerator_type=_TEST_ACCELERATOR_TYPE,
                    accelerator_count=_TEST_ACCELERATOR_COUNT,
                ),
                starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
                max_replica_count=_TEST_MAX_REPLICA_COUNT,
            ),
            generate_explanation=True,
            explanation_spec=gca_explanation_compat.ExplanationSpec(
                metadata=_TEST_EXPLANATION_METADATA,
                parameters=_TEST_EXPLANATION_PARAMETERS,
            ),
            labels=_TEST_LABEL,
        )

        create_batch_prediction_job_with_explanations_mock.assert_called_once_with(
            parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}",
            batch_prediction_job=expected_gapic_batch_prediction_job,
            timeout=None,
        )
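
# The create_batch_prediction_job mocks used above patch the GAPIC client so no
# real API call is made. A minimal sketch, assuming the compat JobServiceClient
# import path below (the returned resource name is hypothetical); note that
# wait() would additionally require get_batch_prediction_job to be patched:
from unittest import mock

import pytest
from google.cloud.aiplatform.compat.services import job_service_client


@pytest.fixture
def create_batch_prediction_job_mock():
    with mock.patch.object(
        job_service_client.JobServiceClient, "create_batch_prediction_job"
    ) as create_mock:
        create_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob(
            name="projects/my-project/locations/us-central1/batchPredictionJobs/123",
            display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
            state=gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED,
        )
        yield create_mock
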
    f"bq://{_TEST_BATCH_PREDICTION_BQ_PREFIX}"
)

_TEST_JOB_STATE_SUCCESS = gca_job_state_compat.JobState(4)  # JOB_STATE_SUCCEEDED
_TEST_JOB_STATE_RUNNING = gca_job_state_compat.JobState(3)  # JOB_STATE_RUNNING
_TEST_JOB_STATE_PENDING = gca_job_state_compat.JobState(2)  # JOB_STATE_PENDING

_TEST_GCS_INPUT_CONFIG = gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig(
    instances_format="jsonl",
    gcs_source=gca_io_compat.GcsSource(uris=[_TEST_GCS_JSONL_SOURCE_URI]),
)
_TEST_GCS_OUTPUT_CONFIG = (
    gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig(
        predictions_format="jsonl",
        gcs_destination=gca_io_compat.GcsDestination(
            output_uri_prefix=_TEST_GCS_BUCKET_PATH
        ),
    )
)

_TEST_BQ_INPUT_CONFIG = gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig(
    instances_format="bigquery",
    bigquery_source=gca_io_compat.BigQuerySource(input_uri=_TEST_BQ_PATH),
)
_TEST_BQ_OUTPUT_CONFIG = (
    gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig(
        predictions_format="bigquery",
        bigquery_destination=gca_io_compat.BigQueryDestination(
            output_uri=_TEST_BQ_PATH
        ),
    )
)

Example #5

_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default"
_TEST_DEFAULT_ENCRYPTION_SPEC = gca_encryption_spec_compat.EncryptionSpec(
    kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME)

_TEST_SERVICE_ACCOUNT = "*****@*****.**"

_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_ID}"

_TEST_TIMEOUT = 8000
_TEST_RESTART_JOB_ON_WORKER_RESTART = True

_TEST_LABELS = {"my_key": "my_value"}

_TEST_BASE_CUSTOM_JOB_PROTO = gca_custom_job_compat.CustomJob(
    display_name=_TEST_DISPLAY_NAME,
    job_spec=gca_custom_job_compat.CustomJobSpec(
        worker_pool_specs=_TEST_WORKER_POOL_SPEC,
        base_output_directory=gca_io_compat.GcsDestination(
            output_uri_prefix=_TEST_BASE_OUTPUT_DIR),
        scheduling=gca_custom_job_compat.Scheduling(
            timeout=duration_pb2.Duration(seconds=_TEST_TIMEOUT),
            restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
        ),
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
    ),
    labels=_TEST_LABELS,
    encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)


def _get_custom_job_proto(state=None, name=None, error=None):
    custom_job_proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO)
    custom_job_proto.name = name
    custom_job_proto.state = state
    custom_job_proto.error = error
    return custom_job_proto
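
# A hedged example of the helper above: fabricate a failed CustomJob proto for
# a test (the resource name and error message are hypothetical; code 13 is
# google.rpc's INTERNAL, and gca_job_state_compat is assumed to be the alias
# these tests use for the JobState enum module):
from google.rpc import status_pb2

failed_job = _get_custom_job_proto(
    name="projects/my-project/locations/us-central1/customJobs/456",
    state=gca_job_state_compat.JobState.JOB_STATE_FAILED,
    error=status_pb2.Status(code=13, message="worker exited unexpectedly"),
)
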
_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default"
_TEST_DEFAULT_ENCRYPTION_SPEC = gca_encryption_spec_compat.EncryptionSpec(
    kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME)

_TEST_SERVICE_ACCOUNT = "*****@*****.**"

_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_ID}"

_TEST_TIMEOUT = 8000
_TEST_RESTART_JOB_ON_WORKER_RESTART = True

_TEST_BASE_CUSTOM_JOB_PROTO = gca_custom_job_compat.CustomJob(
    display_name=_TEST_DISPLAY_NAME,
    job_spec=gca_custom_job_compat.CustomJobSpec(
        worker_pool_specs=_TEST_WORKER_POOL_SPEC,
        base_output_directory=gca_io_compat.GcsDestination(
            output_uri_prefix=_TEST_STAGING_BUCKET),
        scheduling=gca_custom_job_compat.Scheduling(
            timeout=duration_pb2.Duration(seconds=_TEST_TIMEOUT),
            restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
        ),
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
    ),
    encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)


def _get_custom_job_proto(state=None, name=None, error=None, version="v1"):
    custom_job_proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO)
    custom_job_proto.name = name
    custom_job_proto.state = state