Example 1
    def test_create_then_import(
        self, create_dataset_mock, import_data_mock, get_dataset_mock, sync
    ):
        """Create a nontabular dataset, import data, and verify the GAPIC requests."""

        aiplatform.init(project=_TEST_PROJECT)

        dataset = datasets._Dataset.create(
            display_name=_TEST_DISPLAY_NAME,
            metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR,
            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
            sync=sync,
        )
        dataset.import_data(
            gcs_source=_TEST_SOURCE_URI_GCS,
            import_schema_uri=_TEST_IMPORT_SCHEMA_URI,
            data_item_labels=_TEST_DATA_LABEL_ITEMS,
            sync=sync,
        )
        if not sync:
            dataset.wait()

        # The dataset sent in the CreateDataset request carries no resource name.
        want_dataset = gca_dataset.Dataset(
            display_name=_TEST_DISPLAY_NAME,
            metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR,
            metadata=_TEST_NONTABULAR_DATASET_METADATA,
            encryption_spec=_TEST_ENCRYPTION_SPEC,
        )
        want_import_config = gca_dataset.ImportDataConfig(
            gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]),
            import_schema_uri=_TEST_IMPORT_SCHEMA_URI,
            data_item_labels=_TEST_DATA_LABEL_ITEMS,
        )

        create_dataset_mock.assert_called_once_with(
            parent=_TEST_PARENT,
            dataset=want_dataset,
            metadata=_TEST_REQUEST_METADATA,
        )
        get_dataset_mock.assert_called_once_with(name=_TEST_NAME)
        import_data_mock.assert_called_once_with(
            name=_TEST_NAME, import_configs=[want_import_config]
        )

        # The in-memory resource, by contrast, has the server-assigned name.
        want_dataset.name = _TEST_NAME
        assert dataset._gca_resource == want_dataset
Example 2
    def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_and_dest(
        self, create_batch_prediction_job_mock, sync
    ):
        """batch_predict must inherit the encryption key set via aiplatform.init."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
        )
        model = models.Model(_TEST_ID)

        # Make the SDK batch_predict call.
        job = model.batch_predict(
            job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
            gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
            gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
            sync=sync,
        )
        if not sync:
            job.wait()

        # Build the request we expect the mocked JobService to have received.
        want_input = gca_batch_prediction_job.BatchPredictionJob.InputConfig(
            instances_format="jsonl",
            gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]),
        )
        want_output = gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
            gcs_destination=gca_io.GcsDestination(
                output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
            ),
            predictions_format="jsonl",
        )
        want_job = gca_batch_prediction_job.BatchPredictionJob(
            display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
            model=model_service_client.ModelServiceClient.model_path(
                _TEST_PROJECT, _TEST_LOCATION, _TEST_ID
            ),
            input_config=want_input,
            output_config=want_output,
            # The key was supplied only to init(), not to batch_predict().
            encryption_spec=_TEST_ENCRYPTION_SPEC,
        )

        create_batch_prediction_job_mock.assert_called_once_with(
            parent=_TEST_PARENT,
            batch_prediction_job=want_job,
        )
Example 3
    def test_batch_predict_gcs_source_bq_dest(
        self, create_batch_prediction_job_mock, sync
    ):
        """Batch prediction from a GCS JSONL source into a BigQuery destination."""
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

        job = jobs.BatchPredictionJob.create(
            model_name=_TEST_MODEL_NAME,
            job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
            gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
            bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
            sync=sync,
        )
        if not sync:
            job.wait()

        # The mocked service has produced no output yet.
        assert (
            job.output_info
            == gca_batch_prediction_job.BatchPredictionJob.OutputInfo()
        )

        # Build the request we expect the mocked JobService to have received.
        want_input = gca_batch_prediction_job.BatchPredictionJob.InputConfig(
            instances_format="jsonl",
            gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]),
        )
        want_output = gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
            bigquery_destination=gca_io.BigQueryDestination(
                output_uri=_TEST_BATCH_PREDICTION_BQ_DEST_PREFIX_WITH_PROTOCOL
            ),
            predictions_format="bigquery",
        )
        want_job = gca_batch_prediction_job.BatchPredictionJob(
            display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
            model=_TEST_MODEL_NAME,
            input_config=want_input,
            output_config=want_output,
        )

        create_batch_prediction_job_mock.assert_called_once_with(
            parent=_TEST_PARENT,
            batch_prediction_job=want_job,
        )
Example 4
    def test_batch_predict_gcs_source_bq_dest(
        self, create_batch_prediction_job_mock, sync
    ):
        """Model.batch_predict with a GCS source and a BigQuery destination."""

        model = models.Model(_TEST_ID)

        # Make the SDK batch_predict call.
        job = model.batch_predict(
            job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
            gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
            bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
            sync=sync,
        )
        if not sync:
            job.wait()

        # Build the request we expect the mocked JobService to have received.
        want_input = gca_batch_prediction_job.BatchPredictionJob.InputConfig(
            instances_format="jsonl",
            gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]),
        )
        want_output = gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
            bigquery_destination=gca_io.BigQueryDestination(
                output_uri=_TEST_BATCH_PREDICTION_BQ_DEST_PREFIX_WITH_PROTOCOL
            ),
            predictions_format="bigquery",
        )
        want_job = gca_batch_prediction_job.BatchPredictionJob(
            display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
            model=model_service_client.ModelServiceClient.model_path(
                _TEST_PROJECT, _TEST_LOCATION, _TEST_ID
            ),
            input_config=want_input,
            output_config=want_output,
        )

        create_batch_prediction_job_mock.assert_called_once_with(
            parent=_TEST_PARENT,
            batch_prediction_job=want_job,
        )
Example 5
    def test_import_data(self, import_data_mock, sync):
        """Import image data into an existing ImageDataset and verify the request."""
        aiplatform.init(project=_TEST_PROJECT)

        dataset = datasets.ImageDataset(dataset_name=_TEST_NAME)
        dataset.import_data(
            gcs_source=[_TEST_SOURCE_URI_GCS],
            import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE,
            sync=sync,
        )
        if not sync:
            dataset.wait()

        want_config = gca_dataset.ImportDataConfig(
            gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]),
            import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE,
        )
        import_data_mock.assert_called_once_with(
            name=_TEST_NAME, import_configs=[want_config]
        )
Example 6
    def test_dataset_create_to_model_predict(
        self,
        create_dataset_mock,  # noqa: F811
        import_data_mock,  # noqa: F811
        predict_client_predict_mock,  # noqa: F811
        mock_python_package_to_gcs,  # noqa: F811
        mock_pipeline_service_create,  # noqa: F811
        mock_model_service_get,  # noqa: F811
        mock_pipeline_service_get,  # noqa: F811
        sync,
    ):
        """End-to-end SDK flow over mocked services.

        Creates an image dataset, imports data, runs a custom training job,
        deploys the resulting model (via the model and via an explicitly
        created endpoint), issues a prediction, and then verifies the exact
        request each mocked GAPIC client received. Runs in both sync and
        async modes via the ``sync`` parameter.
        """

        aiplatform.init(
            project=test_datasets._TEST_PROJECT,
            staging_bucket=test_training_jobs._TEST_BUCKET_NAME,
            credentials=test_training_jobs._TEST_CREDENTIALS,
        )

        # --- Dataset creation and data import (possibly asynchronous). ---
        my_dataset = aiplatform.ImageDataset.create(
            display_name=test_datasets._TEST_DISPLAY_NAME,
            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
            sync=sync,
        )

        my_dataset.import_data(
            gcs_source=test_datasets._TEST_SOURCE_URI_GCS,
            import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI,
            data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS,
            sync=sync,
        )

        # --- Custom training job producing a managed model. ---
        job = aiplatform.CustomTrainingJob(
            display_name=test_training_jobs._TEST_DISPLAY_NAME,
            script_path=test_training_jobs._TEST_LOCAL_SCRIPT_FILE_NAME,
            container_uri=test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE,
            model_serving_container_image_uri=test_training_jobs.
            _TEST_SERVING_CONTAINER_IMAGE,
            model_serving_container_predict_route=test_training_jobs.
            _TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
            model_serving_container_health_route=test_training_jobs.
            _TEST_SERVING_CONTAINER_HEALTH_ROUTE,
        )

        model_from_job = job.run(
            dataset=my_dataset,
            base_output_dir=test_training_jobs._TEST_BASE_OUTPUT_DIR,
            args=test_training_jobs._TEST_RUN_ARGS,
            replica_count=1,
            machine_type=test_training_jobs._TEST_MACHINE_TYPE,
            accelerator_type=test_training_jobs._TEST_ACCELERATOR_TYPE,
            accelerator_count=test_training_jobs._TEST_ACCELERATOR_COUNT,
            model_display_name=test_training_jobs._TEST_MODEL_DISPLAY_NAME,
            training_fraction_split=test_training_jobs.
            _TEST_TRAINING_FRACTION_SPLIT,
            validation_fraction_split=test_training_jobs.
            _TEST_VALIDATION_FRACTION_SPLIT,
            test_fraction_split=test_training_jobs._TEST_TEST_FRACTION_SPLIT,
            sync=sync,
        )

        # --- Deployment: via an explicit endpoint and via the model itself. ---
        created_endpoint = models.Endpoint.create(
            display_name=test_endpoints._TEST_DISPLAY_NAME,
            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
            sync=sync,
        )

        my_endpoint = model_from_job.deploy(
            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, sync=sync)

        endpoint_deploy_return = created_endpoint.deploy(model_from_job,
                                                         sync=sync)

        # Endpoint.deploy returns None (unlike Model.deploy above).
        assert endpoint_deploy_return is None

        if not sync:
            my_endpoint.wait()
            created_endpoint.wait()

        # --- Prediction against the mocked prediction client. ---
        test_prediction = created_endpoint.predict(instances=[[1.0, 2.0, 3.0],
                                                              [1.0, 3.0, 4.0]],
                                                   parameters={"param": 3.0})

        true_prediction = models.Prediction(
            predictions=test_endpoints._TEST_PREDICTION,
            deployed_model_id=test_endpoints._TEST_ID,
        )

        assert true_prediction == test_prediction
        predict_client_predict_mock.assert_called_once_with(
            endpoint=test_endpoints._TEST_ENDPOINT_NAME,
            instances=[[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]],
            parameters={"param": 3.0},
        )

        # --- Verify the dataset-service requests. ---
        # The dataset sent in the CreateDataset request has no resource name.
        expected_dataset = gca_dataset.Dataset(
            display_name=test_datasets._TEST_DISPLAY_NAME,
            metadata_schema_uri=test_datasets.
            _TEST_METADATA_SCHEMA_URI_NONTABULAR,
            metadata=test_datasets._TEST_NONTABULAR_DATASET_METADATA,
            encryption_spec=_TEST_ENCRYPTION_SPEC,
        )

        expected_import_config = gca_dataset.ImportDataConfig(
            gcs_source=gca_io.GcsSource(
                uris=[test_datasets._TEST_SOURCE_URI_GCS]),
            import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI,
            data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS,
        )

        create_dataset_mock.assert_called_once_with(
            parent=test_datasets._TEST_PARENT,
            dataset=expected_dataset,
            metadata=test_datasets._TEST_REQUEST_METADATA,
        )

        import_data_mock.assert_called_once_with(
            name=test_datasets._TEST_NAME,
            import_configs=[expected_import_config])

        # Name is set only now: the in-memory resource carries the
        # server-assigned name, while the create request (checked above) must not.
        expected_dataset.name = test_datasets._TEST_NAME
        assert my_dataset._gca_resource == expected_dataset

        # --- Verify the training-script packaging call. ---
        mock_python_package_to_gcs.assert_called_once_with(
            gcs_staging_dir=test_training_jobs._TEST_BUCKET_NAME,
            project=test_training_jobs._TEST_PROJECT,
            credentials=initializer.global_config.credentials,
        )

        # --- Build the training pipeline we expect the pipeline service saw. ---
        true_args = test_training_jobs._TEST_RUN_ARGS

        true_worker_pool_spec = {
            "replica_count": test_training_jobs._TEST_REPLICA_COUNT,
            "machine_spec": {
                "machine_type": test_training_jobs._TEST_MACHINE_TYPE,
                "accelerator_type": test_training_jobs._TEST_ACCELERATOR_TYPE,
                "accelerator_count":
                test_training_jobs._TEST_ACCELERATOR_COUNT,
            },
            "python_package_spec": {
                "executor_image_uri":
                test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE,
                "python_module":
                source_utils._TrainingScriptPythonPackager.module_name,
                "package_uris":
                [test_training_jobs._TEST_OUTPUT_PYTHON_PACKAGE_PATH],
                "args":
                true_args,
            },
        }

        true_fraction_split = gca_training_pipeline.FractionSplit(
            training_fraction=test_training_jobs._TEST_TRAINING_FRACTION_SPLIT,
            validation_fraction=test_training_jobs.
            _TEST_VALIDATION_FRACTION_SPLIT,
            test_fraction=test_training_jobs._TEST_TEST_FRACTION_SPLIT,
        )

        true_container_spec = gca_model.ModelContainerSpec(
            image_uri=test_training_jobs._TEST_SERVING_CONTAINER_IMAGE,
            predict_route=test_training_jobs.
            _TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
            health_route=test_training_jobs.
            _TEST_SERVING_CONTAINER_HEALTH_ROUTE,
        )

        true_managed_model = gca_model.Model(
            display_name=test_training_jobs._TEST_MODEL_DISPLAY_NAME,
            container_spec=true_container_spec,
        )

        true_input_data_config = gca_training_pipeline.InputDataConfig(
            fraction_split=true_fraction_split,
            dataset_id=my_dataset.name,
            gcs_destination=gca_io.GcsDestination(
                output_uri_prefix=test_training_jobs._TEST_BASE_OUTPUT_DIR),
        )

        true_training_pipeline = gca_training_pipeline.TrainingPipeline(
            display_name=test_training_jobs._TEST_DISPLAY_NAME,
            training_task_definition=schema.training_job.definition.
            custom_task,
            training_task_inputs=json_format.ParseDict(
                {
                    "worker_pool_specs": [true_worker_pool_spec],
                    "base_output_directory": {
                        "output_uri_prefix":
                        test_training_jobs._TEST_BASE_OUTPUT_DIR
                    },
                },
                struct_pb2.Value(),
            ),
            model_to_upload=true_managed_model,
            input_data_config=true_input_data_config,
        )

        mock_pipeline_service_create.assert_called_once_with(
            parent=initializer.global_config.common_location_path(),
            training_pipeline=true_training_pipeline,
        )

        # --- Final state checks on job and model wrappers. ---
        assert job._gca_resource is mock_pipeline_service_get.return_value

        mock_model_service_get.assert_called_once_with(
            name=test_training_jobs._TEST_MODEL_NAME)

        assert model_from_job._gca_resource is mock_model_service_get.return_value

        assert job.get_model(
        )._gca_resource is mock_model_service_get.return_value

        assert not job.has_failed

        assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
Example 7
# GCS sources and destinations used by the batch-prediction tests.
_TEST_BATCH_PREDICTION_GCS_SOURCE_LIST = [
    "gs://example-bucket/folder/instance1.jsonl",
    "gs://example-bucket/folder/instance2.jsonl",
]
_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX = "gs://example-bucket/folder/output"

# BigQuery destination, without and with the bq:// protocol prefix.
_TEST_BATCH_PREDICTION_BQ_PREFIX = "ucaip-sample-tests"
_TEST_BATCH_PREDICTION_BQ_DEST_PREFIX_WITH_PROTOCOL = (
    "bq://" + _TEST_BATCH_PREDICTION_BQ_PREFIX
)

# Raw JobState enum values; the constant names mirror the intended states.
_TEST_JOB_STATE_SUCCESS = gca_job_state.JobState(4)
_TEST_JOB_STATE_RUNNING = gca_job_state.JobState(3)
_TEST_JOB_STATE_PENDING = gca_job_state.JobState(2)

# Canonical JSONL-over-GCS input/output configs.
_TEST_GCS_INPUT_CONFIG = gca_batch_prediction_job.BatchPredictionJob.InputConfig(
    instances_format="jsonl",
    gcs_source=gca_io.GcsSource(uris=[_TEST_GCS_JSONL_SOURCE_URI]),
)
_TEST_GCS_OUTPUT_CONFIG = gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
    predictions_format="jsonl",
    gcs_destination=gca_io.GcsDestination(output_uri_prefix=_TEST_GCS_BUCKET_PATH),
)

# Canonical BigQuery input/output configs.
_TEST_BQ_INPUT_CONFIG = gca_batch_prediction_job.BatchPredictionJob.InputConfig(
    instances_format="bigquery",
    bigquery_source=gca_io.BigQuerySource(input_uri=_TEST_BQ_PATH),
)
_TEST_BQ_OUTPUT_CONFIG = gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
    predictions_format="bigquery",
    bigquery_destination=gca_io.BigQueryDestination(output_uri=_TEST_BQ_PATH),
)