def test_deploy_with_display_name(self, deploy_model_mock, sync):
    """Check that a custom display name is forwarded on the deployed model.

    Deploys a model to an endpoint with ``deployed_model_display_name`` set
    and asserts that ``deploy_model_mock`` was called exactly once with a
    ``DeployedModel`` carrying that display name, one min/max replica, and
    all traffic routed to key ``"0"``.

    Args:
        deploy_model_mock: Mock whose single call is asserted on
            (presumably a fixture patching the service's deploy call;
            confirm against the fixture definition).
        sync: Passed straight to ``Endpoint.deploy``; when falsy the test
            calls ``wait()`` on the endpoint before asserting.
    """
    endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
    model = models.Model(_TEST_ID)

    endpoint.deploy(
        model=model,
        deployed_model_display_name=_TEST_DISPLAY_NAME,
        sync=sync,
    )
    if not sync:
        # Let the non-blocking deploy finish before inspecting the mock.
        endpoint.wait()

    expected_deployed_model = gca_endpoint.DeployedModel(
        automatic_resources=gca_machine_resources.AutomaticResources(
            min_replica_count=1,
            max_replica_count=1,
        ),
        model=model.resource_name,
        display_name=_TEST_DISPLAY_NAME,
    )
    deploy_model_mock.assert_called_once_with(
        endpoint=endpoint.resource_name,
        deployed_model=expected_deployed_model,
        traffic_split={"0": 100},
        metadata=(),
    )
def test_deploy_with_traffic_percent(self, deploy_model_mock, sync):
    """Check the traffic split sent when deploying with ``traffic_percentage``.

    ``EndpointServiceClient.get_endpoint`` is patched to return an endpoint
    whose traffic split is ``{"model1": 100}``. Deploying with
    ``traffic_percentage=70`` is then expected to call ``deploy_model_mock``
    once with ``traffic_split={"model1": 30, "0": 70}``.

    NOTE(review): the whole test body is kept inside the patch context so
    the stubbed ``get_endpoint`` stays active through ``deploy``/``wait``;
    confirm this matches the intended scope of the ``with`` block.

    Args:
        deploy_model_mock: Mock whose single call is asserted on
            (presumably a fixture patching the service's deploy call;
            confirm against the fixture definition).
        sync: Passed straight to ``Endpoint.deploy``; when falsy the test
            calls ``wait()`` on the endpoint before asserting.
    """
    with mock.patch.object(
        endpoint_service_client.EndpointServiceClient, "get_endpoint"
    ) as patched_get_endpoint:
        # Stub an endpoint that already sends all traffic to "model1".
        existing_endpoint = gca_endpoint.Endpoint(
            display_name=_TEST_DISPLAY_NAME,
            name=_TEST_ENDPOINT_NAME,
            traffic_split={"model1": 100},
        )
        patched_get_endpoint.return_value = existing_endpoint

        endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
        model = models.Model(_TEST_ID)

        endpoint.deploy(model=model, traffic_percentage=70, sync=sync)
        if not sync:
            # Let the non-blocking deploy finish before inspecting the mock.
            endpoint.wait()

        expected_deployed_model = gca_endpoint.DeployedModel(
            automatic_resources=gca_machine_resources.AutomaticResources(
                min_replica_count=1,
                max_replica_count=1,
            ),
            model=model.resource_name,
            display_name=None,
        )
        deploy_model_mock.assert_called_once_with(
            endpoint=endpoint.resource_name,
            deployed_model=expected_deployed_model,
            traffic_split={"model1": 30, "0": 70},
            metadata=(),
        )
def test_deploy_with_max_replica_count(self, deploy_model_mock, sync):
    """Check that ``max_replica_count`` reaches the automatic resources spec.

    After initialising the SDK with the test project and location, deploys
    a model with ``max_replica_count=2`` and asserts that
    ``deploy_model_mock`` was called exactly once with a ``DeployedModel``
    whose ``AutomaticResources`` has ``min_replica_count=1`` and
    ``max_replica_count=2``, no display name, and all traffic routed to
    key ``"0"``.

    Args:
        deploy_model_mock: Mock whose single call is asserted on
            (presumably a fixture patching the service's deploy call;
            confirm against the fixture definition).
        sync: Passed straight to ``Endpoint.deploy``; when falsy the test
            calls ``wait()`` on the endpoint before asserting.
    """
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

    endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
    model = models.Model(_TEST_ID)

    endpoint.deploy(model=model, max_replica_count=2, sync=sync)
    if not sync:
        # Let the non-blocking deploy finish before inspecting the mock.
        endpoint.wait()

    expected_deployed_model = gca_endpoint.DeployedModel(
        automatic_resources=gca_machine_resources.AutomaticResources(
            min_replica_count=1,
            max_replica_count=2,
        ),
        model=model.resource_name,
        display_name=None,
    )
    deploy_model_mock.assert_called_once_with(
        endpoint=endpoint.resource_name,
        deployed_model=expected_deployed_model,
        traffic_split={"0": 100},
        metadata=(),
    )