Beispiel #1
0
def test_retrieve_different_config_per_python_version(config_for_framework, caplog):
    """Each Python version of one framework version can use its own registry and repository.

    ``caplog`` is requested as a fixture but not used by this test.
    """
    config_for_framework.return_value = {
        "processors": ["cpu", "gpu"],
        "scope": ["training", "inference"],
        "versions": {
            "1.0.0": {
                "py3": {"registries": {"us-west-2": "123412341234"}, "repository": "foo"},
                "py37": {"registries": {"us-west-2": "012345678901"}, "repository": "bar"},
            },
        },
    }

    expected_by_py_version = (
        ("py3", "123412341234.dkr.ecr.us-west-2.amazonaws.com/foo:1.0.0-cpu-py3"),
        ("py37", "012345678901.dkr.ecr.us-west-2.amazonaws.com/bar:1.0.0-cpu-py37"),
    )
    for py_version, expected_uri in expected_by_py_version:
        uri = image_uris.retrieve(
            framework="useless-string",
            version="1.0.0",
            py_version=py_version,
            instance_type="ml.c4.xlarge",
            region="us-west-2",
            image_scope="training",
        )
        assert uri == expected_uri
Beispiel #2
0
def test_retrieve_unsupported_version(config_for_framework):
    """An unknown or omitted framework version raises ValueError naming the supported ones."""
    request = dict(
        framework="some-framework",
        py_version="py3",
        instance_type="ml.c4.xlarge",
        region="us-west-2",
        image_scope="training",
    )
    supported_versions_msg = "Supported some-framework version(s): 1.0.0, 1.1.0."

    # An explicitly requested version that the config does not know about.
    with pytest.raises(ValueError) as excinfo:
        image_uris.retrieve(version="1", **request)
    message = str(excinfo.value)
    assert "Unsupported some-framework version: 1." in message
    assert supported_versions_msg in message

    # No version requested at all.
    with pytest.raises(ValueError) as excinfo:
        image_uris.retrieve(**request)
    message = str(excinfo.value)
    assert "Unsupported some-framework version: None." in message
    assert supported_versions_msg in message
Beispiel #3
0
def test_retrieve_processor_type_from_version_specific_processor_config(config_for_framework):
    """A per-version ``processors`` list is honored once the top-level list is removed.

    Version 1.0.0 is given ``["cpu"]`` and expects a ``cpu`` tag. The expected
    URI for 1.1.0 carries no processor tag.
    """
    config = copy.deepcopy(BASE_CONFIG)
    config.pop("processors")
    config["versions"]["1.0.0"]["processors"] = ["cpu"]
    config_for_framework.return_value = config

    expected_by_version = {
        "1.0.0": "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu-py3",
        "1.1.0": "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.1.0-py3",
    }
    for version, expected_uri in expected_by_version.items():
        uri = image_uris.retrieve(
            framework="useless-string",
            version=version,
            py_version="py3",
            instance_type="ml.c4.xlarge",
            region="us-west-2",
            image_scope="training",
        )
        assert uri == expected_uri
Beispiel #4
0
def test_retrieve_unsupported_python_version(config_for_framework):
    """A wrong or omitted ``py_version`` raises ValueError naming the supported versions."""
    request = dict(
        framework="useless-string",
        version="1.0.0",
        instance_type="ml.c4.xlarge",
        region="us-west-2",
        image_scope="training",
    )
    supported_msg = "Supported Python version(s): py3, py37."

    # Explicitly requesting a Python version the config does not offer.
    with pytest.raises(ValueError) as excinfo:
        image_uris.retrieve(py_version="py2", **request)
    message = str(excinfo.value)
    assert "Unsupported Python version: py2." in message
    assert supported_msg in message

    # Not requesting any Python version.
    with pytest.raises(ValueError) as excinfo:
        image_uris.retrieve(**request)
    message = str(excinfo.value)
    assert "Unsupported Python version: None." in message
    assert supported_msg in message
Beispiel #5
0
def test_retrieve_unsupported_image_scope(config_for_framework):
    """An unknown or omitted image scope raises ValueError naming the supported scopes."""
    request = dict(
        framework="useless-string",
        version="1.0.0",
        py_version="py3",
        instance_type="ml.c4.xlarge",
        region="us-west-2",
    )

    with pytest.raises(ValueError) as excinfo:
        image_uris.retrieve(image_scope="invalid-image-scope", **request)
    message = str(excinfo.value)
    assert "Unsupported image scope: invalid-image-scope." in message
    assert "Supported image scope(s): training, inference." in message

    # With an extra "eia" scope configured, omitting image_scope must also fail.
    eia_config = copy.deepcopy(BASE_CONFIG)
    eia_config["scope"].append("eia")
    config_for_framework.return_value = eia_config

    with pytest.raises(ValueError) as excinfo:
        image_uris.retrieve(**request)
    message = str(excinfo.value)
    assert "Unsupported image scope: None." in message
    assert "Supported image scope(s): training, inference, eia." in message
def _test_neo_framework_uris(framework, version):
    """Check Neo image URIs for ``framework`` in every region, plus one GPU lookup.

    Regions listed in ``ACCOUNTS`` must resolve to the expected URI. Every other
    region must be rejected with an "Unsupported region" ValueError.
    """
    framework_in_config = "neo-{}".format(framework)
    # TensorFlow keeps the "neo-" repository name; other frameworks use "inference-<fw>".
    if framework == "tensorflow":
        framework_in_uri = "neo-{}".format(framework)
    else:
        framework_in_uri = "inference-{}".format(framework)

    for region in regions.regions():
        if region not in ACCOUNTS:
            with pytest.raises(ValueError) as excinfo:
                image_uris.retrieve(
                    framework_in_config, region, instance_type="ml_c5", version=version
                )
            assert "Unsupported region: {}.".format(region) in str(excinfo.value)
            continue

        cpu_uri = image_uris.retrieve(
            framework_in_config, region, instance_type="ml_c5", version=version
        )
        assert cpu_uri == _expected_framework_uri(framework_in_uri, version, region=region)

    gpu_uri = image_uris.retrieve(
        framework_in_config, "us-west-2", instance_type="ml_p2", version=version
    )
    assert gpu_uri == _expected_framework_uri(framework_in_uri, version, processor="gpu")
Beispiel #7
0
def _test_image_uris(
    framework,
    fw_version,
    py_version,
    scope,
    expected_fn,
    expected_fn_args,
    base_framework_version=None,
):
    """Compare retrieved image URIs against ``expected_fn`` for a framework/version/scope.

    Every ``(instance_type, processor)`` pair in ``INSTANCE_TYPES_AND_PROCESSORS``
    is checked in ``REGION``. Every region in ``SAGEMAKER_ALTERNATE_REGION_ACCOUNTS``
    is checked with ``ml.c4.xlarge``. ``base_framework_version`` is accepted but not
    used by this helper.
    """

    def retrieve_uri(region, instance_type):
        return image_uris.retrieve(
            framework=framework,
            version=fw_version,
            py_version=py_version,
            image_scope=scope,
            region=region,
            instance_type=instance_type,
        )

    for instance_type, processor in INSTANCE_TYPES_AND_PROCESSORS:
        uri = retrieve_uri(REGION, instance_type)
        assert uri == expected_fn(processor=processor, **expected_fn_args)

    for region in SAGEMAKER_ALTERNATE_REGION_ACCOUNTS:
        uri = retrieve_uri(region, "ml.c4.xlarge")
        assert uri == expected_fn(region=region, **expected_fn_args)
Beispiel #8
0
def test_mxnet_eia(mxnet_eia_version, mxnet_py_version):
    """MXNet EIA inference URIs resolve in the default and every alternate SageMaker region."""
    request = dict(
        framework="mxnet",
        version=mxnet_eia_version,
        py_version=mxnet_py_version,
        image_scope="inference",
        instance_type="ml.c4.xlarge",
        accelerator_type="ml.eia1.medium",
    )

    def expected_eia_uri(**extra):
        return _expected_mxnet_inference_uri(
            mxnet_eia_version, mxnet_py_version, eia=True, **extra
        )

    # The default region is compared against the helper's own default (no region kwarg).
    default_uri = image_uris.retrieve(region=REGION, **request)
    assert default_uri == expected_eia_uri()

    for region in SAGEMAKER_ALTERNATE_REGION_ACCOUNTS:
        regional_uri = image_uris.retrieve(region=region, **request)
        assert regional_uri == expected_eia_uri(region=region)
Beispiel #9
0
def test_pytorch_eia(pytorch_eia_version, pytorch_eia_py_version):
    """PyTorch EIA inference URIs resolve in the default and every alternate DLC region."""
    request = dict(
        framework="pytorch",
        version=pytorch_eia_version,
        py_version=pytorch_eia_py_version,
        image_scope="inference",
        instance_type="ml.c4.xlarge",
        accelerator_type="ml.eia1.medium",
    )

    # Default region first, then the alternates, each paired with its account.
    region_accounts = [(REGION, DLC_ACCOUNT), *DLC_ALTERNATE_REGION_ACCOUNTS.items()]
    for region, account in region_accounts:
        uri = image_uris.retrieve(region=region, **request)
        assert uri == expected_uris.framework_uri(
            "pytorch-inference-eia",
            pytorch_eia_version,
            account,
            py_version=pytorch_eia_py_version,
            region=region,
        )
Beispiel #10
0
def test_retrieve_aliased_version(config_for_framework):
    """A version alias resolves to the aliased config but keeps the alias in the image tag."""
    alias = "1.0.0-build123"

    config = copy.deepcopy(BASE_CONFIG)
    config["version_aliases"] = {alias: "1.0.0"}
    config_for_framework.return_value = config

    def retrieve_alias():
        return image_uris.retrieve(
            framework="useless-string",
            version=alias,
            py_version="py3",
            instance_type="ml.c4.xlarge",
            region="us-west-2",
            image_scope="training",
        )

    uri = retrieve_alias()
    assert uri == "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:{}-cpu-py3".format(alias)

    # After removing version 1.1.0 from the config, the alias must still resolve the same way.
    del config["versions"]["1.1.0"]
    uri = retrieve_alias()
    assert uri == "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:{}-cpu-py3".format(alias)
Beispiel #11
0
def test_retrieve_no_python_version(config_for_framework, caplog):
    """With an empty ``py_versions`` list, URIs have no Python tag.

    A supplied ``py_version`` is ignored and an INFO log entry records that.
    """
    config = copy.deepcopy(BASE_CONFIG)
    config["versions"]["1.0.0"]["py_versions"] = []
    config_for_framework.return_value = config

    expected_uri = "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu"
    request = dict(
        framework="useless-string",
        version="1.0.0",
        instance_type="ml.c4.xlarge",
        region="us-west-2",
        image_scope="training",
    )

    assert image_uris.retrieve(**request) == expected_uri

    # Passing a py_version anyway is tolerated but reported at INFO level.
    caplog.set_level(logging.INFO)
    assert image_uris.retrieve(py_version="py3", **request) == expected_uri
    assert "Ignoring unnecessary Python version: py3." in caplog.text
def test_jumpstart_xgboost_image_uri(patched_get_model_specs, session):
    """JumpStart XGBoost image URIs agree with the URIs built by the XGBoost classes.

    The inference URI is compared with ``XGBoostModel.serving_image_uri``, the
    training URI with ``XGBoost.training_image_uri``, and both with a pinned literal.
    """
    patched_get_model_specs.side_effect = get_prototype_model_spec

    model_id = "xgboost-classification-model"
    model_version = "*"
    instance_type = "ml.p2.xlarge"
    region = "us-west-2"

    model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
        region, model_id, model_version
    )

    def jumpstart_uri(scope):
        return image_uris.retrieve(
            framework=None,
            region=region,
            image_scope=scope,
            model_id=model_id,
            model_version=model_version,
            instance_type=instance_type,
        )

    # inference
    inference_uri = jumpstart_uri("inference")
    hosting_specs = model_specs.hosting_ecr_specs
    model_uri = XGBoostModel(
        role="mock_role",
        model_data="mock_data",
        entry_point="mock_entry_point",
        framework_version=hosting_specs.framework_version,
        py_version=hosting_specs.py_version,
        sagemaker_session=session,
    ).serving_image_uri(region, instance_type)

    assert inference_uri == model_uri
    assert inference_uri == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.3-1"

    # training
    training_uri = jumpstart_uri("training")
    training_specs = model_specs.training_ecr_specs
    estimator_uri = XGBoost(
        role="mock_role",
        entry_point="mock_entry_point",
        framework_version=training_specs.framework_version,
        py_version=training_specs.py_version,
        instance_type=instance_type,
        instance_count=1,
        image_uri_region=region,
        sagemaker_session=session,
    ).training_image_uri(region=region)

    assert training_uri == estimator_uri
    assert training_uri == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.3-1"
def test_gpu_error(sklearn_version):
    """Requesting a GPU instance type for sklearn raises ValueError."""
    gpu_request = dict(region="us-west-2", version=sklearn_version, instance_type="ml.p2.xlarge")

    with pytest.raises(ValueError) as excinfo:
        image_uris.retrieve("sklearn", **gpu_request)

    assert "Unsupported processor: gpu." in str(excinfo.value)
Beispiel #14
0
def test_algo_uris(algo):
    """Each region either resolves ``algo`` to its ``latest`` URI or rejects the region."""
    for region in regions.regions():
        if region not in ACCOUNTS:
            with pytest.raises(ValueError) as excinfo:
                image_uris.retrieve(algo, region)
            assert "Unsupported region: {}.".format(region) in str(excinfo.value)
            continue

        uri = image_uris.retrieve(algo, region)
        assert uri == expected_uris.algo_uri(algo, ACCOUNTS[region], region, version="latest")
def test_jumpstart_catboost_image_uri(patched_get_model_specs, session):
    """JumpStart CatBoost image URIs agree with the URIs built by the PyTorch classes.

    The inference URI is compared with ``PyTorchModel.serving_image_uri``, the
    training URI with ``PyTorch.training_image_uri``, and both with pinned literals.
    """
    patched_get_model_specs.side_effect = get_prototype_model_spec

    model_id = "catboost-classification-model"
    model_version = "*"
    instance_type = "ml.p2.xlarge"
    region = "us-west-2"

    model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
        region, model_id, model_version
    )

    def jumpstart_uri(scope):
        return image_uris.retrieve(
            framework=None,
            region=region,
            image_scope=scope,
            model_id=model_id,
            model_version=model_version,
            instance_type=instance_type,
        )

    # inference
    inference_uri = jumpstart_uri("inference")
    hosting_specs = model_specs.hosting_ecr_specs
    model_uri = PyTorchModel(
        role="mock_role",
        model_data="mock_data",
        entry_point="mock_entry_point",
        framework_version=hosting_specs.framework_version,
        py_version=hosting_specs.py_version,
        sagemaker_session=session,
    ).serving_image_uri(region, instance_type)

    assert inference_uri == model_uri
    assert (
        inference_uri
        == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
    )

    # training
    training_uri = jumpstart_uri("training")
    training_specs = model_specs.training_ecr_specs
    estimator_uri = PyTorch(
        role="mock_role",
        entry_point="mock_entry_point",
        framework_version=training_specs.framework_version,
        py_version=training_specs.py_version,
        instance_type=instance_type,
        instance_count=1,
        sagemaker_session=session,
    ).training_image_uri(region=region)

    assert training_uri == estimator_uri
    assert (
        training_uri
        == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38"
    )
def test_gpu_error(version):
    """Requesting an AutoGluon inference image on a GPU instance raises ValueError."""
    gpu_request = dict(
        region="us-west-2",
        version=version,
        image_scope="inference",
        instance_type="ml.p2.xlarge",
    )

    with pytest.raises(ValueError) as excinfo:
        image_uris.retrieve("autogluon", **gpu_request)

    assert "Unsupported processor: gpu." in str(excinfo.value)
Beispiel #17
0
def test_jumpstart_tensorflow_image_uri(patched_get_model_specs, session):
    """JumpStart TensorFlow image URIs agree with the URIs built by the TensorFlow classes.

    The inference URI (built without a ``py_version``) is compared with
    ``TensorFlowModel.serving_image_uri``. The training URI is compared with
    ``TensorFlow.training_image_uri``. Both are also checked against pinned literals.
    """
    patched_get_model_specs.side_effect = get_prototype_model_spec

    model_id = "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1"
    model_version = "*"
    instance_type = "ml.p2.xlarge"
    region = "us-west-2"

    model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
        region, model_id, model_version
    )

    def jumpstart_uri(scope):
        return image_uris.retrieve(
            framework=None,
            region=region,
            image_scope=scope,
            model_id=model_id,
            model_version=model_version,
            instance_type=instance_type,
        )

    # inference
    inference_uri = jumpstart_uri("inference")
    model_uri = TensorFlowModel(
        role="mock_role",
        model_data="mock_data",
        entry_point="mock_entry_point",
        framework_version=model_specs.hosting_ecr_specs.framework_version,
        sagemaker_session=session,
    ).serving_image_uri(region, instance_type)

    assert inference_uri == model_uri
    assert inference_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:2.3-gpu"

    # training
    training_uri = jumpstart_uri("training")
    training_specs = model_specs.training_ecr_specs
    estimator_uri = TensorFlow(
        role="mock_role",
        entry_point="mock_entry_point",
        framework_version=training_specs.framework_version,
        py_version=training_specs.py_version,
        instance_type=instance_type,
        instance_count=1,
        sagemaker_session=session,
    ).training_image_uri(region=region)

    assert training_uri == estimator_uri
    assert (
        training_uri
        == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:2.3-gpu-py37"
    )
def test_py2_error(sklearn_version):
    """Requesting a Python 2 sklearn image raises ValueError."""
    py2_request = dict(
        region="us-west-2",
        version=sklearn_version,
        py_version="py2",
        instance_type="ml.c4.xlarge",
    )

    with pytest.raises(ValueError) as excinfo:
        image_uris.retrieve("sklearn", **py2_request)

    assert "Unsupported Python version: py2." in str(excinfo.value)
Beispiel #19
0
def get_algo_image_uri(algo_name, region, repo_version):
    """Return the image URI to use for ``algo_name`` in ``region``.

    * ``xgboost`` is always retrieved at version ``1.2-2`` (``repo_version`` is ignored).
    * ``mlp`` reuses everything before the last ``/`` of the ``linear-learner`` URI
      and points it at ``mxnet-algorithms:<repo_version>``.
    * Any other algorithm is retrieved directly at ``repo_version``.
    """
    if algo_name == "xgboost":
        return image_uris.retrieve(algo_name, region=region, version="1.2-2")
    if algo_name != "mlp":
        return image_uris.retrieve(algo_name, region=region, version=repo_version)

    linear_learner_uri = image_uris.retrieve("linear-learner", region=region, version=repo_version)
    uri_prefix = linear_learner_uri[: linear_learner_uri.rfind("/")]
    return "{}/{}:{}".format(uri_prefix, "mxnet-algorithms", repo_version)
Beispiel #20
0
def test_retrieve_invalid_accelerator(config_for_framework):
    """An unknown Elastic Inference accelerator type raises ValueError."""
    request = dict(
        framework="useless-string",
        version="1.0.0",
        py_version="py3",
        instance_type="ml.c4.xlarge",
        accelerator_type="fake-accelerator",
        region="us-west-2",
    )

    with pytest.raises(ValueError) as excinfo:
        image_uris.retrieve(**request)

    message = str(excinfo.value)
    assert "Invalid SageMaker Elastic Inference accelerator type: fake-accelerator." in message
def test_lda():
    """Each region either resolves the ``lda`` URI for its account or rejects the region."""
    algo = "lda"
    supported_accounts = _accounts_for_algo(algo)

    for region in regions.regions():
        if region not in supported_accounts:
            with pytest.raises(ValueError) as excinfo:
                image_uris.retrieve(algo, region)
            assert "Unsupported region: {}.".format(region) in str(excinfo.value)
            continue

        uri = image_uris.retrieve(algo, region)
        assert uri == expected_uris.algo_uri(algo, supported_accounts[region], region)
def test_py3_error(version):
    """Requesting an AutoGluon training image with ``py_version="py3"`` raises ValueError."""
    py3_request = dict(
        region="us-west-2",
        version=version,
        py_version="py3",
        image_scope="training",
        instance_type="ml.c4.xlarge",
    )

    with pytest.raises(ValueError) as excinfo:
        image_uris.retrieve("autogluon", **py3_request)

    assert "Unsupported Python version: py3." in str(excinfo.value)
Beispiel #23
0
def test_retrieve_processor_type_neo(config_for_framework):
    """Neo-style instance names (``ml_*``) map to the expected processor tag.

    The CPU and GPU families are checked with the default config. ``inf`` and
    ``c5`` are each checked against a config whose only processor is that one.
    """

    def retrieve_for(instance_type):
        return image_uris.retrieve(
            framework="useless-string",
            version="1.0.0",
            py_version="py3",
            instance_type=instance_type,
            region="us-west-2",
            image_scope="training",
        )

    cpu_uri = "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu-py3"
    for cpu_instance in ("ml_m4", "ml_m5", "ml_c4", "ml_c5"):
        assert retrieve_for(cpu_instance) == cpu_uri

    gpu_uri = "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-gpu-py3"
    for gpu_instance in ("ml_p2", "ml_p3"):
        assert retrieve_for(gpu_instance) == gpu_uri

    single_processor_cases = (
        ("inf", "ml_inf1", "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-inf-py3"),
        ("c5", "ml_c5", "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-c5-py3"),
    )
    for processor, instance_type, expected_uri in single_processor_cases:
        config = copy.deepcopy(BASE_CONFIG)
        config["processors"] = [processor]
        config_for_framework.return_value = config

        assert retrieve_for(instance_type) == expected_uri
def endpoint_name(sagemaker_session):
    """Deploy an XGBoost model with data capture enabled and yield its endpoint name.

    The model artifact ``xgb_model.tar.gz`` is uploaded from ``XGBOOST_DATA_PATH``.
    The model is deployed as ``INSTANCE_COUNT`` x ``INSTANCE_TYPE`` under an
    endpoint whose name is also used as the model name. Everything runs inside
    ``timeout_and_delete_endpoint_by_name(..., hours=2)``. Judging by its name,
    that context manager deletes the endpoint on exit; TODO confirm.
    """
    name = unique_name_from_base("model-quality-monitor-integ")
    model_artifact = sagemaker_session.upload_data(
        path=os.path.join(XGBOOST_DATA_PATH, "xgb_model.tar.gz"),
        key_prefix="integ-test-data/xgboost/model",
    )
    inference_image = image_uris.retrieve(
        "xgboost",
        sagemaker_session.boto_region_name,
        version="1",
        image_scope="inference",
    )

    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
        endpoint_name=name,
        sagemaker_session=sagemaker_session,
        hours=2,
    ):
        model = Model(
            model_data=model_artifact,
            image_uri=inference_image,
            name=name,  # the model shares the endpoint's name
            role=ROLE,
            sagemaker_session=sagemaker_session,
        )
        capture_config = DataCaptureConfig(True, sagemaker_session=sagemaker_session)
        model.deploy(
            INSTANCE_COUNT,
            INSTANCE_TYPE,
            endpoint_name=name,
            data_capture_config=capture_config,
        )
        yield name
Beispiel #25
0
 def __init__(
     self,
     entry_point,
     region,
     framework_version,
     py_version,
     instance_type,
     source_dir=None,
     hyperparameters=None,
     **kwargs,
 ):
     """Create an AutoGluon training estimator whose image URI is resolved eagerly.

     Args:
         entry_point: Training entry point, passed positionally to the parent initializer.
         region (str): AWS region used to look up the ``autogluon`` training image.
         framework_version (str): AutoGluon version. It is used for the image lookup,
             stored on ``self`` and forwarded to the parent initializer.
         py_version (str): Python version of the training image. It is stored on ``self``.
         instance_type (str): Instance type. It is used for the image lookup and forwarded.
         source_dir: Passed positionally to the parent initializer (default: None).
         hyperparameters: Passed positionally to the parent initializer (default: None).
         **kwargs: Additional keyword arguments for the parent initializer.
     """
     self.framework_version = framework_version
     self.py_version = py_version

     training_image = image_uris.retrieve(
         "autogluon",
         region=region,
         version=framework_version,
         py_version=py_version,
         image_scope="training",
         instance_type=instance_type,
     )
     self.image_uri = training_image

     # The resolved image is handed to the parent so it does not perform its own lookup.
     super().__init__(
         entry_point,
         source_dir,
         hyperparameters,
         framework_version=framework_version,
         instance_type=instance_type,
         image_uri=self.image_uri,
         **kwargs,
     )
    def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
        """Return the inference image URI for this model's framework.

        Args:
            region_name (str): AWS region in which the image is hosted.
            instance_type (str): SageMaker instance type. It selects the image's
                device flavor (cpu/gpu/family-specific optimized).
            accelerator_type (str): Elastic Inference accelerator type to attach
                to the instance for inference, e.g. ``'ml.eia1.medium'``
                (default: None).

        Returns:
            str: The image URI matching ``self._framework_name``,
            ``self.framework_version``, ``self.py_version`` and the arguments.
        """
        return image_uris.retrieve(
            framework=self._framework_name,
            region=region_name,
            version=self.framework_version,
            py_version=self.py_version,
            instance_type=instance_type,
            accelerator_type=accelerator_type,
            image_scope="inference",
        )
    def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
        """Create a LinearLearner model backed by the LinearLearner algorithm image.

        Args:
            model_data (str): S3 location of a SageMaker model data ``.tar.gz`` file.
            role (str): AWS IAM role, given as a name or a full ARN. SageMaker
                training jobs and the APIs that create endpoints use it to access
                training data and model artifacts. After deployment, inference
                code may use it to reach AWS resources.
            sagemaker_session (sagemaker.session.Session): Session for SageMaker
                and other AWS API calls. If falsy, a new ``Session()`` is created
                from the default AWS configuration chain.
            **kwargs: Extra keyword arguments forwarded to the parent initializer.
        """
        session = sagemaker_session or Session()
        # The image is looked up in the session's region using LinearLearner's repo settings.
        linear_learner_image = image_uris.retrieve(
            LinearLearner.repo_name,
            session.boto_region_name,
            version=LinearLearner.repo_version,
        )
        super().__init__(
            linear_learner_image,
            model_data,
            role,
            predictor_cls=LinearLearnerPredictor,
            sagemaker_session=session,
            **kwargs,
        )
 def training_image_uri(self):
     """Return the training image URI for this estimator's algorithm repository.

     The image is looked up from ``self.repo_name`` and ``self.repo_version`` in
     the region of ``self.sagemaker_session``.
     """
     repository = self.repo_name
     region = self.sagemaker_session.boto_region_name
     return image_uris.retrieve(repository, region, version=self.repo_version)
Beispiel #29
0
    def _validate_args(self, py_version):
        """Reject estimator settings that this TensorFlow configuration cannot run.

        Args:
            py_version (str): Requested Python version.

        Raises:
            AttributeError: If ``py_version`` is ``"py2"`` and
                ``self._only_python_3_supported()`` is true.
            ValueError: If ``self.image_uri`` is ``None`` and
                ``self._only_legacy_mode_supported()`` is true. The message embeds
                the legacy training image URI so the caller can pass it explicitly.
        """
        if py_version == "py2" and self._only_python_3_supported():
            raise AttributeError(
                "Python 2 containers are only available with {} and lower versions. "
                "Please use a Python 3 container.".format(defaults.LATEST_PY2_VERSION)
            )

        # Only a missing image_uri on a legacy-only version needs further handling.
        if self.image_uri is not None or not self._only_legacy_mode_supported():
            return

        # NOTE(review): this lookup uses self.py_version rather than the py_version
        # argument; presumably both hold the same value at call time. Confirm.
        legacy_image_uri = image_uris.retrieve(
            "tensorflow",
            self.sagemaker_session.boto_region_name,
            instance_type=self.instance_type,
            version=self.framework_version,
            py_version=self.py_version,
            image_scope="training",
        )
        raise ValueError(
            (
                "TF {} supports only legacy mode. Please supply the image URI directly with "
                "'image_uri={}' and set 'model_dir=False'. If you are using any legacy parameters "
                "(training_steps, evaluation_steps, checkpoint_path, requirements_file), "
                "make sure to pass them directly as hyperparameters instead. For more, see "
                "https://sagemaker.readthedocs.io/en/v2.0.0.rc0/frameworks/tensorflow/upgrade_from_legacy.html."
            ).format(self.framework_version, legacy_image_uri)
        )
    def _compilation_image_uri(self, region, target_instance_type, framework, framework_version):
        """Return the Neo or Inferentia image URI for a compiled model.

        The image name is chosen as follows:

        * ``xgboost`` becomes ``xgboost-neo``, whatever the target instance.
        * A target starting with ``ml_inf`` becomes ``inferentia-<framework>``.
        * Anything else becomes ``neo-<framework>``.

        Args:
            region (str): The AWS region.
            target_instance_type (str): Device on which the compiled model will run,
                for example ``ml_c5``. For valid values, see
                https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
            framework (str): The framework name.
            framework_version (str): The framework version.

        Returns:
            str: The URI returned by ``image_uris.retrieve`` for that image name.
        """
        if framework == "xgboost":
            compiled_framework = "{}-neo".format(framework)
        elif target_instance_type.startswith("ml_inf"):
            compiled_framework = "inferentia-{}".format(framework)
        else:
            compiled_framework = "neo-{}".format(framework)

        return image_uris.retrieve(
            compiled_framework,
            region,
            instance_type=target_instance_type,
            version=framework_version,
        )