def get_batch_prediction_job_mock():
    """Patch ``_TEST_API_CLIENT.get_batch_prediction_job`` with a scripted job lifecycle.

    Successive calls to the patched method return a ``BatchPredictionJob`` in
    the PENDING, RUNNING, SUCCESS and then SUCCESS state again. ``side_effect``
    is a finite list, so a fifth call raises ``StopIteration``
    (``unittest.mock`` behaviour for exhausted iterables).

    Yields:
        The mock that replaces ``get_batch_prediction_job`` while the
        ``patch.object`` context is active.
    """
    scripted_states = (
        _TEST_JOB_STATE_PENDING,
        _TEST_JOB_STATE_RUNNING,
        _TEST_JOB_STATE_SUCCESS,
        _TEST_JOB_STATE_SUCCESS,
    )
    with patch.object(_TEST_API_CLIENT, "get_batch_prediction_job") as getter_mock:
        # One canned response per call, returned in order.
        getter_mock.side_effect = [
            gca_batch_prediction_job_compat.BatchPredictionJob(
                name=_TEST_BATCH_PREDICTION_JOB_NAME,
                display_name=_TEST_DISPLAY_NAME,
                state=job_state,
            )
            for job_state in scripted_states
        ]
        yield getter_mock
def create_batch_prediction_job_with_explanations_mock():
    """Patch ``_TEST_API_CLIENT.create_batch_prediction_job`` to return a finished job.

    Every call to the patched method returns the same ``BatchPredictionJob``,
    which sets only ``name``, ``display_name`` and a SUCCESS ``state``. The
    returned message carries no explanation fields itself, despite the
    mock's name.

    Yields:
        The mock that replaces ``create_batch_prediction_job`` while the
        ``patch.object`` context is active.
    """
    # ``patch.object`` (rather than ``mock.patch.object``) matches how the
    # other ``_TEST_API_CLIENT`` mocks in this file patch the client.
    with patch.object(
        _TEST_API_CLIENT, "create_batch_prediction_job"
    ) as create_batch_prediction_job_mock:
        create_batch_prediction_job_mock.return_value = (
            gca_batch_prediction_job_compat.BatchPredictionJob(
                name=_TEST_BATCH_PREDICTION_JOB_NAME,
                display_name=_TEST_DISPLAY_NAME,
                state=_TEST_JOB_STATE_SUCCESS,
            )
        )
        yield create_batch_prediction_job_mock
def get_batch_prediction_job_running_bq_output_mock():
    """Patch ``_TEST_API_CLIENT.get_batch_prediction_job`` to report a RUNNING BigQuery job.

    Every call to the patched method returns the same ``BatchPredictionJob``
    in the RUNNING state. It has a GCS input config, a BigQuery output config
    and BigQuery output info taken from the module-level test constants.

    Yields:
        The mock that replaces ``get_batch_prediction_job`` while the
        ``patch.object`` context is active.
    """
    running_job = gca_batch_prediction_job_compat.BatchPredictionJob(
        name=_TEST_BATCH_PREDICTION_JOB_NAME,
        display_name=_TEST_DISPLAY_NAME,
        model=_TEST_MODEL_NAME,
        input_config=_TEST_GCS_INPUT_CONFIG,
        output_config=_TEST_BQ_OUTPUT_CONFIG,
        output_info=_TEST_BQ_OUTPUT_INFO,
        state=_TEST_JOB_STATE_RUNNING,
    )
    with patch.object(_TEST_API_CLIENT, "get_batch_prediction_job") as getter_mock:
        getter_mock.return_value = running_job
        yield getter_mock
def test_batch_predict_gcs_source_bq_dest(
    self, create_batch_prediction_job_mock, sync
):
    """Create a GCS-input / BigQuery-output job and verify the request sent.

    Checks that the job's ``output_info`` equals an empty ``OutputInfo``.
    Also checks that ``create_batch_prediction_job`` was called exactly once
    with the expected request, a JSONL GCS input and a BigQuery destination
    in ``"bigquery"`` predictions format.
    """
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

    job = jobs.BatchPredictionJob.create(
        model_name=_TEST_MODEL_NAME,
        job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
        gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
        bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
        sync=sync,
        create_request_timeout=None,
    )
    job.wait_for_resource_creation()
    job.wait()

    job_proto = gca_batch_prediction_job_compat.BatchPredictionJob
    assert job.output_info == job_proto.OutputInfo()

    # Build the request the SDK is expected to have sent.
    expected_input_config = job_proto.InputConfig(
        instances_format="jsonl",
        gcs_source=gca_io_compat.GcsSource(
            uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
        ),
    )
    expected_output_config = job_proto.OutputConfig(
        bigquery_destination=gca_io_compat.BigQueryDestination(
            output_uri=_TEST_BATCH_PREDICTION_BQ_DEST_PREFIX_WITH_PROTOCOL
        ),
        predictions_format="bigquery",
    )
    expected_job = job_proto(
        display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
        model=_TEST_MODEL_NAME,
        input_config=expected_input_config,
        output_config=expected_output_config,
    )

    create_batch_prediction_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        batch_prediction_job=expected_job,
        timeout=None,
    )
def test_batch_predict_gcs_source_and_dest_with_timeout(
    self, create_batch_prediction_job_mock, sync
):
    """Create a GCS-to-GCS job with a 180s request timeout and verify the request.

    Checks that ``create_batch_prediction_job`` was called exactly once with a
    JSONL GCS input, a GCS destination in ``"jsonl"`` predictions format, and
    ``timeout=180.0`` forwarded from ``create_request_timeout``.
    """
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

    # Make SDK batch_predict method call
    job = jobs.BatchPredictionJob.create(
        model_name=_TEST_MODEL_NAME,
        job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
        gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
        gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
        sync=sync,
        create_request_timeout=180.0,
    )
    job.wait_for_resource_creation()
    job.wait()

    # Build the request the SDK is expected to have sent.
    job_proto = gca_batch_prediction_job_compat.BatchPredictionJob
    expected_input_config = job_proto.InputConfig(
        instances_format="jsonl",
        gcs_source=gca_io_compat.GcsSource(
            uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
        ),
    )
    expected_output_config = job_proto.OutputConfig(
        gcs_destination=gca_io_compat.GcsDestination(
            output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
        ),
        predictions_format="jsonl",
    )
    expected_job = job_proto(
        display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
        model=_TEST_MODEL_NAME,
        input_config=expected_input_config,
        output_config=expected_output_config,
    )

    create_batch_prediction_job_mock.assert_called_once_with(
        parent=_TEST_PARENT,
        batch_prediction_job=expected_job,
        timeout=180.0,
    )
def test_batch_predict_with_all_args(
    self, create_batch_prediction_job_with_explanations_mock, sync
):
    """Create a job with every optional argument set and verify the full request.

    Covers:
    - CSV predictions format and dedicated machine/accelerator resources
    - replica counts and labels
    - explanation generation with explicit metadata and parameters
    - anonymous credentials

    Checks that ``create_batch_prediction_job`` was called exactly once with
    the matching request.
    """
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
    anonymous_creds = auth_credentials.AnonymousCredentials()

    job = jobs.BatchPredictionJob.create(
        model_name=_TEST_MODEL_NAME,
        job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
        gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
        gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
        predictions_format="csv",
        model_parameters={},
        machine_type=_TEST_MACHINE_TYPE,
        accelerator_type=_TEST_ACCELERATOR_TYPE,
        accelerator_count=_TEST_ACCELERATOR_COUNT,
        starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
        max_replica_count=_TEST_MAX_REPLICA_COUNT,
        generate_explanation=True,
        explanation_metadata=_TEST_EXPLANATION_METADATA,
        explanation_parameters=_TEST_EXPLANATION_PARAMETERS,
        labels=_TEST_LABEL,
        credentials=anonymous_creds,
        sync=sync,
        create_request_timeout=None,
    )
    job.wait_for_resource_creation()
    job.wait()

    # Build the request the SDK is expected to have sent, piece by piece.
    job_proto = gca_batch_prediction_job_compat.BatchPredictionJob
    expected_input_config = job_proto.InputConfig(
        instances_format="jsonl",
        gcs_source=gca_io_compat.GcsSource(
            uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
        ),
    )
    expected_output_config = job_proto.OutputConfig(
        gcs_destination=gca_io_compat.GcsDestination(
            output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
        ),
        predictions_format="csv",
    )
    expected_resources = gca_machine_resources_compat.BatchDedicatedResources(
        machine_spec=gca_machine_resources_compat.MachineSpec(
            machine_type=_TEST_MACHINE_TYPE,
            accelerator_type=_TEST_ACCELERATOR_TYPE,
            accelerator_count=_TEST_ACCELERATOR_COUNT,
        ),
        starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
        max_replica_count=_TEST_MAX_REPLICA_COUNT,
    )
    expected_explanation_spec = gca_explanation_compat.ExplanationSpec(
        metadata=_TEST_EXPLANATION_METADATA,
        parameters=_TEST_EXPLANATION_PARAMETERS,
    )
    expected_job = job_proto(
        display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
        model=_TEST_MODEL_NAME,
        input_config=expected_input_config,
        output_config=expected_output_config,
        dedicated_resources=expected_resources,
        generate_explanation=True,
        explanation_spec=expected_explanation_spec,
        labels=_TEST_LABEL,
    )

    create_batch_prediction_job_with_explanations_mock.assert_called_once_with(
        parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}",
        batch_prediction_job=expected_job,
        timeout=None,
    )