def test_update_endpoint_no_args(name_from_base):
    """Verify ``update_endpoint()`` called without arguments.

    The predictor is expected to derive a new endpoint config name from the
    existing one, clone the existing config without overrides, and update the
    endpoint while waiting for completion.

    Args:
        name_from_base: Mock of the name generator; presumably injected by a
            patch decorator or fixture that is not visible here.
    """
    old_config = "existing-endpoint-config"
    fresh_config = "new-endpoint-config"
    name_from_base.return_value = fresh_config

    session = empty_sagemaker_session()
    predictor = Predictor(ENDPOINT, sagemaker_session=session)
    predictor._endpoint_config_name = old_config

    predictor.update_endpoint()

    # State recorded on the predictor once the update has gone through.
    # NOTE(review): the two model names presumably come from the stubbed
    # session built by empty_sagemaker_session() -- confirm.
    assert predictor._model_names == ["model-1", "model-2"]
    assert predictor._endpoint_config_name == fresh_config

    # The new config name must be derived from the existing one.
    name_from_base.assert_called_with(old_config)

    # No argument was supplied, so every override must be forwarded as None.
    no_overrides = {
        "new_tags": None,
        "new_kms_key": None,
        "new_data_capture_config_dict": None,
        "new_production_variants": None,
    }
    session.create_endpoint_config_from_existing.assert_called_with(
        old_config, fresh_config, **no_overrides
    )
    session.update_endpoint.assert_called_with(ENDPOINT, fresh_config, wait=True)
def test_update_endpoint_no_one_default_model_name_with_instance_type_and_count():
    """Verify that resizing without a ``model_name`` is rejected.

    Only an instance count and an instance type are passed; the call is
    expected to raise ``ValueError`` reporting that no default model could be
    chosen for the new endpoint config.
    """
    predictor = Predictor(ENDPOINT, sagemaker_session=empty_sagemaker_session())
    resize_kwargs = {"initial_instance_count": 2, "instance_type": "ml.c4.xlarge"}

    with pytest.raises(ValueError) as exc_info:
        predictor.update_endpoint(**resize_kwargs)

    message = str(exc_info.value)
    assert "Unable to choose a default model for a new EndpointConfig" in message
def test_update_endpoint_no_instance_type_or_no_instance_count():
    """Verify that instance type and instance count must be given together.

    Supplying only one of the two is expected to raise ``ValueError`` with a
    message naming the missing argument(s).
    """
    predictor = Predictor(ENDPOINT, sagemaker_session=empty_sagemaker_session())
    expected_msg = "Missing initial_instance_count and/or instance_type."

    # Each entry omits exactly one half of the required pair.
    incomplete_kwargs = [
        {"instance_type": "ml.c4.xlarge"},
        {"initial_instance_count": 2},
    ]
    for kwargs in incomplete_kwargs:
        with pytest.raises(ValueError) as exc_info:
            predictor.update_endpoint(**kwargs)
        assert expected_msg in str(exc_info.value)
def test_update_endpoint_all_args(name_from_base, production_variant):
    """Verify ``update_endpoint()`` when every supported argument is supplied.

    Checks that a production variant is built from the new model and instance
    settings, that the cloned endpoint config receives all overrides, and that
    the endpoint update honours ``wait=False``.

    Args:
        name_from_base: Mock of the name generator; presumably injected by a
            patch decorator or fixture that is not visible here.
        production_variant: Mock of the production-variant builder; presumably
            injected the same way.
    """
    old_config = "existing-endpoint-config"
    fresh_config = "new-endpoint-config"
    name_from_base.return_value = fresh_config

    session = empty_sagemaker_session()
    predictor = Predictor(ENDPOINT, sagemaker_session=session)
    predictor._endpoint_config_name = old_config

    # Values for every option update_endpoint() is called with below.
    model = "new-model"
    count = 2
    instance = "ml.c4.xlarge"
    accelerator = "ml.eia1.medium"
    tags = {"Key": "foo", "Value": "bar"}
    kms_key = "new-key"
    capture_config = {}

    predictor.update_endpoint(
        initial_instance_count=count,
        instance_type=instance,
        accelerator_type=accelerator,
        model_name=model,
        tags=tags,
        kms_key=kms_key,
        data_capture_config_dict=capture_config,
        wait=False,
    )

    # The predictor must now track the new model and the new config name.
    assert predictor._model_names == [model]
    assert predictor._endpoint_config_name == fresh_config

    production_variant.assert_called_with(
        model,
        instance,
        initial_instance_count=count,
        accelerator_type=accelerator,
    )

    # The variant produced by the mock is what must reach the cloned config.
    expected_overrides = {
        "new_tags": tags,
        "new_kms_key": kms_key,
        "new_data_capture_config_dict": capture_config,
        "new_production_variants": [production_variant.return_value],
    }
    session.create_endpoint_config_from_existing.assert_called_with(
        old_config, fresh_config, **expected_overrides
    )
    session.update_endpoint.assert_called_with(ENDPOINT, fresh_config, wait=False)
def test_update_endpoint_instance_type_and_count(name_from_base, production_variant):
    """Verify ``update_endpoint()`` when only instance type and count change.

    The predictor is preloaded with a single known model name; the test checks
    that this model is reused for the new production variant and that all other
    overrides are forwarded as ``None``.

    Args:
        name_from_base: Mock of the name generator; presumably injected by a
            patch decorator or fixture that is not visible here.
        production_variant: Mock of the production-variant builder; presumably
            injected the same way.
    """
    old_config = "existing-endpoint-config"
    fresh_config = "new-endpoint-config"
    current_model = "existing-model"
    name_from_base.return_value = fresh_config

    session = empty_sagemaker_session()
    predictor = Predictor(ENDPOINT, sagemaker_session=session)
    predictor._endpoint_config_name = old_config
    predictor._model_names = [current_model]

    count = 2
    instance = "ml.c4.xlarge"
    predictor.update_endpoint(initial_instance_count=count, instance_type=instance)

    # The model list is unchanged; only the config name moves forward.
    assert predictor._model_names == [current_model]
    assert predictor._endpoint_config_name == fresh_config

    # No accelerator was requested, so None is expected for it.
    production_variant.assert_called_with(
        current_model,
        instance,
        initial_instance_count=count,
        accelerator_type=None,
    )

    expected_overrides = {
        "new_tags": None,
        "new_kms_key": None,
        "new_data_capture_config_dict": None,
        "new_production_variants": [production_variant.return_value],
    }
    session.create_endpoint_config_from_existing.assert_called_with(
        old_config, fresh_config, **expected_overrides
    )
    session.update_endpoint.assert_called_with(ENDPOINT, fresh_config, wait=True)
# Example #6
# 0
def test_multi_data_model_deploy_pretrained_models_update_endpoint(
    container_image, sagemaker_session, cpu_instance_type, alternative_cpu_instance_type
):
    """Deploy a multi-model endpoint, invoke it, update it, and clean up.

    Steps exercised:
      1. Create a ``MultiDataModel``, add one model, deploy, add a second model.
      2. Check that both models are listed and can be invoked by target name.
      3. Call ``Predictor.update_endpoint`` with a different instance type and
         check that the endpoint points at a new config with the new settings.
      4. Delete both endpoint configs and the model, then check that each of
         them can no longer be described.

    Args:
        container_image: Image URI used for the multi-model container.
        sagemaker_session: Session used for all SageMaker calls.
        cpu_instance_type: Instance type for the initial deployment.
        alternative_cpu_instance_type: Instance type applied by the update.
    """
    timestamp = sagemaker_timestamp()
    endpoint_name = "test-multimodel-endpoint-{}".format(timestamp)
    model_name = "test-multimodel-{}".format(timestamp)

    # Define pretrained model local path
    pretrained_model_data_local_path = os.path.join(DATA_DIR, "sparkml_model", "mleap_model.tar.gz")

    with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
        model_data_prefix = os.path.join(
            "s3://", sagemaker_session.default_bucket(), "multimodel-{}/".format(timestamp)
        )
        multi_data_model = MultiDataModel(
            name=model_name,
            model_data_prefix=model_data_prefix,
            image_uri=container_image,
            role=ROLE,
            sagemaker_session=sagemaker_session,
        )

        # Add model before deploy
        multi_data_model.add_model(pretrained_model_data_local_path, PRETRAINED_MODEL_PATH_1)
        # Deploy model to an endpoint
        multi_data_model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
        # Add model after deploy
        multi_data_model.add_model(pretrained_model_data_local_path, PRETRAINED_MODEL_PATH_2)

        # List model assertions
        endpoint_models = list(multi_data_model.list_models())
        assert PRETRAINED_MODEL_PATH_1 in endpoint_models
        assert PRETRAINED_MODEL_PATH_2 in endpoint_models

        predictor = Predictor(
            endpoint_name=endpoint_name,
            sagemaker_session=sagemaker_session,
            serializer=NumpySerializer(),
            deserializer=string_deserializer,
        )

        # Invoke each model by name and check the response names that model.
        data = numpy.zeros(shape=(1, 1, 28, 28))
        result = predictor.predict(data, target_model=PRETRAINED_MODEL_PATH_1)
        assert result == "Invoked model: {}".format(PRETRAINED_MODEL_PATH_1)

        result = predictor.predict(data, target_model=PRETRAINED_MODEL_PATH_2)
        assert result == "Invoked model: {}".format(PRETRAINED_MODEL_PATH_2)

        # Remember the config in use before the update so it can be compared
        # against the new one and deleted afterwards.
        endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
            EndpointName=endpoint_name
        )
        old_config_name = endpoint_desc["EndpointConfigName"]

        # Update endpoint
        predictor.update_endpoint(
            initial_instance_count=1, instance_type=alternative_cpu_instance_type
        )

        endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
            EndpointName=endpoint_name
        )
        new_config_name = endpoint_desc["EndpointConfigName"]

        new_config = sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=new_config_name
        )
        assert old_config_name != new_config_name
        assert new_config["ProductionVariants"][0]["InstanceType"] == alternative_cpu_instance_type
        assert new_config["ProductionVariants"][0]["InitialInstanceCount"] == 1

        # Cleanup
        sagemaker_session.sagemaker_client.delete_endpoint_config(
            EndpointConfigName=old_config_name
        )
        sagemaker_session.sagemaker_client.delete_endpoint_config(
            EndpointConfigName=new_config_name
        )
        multi_data_model.delete_model()

    # Each lookup gets its own pytest.raises block: once a statement inside the
    # block raises, nothing after it in that block runs, so the message checks
    # have to live outside the ``with`` body.
    with pytest.raises(Exception) as exception:
        sagemaker_session.sagemaker_client.describe_model(ModelName=model_name)
    assert "Could not find model" in str(exception.value)

    for deleted_config_name in (old_config_name, new_config_name):
        with pytest.raises(Exception) as exception:
            sagemaker_session.sagemaker_client.describe_endpoint_config(
                EndpointConfigName=deleted_config_name
            )
        assert "Could not find endpoint" in str(exception.value)