def test_deploy_with_explanations(self, deploy_model_with_explanations_mock, sync):
    """Deploying with explanation metadata/parameters should forward an
    ExplanationSpec inside the DeployedModel sent to the deploy-model API."""
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

    endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
    model = models.Model(_TEST_ID)
    endpoint.deploy(
        model=model,
        machine_type=_TEST_MACHINE_TYPE,
        accelerator_type=_TEST_ACCELERATOR_TYPE,
        accelerator_count=_TEST_ACCELERATOR_COUNT,
        explanation_metadata=_TEST_EXPLANATION_METADATA,
        explanation_parameters=_TEST_EXPLANATION_PARAMETERS,
        sync=sync,
    )
    if not sync:
        endpoint.wait()

    # Rebuild, from the same constants, the GAPIC objects the SDK is
    # expected to have assembled for the underlying API call.
    machine_spec = gca_machine_resources_v1beta1.MachineSpec(
        machine_type=_TEST_MACHINE_TYPE,
        accelerator_type=_TEST_ACCELERATOR_TYPE,
        accelerator_count=_TEST_ACCELERATOR_COUNT,
    )
    dedicated_resources = gca_machine_resources_v1beta1.DedicatedResources(
        machine_spec=machine_spec,
        min_replica_count=1,
        max_replica_count=1,
    )
    deployed_model = gca_endpoint_v1beta1.DeployedModel(
        dedicated_resources=dedicated_resources,
        model=model.resource_name,
        display_name=None,
        explanation_spec=gca_endpoint_v1beta1.explanation.ExplanationSpec(
            metadata=_TEST_EXPLANATION_METADATA,
            parameters=_TEST_EXPLANATION_PARAMETERS,
        ),
    )

    deploy_model_with_explanations_mock.assert_called_once_with(
        endpoint=endpoint.resource_name,
        deployed_model=deployed_model,
        traffic_split={"0": 100},
        metadata=(),
    )
def test_batch_predict_with_all_args(
    self, create_batch_prediction_job_with_explanations_mock, sync
):
    """Calling BatchPredictionJob.create with every supported argument should
    produce a fully-populated GAPIC BatchPredictionJob request."""
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
    anon_creds = auth_credentials.AnonymousCredentials()

    job = jobs.BatchPredictionJob.create(
        model_name=_TEST_MODEL_NAME,
        job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
        gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
        gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
        predictions_format="csv",
        model_parameters={},
        machine_type=_TEST_MACHINE_TYPE,
        accelerator_type=_TEST_ACCELERATOR_TYPE,
        accelerator_count=_TEST_ACCELERATOR_COUNT,
        starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
        max_replica_count=_TEST_MAX_REPLICA_COUNT,
        generate_explanation=True,
        explanation_metadata=_TEST_EXPLANATION_METADATA,
        explanation_parameters=_TEST_EXPLANATION_PARAMETERS,
        labels=_TEST_LABEL,
        credentials=anon_creds,
        sync=sync,
    )
    if not sync:
        job.wait()

    # Construct the expected request piece by piece from the same constants.
    expected_input_config = (
        gca_batch_prediction_job_v1beta1.BatchPredictionJob.InputConfig(
            instances_format="jsonl",
            gcs_source=gca_io_v1beta1.GcsSource(
                uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
            ),
        )
    )
    expected_output_config = (
        gca_batch_prediction_job_v1beta1.BatchPredictionJob.OutputConfig(
            gcs_destination=gca_io_v1beta1.GcsDestination(
                output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
            ),
            predictions_format="csv",
        )
    )
    expected_resources = gca_machine_resources_v1beta1.BatchDedicatedResources(
        machine_spec=gca_machine_resources_v1beta1.MachineSpec(
            machine_type=_TEST_MACHINE_TYPE,
            accelerator_type=_TEST_ACCELERATOR_TYPE,
            accelerator_count=_TEST_ACCELERATOR_COUNT,
        ),
        starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
        max_replica_count=_TEST_MAX_REPLICA_COUNT,
    )
    expected_job = gca_batch_prediction_job_v1beta1.BatchPredictionJob(
        display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
        model=_TEST_MODEL_NAME,
        input_config=expected_input_config,
        output_config=expected_output_config,
        dedicated_resources=expected_resources,
        generate_explanation=True,
        explanation_spec=gca_explanation_v1beta1.ExplanationSpec(
            metadata=_TEST_EXPLANATION_METADATA,
            parameters=_TEST_EXPLANATION_PARAMETERS,
        ),
        labels=_TEST_LABEL,
    )

    create_batch_prediction_job_with_explanations_mock.assert_called_once_with(
        parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}",
        batch_prediction_job=expected_job,
    )