def test_explain(self, get_endpoint_mock, predict_client_explain_mock):
        """Endpoint.explain returns the expected Prediction and calls the client once.

        The returned Prediction must carry the test predictions together with
        explanations that include the test attributions, and the prediction
        client's ``explain`` must be invoked exactly once with the endpoint
        name and the request arguments passed by the caller.
        """
        import copy

        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

        test_endpoint = models.Endpoint(_TEST_ID)
        test_prediction = test_endpoint.explain(
            instances=_TEST_INSTANCES,
            parameters={"param": 3.0},
            deployed_model_id=_TEST_MODEL_ID,
        )

        # Work on a private copy: extending _TEST_EXPLANATIONS in place would
        # permanently add attributions to the shared module-level constant
        # and leak them into every later test that uses it.
        expected_explanations = copy.deepcopy(_TEST_EXPLANATIONS)
        expected_explanations[0].attributions.extend(_TEST_ATTRIBUTIONS)

        expected_prediction = models.Prediction(
            predictions=_TEST_PREDICTION,
            deployed_model_id=_TEST_ID,
            explanations=expected_explanations,
        )

        assert expected_prediction == test_prediction
        predict_client_explain_mock.assert_called_once_with(
            endpoint=_TEST_ENDPOINT_NAME,
            instances=_TEST_INSTANCES,
            parameters={"param": 3.0},
            deployed_model_id=_TEST_MODEL_ID,
        )

    def test_predict(self, get_endpoint_mock, predict_client_predict_mock):
        """Endpoint.predict returns the expected Prediction and calls the client once.

        The prediction client's ``predict`` must receive the endpoint name
        together with the instances and parameters supplied by the caller.
        """
        expected = models.Prediction(
            predictions=_TEST_PREDICTION,
            deployed_model_id=_TEST_ID,
        )

        endpoint = models.Endpoint(_TEST_ID)
        result = endpoint.predict(
            instances=_TEST_INSTANCES,
            parameters={"param": 3.0},
        )

        assert expected == result

        # Exactly one client call, carrying the caller's arguments unchanged.
        predict_client_predict_mock.assert_called_once_with(
            endpoint=_TEST_ENDPOINT_NAME,
            instances=_TEST_INSTANCES,
            parameters={"param": 3.0},
        )
示例#3
0
    def test_dataset_create_to_model_predict(
        self,
        create_dataset_mock,  # noqa: F811
        import_data_mock,  # noqa: F811
        predict_client_predict_mock,  # noqa: F811
        mock_python_package_to_gcs,  # noqa: F811
        mock_pipeline_service_create,  # noqa: F811
        mock_model_service_get,  # noqa: F811
        mock_pipeline_service_get,  # noqa: F811
        sync,
    ):
        """End-to-end flow: dataset -> custom training -> deploy -> predict.

        Runs ImageDataset.create / import_data, CustomTrainingJob.run,
        Endpoint.create, Model.deploy, Endpoint.deploy and Endpoint.predict
        against the mock fixtures. It then checks the exact requests each
        mock received and the resources the SDK objects ended up wrapping.
        When ``sync`` is false, the endpoints are explicitly waited on before
        predicting.
        """
        aiplatform.init(
            project=test_datasets._TEST_PROJECT,
            staging_bucket=test_training_jobs._TEST_BUCKET_NAME,
            credentials=test_training_jobs._TEST_CREDENTIALS,
        )

        # ---- Exercise the SDK ------------------------------------------

        my_dataset = aiplatform.ImageDataset.create(
            display_name=test_datasets._TEST_DISPLAY_NAME,
            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
            sync=sync,
        )

        my_dataset.import_data(
            gcs_source=test_datasets._TEST_SOURCE_URI_GCS,
            import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI,
            data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS,
            sync=sync,
        )

        job = aiplatform.CustomTrainingJob(
            display_name=test_training_jobs._TEST_DISPLAY_NAME,
            script_path=test_training_jobs._TEST_LOCAL_SCRIPT_FILE_NAME,
            container_uri=test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE,
            model_serving_container_image_uri=(
                test_training_jobs._TEST_SERVING_CONTAINER_IMAGE
            ),
            model_serving_container_predict_route=(
                test_training_jobs._TEST_SERVING_CONTAINER_PREDICTION_ROUTE
            ),
            model_serving_container_health_route=(
                test_training_jobs._TEST_SERVING_CONTAINER_HEALTH_ROUTE
            ),
        )

        model_from_job = job.run(
            dataset=my_dataset,
            base_output_dir=test_training_jobs._TEST_BASE_OUTPUT_DIR,
            args=test_training_jobs._TEST_RUN_ARGS,
            # Same constant the expected worker pool spec below is built
            # from, so the request and the expectation cannot drift apart.
            replica_count=test_training_jobs._TEST_REPLICA_COUNT,
            machine_type=test_training_jobs._TEST_MACHINE_TYPE,
            accelerator_type=test_training_jobs._TEST_ACCELERATOR_TYPE,
            accelerator_count=test_training_jobs._TEST_ACCELERATOR_COUNT,
            model_display_name=test_training_jobs._TEST_MODEL_DISPLAY_NAME,
            training_fraction_split=(
                test_training_jobs._TEST_TRAINING_FRACTION_SPLIT
            ),
            validation_fraction_split=(
                test_training_jobs._TEST_VALIDATION_FRACTION_SPLIT
            ),
            test_fraction_split=test_training_jobs._TEST_TEST_FRACTION_SPLIT,
            sync=sync,
        )

        created_endpoint = models.Endpoint.create(
            display_name=test_endpoints._TEST_DISPLAY_NAME,
            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
            sync=sync,
        )

        my_endpoint = model_from_job.deploy(
            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, sync=sync
        )

        endpoint_deploy_return = created_endpoint.deploy(model_from_job, sync=sync)

        # Endpoint.deploy is expected to return nothing.
        assert endpoint_deploy_return is None

        if not sync:
            my_endpoint.wait()
            created_endpoint.wait()

        test_prediction = created_endpoint.predict(
            instances=[[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]],
            parameters={"param": 3.0},
        )

        # ---- Prediction ------------------------------------------------

        true_prediction = models.Prediction(
            predictions=test_endpoints._TEST_PREDICTION,
            deployed_model_id=test_endpoints._TEST_ID,
        )

        assert true_prediction == test_prediction
        predict_client_predict_mock.assert_called_once_with(
            endpoint=test_endpoints._TEST_ENDPOINT_NAME,
            instances=[[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]],
            parameters={"param": 3.0},
        )

        # ---- Dataset creation and import -------------------------------

        expected_dataset = gca_dataset.Dataset(
            display_name=test_datasets._TEST_DISPLAY_NAME,
            metadata_schema_uri=(
                test_datasets._TEST_METADATA_SCHEMA_URI_NONTABULAR
            ),
            metadata=test_datasets._TEST_NONTABULAR_DATASET_METADATA,
            encryption_spec=_TEST_ENCRYPTION_SPEC,
        )

        expected_import_config = gca_dataset.ImportDataConfig(
            gcs_source=gca_io.GcsSource(uris=[test_datasets._TEST_SOURCE_URI_GCS]),
            import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI,
            data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS,
        )

        create_dataset_mock.assert_called_once_with(
            parent=test_datasets._TEST_PARENT,
            dataset=expected_dataset,
            metadata=test_datasets._TEST_REQUEST_METADATA,
        )

        import_data_mock.assert_called_once_with(
            name=test_datasets._TEST_NAME,
            import_configs=[expected_import_config],
        )

        # The create request was checked without a resource name. Only now
        # set the name, to compare against the resource the dataset wraps.
        expected_dataset.name = test_datasets._TEST_NAME
        assert my_dataset._gca_resource == expected_dataset

        # ---- Training pipeline -----------------------------------------

        mock_python_package_to_gcs.assert_called_once_with(
            gcs_staging_dir=test_training_jobs._TEST_BUCKET_NAME,
            project=test_training_jobs._TEST_PROJECT,
            credentials=initializer.global_config.credentials,
        )

        true_args = test_training_jobs._TEST_RUN_ARGS

        true_worker_pool_spec = {
            "replica_count": test_training_jobs._TEST_REPLICA_COUNT,
            "machine_spec": {
                "machine_type": test_training_jobs._TEST_MACHINE_TYPE,
                "accelerator_type": test_training_jobs._TEST_ACCELERATOR_TYPE,
                "accelerator_count": test_training_jobs._TEST_ACCELERATOR_COUNT,
            },
            "python_package_spec": {
                "executor_image_uri": (
                    test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE
                ),
                "python_module": (
                    source_utils._TrainingScriptPythonPackager.module_name
                ),
                "package_uris": [
                    test_training_jobs._TEST_OUTPUT_PYTHON_PACKAGE_PATH
                ],
                "args": true_args,
            },
        }

        true_fraction_split = gca_training_pipeline.FractionSplit(
            training_fraction=test_training_jobs._TEST_TRAINING_FRACTION_SPLIT,
            validation_fraction=(
                test_training_jobs._TEST_VALIDATION_FRACTION_SPLIT
            ),
            test_fraction=test_training_jobs._TEST_TEST_FRACTION_SPLIT,
        )

        true_container_spec = gca_model.ModelContainerSpec(
            image_uri=test_training_jobs._TEST_SERVING_CONTAINER_IMAGE,
            predict_route=(
                test_training_jobs._TEST_SERVING_CONTAINER_PREDICTION_ROUTE
            ),
            health_route=test_training_jobs._TEST_SERVING_CONTAINER_HEALTH_ROUTE,
        )

        true_managed_model = gca_model.Model(
            display_name=test_training_jobs._TEST_MODEL_DISPLAY_NAME,
            container_spec=true_container_spec,
        )

        true_input_data_config = gca_training_pipeline.InputDataConfig(
            fraction_split=true_fraction_split,
            dataset_id=my_dataset.name,
            gcs_destination=gca_io.GcsDestination(
                output_uri_prefix=test_training_jobs._TEST_BASE_OUTPUT_DIR
            ),
        )

        true_training_pipeline = gca_training_pipeline.TrainingPipeline(
            display_name=test_training_jobs._TEST_DISPLAY_NAME,
            training_task_definition=(
                schema.training_job.definition.custom_task
            ),
            training_task_inputs=json_format.ParseDict(
                {
                    "worker_pool_specs": [true_worker_pool_spec],
                    "base_output_directory": {
                        "output_uri_prefix": (
                            test_training_jobs._TEST_BASE_OUTPUT_DIR
                        )
                    },
                },
                struct_pb2.Value(),
            ),
            model_to_upload=true_managed_model,
            input_data_config=true_input_data_config,
        )

        mock_pipeline_service_create.assert_called_once_with(
            parent=initializer.global_config.common_location_path(),
            training_pipeline=true_training_pipeline,
        )

        assert job._gca_resource is mock_pipeline_service_get.return_value

        # ---- Resulting model and job state -----------------------------

        # Keep this call-count check before job.get_model() below, in case
        # get_model() issues another model lookup (not visible here).
        mock_model_service_get.assert_called_once_with(
            name=test_training_jobs._TEST_MODEL_NAME
        )

        assert model_from_job._gca_resource is mock_model_service_get.return_value

        assert job.get_model()._gca_resource is mock_model_service_get.return_value

        assert not job.has_failed

        assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED