def test_sparkml(version):
    """Check the ``sparkml-serving`` image URI for every region in ``regions.regions()``.

    For each region, the URI returned by ``image_uris.retrieve`` must equal the
    URI built by ``expected_uris.algo_uri`` for the
    ``sagemaker-sparkml-serving`` repository, using that region's account from
    ``ACCOUNTS`` and the requested ``version``.

    Args:
        version: image version to request. Presumably supplied by a pytest
            fixture or parametrization; confirm against conftest.
    """
    repository = "sagemaker-sparkml-serving"
    for aws_region in regions.regions():
        actual_uri = image_uris.retrieve("sparkml-serving", region=aws_region, version=version)
        # NOTE(review): ACCOUNTS is indexed without a membership check, so every
        # region yielded by regions.regions() is assumed to have an entry.
        expected_uri = expected_uris.algo_uri(
            repository, ACCOUNTS[aws_region], aws_region, version=version
        )
        assert expected_uri == actual_uri
def test_xgboost_algo(xgboost_algo_version):
    """Check the first-party ``xgboost`` algorithm image URI in each registry region.

    Every region key of ``ALGO_REGISTRIES`` is tested. The retrieved URI must match
    the one built from that region's registry account and the requested version.

    Args:
        xgboost_algo_version: version passed to both ``image_uris.retrieve`` and
            ``expected_uris.algo_uri``. Presumably a pytest fixture; confirm.
    """
    for region, registry in ALGO_REGISTRIES.items():
        actual_uri = image_uris.retrieve(
            framework="xgboost", region=region, version=xgboost_algo_version
        )
        expected_uri = expected_uris.algo_uri(
            "xgboost", registry, region, version=xgboost_algo_version
        )
        assert expected_uri == actual_uri
def test_debugger():
    """Check the ``debugger`` image URI for every region listed in ``ACCOUNTS``.

    The expected URI points at the ``sagemaker-debugger-rules`` repository
    with the ``"latest"`` version tag.
    """
    for region, account in ACCOUNTS.items():
        actual_uri = image_uris.retrieve("debugger", region=region)
        expected_uri = expected_uris.algo_uri(
            "sagemaker-debugger-rules", account, region, version="latest"
        )
        assert expected_uri == actual_uri
def test_algo_uris(algo):
    """Check ``algo``'s image URI for every region listed in ``ACCOUNTS``.

    Each retrieved URI must match the ``"latest"`` URI built from that
    region's account.

    NOTE(review): SOURCE defines ``test_algo_uris`` more than once. If those
    definitions share a module, pytest only collects the last one. Confirm
    that they live in separate files.

    Args:
        algo: algorithm / image name. Presumably parametrized by pytest; confirm.
    """
    for region, account in ACCOUNTS.items():
        actual_uri = image_uris.retrieve(algo, region)
        expected_uri = expected_uris.algo_uri(algo, account, region, version="latest")
        assert expected_uri == actual_uri
def test_data_wrangler_ecr_uri():
    """Check the ``data-wrangler`` image URI for every Data Wrangler region.

    Regions and accounts come from ``DATA_WRANGLER_ACCOUNTS``. The expected URI
    points at the ``sagemaker-data-wrangler-container`` repository with
    version ``"1.x"``.
    """
    for region, account in DATA_WRANGLER_ACCOUNTS.items():
        got = image_uris.retrieve("data-wrangler", region=region)
        want = expected_uris.algo_uri(
            "sagemaker-data-wrangler-container",
            account,
            region,
            version="1.x",
        )
        assert want == got
def test_algo_uris(algo):
    """Check ``algo``'s image URI across all regions, including unsupported ones.

    - Regions present in ``ACCOUNTS``: the retrieved URI must equal the
      ``"latest"`` URI built from that region's account.
    - Any other region: ``image_uris.retrieve`` must raise ``ValueError``
      whose message contains ``"Unsupported region: <region>."``.

    NOTE(review): SOURCE defines ``test_algo_uris`` more than once. If those
    definitions share a module, pytest only collects the last one. Confirm
    that they live in separate files.

    Args:
        algo: algorithm / image name. Presumably parametrized by pytest; confirm.
    """
    for region in regions.regions():
        if region not in ACCOUNTS:
            with pytest.raises(ValueError) as excinfo:
                image_uris.retrieve(algo, region)
            # Checked after the ``with`` block exits; inside it, this line would
            # never run because retrieve() raises first.
            assert "Unsupported region: {}.".format(region) in str(excinfo.value)
            continue

        actual_uri = image_uris.retrieve(algo, region)
        expected_uri = expected_uris.algo_uri(algo, ACCOUNTS[region], region, version="latest")
        assert expected_uri == actual_uri
def test_lda():
    """Check the ``lda`` image URI across all regions, including unsupported ones.

    The supported regions are those returned by ``_accounts_for_algo("lda")``.
    - Supported region: the retrieved URI must equal
      ``expected_uris.algo_uri`` for that region's account, using default
      version arguments.
    - Any other region: ``image_uris.retrieve`` must raise ``ValueError``
      whose message contains ``"Unsupported region: <region>."``.
    """
    algorithm = "lda"
    algo_accounts = _accounts_for_algo(algorithm)
    for region in regions.regions():
        if region not in algo_accounts:
            with pytest.raises(ValueError) as excinfo:
                image_uris.retrieve(algorithm, region)
            # Must sit outside the ``with`` block to be reachable.
            assert "Unsupported region: {}.".format(region) in str(excinfo.value)
            continue

        retrieved = image_uris.retrieve(algorithm, region)
        assert expected_uris.algo_uri(algorithm, algo_accounts[region], region) == retrieved
def test_algo_uris(algo):
    """Check ``algo``'s image URI in every region that ``_accounts_for_algo`` lists.

    Each retrieved URI must equal ``expected_uris.algo_uri`` for that
    region's account, using default version arguments.

    NOTE(review): SOURCE defines ``test_algo_uris`` more than once. If those
    definitions share a module, pytest only collects the last one. Confirm
    that they live in separate files.

    Args:
        algo: algorithm / image name. Presumably parametrized by pytest; confirm.
    """
    for region, account in _accounts_for_algo(algo).items():
        actual_uri = image_uris.retrieve(algo, region)
        assert expected_uris.algo_uri(algo, account, region) == actual_uri