def test_update_endpoint_no_args(name_from_base):
    """Check ``update_endpoint()`` called without any arguments.

    Expects a new endpoint config to be derived from the existing one with
    every override left as ``None``, the predictor to adopt the new config
    name, and the endpoint update to be issued with ``wait=True``.

    ``name_from_base`` is used as a mock here (presumably injected by a patch
    decorator or fixture outside this block -- confirm against the file).
    """
    old_config = "existing-endpoint-config"
    fresh_config = "new-endpoint-config"
    name_from_base.return_value = fresh_config

    session = empty_sagemaker_session()
    predictor = Predictor(ENDPOINT, sagemaker_session=session)
    predictor._endpoint_config_name = old_config

    predictor.update_endpoint()

    # Predictor state after the update.
    assert ["model-1", "model-2"] == predictor._model_names
    assert fresh_config == predictor._endpoint_config_name

    # The new config name is generated from the existing one.
    name_from_base.assert_called_with(old_config)

    # Nothing was overridden, so every "new_*" argument must be None.
    untouched_overrides = {
        "new_tags": None,
        "new_kms_key": None,
        "new_data_capture_config_dict": None,
        "new_production_variants": None,
    }
    session.create_endpoint_config_from_existing.assert_called_with(
        old_config, fresh_config, **untouched_overrides
    )
    session.update_endpoint.assert_called_with(ENDPOINT, fresh_config, wait=True)
def test_update_endpoint_all_args(name_from_base, production_variant):
    """Check ``update_endpoint()`` when every supported argument is supplied.

    Expects a production variant to be built from the new model name, instance
    type, instance count and accelerator type; the new endpoint config to carry
    the supplied tags, KMS key and data capture config; and ``wait=False`` to
    be forwarded to the session's ``update_endpoint``.

    ``name_from_base`` and ``production_variant`` are used as mocks here
    (presumably injected by patch decorators or fixtures outside this block --
    confirm against the file).
    """
    old_config = "existing-endpoint-config"
    fresh_config = "new-endpoint-config"
    name_from_base.return_value = fresh_config

    session = empty_sagemaker_session()
    predictor = Predictor(ENDPOINT, sagemaker_session=session)
    predictor._endpoint_config_name = old_config

    # Values handed to update_endpoint(); the same objects are reused in the
    # expectations below.
    count = 2
    instance_type = "ml.c4.xlarge"
    accelerator = "ml.eia1.medium"
    model = "new-model"
    tags = {"Key": "foo", "Value": "bar"}
    kms_key = "new-key"
    capture_config = {}

    predictor.update_endpoint(
        initial_instance_count=count,
        instance_type=instance_type,
        accelerator_type=accelerator,
        model_name=model,
        tags=tags,
        kms_key=kms_key,
        data_capture_config_dict=capture_config,
        wait=False,
    )

    # Predictor state after the update.
    assert [model] == predictor._model_names
    assert fresh_config == predictor._endpoint_config_name

    production_variant.assert_called_with(
        model,
        instance_type,
        initial_instance_count=count,
        accelerator_type=accelerator,
    )

    expected_overrides = {
        "new_tags": tags,
        "new_kms_key": kms_key,
        "new_data_capture_config_dict": capture_config,
        "new_production_variants": [production_variant.return_value],
    }
    session.create_endpoint_config_from_existing.assert_called_with(
        old_config, fresh_config, **expected_overrides
    )
    session.update_endpoint.assert_called_with(ENDPOINT, fresh_config, wait=False)
def test_update_endpoint_instance_type_and_count(name_from_base, production_variant):
    """Check ``update_endpoint()`` given only an instance type and count.

    Expects the predictor's existing model name to be reused for the new
    production variant (with ``accelerator_type=None``), tags / KMS key / data
    capture config to stay ``None``, and the endpoint update to be issued with
    ``wait=True``.

    ``name_from_base`` and ``production_variant`` are used as mocks here
    (presumably injected by patch decorators or fixtures outside this block --
    confirm against the file).
    """
    old_config = "existing-endpoint-config"
    fresh_config = "new-endpoint-config"
    current_model = "existing-model"
    name_from_base.return_value = fresh_config

    session = empty_sagemaker_session()
    predictor = Predictor(ENDPOINT, sagemaker_session=session)
    predictor._endpoint_config_name = old_config
    predictor._model_names = [current_model]

    count = 2
    instance_type = "ml.c4.xlarge"

    predictor.update_endpoint(
        initial_instance_count=count,
        instance_type=instance_type,
    )

    # The model list is unchanged; only the endpoint config name moves on.
    assert [current_model] == predictor._model_names
    assert fresh_config == predictor._endpoint_config_name

    production_variant.assert_called_with(
        current_model,
        instance_type,
        initial_instance_count=count,
        accelerator_type=None,
    )

    expected_overrides = {
        "new_tags": None,
        "new_kms_key": None,
        "new_data_capture_config_dict": None,
        "new_production_variants": [production_variant.return_value],
    }
    session.create_endpoint_config_from_existing.assert_called_with(
        old_config, fresh_config, **expected_overrides
    )
    session.update_endpoint.assert_called_with(ENDPOINT, fresh_config, wait=True)