def _test_neo_framework_uris(framework, version):
    """Check Neo image URIs for ``framework`` at ``version`` in every region.

    The image is always looked up under the config key ``neo-<framework>``.
    The repository name used to build the expected URI is ``neo-tensorflow``
    for TensorFlow and ``inference-<framework>`` for any other framework.

    Regions present in ``ACCOUNTS`` must resolve (with ``instance_type="ml_c5"``)
    to the URI produced by ``_expected_framework_uri``. Every other region must
    raise ``ValueError`` whose message contains ``"Unsupported region: <region>."``.
    Afterwards the ``ml_p2`` lookup in us-west-2 is compared against the
    expected URI built with ``processor="gpu"``.

    NOTE(review): another ``_test_neo_framework_uris`` definition appears in
    this listing; if both live in the same module, the later one shadows this one.
    """
    config_key = f"neo-{framework}"
    if framework == "tensorflow":
        repo_name = f"neo-{framework}"
    else:
        repo_name = f"inference-{framework}"

    for region in regions.regions():
        if region not in ACCOUNTS:
            # Regions without an account mapping must be rejected.
            with pytest.raises(ValueError) as excinfo:
                image_uris.retrieve(config_key, region, instance_type="ml_c5", version=version)
            assert "Unsupported region: {}.".format(region) in str(excinfo.value)
            continue

        actual = image_uris.retrieve(config_key, region, instance_type="ml_c5", version=version)
        assert _expected_framework_uri(repo_name, version, region=region) == actual

    gpu_actual = image_uris.retrieve(
        config_key, "us-west-2", instance_type="ml_p2", version=version
    )
    assert _expected_framework_uri(repo_name, version, processor="gpu") == gpu_actual
def test_model_monitor():
    """Compare the model-monitor image URI with ``expected_uris.monitor_uri`` per region.

    Only regions that appear in ``ACCOUNTS`` are checked; the rest are skipped.
    """
    for region in regions.regions():
        if region not in ACCOUNTS:
            continue
        actual = image_uris.retrieve("model-monitor", region=region)
        assert expected_uris.monitor_uri(ACCOUNTS[region], region) == actual
def test_sparkml(version):
    """Compare the sparkml-serving image URI for ``version`` with the expected algo URI.

    Every region returned by ``regions.regions()`` is checked. ``ACCOUNTS`` is
    indexed without a membership guard, so a region missing from it fails the
    test (via whatever ``retrieve`` or the lookup raises first).
    """
    for region in regions.regions():
        actual = image_uris.retrieve("sparkml-serving", region=region, version=version)
        account = ACCOUNTS[region]
        wanted = expected_uris.algo_uri(
            "sagemaker-sparkml-serving", account, region, version=version
        )
        assert wanted == actual
def test_xgboost_algo(xgboost_algo_version):
    """Compare the built-in XGBoost algorithm URI with the expected algo URI per region.

    The expected registry for each region is taken from ``ALGO_REGISTRIES``.
    """
    for region in regions.regions():
        actual = image_uris.retrieve(
            framework="xgboost", region=region, version=xgboost_algo_version
        )
        registry = ALGO_REGISTRIES[region]
        wanted = expected_uris.algo_uri(
            "xgboost", registry, region, version=xgboost_algo_version
        )
        assert wanted == actual
def test_debugger():
    """Compare the debugger image URI with the ``sagemaker-debugger-rules:latest`` algo URI.

    Only regions that appear in ``ACCOUNTS`` are checked; the rest are skipped.
    """
    for region in regions.regions():
        if region not in ACCOUNTS:
            continue
        actual = image_uris.retrieve("debugger", region=region)
        wanted = expected_uris.algo_uri(
            "sagemaker-debugger-rules", ACCOUNTS[region], region, version="latest"
        )
        assert wanted == actual
def test_algo_uris(algo):
    """Check the ``latest`` image URI for algorithm ``algo`` in every region.

    Regions in ``ACCOUNTS`` must resolve to the expected algo URI. Any other
    region must raise ``ValueError`` containing ``"Unsupported region: <region>."``.

    NOTE(review): another ``test_algo_uris`` definition appears in this listing;
    if both live in the same module, the later one shadows this one.
    """
    for region in regions.regions():
        if region not in ACCOUNTS:
            with pytest.raises(ValueError) as excinfo:
                image_uris.retrieve(algo, region)
            assert "Unsupported region: {}.".format(region) in str(excinfo.value)
            continue

        actual = image_uris.retrieve(algo, region)
        wanted = expected_uris.algo_uri(algo, ACCOUNTS[region], region, version="latest")
        assert wanted == actual
def test_data_wrangler_ecr_uri():
    """Compare the Data Wrangler image URI with the expected ``1.x`` container URI.

    Only regions present in ``DATA_WRANGLER_ACCOUNTS`` are checked; the rest
    are skipped.
    """
    for region in regions.regions():
        if region not in DATA_WRANGLER_ACCOUNTS:
            continue
        actual_uri = image_uris.retrieve("data-wrangler", region=region)
        account = DATA_WRANGLER_ACCOUNTS[region]
        expected_uri = expected_uris.algo_uri(
            "sagemaker-data-wrangler-container",
            account,
            region,
            version="1.x",
        )
        assert expected_uri == actual_uri
def test_lda():
    """Check the LDA image URI in every region.

    Regions covered by ``_accounts_for_algo("lda")`` must resolve to the
    expected algo URI. Any other region must raise ``ValueError`` containing
    ``"Unsupported region: <region>."``.
    """
    algo_name = "lda"
    account_map = _accounts_for_algo(algo_name)
    for region in regions.regions():
        if region not in account_map:
            with pytest.raises(ValueError) as excinfo:
                image_uris.retrieve(algo_name, region)
            assert "Unsupported region: {}.".format(region) in str(excinfo.value)
            continue

        actual = image_uris.retrieve(algo_name, region)
        assert expected_uris.algo_uri(algo_name, account_map[region], region) == actual
def test_xgboost_framework(xgboost_framework_version):
    """Compare the XGBoost framework (py3) URI with the expected framework URI per region.

    The expected registry for each region is taken from ``FRAMEWORK_REGISTRIES``.
    """
    for region in regions.regions():
        actual = image_uris.retrieve(
            framework="xgboost",
            region=region,
            version=xgboost_framework_version,
            py_version="py3",
        )
        registry = FRAMEWORK_REGISTRIES[region]
        wanted = expected_uris.framework_uri(
            "sagemaker-xgboost",
            xgboost_framework_version,
            registry,
            py_version="py3",
            region=region,
        )
        assert wanted == actual
def test_valid_uris(sklearn_version):
    """Compare the scikit-learn (py3, ml.c4.xlarge) URI with the expected framework URI.

    Every region returned by ``regions.regions()`` is checked, with the
    expected account taken from ``ACCOUNTS``.
    """
    for region in regions.regions():
        actual = image_uris.retrieve(
            "sklearn",
            region=region,
            version=sklearn_version,
            py_version="py3",
            instance_type="ml.c4.xlarge",
        )
        account = ACCOUNTS[region]
        wanted = expected_uris.framework_uri(
            "sagemaker-scikit-learn",
            sklearn_version,
            account,
            py_version="py3",
            region=region,
        )
        assert wanted == actual
def _test_inferentia_framework_uris(framework, version):
    """Check Inferentia image URIs for ``framework`` at ``version`` in every region.

    The image is looked up under the config key ``inferentia-<framework>``.
    Regions in ``INFERENTIA_REGIONS`` must resolve to the URI built by
    ``_expected_framework_uri`` for ``neo-<framework>`` with ``processor="inf"``.
    Every other region must raise ``ValueError`` whose message contains
    ``"Unsupported region: <region>."``.
    """
    config_key = f"inferentia-{framework}"
    # Use one instance type for both branches so the unsupported-region case
    # differs from the supported case only in the region it passes.
    instance_type = "ml_inf1"
    for region in regions.regions():
        if region in INFERENTIA_REGIONS:
            uri = image_uris.retrieve(
                config_key, region, instance_type=instance_type, version=version
            )
            expected = _expected_framework_uri(
                f"neo-{framework}", version, region=region, processor="inf"
            )
            assert expected == uri
        else:
            with pytest.raises(ValueError) as e:
                image_uris.retrieve(
                    config_key, region, instance_type=instance_type, version=version
                )
            assert f"Unsupported region: {region}." in str(e.value)
def _test_neo_framework_uris(framework, version):
    """Check Neo image URIs for ``neo-<framework>`` at ``version`` in every region.

    The same ``neo-<framework>`` name is used both as the lookup key and as
    the repository name for the expected URI. Regions in ``ACCOUNTS`` must
    resolve (with ``instance_type="ml_c5"``) to the expected URI. Any other
    region must raise ``ValueError`` containing ``"Unsupported region: <region>."``.
    Afterwards the ``ml_p2`` lookup in us-west-2 is compared against the
    expected URI built with ``processor="gpu"``.
    """
    neo_name = "neo-{}".format(framework)
    for region in regions.regions():
        if region not in ACCOUNTS:
            with pytest.raises(ValueError) as excinfo:
                image_uris.retrieve(neo_name, region, instance_type="ml_c5", version=version)
            assert "Unsupported region: {}.".format(region) in str(excinfo.value)
            continue

        actual = image_uris.retrieve(neo_name, region, instance_type="ml_c5", version=version)
        assert _expected_framework_uri(neo_name, version, region=region) == actual

    gpu_actual = image_uris.retrieve(
        neo_name, "us-west-2", instance_type="ml_p2", version=version
    )
    assert _expected_framework_uri(neo_name, version, processor="gpu") == gpu_actual
def test_algo_uris(algo):
    """Compare the image URI for algorithm ``algo`` with the expected algo URI per region.

    Every region returned by ``regions.regions()`` is checked, with no
    unsupported-region branch. ``retrieve`` runs before the account lookup, so
    an unsupported region fails the test with whatever ``retrieve`` raises.
    """
    account_map = _accounts_for_algo(algo)
    for region in regions.regions():
        actual = image_uris.retrieve(algo, region)
        wanted = expected_uris.algo_uri(algo, account_map[region], region)
        assert wanted == actual