def test_create_sagemaker_model_optional_model_params(base_name_from_image,
                                                      name_from_base,
                                                      prepare_container_def,
                                                      sagemaker_session):
    """Explicit name, role, VPC config and network isolation reach ``create_model``.

    Because ``name`` is supplied to ``Model``, neither name-generation helper
    is expected to be invoked.
    """
    expected_container = dict(Image=MODEL_IMAGE, Environment={}, ModelDataUrl=MODEL_DATA)
    prepare_container_def.return_value = expected_container

    network_settings = {"Subnets": ["123"], "SecurityGroupIds": ["456", "789"]}

    model = Model(
        MODEL_IMAGE,
        MODEL_DATA,
        name=MODEL_NAME,
        role=ROLE,
        vpc_config=network_settings,
        enable_network_isolation=True,
        sagemaker_session=sagemaker_session,
    )
    model._create_sagemaker_model(INSTANCE_TYPE)

    # An explicit model name means nothing should be derived from the image.
    for name_helper in (base_name_from_image, name_from_base):
        name_helper.assert_not_called()

    sagemaker_session.create_model.assert_called_with(
        MODEL_NAME,
        ROLE,
        expected_container,
        vpc_config=network_settings,
        enable_network_isolation=True,
        tags=None,
    )


def test_create_sagemaker_model_generates_model_name(base_name_from_image,
                                                     name_from_base,
                                                     prepare_container_def,
                                                     sagemaker_session):
    """Without an explicit name, the model name is derived from the image."""
    expected_container = dict(Image=MODEL_IMAGE, Environment={}, ModelDataUrl=MODEL_DATA)
    prepare_container_def.return_value = expected_container

    model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session)
    model._create_sagemaker_model(INSTANCE_TYPE)

    # The base name is computed from the image, then the final name from that base.
    base_name_from_image.assert_called_with(MODEL_IMAGE)
    derived_base = base_name_from_image.return_value
    name_from_base.assert_called_with(derived_base)

    # NOTE(review): relies on name_from_base being patched to return MODEL_NAME
    # outside this view — confirm against the fixture setup.
    sagemaker_session.create_model.assert_called_with(
        MODEL_NAME,
        None,
        expected_container,
        vpc_config=None,
        enable_network_isolation=False,
        tags=None,
    )


# --- Example no. 3 (0) ---

def test_create_sagemaker_model_instance_type(prepare_container_def,
                                              sagemaker_session):
    """The requested instance type is handed to ``prepare_container_def``."""
    # NOTE(review): positional order (MODEL_DATA, MODEL_IMAGE) differs from the
    # (MODEL_IMAGE, MODEL_DATA) order used elsewhere in this listing; presumably
    # this targets an older Model signature — confirm.
    subject = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session)
    subject._create_sagemaker_model(INSTANCE_TYPE)

    expected_kwargs = {"accelerator_type": None}
    prepare_container_def.assert_called_with(INSTANCE_TYPE, **expected_kwargs)


# --- Example no. 4 (0) ---

def test_create_sagemaker_model(name_from_image, prepare_container_def,
                                sagemaker_session):
    """An unnamed model gets its name from ``name_from_image`` and is created with it."""
    # NOTE(review): another test later in this listing uses the same function
    # name; if both end up in one module only the later one is collected.
    name_from_image.return_value = MODEL_NAME

    expected_container = dict(Image=MODEL_IMAGE, Environment={}, ModelDataUrl=MODEL_DATA)
    prepare_container_def.return_value = expected_container

    subject = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session)
    subject._create_sagemaker_model(INSTANCE_TYPE)

    prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=None)
    name_from_image.assert_called_with(MODEL_IMAGE)

    sagemaker_session.create_model.assert_called_with(
        MODEL_NAME,
        None,
        expected_container,
        vpc_config=None,
        enable_network_isolation=False,
        tags=None,
    )


# --- Example no. 5 (0) ---

def test_create_sagemaker_model_accelerator_type(prepare_container_def, sagemaker_session):
    """An accelerator type given to ``_create_sagemaker_model`` reaches ``prepare_container_def``."""
    eia_type = "ml.eia.medium"

    model = Model(MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, sagemaker_session=sagemaker_session)
    model._create_sagemaker_model(INSTANCE_TYPE, accelerator_type=eia_type)

    prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=eia_type)


# --- Example no. 6 (0) ---

def test_create_sagemaker_model_creates_correct_session(local_session, session):
    """Instance type ``"local"`` picks the local session; a cloud instance type picks the regular one."""
    cases = (
        ("local", local_session),
        ("ml.m5.xlarge", session),
    )
    for instance_type, session_factory in cases:
        # A fresh Model per case so no session is carried over between them.
        model = Model(MODEL_IMAGE, MODEL_DATA)
        model._create_sagemaker_model(instance_type)
        assert model.sagemaker_session == session_factory.return_value


# --- Example no. 7 (0) ---

def test_create_sagemaker_model_tags(prepare_container_def, sagemaker_session):
    """Tags passed to ``_create_sagemaker_model`` are forwarded to ``create_model``."""
    expected_container = dict(Image=MODEL_IMAGE, Environment={}, ModelDataUrl=MODEL_DATA)
    prepare_container_def.return_value = expected_container

    model = Model(MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, sagemaker_session=sagemaker_session)

    # NOTE(review): a single tag dict is used here; SageMaker tags are commonly a
    # list of {"Key": ..., "Value": ...} dicts — confirm this shape is intended.
    model_tags = {"Key": "foo", "Value": "bar"}
    model._create_sagemaker_model(INSTANCE_TYPE, tags=model_tags)

    sagemaker_session.create_model.assert_called_with(
        MODEL_NAME,
        None,
        expected_container,
        vpc_config=None,
        enable_network_isolation=False,
        tags=model_tags,
    )


# --- Example no. 8 (0) ---

def test_create_sagemaker_model_generates_model_name_each_time(
    base_name_from_image, name_from_base, prepare_container_def, sagemaker_session
):
    """Every call builds a fresh name, but the base name is derived from the image only once."""
    prepare_container_def.return_value = dict(
        Image=MODEL_IMAGE, Environment={}, ModelDataUrl=MODEL_DATA
    )

    model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session)
    for _ in range(2):
        model._create_sagemaker_model(INSTANCE_TYPE)

    base_name_from_image.assert_called_once_with(MODEL_IMAGE)
    name_from_base.assert_called_with(base_name_from_image.return_value)
    assert name_from_base.call_count == 2


# --- Example no. 9 (0) ---

def test_create_sagemaker_model(prepare_container_def, sagemaker_session):
    """With no arguments, ``prepare_container_def`` receives ``None`` and defaults.

    The resulting container definition is then passed to ``create_model``
    under the explicit model name.
    """
    # NOTE(review): an earlier test in this listing uses the same function name;
    # if both end up in one module only this later one is collected.
    expected_container = dict(Image=MODEL_IMAGE, Environment={}, ModelDataUrl=MODEL_DATA)
    prepare_container_def.return_value = expected_container

    named_model = Model(
        MODEL_DATA,
        MODEL_IMAGE,
        name=MODEL_NAME,
        sagemaker_session=sagemaker_session,
    )
    named_model._create_sagemaker_model()

    prepare_container_def.assert_called_with(
        None, accelerator_type=None, serverless_inference_config=None
    )
    sagemaker_session.create_model.assert_called_with(
        MODEL_NAME,
        None,
        expected_container,
        vpc_config=None,
        enable_network_isolation=False,
        tags=None,
    )