Example 1
def create_batch_prediction_job_mock():
    with mock.patch.object(
            job_service_client.JobServiceClient,
            "create_batch_prediction_job") as create_batch_prediction_job_mock:
        create_batch_prediction_job_mock.return_value = gca_batch_prediction_job.BatchPredictionJob(
            name=_TEST_BATCH_PREDICTION_JOB_NAME,
            display_name=_TEST_DISPLAY_NAME,
            state=_TEST_JOB_STATE_SUCCESS,
        )
        yield create_batch_prediction_job_mock
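In the test suite such generators are registered as pytest fixtures, so the patch on JobServiceClient is only active while a test that requests the fixture by name is running. Below is a minimal, self-contained sketch of that wiring; the @pytest.fixture decorator, the versioned aiplatform_v1 import paths, and the concrete resource and display names are assumptions, not part of the snippet above.

import pytest
from unittest import mock

from google.cloud.aiplatform_v1.services.job_service import client as job_service_client
from google.cloud.aiplatform_v1.types import batch_prediction_job as gca_batch_prediction_job
from google.cloud.aiplatform_v1.types import job_state


@pytest.fixture
def create_batch_prediction_job_mock():
    # Patch the GAPIC client method so no real JobService call is made; the SDK
    # receives a BatchPredictionJob proto that is already in a terminal state.
    with mock.patch.object(
            job_service_client.JobServiceClient,
            "create_batch_prediction_job") as create_mock:
        create_mock.return_value = gca_batch_prediction_job.BatchPredictionJob(
            name="projects/123/locations/us-central1/batchPredictionJobs/456",
            display_name="test-batch-prediction-job",
            state=job_state.JobState.JOB_STATE_SUCCEEDED,
        )
        yield create_mock

A test that lists create_batch_prediction_job_mock as a parameter (as Examples 4-6 below do) then runs with the patch in place and can assert on the call that Model.batch_predict or BatchPredictionJob.create issued.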
Example 2
def get_batch_prediction_job_mock():
    with patch.object(
            job_service_client.JobServiceClient,
            "get_batch_prediction_job") as get_batch_prediction_job_mock:
        # Successive get calls return RUNNING, then SUCCEEDED twice, simulating
        # the polling the SDK performs while waiting for the job to finish.
        get_batch_prediction_job_mock.side_effect = [
            gca_batch_prediction_job.BatchPredictionJob(
                name=_TEST_BATCH_PREDICTION_JOB_NAME,
                display_name=_TEST_DISPLAY_NAME,
                state=_TEST_JOB_STATE_RUNNING,
            ),
            gca_batch_prediction_job.BatchPredictionJob(
                name=_TEST_BATCH_PREDICTION_JOB_NAME,
                display_name=_TEST_DISPLAY_NAME,
                state=_TEST_JOB_STATE_SUCCESS,
            ),
            gca_batch_prediction_job.BatchPredictionJob(
                name=_TEST_BATCH_PREDICTION_JOB_NAME,
                display_name=_TEST_DISPLAY_NAME,
                state=_TEST_JOB_STATE_SUCCESS,
            ),
        ]
        yield get_batch_prediction_job_mock
Example 3
def get_batch_prediction_job_running_bq_output_mock():
    with patch.object(
            job_service_client.JobServiceClient,
            "get_batch_prediction_job") as get_batch_prediction_job_mock:
        get_batch_prediction_job_mock.return_value = gca_batch_prediction_job.BatchPredictionJob(
            name=_TEST_BATCH_PREDICTION_JOB_NAME,
            display_name=_TEST_DISPLAY_NAME,
            model=_TEST_MODEL_NAME,
            input_config=_TEST_GCS_INPUT_CONFIG,
            output_config=_TEST_BQ_OUTPUT_CONFIG,
            output_info=_TEST_BQ_OUTPUT_INFO,
            state=_TEST_JOB_STATE_RUNNING,
        )
        yield get_batch_prediction_job_mock
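A get fixture like this one is typically consumed by constructing the job wrapper straight from its resource name, so every lookup the SDK performs returns the canned RUNNING proto with its BigQuery output config. The test below is a hypothetical illustration of that, assuming the same module-level constants and the aiplatform/jobs imports used by the other examples:

    def test_get_batch_prediction_job_running_with_bq_output(
        self, get_batch_prediction_job_running_bq_output_mock
    ):
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

        # Constructing the wrapper fetches the job, and reading .state re-syncs
        # it; both calls are answered by the fixture's canned response.
        job = jobs.BatchPredictionJob(
            batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME
        )

        assert job.state == _TEST_JOB_STATE_RUNNING
        get_batch_prediction_job_running_bq_output_mock.assert_called()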
Example 4
    def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_and_dest(
        self, create_batch_prediction_job_mock, sync
    ):
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
        )
        test_model = models.Model(_TEST_ID)

        # Make SDK batch_predict method call
        batch_prediction_job = test_model.batch_predict(
            job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
            gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
            gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
            sync=sync,
        )

        if not sync:
            batch_prediction_job.wait()

        # Construct expected request
        expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob(
            display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
            model=model_service_client.ModelServiceClient.model_path(
                _TEST_PROJECT, _TEST_LOCATION, _TEST_ID
            ),
            input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig(
                instances_format="jsonl",
                gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]),
            ),
            output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
                gcs_destination=gca_io.GcsDestination(
                    output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
                ),
                predictions_format="jsonl",
            ),
            encryption_spec=_TEST_ENCRYPTION_SPEC,
        )

        create_batch_prediction_job_mock.assert_called_once_with(
            parent=_TEST_PARENT,
            batch_prediction_job=expected_gapic_batch_prediction_job,
        )
Example 5
    def test_batch_predict_gcs_source_bq_dest(self,
                                              create_batch_prediction_job_mock,
                                              sync):
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

        batch_prediction_job = jobs.BatchPredictionJob.create(
            model_name=_TEST_MODEL_NAME,
            job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
            gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
            bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
            sync=sync,
        )

        if not sync:
            batch_prediction_job.wait()

        assert (batch_prediction_job.output_info ==
                gca_batch_prediction_job.BatchPredictionJob.OutputInfo())

        # Construct expected request
        expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob(
            display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
            model=_TEST_MODEL_NAME,
            input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig(
                instances_format="jsonl",
                gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]),
            ),
            output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
                bigquery_destination=gca_io.BigQueryDestination(
                    output_uri=_TEST_BATCH_PREDICTION_BQ_DEST_PREFIX_WITH_PROTOCOL
                ),
                predictions_format="bigquery",
            ),
        )

        create_batch_prediction_job_mock.assert_called_once_with(
            parent=_TEST_PARENT,
            batch_prediction_job=expected_gapic_batch_prediction_job,
        )
Example 6
    def test_batch_predict_gcs_source_bq_dest(self,
                                              create_batch_prediction_job_mock,
                                              sync):

        test_model = models.Model(_TEST_ID)

        # Make SDK batch_predict method call
        batch_prediction_job = test_model.batch_predict(
            job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
            gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
            bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
            sync=sync,
        )

        if not sync:
            batch_prediction_job.wait()

        # Construct expected request
        expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob(
            display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
            model=model_service_client.ModelServiceClient.model_path(
                _TEST_PROJECT, _TEST_LOCATION, _TEST_ID
            ),
            input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig(
                instances_format="jsonl",
                gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]),
            ),
            output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
                bigquery_destination=gca_io.BigQueryDestination(
                    output_uri=_TEST_BATCH_PREDICTION_BQ_DEST_PREFIX_WITH_PROTOCOL
                ),
                predictions_format="bigquery",
            ),
        )

        create_batch_prediction_job_mock.assert_called_once_with(
            parent=_TEST_PARENT,
            batch_prediction_job=expected_gapic_batch_prediction_job,
        )
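Outside the test suite, the call paths these examples exercise reduce to a few lines against the real SDK. A minimal sketch with placeholder project, model, and bucket values that are not taken from the examples above:

from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

# The model can be referenced by ID or by full resource name.
model = aiplatform.Model("1234567890")

# Read JSONL instances from Cloud Storage and write predictions to BigQuery;
# sync=False returns immediately and wait() blocks until the job finishes.
batch_prediction_job = model.batch_predict(
    job_display_name="my-batch-prediction-job",
    gcs_source="gs://my-bucket/instances.jsonl",
    bigquery_destination_prefix="bq://my-project",
    sync=False,
)

batch_prediction_job.wait()
print(batch_prediction_job.state)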