예제 #1
0
def _test_neo_framework_uris(framework, version):
    """Check Neo image URI retrieval for ``framework`` at ``version``.

    For each region yielded by ``regions.regions()``:

    * a region present in ``ACCOUNTS`` must resolve, for an ``ml_c5``
      instance, to the URI built by ``_expected_framework_uri``;
    * any other region must make ``image_uris.retrieve`` raise
      ``ValueError`` whose message contains ``"Unsupported region: <region>."``.

    Afterwards an ``ml_p2`` lookup in ``us-west-2`` is compared against the
    expected URI built with ``processor="gpu"``.
    """
    config_name = f"neo-{framework}"
    # Only tensorflow keeps the "neo-" prefix in the expected URI; every other
    # framework is expected under an "inference-" prefix.
    if framework == "tensorflow":
        uri_name = f"neo-{framework}"
    else:
        uri_name = f"inference-{framework}"

    for region in regions.regions():
        if region not in ACCOUNTS:
            with pytest.raises(ValueError) as e:
                image_uris.retrieve(
                    config_name, region, instance_type="ml_c5", version=version
                )
            assert "Unsupported region: {}.".format(region) in str(e.value)
            continue

        uri = image_uris.retrieve(
            config_name, region, instance_type="ml_c5", version=version
        )
        expected = _expected_framework_uri(uri_name, version, region=region)
        assert expected == uri

    gpu_uri = image_uris.retrieve(
        config_name, "us-west-2", instance_type="ml_p2", version=version
    )
    expected_gpu = _expected_framework_uri(uri_name, version, processor="gpu")
    assert expected_gpu == gpu_uri
예제 #2
0
def test_model_monitor():
    """Check the ``model-monitor`` image URI in every region with an account.

    Regions returned by ``regions.regions()`` that have no entry in
    ``ACCOUNTS`` are skipped; no assertion is made about them.
    """
    for region in regions.regions():
        # Membership is tested on the mapping itself, matching the other
        # tests in this file; going through ``.keys()`` adds nothing.
        if region not in ACCOUNTS:
            continue

        uri = image_uris.retrieve("model-monitor", region=region)

        expected = expected_uris.monitor_uri(ACCOUNTS[region], region)
        assert expected == uri
def test_sparkml(version):
    """Check the ``sparkml-serving`` image URI for ``version`` in every region.

    ``ACCOUNTS`` is indexed for each region from ``regions.regions()``, so a
    region without an entry there fails with ``KeyError`` (provided the
    retrieval itself succeeded first).
    """
    repo_name = "sagemaker-sparkml-serving"

    for region in regions.regions():
        actual = image_uris.retrieve("sparkml-serving", region=region, version=version)

        account = ACCOUNTS[region]
        wanted = expected_uris.algo_uri(repo_name, account, region, version=version)
        assert wanted == actual
def test_xgboost_algo(xgboost_algo_version):
    """Check the ``xgboost`` algorithm image URI in every region.

    The expected URI is built from the registry listed for the region in
    ``ALGO_REGISTRIES``; that mapping is indexed for every region returned by
    ``regions.regions()``.
    """
    for region in regions.regions():
        actual = image_uris.retrieve(
            framework="xgboost", region=region, version=xgboost_algo_version
        )

        registry = ALGO_REGISTRIES[region]
        wanted = expected_uris.algo_uri(
            "xgboost", registry, region, version=xgboost_algo_version
        )
        assert wanted == actual
def test_debugger():
    """Check the ``debugger`` image URI in every region with an account.

    The expected URI is the ``sagemaker-debugger-rules`` repository at version
    ``latest``. Regions missing from ``ACCOUNTS`` are skipped without any
    assertion.
    """
    for region in regions.regions():
        # Membership is tested on the mapping itself, matching the other
        # tests in this file; going through ``.keys()`` adds nothing.
        if region not in ACCOUNTS:
            continue

        uri = image_uris.retrieve("debugger", region=region)

        expected = expected_uris.algo_uri(
            "sagemaker-debugger-rules", ACCOUNTS[region], region, version="latest"
        )
        assert expected == uri
예제 #6
0
def test_algo_uris(algo):
    """Check URI retrieval for ``algo`` across all regions.

    Regions in ``ACCOUNTS`` must resolve to the expected ``latest`` URI; any
    other region must raise ``ValueError`` whose message contains
    ``"Unsupported region: <region>."``.
    """
    for region in regions.regions():
        if region not in ACCOUNTS:
            with pytest.raises(ValueError) as e:
                image_uris.retrieve(algo, region)
            assert "Unsupported region: {}.".format(region) in str(e.value)
            continue

        actual = image_uris.retrieve(algo, region)
        wanted = expected_uris.algo_uri(
            algo, ACCOUNTS[region], region, version="latest"
        )
        assert wanted == actual
def test_data_wrangler_ecr_uri():
    """Check the ``data-wrangler`` image URI in every supported region.

    The expected URI is the ``sagemaker-data-wrangler-container`` repository
    at version ``1.x``. Regions missing from ``DATA_WRANGLER_ACCOUNTS`` are
    skipped without any assertion.
    """
    for region in regions.regions():
        # Test membership on the mapping directly; ``.keys()`` is redundant.
        if region not in DATA_WRANGLER_ACCOUNTS:
            continue

        actual_uri = image_uris.retrieve("data-wrangler", region=region)

        expected_uri = expected_uris.algo_uri(
            "sagemaker-data-wrangler-container",
            DATA_WRANGLER_ACCOUNTS[region],
            region,
            version="1.x",
        )
        assert expected_uri == actual_uri
def test_lda():
    """Check URI retrieval for the ``lda`` algorithm across all regions.

    The account mapping comes from ``_accounts_for_algo("lda")``. Regions in
    that mapping must resolve to the expected URI; any other region must raise
    ``ValueError`` whose message contains ``"Unsupported region: <region>."``.
    """
    algo = "lda"
    lda_accounts = _accounts_for_algo(algo)

    for region in regions.regions():
        if region not in lda_accounts:
            with pytest.raises(ValueError) as e:
                image_uris.retrieve(algo, region)
            assert "Unsupported region: {}.".format(region) in str(e.value)
            continue

        actual = image_uris.retrieve(algo, region)
        wanted = expected_uris.algo_uri(algo, lda_accounts[region], region)
        assert wanted == actual
def test_xgboost_framework(xgboost_framework_version):
    """Check the ``xgboost`` framework image URI (``py3``) in every region.

    The expected URI is the ``sagemaker-xgboost`` repository in the registry
    listed for the region in ``FRAMEWORK_REGISTRIES``; that mapping is indexed
    for every region returned by ``regions.regions()``.
    """
    for region in regions.regions():
        actual = image_uris.retrieve(
            framework="xgboost",
            region=region,
            version=xgboost_framework_version,
            py_version="py3",
        )

        registry = FRAMEWORK_REGISTRIES[region]
        wanted = expected_uris.framework_uri(
            "sagemaker-xgboost",
            xgboost_framework_version,
            registry,
            py_version="py3",
            region=region,
        )
        assert wanted == actual
예제 #10
0
def test_valid_uris(sklearn_version):
    """Check the ``sklearn`` image URI (``py3``, ``ml.c4.xlarge``) per region.

    The expected URI is the ``sagemaker-scikit-learn`` repository in the
    account listed for the region in ``ACCOUNTS``; that mapping is indexed for
    every region returned by ``regions.regions()``.
    """
    for region in regions.regions():
        actual = image_uris.retrieve(
            "sklearn",
            region=region,
            version=sklearn_version,
            py_version="py3",
            instance_type="ml.c4.xlarge",
        )

        account = ACCOUNTS[region]
        wanted = expected_uris.framework_uri(
            "sagemaker-scikit-learn",
            sklearn_version,
            account,
            py_version="py3",
            region=region,
        )
        assert wanted == actual
예제 #11
0
def _test_inferentia_framework_uris(framework, version):
    """Check Inferentia image URI retrieval for ``framework`` at ``version``.

    For each region yielded by ``regions.regions()``:

    * a region in ``INFERENTIA_REGIONS`` must resolve to the URI built by
      ``_expected_framework_uri`` for ``neo-<framework>`` with
      ``processor="inf"``;
    * any other region must make ``image_uris.retrieve`` raise ``ValueError``
      whose message contains ``"Unsupported region: <region>."``.
    """
    inferentia_framework = "inferentia-{}".format(framework)
    # Both branches request the same instance type, so the region is the only
    # thing that differs between the passing and the failing lookup and the
    # expected ValueError cannot be triggered by the instance type instead.
    instance_type = "ml_inf1"

    for region in regions.regions():
        if region in INFERENTIA_REGIONS:
            uri = image_uris.retrieve(
                inferentia_framework, region, instance_type=instance_type, version=version
            )
            expected = _expected_framework_uri(
                "neo-{}".format(framework), version, region=region, processor="inf"
            )
            assert expected == uri
        else:
            with pytest.raises(ValueError) as e:
                image_uris.retrieve(
                    inferentia_framework,
                    region,
                    instance_type=instance_type,
                    version=version,
                )
            assert "Unsupported region: {}.".format(region) in str(e.value)
예제 #12
0
def _test_neo_framework_uris(framework, version):
    """Check Neo image URI retrieval for ``neo-<framework>`` at ``version``.

    The same ``neo-`` prefixed name is used both for the lookup and for the
    expected URI. Regions in ``ACCOUNTS`` must resolve, for an ``ml_c5``
    instance, to the expected URI; any other region must raise ``ValueError``
    whose message contains ``"Unsupported region: <region>."``. Afterwards an
    ``ml_p2`` lookup in ``us-west-2`` is compared against the expected URI
    built with ``processor="gpu"``.
    """
    neo_name = "neo-{}".format(framework)

    for region in regions.regions():
        if region not in ACCOUNTS:
            with pytest.raises(ValueError) as e:
                image_uris.retrieve(
                    neo_name, region, instance_type="ml_c5", version=version
                )
            assert "Unsupported region: {}.".format(region) in str(e.value)
            continue

        uri = image_uris.retrieve(
            neo_name, region, instance_type="ml_c5", version=version
        )
        expected = _expected_framework_uri(neo_name, version, region=region)
        assert expected == uri

    gpu_uri = image_uris.retrieve(
        neo_name, "us-west-2", instance_type="ml_p2", version=version
    )
    expected_gpu = _expected_framework_uri(neo_name, version, processor="gpu")
    assert expected_gpu == gpu_uri
def test_algo_uris(algo):
    """Check URI retrieval for ``algo`` in every region.

    The account mapping comes from ``_accounts_for_algo(algo)`` and is indexed
    for every region returned by ``regions.regions()``.
    """
    algo_accounts = _accounts_for_algo(algo)

    for region in regions.regions():
        actual = image_uris.retrieve(algo, region)
        wanted = expected_uris.algo_uri(algo, algo_accounts[region], region)
        assert wanted == actual