def test_sparkml(version):
    """Verify the SparkML Serving image URI for every region in ``regions.regions()``.

    For each region, the URI returned by ``image_uris.retrieve`` must equal the
    one built by ``expected_uris.algo_uri`` from the
    ``sagemaker-sparkml-serving`` repository name and that region's account
    in ``ACCOUNTS``.

    Args:
        version: SparkML Serving image version to look up (presumably supplied
            by pytest parametrization — confirm against the module).
    """
    # NOTE(review): every region from regions.regions() is indexed into
    # ACCOUNTS, so a region missing from that table presumably fails here
    # with KeyError rather than an assertion error.
    for region_name in regions.regions():
        actual_uri = image_uris.retrieve(
            "sparkml-serving", region=region_name, version=version
        )
        account_id = ACCOUNTS[region_name]
        expected_uri = expected_uris.algo_uri(
            "sagemaker-sparkml-serving", account_id, region_name, version=version
        )
        assert expected_uri == actual_uri
def test_xgboost_algo(xgboost_algo_version):
    """Verify the XGBoost algorithm image URI for every region in ``ALGO_REGISTRIES``.

    The expected URI uses the ``xgboost`` repository name and the registry
    account that ``ALGO_REGISTRIES`` maps the region to.

    Args:
        xgboost_algo_version: XGBoost algorithm version to request (presumably
            a pytest fixture or parameter — confirm).
    """
    for region, registry_account in ALGO_REGISTRIES.items():
        retrieved = image_uris.retrieve(
            framework="xgboost", region=region, version=xgboost_algo_version
        )
        anticipated = expected_uris.algo_uri(
            "xgboost", registry_account, region, version=xgboost_algo_version
        )
        assert anticipated == retrieved
# Example #3
def test_debugger():
    """Verify the Debugger rules image URI in every region listed in ``ACCOUNTS``.

    The expected URI uses the ``sagemaker-debugger-rules`` repository with the
    ``latest`` tag and the region's account from ``ACCOUNTS``.
    """
    for region, account in ACCOUNTS.items():
        debugger_uri = image_uris.retrieve("debugger", region=region)
        reference_uri = expected_uris.algo_uri(
            "sagemaker-debugger-rules", account, region, version="latest"
        )
        assert reference_uri == debugger_uri
def test_algo_uris(algo):
    """Verify the ``latest`` image URI of ``algo`` in every region listed in ``ACCOUNTS``.

    Args:
        algo: algorithm name, passed to ``image_uris.retrieve`` and also used
            as the expected repository name.
    """
    for region, account in ACCOUNTS.items():
        got = image_uris.retrieve(algo, region)
        want = expected_uris.algo_uri(algo, account, region, version="latest")
        assert want == got
# Example #5
def test_data_wrangler_ecr_uri():
    """Verify the Data Wrangler image URI in every ``DATA_WRANGLER_ACCOUNTS`` region.

    The expected URI uses the ``sagemaker-data-wrangler-container`` repository
    with the ``1.x`` tag and the region's account from
    ``DATA_WRANGLER_ACCOUNTS``.
    """
    for region, account in DATA_WRANGLER_ACCOUNTS.items():
        retrieved_uri = image_uris.retrieve("data-wrangler", region=region)
        reference_uri = expected_uris.algo_uri(
            "sagemaker-data-wrangler-container", account, region, version="1.x"
        )
        assert reference_uri == retrieved_uri
# Example #6
def test_algo_uris(algo):
    """Verify ``algo`` image URIs across every region from ``regions.regions()``.

    Regions present in ``ACCOUNTS`` must resolve to the ``latest`` URI built by
    ``expected_uris.algo_uri``. For any other region, ``image_uris.retrieve``
    must raise ``ValueError`` whose message contains
    ``"Unsupported region: <region>."``.

    Args:
        algo: algorithm name passed to both ``image_uris.retrieve`` and
            ``expected_uris.algo_uri``.
    """
    for region in regions.regions():
        if region not in ACCOUNTS:
            # Regions without an account entry are expected to be rejected.
            with pytest.raises(ValueError) as excinfo:
                image_uris.retrieve(algo, region)
            assert "Unsupported region: {}.".format(region) in str(excinfo.value)
            continue

        resolved = image_uris.retrieve(algo, region)
        reference = expected_uris.algo_uri(
            algo, ACCOUNTS[region], region, version="latest"
        )
        assert reference == resolved
def test_lda():
    """Verify LDA image URIs across every region from ``regions.regions()``.

    Per-region accounts come from ``_accounts_for_algo("lda")``. Supported
    regions must match ``expected_uris.algo_uri``, which is called without an
    explicit ``version``. For every other region, ``image_uris.retrieve`` must
    raise ``ValueError`` whose message contains
    ``"Unsupported region: <region>."``.
    """
    algorithm = "lda"
    lda_accounts = _accounts_for_algo(algorithm)

    for region in regions.regions():
        if region not in lda_accounts:
            # LDA is not offered here, so retrieval must be refused.
            with pytest.raises(ValueError) as excinfo:
                image_uris.retrieve(algorithm, region)
            assert "Unsupported region: {}.".format(region) in str(excinfo.value)
            continue

        resolved_uri = image_uris.retrieve(algorithm, region)
        reference_uri = expected_uris.algo_uri(algorithm, lda_accounts[region], region)
        assert reference_uri == resolved_uri
# Example #8
def test_algo_uris(algo):
    """Verify ``algo`` image URIs for every region returned by ``_accounts_for_algo``.

    Args:
        algo: algorithm name passed to ``_accounts_for_algo``,
            ``image_uris.retrieve`` and ``expected_uris.algo_uri``.
    """
    region_accounts = _accounts_for_algo(algo)
    for region_name in region_accounts:
        actual = image_uris.retrieve(algo, region_name)
        expected = expected_uris.algo_uri(algo, region_accounts[region_name], region_name)
        assert expected == actual