コード例 #1
0
def test_update_endpoint_no_args(name_from_base):
    """``update_endpoint()`` with no arguments copies the current endpoint config.

    No production variants are built: the session is asked to clone the
    existing endpoint config under a freshly generated name with every
    override left as ``None``. The endpoint is then switched to that new
    config, and the default ``wait=True`` is forwarded.

    NOTE(review): ``name_from_base`` is presumably a mock injected by a
    ``mock.patch`` decorator that is not visible in this snippet — confirm.
    """
    old_config = "existing-endpoint-config"
    new_config = "new-endpoint-config"
    name_from_base.return_value = new_config

    session = empty_sagemaker_session()
    predictor = Predictor(ENDPOINT, sagemaker_session=session)
    predictor._endpoint_config_name = old_config

    predictor.update_endpoint()

    # Model names are presumably those reported by the stub session built by
    # empty_sagemaker_session() (defined elsewhere) — verify there.
    assert predictor._model_names == ["model-1", "model-2"]
    assert predictor._endpoint_config_name == new_config

    # The new config name must be derived from the config currently in use.
    name_from_base.assert_called_with(old_config)

    no_overrides = {
        "new_tags": None,
        "new_kms_key": None,
        "new_data_capture_config_dict": None,
        "new_production_variants": None,
    }
    session.create_endpoint_config_from_existing.assert_called_with(
        old_config, new_config, **no_overrides
    )
    session.update_endpoint.assert_called_with(ENDPOINT, new_config, wait=True)
コード例 #2
0
def test_update_endpoint_all_args(name_from_base, production_variant):
    """Every ``update_endpoint`` argument is forwarded to the session.

    Instance count/type, accelerator type and model name produce a single
    production variant for the new model. That variant replaces the
    predictor's model names. Tags, KMS key and data-capture config are passed
    through to ``create_endpoint_config_from_existing``, and ``wait=False``
    reaches ``update_endpoint``.

    NOTE(review): ``name_from_base`` and ``production_variant`` are presumably
    mocks injected by ``mock.patch`` decorators not visible in this snippet —
    confirm.
    """
    new_endpoint_config_name = "new-endpoint-config"
    name_from_base.return_value = new_endpoint_config_name

    sagemaker_session = empty_sagemaker_session()
    existing_endpoint_config_name = "existing-endpoint-config"

    predictor = Predictor(ENDPOINT, sagemaker_session=sagemaker_session)
    predictor._endpoint_config_name = existing_endpoint_config_name

    new_instance_count = 2
    new_instance_type = "ml.c4.xlarge"
    new_accelerator_type = "ml.eia1.medium"
    new_model_name = "new-model"
    # SageMaker tags are a list of {"Key": ..., "Value": ...} dicts; a bare dict
    # is not a valid tag list and may be reshaped by tag-normalizing helpers.
    new_tags = [{"Key": "foo", "Value": "bar"}]
    new_kms_key = "new-key"
    new_data_capture_config_dict = {}

    predictor.update_endpoint(
        initial_instance_count=new_instance_count,
        instance_type=new_instance_type,
        accelerator_type=new_accelerator_type,
        model_name=new_model_name,
        tags=new_tags,
        kms_key=new_kms_key,
        data_capture_config_dict=new_data_capture_config_dict,
        wait=False,
    )

    assert [new_model_name] == predictor._model_names
    assert new_endpoint_config_name == predictor._endpoint_config_name

    # Same as the no-args case: the new name derives from the current config.
    name_from_base.assert_called_with(existing_endpoint_config_name)
    production_variant.assert_called_with(
        new_model_name,
        new_instance_type,
        initial_instance_count=new_instance_count,
        accelerator_type=new_accelerator_type,
    )
    sagemaker_session.create_endpoint_config_from_existing.assert_called_with(
        existing_endpoint_config_name,
        new_endpoint_config_name,
        new_tags=new_tags,
        new_kms_key=new_kms_key,
        new_data_capture_config_dict=new_data_capture_config_dict,
        new_production_variants=[production_variant.return_value],
    )
    sagemaker_session.update_endpoint.assert_called_with(
        ENDPOINT, new_endpoint_config_name, wait=False
    )
コード例 #3
0
def test_update_endpoint_instance_type_and_count(name_from_base, production_variant):
    """Changing only instance type and count reuses the existing model.

    The predictor already tracks a single model name. Passing just
    ``initial_instance_count`` and ``instance_type`` builds one production
    variant for that model with no accelerator, and leaves the predictor's
    model names unchanged. The remaining config overrides stay ``None`` and
    the default ``wait=True`` is forwarded.

    NOTE(review): ``name_from_base`` and ``production_variant`` are presumably
    mocks injected by ``mock.patch`` decorators not visible in this snippet —
    confirm.
    """
    new_endpoint_config_name = "new-endpoint-config"
    name_from_base.return_value = new_endpoint_config_name

    sagemaker_session = empty_sagemaker_session()
    existing_endpoint_config_name = "existing-endpoint-config"
    existing_model_name = "existing-model"

    predictor = Predictor(ENDPOINT, sagemaker_session=sagemaker_session)
    predictor._endpoint_config_name = existing_endpoint_config_name
    predictor._model_names = [existing_model_name]

    new_instance_count = 2
    new_instance_type = "ml.c4.xlarge"

    predictor.update_endpoint(
        initial_instance_count=new_instance_count,
        instance_type=new_instance_type,
    )

    assert [existing_model_name] == predictor._model_names
    assert new_endpoint_config_name == predictor._endpoint_config_name

    # Same as the no-args case: the new name derives from the current config.
    name_from_base.assert_called_with(existing_endpoint_config_name)
    production_variant.assert_called_with(
        existing_model_name,
        new_instance_type,
        initial_instance_count=new_instance_count,
        accelerator_type=None,
    )
    sagemaker_session.create_endpoint_config_from_existing.assert_called_with(
        existing_endpoint_config_name,
        new_endpoint_config_name,
        new_tags=None,
        new_kms_key=None,
        new_data_capture_config_dict=None,
        new_production_variants=[production_variant.return_value],
    )
    sagemaker_session.update_endpoint.assert_called_with(
        ENDPOINT, new_endpoint_config_name, wait=True
    )