def test_create_model_default_framework_version(sagemaker_session):
    """Check the container image chosen when ``framework_version`` is omitted.

    The model is built without ``framework_version``. The resulting container
    definition is expected to reference the scikit-learn 0.20.0 CPU image.
    """
    # Named distinctly on purpose: this module also defines
    # ``test_create_model(sagemaker_session, sklearn_version)``. Sharing that
    # name meant the later definition rebound it, so pytest never collected
    # this test (pyflakes/ruff F811).
    # NOTE(review): if the SDK in use no longer defaults ``framework_version``,
    # this test is obsolete and should be removed rather than kept — confirm.
    source_dir = "s3://mybucket/source"

    sklearn_model = SKLearnModel(
        model_data=source_dir,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        entry_point=SCRIPT_PATH,
    )
    default_image_uri = _get_full_cpu_image_uri("0.20.0")
    model_values = sklearn_model.prepare_container_def(CPU)
    assert model_values["Image"] == default_image_uri
def test_create_model(sagemaker_session, sklearn_version):
    """The container definition references the CPU image for ``sklearn_version``.

    Builds an ``SKLearnModel`` pinned to the requested scikit-learn version.
    It then checks that ``prepare_container_def`` for the CPU instance type
    reports the matching full image URI.
    """
    model = SKLearnModel(
        model_data="s3://mybucket/source",
        role=ROLE,
        sagemaker_session=sagemaker_session,
        entry_point=SCRIPT_PATH,
        framework_version=sklearn_version,
    )

    # Compute the expected value before building the container definition,
    # keeping the same call order as before.
    expected_image = _get_full_cpu_image_uri(sklearn_version)
    container_def = model.prepare_container_def(CPU)

    assert container_def["Image"] == expected_image
def test_create_model_with_network_isolation(upload, sagemaker_session):
    """With network isolation enabled, the container reads code from the model dir.

    The uploaded-code record and the repacked archive location are set
    directly on the model. The test then checks two things: the submit
    directory points inside ``/opt/ml/model``, and the model data URL is the
    repacked archive.

    ``upload`` is not referenced here. Presumably it is supplied by a
    ``mock.patch`` decorator outside this view — TODO confirm.
    """
    repacked = "s3://mybucket/prefix/model.tar.gz"

    model = SKLearnModel(
        model_data="s3://mybucket/source",
        role=ROLE,
        sagemaker_session=sagemaker_session,
        entry_point=SCRIPT_PATH,
        enable_network_isolation=True,
    )
    # Pre-populate the uploaded-code and repacked-archive attributes that
    # prepare_container_def reads.
    model.uploaded_code = UploadedCode(s3_prefix=repacked, script_name="script")
    model.repacked_model_data = repacked

    container = model.prepare_container_def(CPU)

    assert container["Environment"]["SAGEMAKER_SUBMIT_DIRECTORY"] == "/opt/ml/model/code"
    assert container["ModelDataUrl"] == repacked