def test_mxnet_training_sm_env_variables(mxnet_training):
    """Run the shared env-variable check against the MXNet training image.

    Expects ``SAGEMAKER_TRAINING_MODULE`` to equal
    ``sagemaker_mxnet_container.training:main``; the actual checking is done
    by ``execute_env_variables_test``.

    Args:
        mxnet_training: image URI supplied by the ``mxnet_training`` fixture.
    """
    expected_env = {
        "SAGEMAKER_TRAINING_MODULE": "sagemaker_mxnet_container.training:main",
    }
    execute_env_variables_test(
        image_uri=mxnet_training,
        env_vars_to_test=expected_env,
        container_name_prefix="mx_training_sm_env",
    )
def test_tensorflow_inference_sm_env_variables(tensorflow_inference):
    """Run the shared env-variable check against the TensorFlow inference image.

    The expected value of ``SAGEMAKER_TFS_VERSION`` is the ``major.minor``
    portion of the framework version parsed from the image tag.

    Args:
        tensorflow_inference: image URI supplied by the ``tensorflow_inference`` fixture.
    """
    _, framework_version = get_framework_and_version_from_tag(tensorflow_inference)
    parsed = Version(framework_version)
    # Only major.minor is compared, e.g. "2.13.0" -> "2.13".
    major_minor = f"{parsed.major}.{parsed.minor}"
    execute_env_variables_test(
        image_uri=tensorflow_inference,
        env_vars_to_test={"SAGEMAKER_TFS_VERSION": major_minor},
        container_name_prefix="tf_inference_sm_env",
    )
def test_pytorch_inference_sm_env_variables(pytorch_inference):
    """Run the shared env-variable check against the PyTorch inference image.

    Expects ``SAGEMAKER_SERVING_MODULE`` to equal
    ``sagemaker_pytorch_serving_container.serving:main``; the actual checking
    is done by ``execute_env_variables_test``.

    Args:
        pytorch_inference: image URI supplied by the ``pytorch_inference`` fixture.
    """
    expected_env = {
        "SAGEMAKER_SERVING_MODULE": "sagemaker_pytorch_serving_container.serving:main",
    }
    execute_env_variables_test(
        image_uri=pytorch_inference,
        env_vars_to_test=expected_env,
        container_name_prefix="pt_inference_sm_env",
    )
def test_pytorch_training_job_type_env_var(pytorch_training):
    """Check that the PyTorch training image sets ``DLC_CONTAINER_TYPE=training``.

    Images whose framework version (parsed from the tag) is below 1.10 are
    skipped; otherwise the check is delegated to
    ``test_utils.execute_env_variables_test``.

    Args:
        pytorch_training: image URI supplied by the ``pytorch_training`` fixture.
    """
    _, framework_version = test_utils.get_framework_and_version_from_tag(pytorch_training)
    # NOTE(review): the skip message says the variable was added *after* the
    # 1.10 release, yet 1.10 itself is not skipped by this condition. Confirm
    # whether the threshold or the message is the intended one.
    too_old = Version(framework_version) < Version("1.10")
    if too_old:
        pytest.skip("This env variable was added after PT 1.10 release. Skipping test.")

    test_utils.execute_env_variables_test(
        image_uri=pytorch_training,
        env_vars_to_test={"DLC_CONTAINER_TYPE": "training"},
        container_name_prefix="pt_train_job_type_env_var",
    )