Code Example #1
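Verifies that repr() of a Model appends the resource name to the default object representation.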
 def test_print_model(self):
     aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
     test_model = models.Model(_TEST_ID)
     assert (
         repr(test_model)
         == f"{object.__repr__(test_model)} \nresource name: {test_model.resource_name}"
     )
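All of the snippets below assume shared imports and module-level test constants defined elsewhere in the suite. A minimal, hypothetical sketch of that setup (the names come from the snippets; the values are illustrative assumptions, not taken from the source):

from concurrent import futures
from unittest import mock

import pytest
from google.auth import credentials as auth_credentials

from google.cloud import aiplatform
from google.cloud.aiplatform import initializer, models, utils

# Illustrative values only; the real module defines many more constants.
_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"
_TEST_LOCATION_2 = "europe-west4"
_TEST_ID = "1028944691210842416"
_TEST_MODEL_RESOURCE_NAME = (
    f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_ID}"
)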
Code Example #2
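Verifies that a location passed to the Model constructor overrides the location from aiplatform.init() when building the resource name for the get_model call.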
 def test_constructor_gets_model_with_custom_location(self, get_model_mock):
     aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
     models.Model(_TEST_ID, location=_TEST_LOCATION_2)
     test_model_resource_name = model_service_client.ModelServiceClient.model_path(
         _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID
     )
     get_model_mock.assert_called_once_with(name=test_model_resource_name)
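The get_model_mock fixture is not shown in these examples. A plausible sketch, assuming it patches the GAPIC ModelServiceClient so that constructing a Model never calls the real API (the returned proto is a guess):

from google.cloud.aiplatform_v1.services.model_service import (
    client as model_service_client,
)
from google.cloud.aiplatform_v1.types import model as gca_model

@pytest.fixture
def get_model_mock():
    # Patch the raw GAPIC call that models.Model() issues on construction.
    with mock.patch.object(
        model_service_client.ModelServiceClient, "get_model"
    ) as get_model_mock:
        get_model_mock.return_value = gca_model.Model(
            display_name="test-model",
            name=_TEST_MODEL_RESOURCE_NAME,
        )
        yield get_model_mock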
Code Example #3
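Deploys a model to an endpoint with a custom display name and asserts that deploy_model receives a DeployedModel backed by AutomaticResources with 100% of the traffic.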
    def test_deploy_with_display_name(self, deploy_model_mock, sync):
        test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
        test_model = models.Model(_TEST_ID)
        test_endpoint.deploy(model=test_model,
                             deployed_model_display_name=_TEST_DISPLAY_NAME,
                             sync=sync)

        if not sync:
            test_endpoint.wait()

        automatic_resources = gca_machine_resources.AutomaticResources(
            min_replica_count=1,
            max_replica_count=1,
        )
        deployed_model = gca_endpoint.DeployedModel(
            automatic_resources=automatic_resources,
            model=test_model.resource_name,
            display_name=_TEST_DISPLAY_NAME,
        )
        deploy_model_mock.assert_called_once_with(
            endpoint=test_endpoint.resource_name,
            deployed_model=deployed_model,
            traffic_split={"0": 100},
            metadata=(),
        )
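Many of these tests take a sync flag and a deploy_model_mock fixture, neither of which is shown. The flag is presumably supplied by @pytest.mark.parametrize("sync", [True, False]) so both blocking and non-blocking deploys are exercised. A plausible sketch of the fixture, assuming it patches the GAPIC EndpointServiceClient and hands back a fake long-running operation:

from google.cloud.aiplatform_v1.services.endpoint_service import (
    client as endpoint_service_client,
)
from google.cloud.aiplatform_v1.types import endpoint_service as gca_endpoint_service

@pytest.fixture
def deploy_model_mock():
    with mock.patch.object(
        endpoint_service_client.EndpointServiceClient, "deploy_model"
    ) as deploy_model_mock:
        # Fake LRO: result() immediately returns an empty DeployModelResponse.
        deploy_model_lro_mock = mock.Mock()
        deploy_model_lro_mock.result.return_value = (
            gca_endpoint_service.DeployModelResponse()
        )
        deploy_model_mock.return_value = deploy_model_lro_mock
        yield deploy_model_mock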
Code Example #4
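Calls Model.deploy() with no pre-existing endpoint, specifying a machine type and accelerators, and asserts the resulting DeployedModel uses DedicatedResources.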
    def test_deploy_no_endpoint_dedicated_resources(self, deploy_model_mock, sync):

        test_model = models.Model(_TEST_ID)
        test_endpoint = test_model.deploy(
            machine_type=_TEST_MACHINE_TYPE,
            accelerator_type=_TEST_ACCELERATOR_TYPE,
            accelerator_count=_TEST_ACCELERATOR_COUNT,
            sync=sync,
        )

        if not sync:
            test_endpoint.wait()

        expected_machine_spec = gca_machine_resources.MachineSpec(
            machine_type=_TEST_MACHINE_TYPE,
            accelerator_type=_TEST_ACCELERATOR_TYPE,
            accelerator_count=_TEST_ACCELERATOR_COUNT,
        )
        expected_dedicated_resources = gca_machine_resources.DedicatedResources(
            machine_spec=expected_machine_spec, min_replica_count=1, max_replica_count=1
        )
        expected_deployed_model = gca_endpoint.DeployedModel(
            dedicated_resources=expected_dedicated_resources,
            model=test_model.resource_name,
            display_name=None,
        )
        deploy_model_mock.assert_called_once_with(
            endpoint=test_endpoint.resource_name,
            deployed_model=expected_deployed_model,
            traffic_split={"0": 100},
            metadata=(),
        )
Code Example #5
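Asserts that a negative max_replica_count makes deploy() raise ValueError.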
 def test_deploy_raise_error_max_replica(self, sync):
     with pytest.raises(ValueError):
         test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
         test_model = models.Model(_TEST_ID)
         test_endpoint.deploy(model=test_model,
                              max_replica_count=-2,
                              sync=sync)
Code Example #6
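Asserts that a negative traffic_percentage makes deploy() raise ValueError.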
 def test_deploy_raise_error_traffic_negative(self, sync):
     with pytest.raises(ValueError):
         test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
         test_model = models.Model(_TEST_ID)
         test_endpoint.deploy(model=test_model,
                              traffic_percentage=-18,
                              sync=sync)
Code Example #7
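Asserts that a traffic_split whose values do not sum to 100 makes deploy() raise ValueError.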
 def test_deploy_raise_error_traffic_split(self, sync):
     with pytest.raises(ValueError):
         test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
         test_model = models.Model(_TEST_ID)
         test_endpoint.deploy(model=test_model,
                              traffic_split={"a": 99},
                              sync=sync)
Code Example #8
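Deploys with traffic_percentage=70 to an endpoint already serving "model1" and asserts the existing model's share is scaled down to the remaining 30%.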
    def test_deploy_with_traffic_percent(self, deploy_model_mock, sync):
        with mock.patch.object(endpoint_service_client.EndpointServiceClient,
                               "get_endpoint") as get_endpoint_mock:
            get_endpoint_mock.return_value = gca_endpoint.Endpoint(
                display_name=_TEST_DISPLAY_NAME,
                name=_TEST_ENDPOINT_NAME,
                traffic_split={"model1": 100},
            )

            test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
            test_model = models.Model(_TEST_ID)
            test_endpoint.deploy(model=test_model,
                                 traffic_percentage=70,
                                 sync=sync)
            if not sync:
                test_endpoint.wait()
            automatic_resources = gca_machine_resources.AutomaticResources(
                min_replica_count=1,
                max_replica_count=1,
            )
            deployed_model = gca_endpoint.DeployedModel(
                automatic_resources=automatic_resources,
                model=test_model.resource_name,
                display_name=None,
            )
            deploy_model_mock.assert_called_once_with(
                endpoint=test_endpoint.resource_name,
                deployed_model=deployed_model,
                traffic_split={
                    "model1": 30,
                    "0": 70
                },
                metadata=(),
            )
Code Example #9
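A variant of Example #7 that also calls aiplatform.init() before the invalid traffic_split check.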
 def test_deploy_raise_error_traffic_split(self, sync):
     with pytest.raises(ValueError):
         aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
         test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
         test_model = models.Model(_TEST_ID)
         test_endpoint.deploy(model=test_model,
                              traffic_split={"a": 99},
                              sync=sync)
Code Example #10
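A variant of Example #5 that also calls aiplatform.init() before the negative max_replica_count check.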
 def test_deploy_raise_error_max_replica(self, sync):
     with pytest.raises(ValueError):
         aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
         test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
         test_model = models.Model(_TEST_ID)
         test_endpoint.deploy(model=test_model,
                              max_replica_count=-2,
                              sync=sync)
Code Example #11
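Deletes a model and asserts delete_model is called once with its resource name.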
    def test_delete_model(self, delete_model_mock, sync):

        test_model = models.Model(_TEST_ID)
        test_model.delete(sync=sync)

        if not sync:
            test_model.wait()

        delete_model_mock.assert_called_once_with(name=test_model.resource_name)
Code Example #12
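Verifies the repr() shown while a Model is still waiting on upstream dependencies (a pending future and no backing GAPIC resource).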
 def test_print_model_if_waiting(self):
     aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
     test_model = models.Model(_TEST_ID)
     test_model._gca_resource = None
     test_model._latest_future = futures.Future()
     assert (
         repr(test_model)
         == f"{object.__repr__(test_model)} is waiting for upstream dependencies to complete."
     )
Code Example #13
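Verifies the repr() shown after a Model's asynchronous work has failed with an exception.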
 def test_print_model_if_exception(self):
     aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
     test_model = models.Model(_TEST_ID)
     test_model._gca_resource = None
     mock_exception = Exception("mock exception")
     test_model._exception = mock_exception
     assert (
         repr(test_model)
         == f"{object.__repr__(test_model)} failed with {str(mock_exception)}"
     )
Code Example #14
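Asserts that traffic_percentage=80 makes deploy() raise ValueError, presumably because the remaining 20% cannot be assigned on an endpoint with no other deployed models.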
    def test_deploy_raise_error_traffic_80(self, sync):
        with pytest.raises(ValueError):
            aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
            test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
            test_model = models.Model(_TEST_ID)
            test_endpoint.deploy(model=test_model, traffic_percentage=80, sync=sync)

            if not sync:
                test_endpoint.wait()
Code Example #15
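A variant of Example #14 without the aiplatform.init() call.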
    def test_deploy_raise_error_traffic_80(self, sync):
        with pytest.raises(ValueError):
            test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
            test_model = models.Model(_TEST_ID)
            test_endpoint.deploy(model=test_model,
                                 traffic_percentage=80,
                                 sync=sync)

            if not sync:
                test_endpoint.wait()
Code Example #16
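Asserts that credentials passed to the Model constructor are forwarded to the client factory along with the initialized location.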
 def test_constructor_creates_client_with_custom_credentials(
         self, create_client_mock):
     aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
     creds = auth_credentials.AnonymousCredentials()
     models.Model(_TEST_ID, credentials=creds)
     create_client_mock.assert_called_once_with(
         client_class=utils.ModelClientWithOverride,
         credentials=creds,
         location_override=_TEST_LOCATION,
         prediction_client=False,
     )
Code Example #17
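Asserts that batch_predict() raises a ValueError mentioning "source" when called without any input source.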
    def test_batch_predict_no_source(self, create_batch_prediction_job_mock):

        test_model = models.Model(_TEST_ID)

        # Make SDK batch_predict method call without source
        with pytest.raises(ValueError) as e:
            test_model.batch_predict(
                job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
                bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
            )

        assert e.match(regexp=r"source")
Code Example #18
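Asserts that batch_predict() raises a ValueError mentioning "destination" when called without any output destination.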
    def test_batch_predict_no_destination(self):

        test_model = models.Model(_TEST_ID)

        # Make SDK batch_predict method call without destination
        with pytest.raises(ValueError) as e:
            test_model.batch_predict(
                job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
                gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
            )

        assert e.match(regexp=r"destination")
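The create_batch_prediction_job_mock fixture used by the batch-prediction tests is likewise defined elsewhere. A plausible sketch, assuming it patches the GAPIC JobServiceClient and returns a minimal, already-succeeded job (the resource name and state are illustrative):

from google.cloud.aiplatform_v1.services.job_service import (
    client as job_service_client,
)
from google.cloud.aiplatform_v1.types import (
    batch_prediction_job as gca_batch_prediction_job,
)
from google.cloud.aiplatform_v1.types import job_state as gca_job_state

@pytest.fixture
def create_batch_prediction_job_mock():
    with mock.patch.object(
        job_service_client.JobServiceClient, "create_batch_prediction_job"
    ) as create_mock:
        create_mock.return_value = gca_batch_prediction_job.BatchPredictionJob(
            name=(
                f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
                "/batchPredictionJobs/123"
            ),
            display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
            state=gca_job_state.JobState.JOB_STATE_SUCCEEDED,
        )
        yield create_mock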
Code Example #19
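Asserts that a custom location yields a client created with that location_override, while credentials fall back to the global configuration.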
 def test_constructor_create_client_with_custom_location(self, create_client_mock):
     aiplatform.init(
         project=_TEST_PROJECT,
         location=_TEST_LOCATION,
         credentials=_TEST_CREDENTIALS,
     )
     models.Model(_TEST_ID, location=_TEST_LOCATION_2)
     create_client_mock.assert_called_once_with(
         client_class=utils.ModelClientWithOverride,
         credentials=initializer.global_config.credentials,
         location_override=_TEST_LOCATION_2,
         prediction_client=False,
     )
Code Example #20
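Asserts that supplying explanation_metadata without explanation_parameters makes deploy() raise ValueError.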
    def test_deploy_raises_with_impartial_explanation_spec(self):

        test_model = models.Model(_TEST_ID)

        with pytest.raises(ValueError) as e:
            test_model.deploy(
                machine_type=_TEST_MACHINE_TYPE,
                accelerator_type=_TEST_ACCELERATOR_TYPE,
                accelerator_count=_TEST_ACCELERATOR_COUNT,
                explanation_metadata=_TEST_EXPLANATION_METADATA,
                # Missing required `explanation_parameters` argument
            )

        assert e.match(regexp=r"`explanation_parameters` should be specified or None.")
Code Example #21
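Asserts that an unsupported predictions_format makes batch_predict() raise ValueError.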
    def test_batch_predict_wrong_prediction_format(self):

        test_model = models.Model(_TEST_ID)

        # Make SDK batch_predict method call
        with pytest.raises(ValueError) as e:
            test_model.batch_predict(
                job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
                gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
                predictions_format="wrong",
                bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
            )

        assert e.match(regexp=r"accepted prediction format")
Code Example #22
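Deploys with an explanation spec and asserts the v1beta1 deploy_model call carries an ExplanationSpec alongside DedicatedResources.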
    def test_deploy_with_explanations(self,
                                      deploy_model_with_explanations_mock,
                                      sync):
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
        test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
        test_model = models.Model(_TEST_ID)
        test_endpoint.deploy(
            model=test_model,
            machine_type=_TEST_MACHINE_TYPE,
            accelerator_type=_TEST_ACCELERATOR_TYPE,
            accelerator_count=_TEST_ACCELERATOR_COUNT,
            explanation_metadata=_TEST_EXPLANATION_METADATA,
            explanation_parameters=_TEST_EXPLANATION_PARAMETERS,
            sync=sync,
        )

        if not sync:
            test_endpoint.wait()

        expected_machine_spec = gca_machine_resources_v1beta1.MachineSpec(
            machine_type=_TEST_MACHINE_TYPE,
            accelerator_type=_TEST_ACCELERATOR_TYPE,
            accelerator_count=_TEST_ACCELERATOR_COUNT,
        )
        expected_dedicated_resources = gca_machine_resources_v1beta1.DedicatedResources(
            machine_spec=expected_machine_spec,
            min_replica_count=1,
            max_replica_count=1,
        )
        expected_deployed_model = gca_endpoint_v1beta1.DeployedModel(
            dedicated_resources=expected_dedicated_resources,
            model=test_model.resource_name,
            display_name=None,
            explanation_spec=gca_endpoint_v1beta1.explanation.ExplanationSpec(
                metadata=_TEST_EXPLANATION_METADATA,
                parameters=_TEST_EXPLANATION_PARAMETERS,
            ),
        )
        deploy_model_with_explanations_mock.assert_called_once_with(
            endpoint=test_endpoint.resource_name,
            deployed_model=expected_deployed_model,
            traffic_split={"0": 100},
            metadata=(),
        )
Code Example #23
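Verifies that an encryption_spec_key_name set in aiplatform.init() propagates into the BatchPredictionJob created for a GCS-to-GCS batch prediction.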
    def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_and_dest(
        self, create_batch_prediction_job_mock, sync
    ):
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
        )
        test_model = models.Model(_TEST_ID)

        # Make SDK batch_predict method call
        batch_prediction_job = test_model.batch_predict(
            job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
            gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
            gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
            sync=sync,
        )

        if not sync:
            batch_prediction_job.wait()

        # Construct expected request
        expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob(
            display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
            model=model_service_client.ModelServiceClient.model_path(
                _TEST_PROJECT, _TEST_LOCATION, _TEST_ID
            ),
            input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig(
                instances_format="jsonl",
                gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]),
            ),
            output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
                gcs_destination=gca_io.GcsDestination(
                    output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
                ),
                predictions_format="jsonl",
            ),
            encryption_spec=_TEST_ENCRYPTION_SPEC,
        )

        create_batch_prediction_job_mock.assert_called_once_with(
            parent=_TEST_PARENT,
            batch_prediction_job=expected_gapic_batch_prediction_job,
        )
Code Example #24
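Runs a batch prediction from a GCS source to a BigQuery destination and checks the BatchPredictionJob request constructed from it.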
    def test_batch_predict_gcs_source_bq_dest(self,
                                              create_batch_prediction_job_mock,
                                              sync):

        test_model = models.Model(_TEST_ID)

        # Make SDK batch_predict method call
        batch_prediction_job = test_model.batch_predict(
            job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
            gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
            bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
            sync=sync,
        )

        if not sync:
            batch_prediction_job.wait()

        # Construct expected request
        expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob(
            display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
            model=model_service_client.ModelServiceClient.model_path(
                _TEST_PROJECT, _TEST_LOCATION, _TEST_ID
            ),
            input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig(
                instances_format="jsonl",
                gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]),
            ),
            output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
                bigquery_destination=gca_io.BigQueryDestination(
                    output_uri=_TEST_BATCH_PREDICTION_BQ_DEST_PREFIX_WITH_PROTOCOL
                ),
                predictions_format="bigquery",
            ),
        )

        create_batch_prediction_job_mock.assert_called_once_with(
            parent=_TEST_PARENT,
            batch_prediction_job=expected_gapic_batch_prediction_job,
        )
Code Example #25
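Deploys with max_replica_count=2 and asserts AutomaticResources reflects the new maximum while min_replica_count stays at 1.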
 def test_deploy_with_max_replica_count(self, deploy_model_mock, sync):
     aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
     test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
     test_model = models.Model(_TEST_ID)
     test_endpoint.deploy(model=test_model, max_replica_count=2, sync=sync)
     if not sync:
         test_endpoint.wait()
     automatic_resources = gca_machine_resources.AutomaticResources(
         min_replica_count=1, max_replica_count=2,
     )
     deployed_model = gca_endpoint.DeployedModel(
         automatic_resources=automatic_resources,
         model=test_model.resource_name,
         display_name=None,
     )
     deploy_model_mock.assert_called_once_with(
         endpoint=test_endpoint.resource_name,
         deployed_model=deployed_model,
         traffic_split={"0": 100},
         metadata=(),
     )
Code Example #26
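Deploys with dedicated resources and a service account, asserting both appear on the DeployedModel sent to deploy_model.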
    def test_deploy_with_dedicated_resources(self, deploy_model_mock, sync):
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
        test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
        test_model = models.Model(_TEST_ID)
        test_endpoint.deploy(
            model=test_model,
            machine_type=_TEST_MACHINE_TYPE,
            accelerator_type=_TEST_ACCELERATOR_TYPE,
            accelerator_count=_TEST_ACCELERATOR_COUNT,
            service_account=_TEST_SERVICE_ACCOUNT,
            sync=sync,
        )

        if not sync:
            test_endpoint.wait()

        expected_machine_spec = gca_machine_resources.MachineSpec(
            machine_type=_TEST_MACHINE_TYPE,
            accelerator_type=_TEST_ACCELERATOR_TYPE,
            accelerator_count=_TEST_ACCELERATOR_COUNT,
        )
        expected_dedicated_resources = gca_machine_resources.DedicatedResources(
            machine_spec=expected_machine_spec,
            min_replica_count=1,
            max_replica_count=1,
        )
        expected_deployed_model = gca_endpoint.DeployedModel(
            dedicated_resources=expected_dedicated_resources,
            model=test_model.resource_name,
            display_name=None,
            service_account=_TEST_SERVICE_ACCOUNT,
        )
        deploy_model_mock.assert_called_once_with(
            endpoint=test_endpoint.resource_name,
            deployed_model=expected_deployed_model,
            traffic_split={"0": 100},
            metadata=(),
        )
Code Example #27
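Asserts that constructing a Model fetches it via get_model using the resource name derived from the initialized project and location.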
 def test_constructor_gets_model(self, get_model_mock):
     aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
     models.Model(_TEST_ID)
     get_model_mock.assert_called_once_with(name=_TEST_MODEL_RESOURCE_NAME)
Code Example #28
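Exercises batch_predict() with every supported argument and checks the full v1beta1 BatchPredictionJob, including explanation spec, labels, and encryption.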
    def test_batch_predict_with_all_args(
        self, create_batch_prediction_job_with_explanations_mock, sync
    ):
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
        test_model = models.Model(_TEST_ID)
        creds = auth_credentials.AnonymousCredentials()

        # Make SDK batch_predict method call passing all arguments
        batch_prediction_job = test_model.batch_predict(
            job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
            gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
            gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
            predictions_format="csv",
            model_parameters={},
            machine_type=_TEST_MACHINE_TYPE,
            accelerator_type=_TEST_ACCELERATOR_TYPE,
            accelerator_count=_TEST_ACCELERATOR_COUNT,
            starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
            max_replica_count=_TEST_MAX_REPLICA_COUNT,
            generate_explanation=True,
            explanation_metadata=_TEST_EXPLANATION_METADATA,
            explanation_parameters=_TEST_EXPLANATION_PARAMETERS,
            labels=_TEST_LABEL,
            credentials=creds,
            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
            sync=sync,
        )

        if not sync:
            batch_prediction_job.wait()

        # Construct expected request
        expected_gapic_batch_prediction_job = gca_batch_prediction_job_v1beta1.BatchPredictionJob(
            display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
            model=model_service_client_v1beta1.ModelServiceClient.model_path(
                _TEST_PROJECT, _TEST_LOCATION, _TEST_ID
            ),
            input_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.InputConfig(
                instances_format="jsonl",
                gcs_source=gca_io_v1beta1.GcsSource(
                    uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
                ),
            ),
            output_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.OutputConfig(
                gcs_destination=gca_io_v1beta1.GcsDestination(
                    output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
                ),
                predictions_format="csv",
            ),
            dedicated_resources=gca_machine_resources_v1beta1.BatchDedicatedResources(
                machine_spec=gca_machine_resources_v1beta1.MachineSpec(
                    machine_type=_TEST_MACHINE_TYPE,
                    accelerator_type=_TEST_ACCELERATOR_TYPE,
                    accelerator_count=_TEST_ACCELERATOR_COUNT,
                ),
                starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
                max_replica_count=_TEST_MAX_REPLICA_COUNT,
            ),
            generate_explanation=True,
            explanation_spec=gca_explanation_v1beta1.ExplanationSpec(
                metadata=_TEST_EXPLANATION_METADATA,
                parameters=_TEST_EXPLANATION_PARAMETERS,
            ),
            labels=_TEST_LABEL,
            encryption_spec=_TEST_ENCRYPTION_SPEC_V1BETA1,
        )

        create_batch_prediction_job_with_explanations_mock.assert_called_once_with(
            parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}",
            batch_prediction_job=expected_gapic_batch_prediction_job,
        )