def test_pytorch_eia(pytorch_eia_version, pytorch_eia_py_version):
    """Check PyTorch EIA inference image URIs in the default and alternate DLC regions.

    Each URI is retrieved with an ``ml.eia1.medium`` accelerator on an
    ``ml.c4.xlarge`` instance. It is compared against the
    ``pytorch-inference-eia`` URI built for that region's account.
    """
    retrieve_kwargs = dict(
        framework="pytorch",
        version=pytorch_eia_version,
        py_version=pytorch_eia_py_version,
        image_scope="inference",
        instance_type="ml.c4.xlarge",
        accelerator_type="ml.eia1.medium",
    )

    # The default region is checked first against DLC_ACCOUNT, then every
    # alternate region against its own account.
    region_accounts = [(REGION, DLC_ACCOUNT)]
    region_accounts.extend(DLC_ALTERNATE_REGION_ACCOUNTS.items())

    for target_region, registry_account in region_accounts:
        actual = image_uris.retrieve(region=target_region, **retrieve_kwargs)
        wanted = expected_uris.framework_uri(
            "pytorch-inference-eia",
            pytorch_eia_version,
            registry_account,
            py_version=pytorch_eia_py_version,
            region=target_region,
        )
        assert wanted == actual
def _expected_ray_tf_uri(ray_tf_version, processor):
    """Return the expected RL Ray + TensorFlow image URI.

    Versions 0.6.5 and older use the ``sagemaker-rl-tensorflow`` repo in
    SAGEMAKER_ACCOUNT with py3. Newer versions use
    ``sagemaker-rl-ray-container`` in RL_ACCOUNT, with py37 above 1.0.0 and
    py36 otherwise.
    """
    parsed = Version(ray_tf_version)

    if parsed <= Version("0.6.5"):
        return expected_uris.framework_uri(
            repo="sagemaker-rl-tensorflow",
            fw_version=_version_for_tag("ray", ray_tf_version, "tf"),
            account=SAGEMAKER_ACCOUNT,
            py_version="py3",
            processor=processor,
        )

    return expected_uris.framework_uri(
        repo="sagemaker-rl-ray-container",
        fw_version=_version_for_tag("ray", ray_tf_version, "tf", True),
        account=RL_ACCOUNT,
        py_version="py37" if parsed > Version("1.0.0") else "py36",
        processor=processor,
    )
def _expected_mxnet_inference_uri(mxnet_version, py_version, processor="cpu", region=REGION, eia=False):
    """Return the expected MXNet inference image URI.

    The repo is chosen by version:

    * below 1.4: ``sagemaker-mxnet``
    * the exact string ``"1.4.0"``: ``sagemaker-mxnet-serving``
    * 1.5 and later: ``mxnet-inference``
    * any other 1.4.x release: ``sagemaker-mxnet-serving`` for non-EIA py2,
      otherwise ``mxnet-inference``

    When ``eia`` is true, ``-eia`` is appended to the repo name. The account is
    resolved from the final repo name and the region.
    """
    parsed = Version(mxnet_version)

    if parsed < Version("1.4"):
        base_repo = "sagemaker-mxnet"
    elif mxnet_version == "1.4.0":
        # NOTE: plain string comparison, not a Version comparison.
        base_repo = "sagemaker-mxnet-serving"
    elif parsed >= Version("1.5"):
        base_repo = "mxnet-inference"
    elif py_version == "py2" and not eia:
        base_repo = "sagemaker-mxnet-serving"
    else:
        base_repo = "mxnet-inference"

    repo = base_repo + "-eia" if eia else base_repo

    return expected_uris.framework_uri(
        repo=repo,
        fw_version=mxnet_version,
        account=_sagemaker_or_dlc_account(repo, region),
        py_version=py_version,
        processor=processor,
        region=region,
    )
def _expected_coach_tf_uri(coach_tf_version, processor):
    """Return the expected RL Coach + TensorFlow image URI (always py3).

    Versions newer than 0.11.1 use ``sagemaker-rl-coach-container`` in
    RL_ACCOUNT. Older versions use ``sagemaker-rl-tensorflow`` in
    SAGEMAKER_ACCOUNT.
    """
    if Version(coach_tf_version) > Version("0.11.1"):
        repo = "sagemaker-rl-coach-container"
        tag = _version_for_tag("coach", coach_tf_version, "tf", True)
        account = RL_ACCOUNT
    else:
        repo = "sagemaker-rl-tensorflow"
        tag = _version_for_tag("coach", coach_tf_version, "tf")
        account = SAGEMAKER_ACCOUNT

    return expected_uris.framework_uri(
        repo, tag, account, py_version="py3", processor=processor
    )
def _expected_framework_uri(framework, version, region="us-west-2", processor="cpu"):
    """Return the expected py3 ``sagemaker-<framework>`` image URI for ``region``.

    The account is looked up in ACCOUNTS by region.
    """
    repo_name = "sagemaker-{}".format(framework)
    account_id = ACCOUNTS[region]
    return expected_uris.framework_uri(
        repo_name,
        version,
        account_id,
        py_version="py3",
        processor=processor,
        region=region,
    )
def _expected_chainer_uri(chainer_version, py_version, processor="cpu", region=REGION):
    """Return the expected ``sagemaker-chainer`` image URI.

    The default REGION uses SAGEMAKER_ACCOUNT. Any other region is looked up in
    SAGEMAKER_ALTERNATE_REGION_ACCOUNTS.
    """
    if region == REGION:
        account = SAGEMAKER_ACCOUNT
    else:
        account = SAGEMAKER_ALTERNATE_REGION_ACCOUNTS[region]

    return expected_uris.framework_uri(
        "sagemaker-chainer",
        chainer_version,
        account,
        py_version=py_version,
        processor=processor,
        region=region,
    )
def test_vw(vw_version):
    """Check the ``vw`` RL image URI in the default region.

    The version passed to retrieve is prefixed with ``vw-``.
    """
    tag = "vw-{}".format(vw_version)
    actual = image_uris.retrieve(
        framework="vw", region=REGION, version=tag, instance_type="ml.c4.xlarge"
    )
    assert expected_uris.framework_uri("sagemaker-rl-vw-container", tag, RL_ACCOUNT) == actual
def test_xgboost_framework(xgboost_framework_version):
    """Check the py3 XGBoost framework URI in every region from ``regions.regions()``."""
    for region_name in regions.regions():
        actual = image_uris.retrieve(
            framework="xgboost",
            region=region_name,
            version=xgboost_framework_version,
            py_version="py3",
        )
        registry = FRAMEWORK_REGISTRIES[region_name]
        wanted = expected_uris.framework_uri(
            repo="sagemaker-xgboost",
            fw_version=xgboost_framework_version,
            account=registry,
            py_version="py3",
            region=region_name,
        )
        assert wanted == actual
def test_coach_mxnet(coach_mxnet_version):
    """Check the Coach + MXNet RL image URI for each (instance type, processor) pair."""
    # The expected tag is the version prefixed with "coach"; it does not vary per instance type.
    coach_tag = "coach{}".format(coach_mxnet_version)
    for instance_type, processor in INSTANCE_TYPES_AND_PROCESSORS:
        actual = image_uris.retrieve(
            framework="coach-mxnet",
            region=REGION,
            version=coach_mxnet_version,
            instance_type=instance_type,
        )
        wanted = expected_uris.framework_uri(
            repo="sagemaker-rl-mxnet",
            fw_version=coach_tag,
            account=SAGEMAKER_ACCOUNT,
            py_version="py3",
            processor=processor,
        )
        assert wanted == actual
def _expected_tf_inference_uri(tf_inference_version, processor="cpu", region=REGION, eia=False):
    """Return the expected TensorFlow inference image URI.

    The repo comes from ``_expected_tf_inference_repo``. The py_version is
    "py2" for versions below 1.11 and None otherwise. The account is resolved
    from the repo and region.
    """
    parsed = Version(tf_inference_version)
    repo = _expected_tf_inference_repo(parsed, eia)

    if parsed < Version("1.11"):
        py_version = "py2"
    else:
        py_version = None

    return expected_uris.framework_uri(
        repo,
        tf_inference_version,
        _sagemaker_or_dlc_account(repo, region),
        py_version,
        processor=processor,
        region=region,
    )
def test_ray_pytorch(ray_pytorch_version):
    """Check the Ray + PyTorch RL image URI for each (instance type, processor) pair."""
    # The expected tag has the form "ray-<version>-torch" and does not vary per instance type.
    ray_tag = "ray-{}-torch".format(ray_pytorch_version)
    for instance_type, processor in INSTANCE_TYPES_AND_PROCESSORS:
        actual = image_uris.retrieve(
            framework="ray-pytorch",
            region=REGION,
            version=ray_pytorch_version,
            instance_type=instance_type,
        )
        wanted = expected_uris.framework_uri(
            repo="sagemaker-rl-ray-container",
            fw_version=ray_tag,
            account=RL_ACCOUNT,
            py_version="py36",
            processor=processor,
        )
        assert wanted == actual
def _expected_pytorch_inference_uri(pytorch_version, py_version, processor="cpu", region=REGION):
    """Return the expected PyTorch inference image URI.

    Versions below 1.2 use ``sagemaker-pytorch``; later versions use
    ``pytorch-inference``. The account is resolved from the repo and region.
    """
    is_legacy = Version(pytorch_version) < Version("1.2")
    repo = "sagemaker-pytorch" if is_legacy else "pytorch-inference"
    return expected_uris.framework_uri(
        repo=repo,
        fw_version=pytorch_version,
        account=_sagemaker_or_dlc_account(repo, region),
        py_version=py_version,
        processor=processor,
        region=region,
    )
def test_xgboost_framework_cpu_only(xgboost_framework_version):
    """Check that retrieving XGBoost without py_version or instance type yields the py3 CPU URI.

    Every region in FRAMEWORK_REGISTRIES is checked against its registry account.
    """
    for region_name, registry in FRAMEWORK_REGISTRIES.items():
        actual = image_uris.retrieve(
            framework="xgboost",
            region=region_name,
            version=xgboost_framework_version,
        )
        wanted = expected_uris.framework_uri(
            repo="sagemaker-xgboost",
            fw_version=xgboost_framework_version,
            account=registry,
            py_version="py3",
            processor="cpu",
            region=region_name,
        )
        assert wanted == actual
def _expected_mxnet_training_uri(mxnet_version, py_version, processor="cpu", region=REGION):
    """Return the expected MXNet training image URI.

    The legacy ``sagemaker-mxnet`` repo is used for:

    * versions below 1.4
    * the exact string ``"1.4.0"``
    * py2 on any remaining version below 1.6.0

    Everything else uses ``mxnet-training``. The account is resolved from the
    repo and region.
    """
    parsed = Version(mxnet_version)
    uses_legacy_repo = (
        parsed < Version("1.4")
        or mxnet_version == "1.4.0"
        or (parsed < Version("1.6.0") and py_version == "py2")
    )
    repo = "sagemaker-mxnet" if uses_legacy_repo else "mxnet-training"

    return expected_uris.framework_uri(
        repo=repo,
        fw_version=mxnet_version,
        account=_sagemaker_or_dlc_account(repo, region),
        py_version=py_version,
        processor=processor,
        region=region,
    )
def test_valid_uris(sklearn_version):
    """Check the py3 scikit-learn image URI in every region from ``regions.regions()``."""
    for region_name in regions.regions():
        actual = image_uris.retrieve(
            "sklearn",
            region=region_name,
            version=sklearn_version,
            py_version="py3",
            instance_type="ml.c4.xlarge",
        )
        account_id = ACCOUNTS[region_name]
        wanted = expected_uris.framework_uri(
            repo="sagemaker-scikit-learn",
            fw_version=sklearn_version,
            account=account_id,
            py_version="py3",
            region=region_name,
        )
        assert wanted == actual
def _expected_tf_training_uri(tf_training_version, py_version, processor="cpu", region=REGION):
    """Return the expected TensorFlow training image URI.

    The repo is chosen by version:

    * below 1.11: ``sagemaker-tensorflow``
    * 1.11 up to (not including) 1.13: ``sagemaker-tensorflow-scriptmode``
    * 1.13.x: ``sagemaker-tensorflow-scriptmode`` for py2, otherwise
      ``tensorflow-training``
    * 1.14 and later: ``tensorflow-training``

    The account is resolved from the repo and region.
    """
    parsed = Version(tf_training_version)

    if parsed < Version("1.11"):
        repo = "sagemaker-tensorflow"
    elif parsed < Version("1.13") or (parsed < Version("1.14") and py_version == "py2"):
        repo = "sagemaker-tensorflow-scriptmode"
    else:
        repo = "tensorflow-training"

    return expected_uris.framework_uri(
        repo=repo,
        fw_version=tf_training_version,
        account=_sagemaker_or_dlc_account(repo, region),
        py_version=py_version,
        processor=processor,
        region=region,
    )
def test_valid_uris(version):
    """Check the AutoGluon training image URI in every region in ACCOUNTS.

    Version "0.3.1" is expected with py37; all other versions with py38.
    """
    if version == "0.3.1":
        py_version = "py37"
    else:
        py_version = "py38"

    for region_name, account_id in ACCOUNTS.items():
        actual = image_uris.retrieve(
            "autogluon",
            region=region_name,
            version=version,
            py_version=py_version,
            image_scope="training",
            instance_type="ml.c4.xlarge",
        )
        wanted = expected_uris.framework_uri(
            repo="autogluon-training",
            fw_version=version,
            account=account_id,
            py_version=py_version,
            region=region_name,
        )
        assert actual == wanted