def test_explain(self, get_endpoint_mock, predict_client_explain_mock):
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

    test_endpoint = models.Endpoint(_TEST_ID)

    test_prediction = test_endpoint.explain(
        instances=_TEST_INSTANCES,
        parameters={"param": 3.0},
        deployed_model_id=_TEST_MODEL_ID,
    )

    expected_explanations = _TEST_EXPLANATIONS
    expected_explanations[0].attributions.extend(_TEST_ATTRIBUTIONS)

    expected_prediction = models.Prediction(
        predictions=_TEST_PREDICTION,
        deployed_model_id=_TEST_ID,
        explanations=expected_explanations,
    )

    assert expected_prediction == test_prediction
    predict_client_explain_mock.assert_called_once_with(
        endpoint=_TEST_ENDPOINT_NAME,
        instances=_TEST_INSTANCES,
        parameters={"param": 3.0},
        deployed_model_id=_TEST_MODEL_ID,
    )
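
# The `predict_client_explain_mock` fixture requested above is not shown in this
# excerpt. A minimal sketch of one plausible shape for it follows; the patch
# target, import paths, and the way the canned response is populated are
# assumptions, not taken from this file (imports would normally live at the top
# of the module).
from unittest import mock

import pytest
from google.cloud.aiplatform_v1.services.prediction_service import (
    client as prediction_service_client,
)
from google.cloud.aiplatform_v1.types import (
    prediction_service as gca_prediction_service,
)


@pytest.fixture
def predict_client_explain_mock():
    # Patch the GAPIC explain RPC so Endpoint.explain never hits the network.
    with mock.patch.object(
        prediction_service_client.PredictionServiceClient, "explain"
    ) as explain_mock:
        explain_mock.return_value = gca_prediction_service.ExplainResponse(
            deployed_model_id=_TEST_ID,
        )
        explain_mock.return_value.predictions.extend(_TEST_PREDICTION)
        explain_mock.return_value.explanations.extend(_TEST_EXPLANATIONS)
        explain_mock.return_value.explanations[0].attributions.extend(
            _TEST_ATTRIBUTIONS
        )
        yield explain_mock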
def test_predict(self, get_endpoint_mock, predict_client_predict_mock):
    test_endpoint = models.Endpoint(_TEST_ID)

    test_prediction = test_endpoint.predict(
        instances=_TEST_INSTANCES, parameters={"param": 3.0}
    )

    true_prediction = models.Prediction(
        predictions=_TEST_PREDICTION, deployed_model_id=_TEST_ID
    )

    assert true_prediction == test_prediction
    predict_client_predict_mock.assert_called_once_with(
        endpoint=_TEST_ENDPOINT_NAME,
        instances=_TEST_INSTANCES,
        parameters={"param": 3.0},
    )
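
# Similarly, `predict_client_predict_mock` (requested by test_predict above and by
# the end-to-end test below) is not shown here. A plausible sketch, reusing the
# imports from the explain fixture sketch above; the patch target and canned
# response values are assumptions.
@pytest.fixture
def predict_client_predict_mock():
    # Patch the GAPIC predict RPC and return a canned PredictResponse that the
    # tests compare against via models.Prediction.
    with mock.patch.object(
        prediction_service_client.PredictionServiceClient, "predict"
    ) as predict_mock:
        predict_mock.return_value = gca_prediction_service.PredictResponse(
            deployed_model_id=_TEST_ID
        )
        predict_mock.return_value.predictions.extend(_TEST_PREDICTION)
        yield predict_mock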
def test_dataset_create_to_model_predict(
    self,
    create_dataset_mock,  # noqa: F811
    import_data_mock,  # noqa: F811
    predict_client_predict_mock,  # noqa: F811
    mock_python_package_to_gcs,  # noqa: F811
    mock_pipeline_service_create,  # noqa: F811
    mock_model_service_get,  # noqa: F811
    mock_pipeline_service_get,  # noqa: F811
    sync,
):
    aiplatform.init(
        project=test_datasets._TEST_PROJECT,
        staging_bucket=test_training_jobs._TEST_BUCKET_NAME,
        credentials=test_training_jobs._TEST_CREDENTIALS,
    )

    my_dataset = aiplatform.ImageDataset.create(
        display_name=test_datasets._TEST_DISPLAY_NAME,
        encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
        sync=sync,
    )

    my_dataset.import_data(
        gcs_source=test_datasets._TEST_SOURCE_URI_GCS,
        import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI,
        data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS,
        sync=sync,
    )

    job = aiplatform.CustomTrainingJob(
        display_name=test_training_jobs._TEST_DISPLAY_NAME,
        script_path=test_training_jobs._TEST_LOCAL_SCRIPT_FILE_NAME,
        container_uri=test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE,
        model_serving_container_image_uri=test_training_jobs._TEST_SERVING_CONTAINER_IMAGE,
        model_serving_container_predict_route=test_training_jobs._TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
        model_serving_container_health_route=test_training_jobs._TEST_SERVING_CONTAINER_HEALTH_ROUTE,
    )

    model_from_job = job.run(
        dataset=my_dataset,
        base_output_dir=test_training_jobs._TEST_BASE_OUTPUT_DIR,
        args=test_training_jobs._TEST_RUN_ARGS,
        replica_count=1,
        machine_type=test_training_jobs._TEST_MACHINE_TYPE,
        accelerator_type=test_training_jobs._TEST_ACCELERATOR_TYPE,
        accelerator_count=test_training_jobs._TEST_ACCELERATOR_COUNT,
        model_display_name=test_training_jobs._TEST_MODEL_DISPLAY_NAME,
        training_fraction_split=test_training_jobs._TEST_TRAINING_FRACTION_SPLIT,
        validation_fraction_split=test_training_jobs._TEST_VALIDATION_FRACTION_SPLIT,
        test_fraction_split=test_training_jobs._TEST_TEST_FRACTION_SPLIT,
        sync=sync,
    )

    created_endpoint = models.Endpoint.create(
        display_name=test_endpoints._TEST_DISPLAY_NAME,
        encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
        sync=sync,
    )

    my_endpoint = model_from_job.deploy(
        encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, sync=sync
    )

    endpoint_deploy_return = created_endpoint.deploy(model_from_job, sync=sync)

    assert endpoint_deploy_return is None

    if not sync:
        my_endpoint.wait()
        created_endpoint.wait()

    test_prediction = created_endpoint.predict(
        instances=[[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]], parameters={"param": 3.0}
    )

    true_prediction = models.Prediction(
        predictions=test_endpoints._TEST_PREDICTION,
        deployed_model_id=test_endpoints._TEST_ID,
    )

    assert true_prediction == test_prediction
    predict_client_predict_mock.assert_called_once_with(
        endpoint=test_endpoints._TEST_ENDPOINT_NAME,
        instances=[[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]],
        parameters={"param": 3.0},
    )

    expected_dataset = gca_dataset.Dataset(
        display_name=test_datasets._TEST_DISPLAY_NAME,
        metadata_schema_uri=test_datasets._TEST_METADATA_SCHEMA_URI_NONTABULAR,
        metadata=test_datasets._TEST_NONTABULAR_DATASET_METADATA,
        encryption_spec=_TEST_ENCRYPTION_SPEC,
    )

    expected_import_config = gca_dataset.ImportDataConfig(
        gcs_source=gca_io.GcsSource(uris=[test_datasets._TEST_SOURCE_URI_GCS]),
        import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI,
        data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS,
    )

    create_dataset_mock.assert_called_once_with(
        parent=test_datasets._TEST_PARENT,
        dataset=expected_dataset,
        metadata=test_datasets._TEST_REQUEST_METADATA,
    )

    import_data_mock.assert_called_once_with(
        name=test_datasets._TEST_NAME,
        import_configs=[expected_import_config],
    )

    expected_dataset.name = test_datasets._TEST_NAME
    assert my_dataset._gca_resource == expected_dataset

    mock_python_package_to_gcs.assert_called_once_with(
        gcs_staging_dir=test_training_jobs._TEST_BUCKET_NAME,
        project=test_training_jobs._TEST_PROJECT,
        credentials=initializer.global_config.credentials,
    )

    true_args = test_training_jobs._TEST_RUN_ARGS

    true_worker_pool_spec = {
        "replica_count": test_training_jobs._TEST_REPLICA_COUNT,
        "machine_spec": {
            "machine_type": test_training_jobs._TEST_MACHINE_TYPE,
            "accelerator_type": test_training_jobs._TEST_ACCELERATOR_TYPE,
            "accelerator_count": test_training_jobs._TEST_ACCELERATOR_COUNT,
        },
        "python_package_spec": {
            "executor_image_uri": test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE,
            "python_module": source_utils._TrainingScriptPythonPackager.module_name,
            "package_uris": [test_training_jobs._TEST_OUTPUT_PYTHON_PACKAGE_PATH],
            "args": true_args,
        },
    }

    true_fraction_split = gca_training_pipeline.FractionSplit(
        training_fraction=test_training_jobs._TEST_TRAINING_FRACTION_SPLIT,
        validation_fraction=test_training_jobs._TEST_VALIDATION_FRACTION_SPLIT,
        test_fraction=test_training_jobs._TEST_TEST_FRACTION_SPLIT,
    )

    true_container_spec = gca_model.ModelContainerSpec(
        image_uri=test_training_jobs._TEST_SERVING_CONTAINER_IMAGE,
        predict_route=test_training_jobs._TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
        health_route=test_training_jobs._TEST_SERVING_CONTAINER_HEALTH_ROUTE,
    )

    true_managed_model = gca_model.Model(
        display_name=test_training_jobs._TEST_MODEL_DISPLAY_NAME,
        container_spec=true_container_spec,
    )

    true_input_data_config = gca_training_pipeline.InputDataConfig(
        fraction_split=true_fraction_split,
        dataset_id=my_dataset.name,
        gcs_destination=gca_io.GcsDestination(
            output_uri_prefix=test_training_jobs._TEST_BASE_OUTPUT_DIR
        ),
    )

    true_training_pipeline = gca_training_pipeline.TrainingPipeline(
        display_name=test_training_jobs._TEST_DISPLAY_NAME,
        training_task_definition=schema.training_job.definition.custom_task,
        training_task_inputs=json_format.ParseDict(
            {
                "worker_pool_specs": [true_worker_pool_spec],
                "base_output_directory": {
                    "output_uri_prefix": test_training_jobs._TEST_BASE_OUTPUT_DIR
                },
            },
            struct_pb2.Value(),
        ),
        model_to_upload=true_managed_model,
        input_data_config=true_input_data_config,
    )

    mock_pipeline_service_create.assert_called_once_with(
        parent=initializer.global_config.common_location_path(),
        training_pipeline=true_training_pipeline,
    )

    assert job._gca_resource is mock_pipeline_service_get.return_value

    mock_model_service_get.assert_called_once_with(
        name=test_training_jobs._TEST_MODEL_NAME
    )

    assert model_from_job._gca_resource is mock_model_service_get.return_value
    assert job.get_model()._gca_resource is mock_model_service_get.return_value
    assert not job.has_failed
    assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
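
# The end-to-end test asserts that the training pipeline reached
# PIPELINE_STATE_SUCCEEDED and that the trained model was fetched by name. One
# plausible sketch of the `mock_pipeline_service_get` fixture that would satisfy
# those assertions follows; the patch target and the resource-name constants
# (e.g. _TEST_PIPELINE_RESOURCE_NAME) are assumptions, not taken from this file.
from google.cloud.aiplatform_v1.services.pipeline_service import (
    client as pipeline_service_client,
)


@pytest.fixture
def mock_pipeline_service_get():
    # Return a completed pipeline whose model_to_upload points at the model the
    # test later resolves via ModelServiceClient.get_model.
    with mock.patch.object(
        pipeline_service_client.PipelineServiceClient, "get_training_pipeline"
    ) as mock_get_training_pipeline:
        mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline(
            name=test_training_jobs._TEST_PIPELINE_RESOURCE_NAME,
            state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
            model_to_upload=gca_model.Model(name=test_training_jobs._TEST_MODEL_NAME),
        )
        yield mock_get_training_pipeline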