コード例 #1
0
def test_mxnet_training_sm_env_variables(mxnet_training):
    """Run the env-variable check for the MXNet training image.

    Hands the image URI, the expected ``SAGEMAKER_TRAINING_MODULE`` value and
    a container name prefix to ``execute_env_variables_test``.

    Args:
        mxnet_training: Image URI forwarded as ``image_uri``.
    """
    expected_env = {
        "SAGEMAKER_TRAINING_MODULE": "sagemaker_mxnet_container.training:main",
    }
    execute_env_variables_test(
        image_uri=mxnet_training,
        env_vars_to_test=expected_env,
        container_name_prefix="mx_training_sm_env",
    )
コード例 #2
0
def test_tensorflow_inference_sm_env_variables(tensorflow_inference):
    """Run the env-variable check for the TensorFlow inference image.

    The expected ``SAGEMAKER_TFS_VERSION`` value is the ``major.minor`` part
    of the framework version extracted from the image tag.

    Args:
        tensorflow_inference: Image URI; its tag supplies the framework
            version and it is forwarded as ``image_uri``.
    """
    _framework, full_version = get_framework_and_version_from_tag(tensorflow_inference)
    parsed = Version(full_version)
    # Only major.minor is compared, so the patch component is dropped.
    short_version = f"{parsed.major}.{parsed.minor}"
    execute_env_variables_test(
        image_uri=tensorflow_inference,
        env_vars_to_test={"SAGEMAKER_TFS_VERSION": short_version},
        container_name_prefix="tf_inference_sm_env",
    )
コード例 #3
0
def test_pytorch_inference_sm_env_variables(pytorch_inference):
    """Run the env-variable check for the PyTorch inference image.

    Hands the image URI, the expected ``SAGEMAKER_SERVING_MODULE`` value and
    a container name prefix to ``execute_env_variables_test``.

    Args:
        pytorch_inference: Image URI forwarded as ``image_uri``.
    """
    serving_module = "sagemaker_pytorch_serving_container.serving:main"
    execute_env_variables_test(
        image_uri=pytorch_inference,
        env_vars_to_test={"SAGEMAKER_SERVING_MODULE": serving_module},
        container_name_prefix="pt_inference_sm_env",
    )
コード例 #4
0
def test_pytorch_training_job_type_env_var(pytorch_training):
    """Run the ``DLC_CONTAINER_TYPE`` env-variable check for the PyTorch training image.

    The test is skipped when the framework version parsed from the image tag
    is lower than 1.10; otherwise ``DLC_CONTAINER_TYPE`` is expected to be
    ``"training"``.

    Args:
        pytorch_training: Image URI; its tag supplies the framework version
            and it is forwarded as ``image_uri``.
    """
    _framework, tag_version = test_utils.get_framework_and_version_from_tag(pytorch_training)
    image_version = Version(tag_version)
    if image_version < Version("1.10"):
        pytest.skip("This env variable was added after PT 1.10 release. Skipping test.")

    test_utils.execute_env_variables_test(
        image_uri=pytorch_training,
        env_vars_to_test={"DLC_CONTAINER_TYPE": "training"},
        container_name_prefix="pt_train_job_type_env_var",
    )