def test_deploy_with_display_name(self, deploy_model_mock, sync):
        """Deploying with ``deployed_model_display_name`` forwards that name.

        The expected DeployedModel uses 1/1 automatic replicas and carries
        ``_TEST_DISPLAY_NAME``. The expected traffic split routes all traffic
        to the new deployment (key ``"0"``).
        """
        endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
        model = models.Model(_TEST_ID)

        endpoint.deploy(
            model=model,
            deployed_model_display_name=_TEST_DISPLAY_NAME,
            sync=sync,
        )
        # With sync=False the deployment may still be in flight; wait for it
        # before inspecting the mock.
        if not sync:
            endpoint.wait()

        expected_deployed_model = gca_endpoint.DeployedModel(
            automatic_resources=gca_machine_resources.AutomaticResources(
                min_replica_count=1,
                max_replica_count=1,
            ),
            model=model.resource_name,
            display_name=_TEST_DISPLAY_NAME,
        )
        deploy_model_mock.assert_called_once_with(
            endpoint=endpoint.resource_name,
            deployed_model=expected_deployed_model,
            traffic_split={"0": 100},
            metadata=(),
        )

    def test_deploy_with_traffic_percent(self, deploy_model_mock, sync):
        """Deploying with ``traffic_percentage=70`` rescales the existing split.

        ``EndpointServiceClient.get_endpoint`` is patched to return an endpoint
        whose traffic split is ``{"model1": 100}``. The call recorded on
        ``deploy_model_mock`` is then expected to give the new deployment
        (key ``"0"``) 70% and leave ``"model1"`` with the remaining 30%.
        """
        current_endpoint_proto = gca_endpoint.Endpoint(
            display_name=_TEST_DISPLAY_NAME,
            name=_TEST_ENDPOINT_NAME,
            traffic_split={"model1": 100},
        )

        with mock.patch.object(
            endpoint_service_client.EndpointServiceClient, "get_endpoint"
        ) as get_endpoint_mock:
            get_endpoint_mock.return_value = current_endpoint_proto

            endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
            model = models.Model(_TEST_ID)
            endpoint.deploy(model=model, traffic_percentage=70, sync=sync)
            # With sync=False the deployment may still be in flight; wait for
            # it before inspecting the mock.
            if not sync:
                endpoint.wait()

            # No display name was passed to deploy(), so None is expected.
            expected_deployed_model = gca_endpoint.DeployedModel(
                automatic_resources=gca_machine_resources.AutomaticResources(
                    min_replica_count=1,
                    max_replica_count=1,
                ),
                model=model.resource_name,
                display_name=None,
            )
            expected_traffic_split = {"model1": 30, "0": 70}
            deploy_model_mock.assert_called_once_with(
                endpoint=endpoint.resource_name,
                deployed_model=expected_deployed_model,
                traffic_split=expected_traffic_split,
                metadata=(),
            )

 def test_deploy_with_max_replica_count(self, deploy_model_mock, sync):
     """Deploying with ``max_replica_count=2`` raises only the upper bound.

     The expected automatic resources keep ``min_replica_count`` at 1 and use
     a ``max_replica_count`` of 2. No display name is expected, and the new
     deployment (key ``"0"``) is expected to receive all traffic.
     """
     aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

     endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
     model = models.Model(_TEST_ID)
     endpoint.deploy(model=model, max_replica_count=2, sync=sync)
     # With sync=False the deployment may still be in flight; wait for it
     # before inspecting the mock.
     if not sync:
         endpoint.wait()

     expected_deployed_model = gca_endpoint.DeployedModel(
         automatic_resources=gca_machine_resources.AutomaticResources(
             min_replica_count=1,
             max_replica_count=2,
         ),
         model=model.resource_name,
         display_name=None,
     )
     deploy_model_mock.assert_called_once_with(
         endpoint=endpoint.resource_name,
         deployed_model=expected_deployed_model,
         traffic_split={"0": 100},
         metadata=(),
     )