def deploy_model_with_explanations_mock():
    """Patch the v1beta1 ``EndpointServiceClient.deploy_model`` and yield the mock.

    While the patch is active, ``deploy_model`` returns a mocked
    ``ga_operation.Operation`` whose ``result()`` gives back a
    ``DeployModelResponse`` wrapping a ``DeployedModel`` built from
    ``_TEST_MODEL_NAME`` and ``_TEST_DISPLAY_NAME``.

    NOTE(review): this generator is presumably registered as a pytest
    fixture; the decorator is not visible here -- confirm.

    Yields:
        The mock that replaces ``deploy_model`` for the duration of the patch.
    """
    patcher = mock.patch.object(
        endpoint_service_client_v1beta1.EndpointServiceClient,
        "deploy_model",
    )
    with patcher as deploy_model_mock:
        # Canned response handed back by the mocked long-running operation.
        response = gca_endpoint_service_v1beta1.DeployModelResponse(
            deployed_model=gca_endpoint_v1beta1.DeployedModel(
                model=_TEST_MODEL_NAME,
                display_name=_TEST_DISPLAY_NAME,
            ),
        )

        # Spec the mock against Operation so only real attributes are usable.
        operation = mock.Mock(ga_operation.Operation)
        operation.result.return_value = response

        deploy_model_mock.return_value = operation
        yield deploy_model_mock
예제 #2
0
    def test_deploy_with_explanations(
        self, deploy_model_with_explanations_mock, sync
    ):
        """Check that explanation settings are forwarded to ``deploy_model``.

        Deploys a model to an endpoint with machine, accelerator and
        explanation arguments, then asserts the patched ``deploy_model`` was
        called exactly once with the expected ``DeployedModel`` payload.

        Args:
            deploy_model_with_explanations_mock: Mock standing in for the
                v1beta1 ``deploy_model`` client call.
            sync: Forwarded to ``Endpoint.deploy``; when falsy the test calls
                ``wait()`` on the endpoint before asserting.
        """
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
        endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
        model = models.Model(_TEST_ID)

        endpoint.deploy(
            model=model,
            machine_type=_TEST_MACHINE_TYPE,
            accelerator_type=_TEST_ACCELERATOR_TYPE,
            accelerator_count=_TEST_ACCELERATOR_COUNT,
            explanation_metadata=_TEST_EXPLANATION_METADATA,
            explanation_parameters=_TEST_EXPLANATION_PARAMETERS,
            sync=sync,
        )

        # In the non-sync case, wait on the endpoint before inspecting the mock.
        if not sync:
            endpoint.wait()

        # Expected request payload, assembled piece by piece.
        explanation_spec = gca_endpoint_v1beta1.explanation.ExplanationSpec(
            metadata=_TEST_EXPLANATION_METADATA,
            parameters=_TEST_EXPLANATION_PARAMETERS,
        )
        dedicated_resources = gca_machine_resources_v1beta1.DedicatedResources(
            machine_spec=gca_machine_resources_v1beta1.MachineSpec(
                machine_type=_TEST_MACHINE_TYPE,
                accelerator_type=_TEST_ACCELERATOR_TYPE,
                accelerator_count=_TEST_ACCELERATOR_COUNT,
            ),
            min_replica_count=1,
            max_replica_count=1,
        )
        deployed_model = gca_endpoint_v1beta1.DeployedModel(
            dedicated_resources=dedicated_resources,
            model=model.resource_name,
            display_name=None,
            explanation_spec=explanation_spec,
        )

        deploy_model_with_explanations_mock.assert_called_once_with(
            endpoint=endpoint.resource_name,
            deployed_model=deployed_model,
            traffic_split={"0": 100},
            metadata=(),
        )