def test_create_sagemaker_model_generates_model_name(name_from_base,
                                                     sagemaker_session):
    """An unnamed ModelPackage takes its name from ``name_from_base``.

    ``name_from_base`` is a mock here: the test checks that it was called
    with the package ARN and that its return value became
    ``ModelPackage.name``.
    """
    package_arn = "my-model-package"
    package = ModelPackage(
        role="role",
        model_package_arn=package_arn,
        sagemaker_session=sagemaker_session,
    )

    package._create_sagemaker_model()

    # The generated name is exactly whatever the mocked helper returned.
    assert package.name == name_from_base.return_value
    name_from_base.assert_called_with(package_arn)


def test_create_sagemaker_model_generates_model_name_each_time(
        name_from_base, sagemaker_session):
    """Every ``_create_sagemaker_model`` call generates a name anew.

    Two calls must invoke the mocked ``name_from_base`` twice. The most
    recent invocation must receive the package ARN.
    """
    package_arn = "my-model-package"
    package = ModelPackage(
        role="role",
        model_package_arn=package_arn,
        sagemaker_session=sagemaker_session,
    )

    # The name must not be cached between calls, so create the model twice.
    for _ in range(2):
        package._create_sagemaker_model()

    assert name_from_base.call_count == 2
    name_from_base.assert_called_with(package_arn)


def test_create_sagemaker_model_uses_model_name(name_from_base,
                                                sagemaker_session):
    """An explicit ``name`` is used verbatim and no name is generated.

    Also pins the exact arguments forwarded to
    ``sagemaker_session.create_model``.
    """
    explicit_name = "my-model"
    package_arn = "my-model-package"
    package = ModelPackage(
        role="role",
        name=explicit_name,
        model_package_arn=package_arn,
        sagemaker_session=sagemaker_session,
    )

    package._create_sagemaker_model()

    # A user-supplied name must bypass name generation entirely.
    name_from_base.assert_not_called()
    assert package.name == explicit_name

    # The container dict passed to create_model holds only the package ARN.
    expected_container = {"ModelPackageName": package_arn}
    sagemaker_session.create_model.assert_called_with(
        explicit_name,
        "role",
        expected_container,
        vpc_config=None,
        enable_network_isolation=False,
    )