コード例 #1
0
def test_pytorch_eia(pytorch_eia_version, pytorch_eia_py_version):
    """Compare retrieved EIA PyTorch inference URIs with the expected DLC URIs.

    The default ``REGION`` is checked against ``DLC_ACCOUNT`` first. Each entry of
    ``DLC_ALTERNATE_REGION_ACCOUNTS`` is then checked against its own account.
    """
    retrieve_kwargs = dict(
        framework="pytorch",
        version=pytorch_eia_version,
        py_version=pytorch_eia_py_version,
        image_scope="inference",
        instance_type="ml.c4.xlarge",
        accelerator_type="ml.eia1.medium",
    )

    # Default region first, then every alternate region with its own account.
    region_accounts = [(REGION, DLC_ACCOUNT)] + list(DLC_ALTERNATE_REGION_ACCOUNTS.items())
    for region_name, account_id in region_accounts:
        actual = image_uris.retrieve(region=region_name, **retrieve_kwargs)
        wanted = expected_uris.framework_uri(
            "pytorch-inference-eia",
            pytorch_eia_version,
            account_id,
            py_version=pytorch_eia_py_version,
            region=region_name,
        )
        assert wanted == actual
コード例 #2
0
def _expected_ray_tf_uri(ray_tf_version, processor):
    """Build the expected image URI for a Ray + TensorFlow RL container.

    Versions up to and including 0.6.5 use the legacy ``sagemaker-rl-tensorflow``
    repo under ``SAGEMAKER_ACCOUNT`` with a ``py3`` tag. Newer versions use
    ``sagemaker-rl-ray-container`` under ``RL_ACCOUNT``, tagged ``py37`` above
    1.0.0 and ``py36`` otherwise.
    """
    parsed = Version(ray_tf_version)
    if parsed <= Version("0.6.5"):
        return expected_uris.framework_uri(
            "sagemaker-rl-tensorflow",
            _version_for_tag("ray", ray_tf_version, "tf"),
            SAGEMAKER_ACCOUNT,
            py_version="py3",
            processor=processor,
        )

    python_tag = "py37" if parsed > Version("1.0.0") else "py36"
    return expected_uris.framework_uri(
        "sagemaker-rl-ray-container",
        _version_for_tag("ray", ray_tf_version, "tf", True),
        RL_ACCOUNT,
        py_version=python_tag,
        processor=processor,
    )
コード例 #3
0
def _expected_mxnet_inference_uri(mxnet_version, py_version, processor="cpu", region=REGION, eia=False):
    """Build the expected MXNet inference image URI.

    Repo selection by version:
      * below 1.4 -> ``sagemaker-mxnet``
      * exactly "1.4.0" -> ``sagemaker-mxnet-serving``
      * 1.5 and above -> ``mxnet-inference``
      * anything else (e.g. 1.4.1) -> ``sagemaker-mxnet-serving`` for py2 without
        EIA, otherwise ``mxnet-inference``
    When ``eia`` is set, ``-eia`` is appended to the chosen repo name.
    """
    parsed = Version(mxnet_version)
    if parsed < Version("1.4"):
        base_repo = "sagemaker-mxnet"
    elif mxnet_version == "1.4.0":
        base_repo = "sagemaker-mxnet-serving"
    elif parsed >= Version("1.5"):
        base_repo = "mxnet-inference"
    elif py_version == "py2" and not eia:
        base_repo = "sagemaker-mxnet-serving"
    else:
        base_repo = "mxnet-inference"

    repo = base_repo + "-eia" if eia else base_repo
    account = _sagemaker_or_dlc_account(repo, region)

    return expected_uris.framework_uri(
        repo,
        mxnet_version,
        account,
        py_version=py_version,
        processor=processor,
        region=region,
    )
コード例 #4
0
def _expected_coach_tf_uri(coach_tf_version, processor):
    """Build the expected image URI for a Coach + TensorFlow RL container.

    Versions up to and including 0.11.1 use the legacy ``sagemaker-rl-tensorflow``
    repo under ``SAGEMAKER_ACCOUNT``. Newer versions use
    ``sagemaker-rl-coach-container`` under ``RL_ACCOUNT``. Both use a ``py3`` tag.
    """
    if Version(coach_tf_version) <= Version("0.11.1"):
        repo, tag, account = (
            "sagemaker-rl-tensorflow",
            _version_for_tag("coach", coach_tf_version, "tf"),
            SAGEMAKER_ACCOUNT,
        )
    else:
        repo, tag, account = (
            "sagemaker-rl-coach-container",
            _version_for_tag("coach", coach_tf_version, "tf", True),
            RL_ACCOUNT,
        )

    return expected_uris.framework_uri(
        repo,
        tag,
        account,
        py_version="py3",
        processor=processor,
    )
コード例 #5
0
def _expected_framework_uri(framework, version, region="us-west-2", processor="cpu"):
    """Build the expected ``sagemaker-<framework>`` py3 image URI for ``region``.

    The account is looked up in ``ACCOUNTS`` by region.
    """
    repo_name = "sagemaker-{}".format(framework)
    account_id = ACCOUNTS[region]
    return expected_uris.framework_uri(
        repo_name,
        account=account_id,
        fw_version=version,
        processor=processor,
        py_version="py3",
        region=region,
    )
コード例 #6
0
def _expected_chainer_uri(chainer_version, py_version, processor="cpu", region=REGION):
    """Build the expected ``sagemaker-chainer`` image URI.

    The default ``REGION`` maps to ``SAGEMAKER_ACCOUNT``. Any other region is looked
    up in ``SAGEMAKER_ALTERNATE_REGION_ACCOUNTS``.
    """
    if region == REGION:
        account_id = SAGEMAKER_ACCOUNT
    else:
        account_id = SAGEMAKER_ALTERNATE_REGION_ACCOUNTS[region]

    uri_kwargs = dict(
        repo="sagemaker-chainer",
        fw_version=chainer_version,
        py_version=py_version,
        processor=processor,
        region=region,
        account=account_id,
    )
    return expected_uris.framework_uri(**uri_kwargs)
コード例 #7
0
def test_vw(vw_version):
    """Check the Vowpal Wabbit RL image URI, whose version is prefixed with ``vw-``."""
    tag = "vw-{}".format(vw_version)
    actual = image_uris.retrieve(
        "vw",
        REGION,
        version=tag,
        instance_type="ml.c4.xlarge",
    )
    assert expected_uris.framework_uri("sagemaker-rl-vw-container", tag, RL_ACCOUNT) == actual
コード例 #8
0
def test_xgboost_framework(xgboost_framework_version):
    """Check the py3 XGBoost framework URI in every region from ``regions.regions()``.

    The expected account for each region comes from ``FRAMEWORK_REGISTRIES``.
    """
    fw_kwargs = {
        "framework": "xgboost",
        "version": xgboost_framework_version,
        "py_version": "py3",
    }
    for region_name in regions.regions():
        actual = image_uris.retrieve(region=region_name, **fw_kwargs)
        wanted = expected_uris.framework_uri(
            "sagemaker-xgboost",
            xgboost_framework_version,
            FRAMEWORK_REGISTRIES[region_name],
            py_version="py3",
            region=region_name,
        )
        assert wanted == actual
コード例 #9
0
def test_coach_mxnet(coach_mxnet_version):
    """Check Coach + MXNet RL URIs for each (instance type, processor) pair.

    The expected tag is ``coach<version>`` in ``sagemaker-rl-mxnet`` under
    ``SAGEMAKER_ACCOUNT``.
    """
    image_tag = "coach{}".format(coach_mxnet_version)
    for inst_type, proc in INSTANCE_TYPES_AND_PROCESSORS:
        actual = image_uris.retrieve(
            "coach-mxnet",
            REGION,
            version=coach_mxnet_version,
            instance_type=inst_type,
        )
        wanted = expected_uris.framework_uri(
            "sagemaker-rl-mxnet",
            image_tag,
            SAGEMAKER_ACCOUNT,
            py_version="py3",
            processor=proc,
        )
        assert wanted == actual
コード例 #10
0
def _expected_tf_inference_uri(tf_inference_version, processor="cpu", region=REGION, eia=False):
    """Build the expected TensorFlow inference image URI.

    The repo is chosen by ``_expected_tf_inference_repo`` from the parsed version
    and the ``eia`` flag. Versions below 1.11 get a ``py2`` Python tag; for newer
    versions ``None`` is passed as ``py_version``.
    """
    parsed = Version(tf_inference_version)
    repo = _expected_tf_inference_repo(parsed, eia)

    tag_py_version = None
    if parsed < Version("1.11"):
        tag_py_version = "py2"

    account_id = _sagemaker_or_dlc_account(repo, region)
    return expected_uris.framework_uri(
        repo,
        tf_inference_version,
        account_id,
        py_version=tag_py_version,
        processor=processor,
        region=region,
    )
コード例 #11
0
def test_ray_pytorch(ray_pytorch_version):
    """Check Ray + PyTorch RL URIs for each (instance type, processor) pair.

    The expected tag is ``ray-<version>-torch`` with ``py36`` in
    ``sagemaker-rl-ray-container`` under ``RL_ACCOUNT``.
    """
    image_tag = "ray-{}-torch".format(ray_pytorch_version)
    for inst_type, proc in INSTANCE_TYPES_AND_PROCESSORS:
        actual = image_uris.retrieve(
            "ray-pytorch",
            REGION,
            version=ray_pytorch_version,
            instance_type=inst_type,
        )
        wanted = expected_uris.framework_uri(
            "sagemaker-rl-ray-container",
            image_tag,
            RL_ACCOUNT,
            py_version="py36",
            processor=proc,
        )
        assert wanted == actual
コード例 #12
0
def _expected_pytorch_inference_uri(pytorch_version, py_version, processor="cpu", region=REGION):
    """Build the expected PyTorch inference image URI.

    Versions below 1.2 use ``sagemaker-pytorch``; later ones use
    ``pytorch-inference``. The account comes from ``_sagemaker_or_dlc_account``.
    """
    is_legacy = Version(pytorch_version) < Version("1.2")
    repo = "sagemaker-pytorch" if is_legacy else "pytorch-inference"
    account_id = _sagemaker_or_dlc_account(repo, region)

    return expected_uris.framework_uri(
        repo,
        pytorch_version,
        account_id,
        py_version=py_version,
        processor=processor,
        region=region,
    )
コード例 #13
0
def test_xgboost_framework_cpu_only(xgboost_framework_version):
    """Check XGBoost URIs retrieved without a ``py_version`` in every registry region.

    Each is expected to equal the py3 / cpu URI under that region's registry.
    """
    for region_name, registry in FRAMEWORK_REGISTRIES.items():
        actual = image_uris.retrieve(
            framework="xgboost",
            region=region_name,
            version=xgboost_framework_version,
        )
        wanted = expected_uris.framework_uri(
            "sagemaker-xgboost",
            xgboost_framework_version,
            registry,
            region=region_name,
            py_version="py3",
            processor="cpu",
        )
        assert wanted == actual
コード例 #14
0
def _expected_mxnet_training_uri(mxnet_version, py_version, processor="cpu", region=REGION):
    """Build the expected MXNet training image URI.

    ``sagemaker-mxnet`` is used for versions below 1.4, for exactly "1.4.0", and
    for py2 on versions below 1.6.0. Every other case uses ``mxnet-training``.
    """
    parsed = Version(mxnet_version)
    legacy = parsed < Version("1.4") or mxnet_version == "1.4.0"

    if legacy:
        repo = "sagemaker-mxnet"
    elif parsed >= Version("1.6.0") or py_version != "py2":
        repo = "mxnet-training"
    else:
        repo = "sagemaker-mxnet"

    account_id = _sagemaker_or_dlc_account(repo, region)
    return expected_uris.framework_uri(
        repo,
        mxnet_version,
        account_id,
        py_version=py_version,
        processor=processor,
        region=region,
    )
コード例 #15
0
def test_valid_uris(sklearn_version):
    """Check the py3 scikit-learn URI in every region from ``regions.regions()``.

    The expected account for each region comes from ``ACCOUNTS``.
    """
    for region_name in regions.regions():
        actual = image_uris.retrieve(
            "sklearn",
            region=region_name,
            version=sklearn_version,
            py_version="py3",
            instance_type="ml.c4.xlarge",
        )
        wanted = expected_uris.framework_uri(
            "sagemaker-scikit-learn",
            sklearn_version,
            ACCOUNTS[region_name],
            py_version="py3",
            region=region_name,
        )
        assert wanted == actual
コード例 #16
0
def _expected_tf_training_uri(tf_training_version, py_version, processor="cpu", region=REGION):
    """Build the expected TensorFlow training image URI.

    Repo selection by version:
      * below 1.11 -> ``sagemaker-tensorflow``
      * 1.11 up to (not including) 1.13 -> ``sagemaker-tensorflow-scriptmode``
      * 1.13 up to (not including) 1.14 -> scriptmode for py2, otherwise
        ``tensorflow-training``
      * 1.14 and above -> ``tensorflow-training``
    """
    parsed = Version(tf_training_version)
    if parsed < Version("1.11"):
        repo = "sagemaker-tensorflow"
    elif parsed < Version("1.13") or (parsed < Version("1.14") and py_version == "py2"):
        repo = "sagemaker-tensorflow-scriptmode"
    else:
        repo = "tensorflow-training"

    account_id = _sagemaker_or_dlc_account(repo, region)
    return expected_uris.framework_uri(
        repo,
        tf_training_version,
        account_id,
        py_version=py_version,
        processor=processor,
        region=region,
    )
コード例 #17
0
def test_valid_uris(version):
    """Check AutoGluon training URIs in every region listed in ``ACCOUNTS``.

    Version "0.3.1" uses a ``py37`` tag; every other version uses ``py38``.
    """
    py_tag = "py38"
    if version == "0.3.1":
        py_tag = "py37"

    for region_name, account_id in ACCOUNTS.items():
        actual = image_uris.retrieve(
            "autogluon",
            region=region_name,
            version=version,
            py_version=py_tag,
            image_scope="training",
            instance_type="ml.c4.xlarge",
        )
        wanted = expected_uris.framework_uri(
            "autogluon-training",
            version,
            account_id,
            py_version=py_tag,
            region=region_name,
        )
        assert actual == wanted