def test_ortmodule_fallback_warn_message(is_training, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    policy = "FALLBACK_UNSUPPORTED_DEVICE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)
    os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"

    data_device = "cuda"
    N, D_in, H, D_out = 64, 784, 500, 10

    pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
    ort_model = ORTModule(copy.deepcopy(pt_model))
    pt_model.train(is_training)
    ort_model.train(is_training)
    # For the initial model export, use the same device for data and model so that
    # the PyTorch model can be traced during export.
    _ = ort_model(torch.randn(N, D_in))

    # Use data on a different device for testing.
    inputs = torch.randn(N, D_in, device=data_device)

    for _ in range(3):
        with pytest.raises(RuntimeError):
            with pytest.warns(UserWarning) as warning_record:
                ort_model(inputs)
        assert "Fallback to PyTorch due to exception" in str(warning_record[0].message.args[0])

    del os.environ["ORTMODULE_SKIPCHECK_POLICY"]
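# The env-var setup/teardown pattern above repeats across these tests. Below is a
# minimal sketch of a scoped helper, assuming only the standard library; the name
# `ortmodule_env` is hypothetical and not part of the ORTModule API.
import contextlib
import os


@contextlib.contextmanager
def ortmodule_env(**env_vars):
    # Set each ORTMODULE_* variable, then restore the previous state on exit
    # so one test's fallback policy cannot leak into the next.
    saved = {name: os.environ.get(name) for name in env_vars}
    os.environ.update({name: str(value) for name, value in env_vars.items()})
    try:
        yield
    finally:
        for name, old_value in saved.items():
            if old_value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old_value


# Usage sketch:
# with ortmodule_env(ORTMODULE_FALLBACK_POLICY="FALLBACK_UNSUPPORTED_DEVICE",
#                    ORTMODULE_FALLBACK_RETRY="False"):
#     ort_model(inputs)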
def test_ortmodule_fallback_device__mismatch(is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_UNSUPPORTED_DEVICE policy to ORTModuleDeviceException exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DATA) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
        else:
            policy = "FALLBACK_UNSUPPORTED_DATA"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)
    os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"

    data_device = "cuda"
    N, D_in, H, D_out = 64, 784, 500, 10

    pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
    ort_model = ORTModule(copy.deepcopy(pt_model))
    pt_model.train(is_training)
    ort_model.train(is_training)
    # For the initial model export, use the same device for data and model so that
    # the PyTorch model can be traced during export.
    _ = ort_model(torch.randn(N, D_in))

    # Use data on a different device for testing.
    inputs = torch.randn(N, D_in, device=data_device)
    ort_model_device = ort_model._torch_module._execution_manager(ort_model._is_training())._device
    input_device = torch.device(inputs.device)

    for _ in range(3):
        if fallback_enabled:
            if matching_policy:
                with pytest.raises(RuntimeError) as e:
                    ort_model(inputs)
                assert (
                    "Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for addmm)"
                    in str(e.value)
                ) or (
                    "Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!"
                    in str(e.value)
                )
            else:
                with pytest.raises(_fallback.ORTModuleDeviceException) as e:
                    ort_model(inputs)
                assert (
                    f"Input argument to forward found on device {input_device}, "
                    f"but expected it to be on module device {ort_model_device}." in str(e.value)
                )
        else:
            with pytest.raises(_fallback.ORTModuleDeviceException) as e:
                ort_model(inputs)
            assert (
                f"Input argument to forward found on device {input_device}, "
                f"but expected it to be on module device {ort_model_device}." in str(e.value)
            )

    del os.environ["ORTMODULE_SKIPCHECK_POLICY"]
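# A minimal sketch of the kind of device check the test above exercises. It mirrors
# the ORTModuleDeviceException message asserted above but is illustrative only, not
# ORTModule's actual implementation; `check_forward_input_device` is hypothetical.
def check_forward_input_device(module_device, tensor):
    input_device = torch.device(tensor.device)
    if input_device != torch.device(module_device):
        # The tests above match this message against _fallback.ORTModuleDeviceException.
        raise RuntimeError(
            f"Input argument to forward found on device {input_device}, "
            f"but expected it to be on module device {module_device}."
        )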
def test_ortmodule_fallback_init__torch_version(is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_BAD_INITIALIZATION policy to ORTModuleInitException exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    from packaging import version

    from onnxruntime.training.ortmodule import MINIMUM_RUNTIME_PYTORCH_VERSION_STR

    runtime_pytorch_version = version.parse(torch.__version__.split("+")[0])
    minimum_runtime_pytorch_version = version.parse(MINIMUM_RUNTIME_PYTORCH_VERSION_STR)
    if runtime_pytorch_version < minimum_runtime_pytorch_version:
        if fallback_enabled:
            if matching_policy:
                policy = "FALLBACK_BAD_INITIALIZATION"
            else:
                policy = "FALLBACK_UNSUPPORTED_DEVICE"
        else:
            policy = "FALLBACK_DISABLE"
        os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
        os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

        device = "cuda"
        N, D_in, H, D_out = 64, 784, 500, 10
        x = torch.randn(N, D_in, device=device)

        pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)

        for _ in range(3):
            if fallback_enabled:
                if matching_policy:
                    ort_model = ORTModule(pt_model)
                    ort_model.train(is_training)
                    pt_model.train(is_training)

                    ort_out = ort_model(x)
                    pt_out = pt_model(x)
                    _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0)
                else:
                    with pytest.raises(_fallback.ORTModuleInitException) as ex_info:
                        ort_model = ORTModule(pt_model)
                    assert "ONNX Runtime ORTModule frontend requires PyTorch version greater or equal to" in str(
                        ex_info.value
                    )
            else:
                with pytest.raises(_fallback.ORTModuleInitException) as ex_info:
                    # The fallback policy must be set before construction because the exception is raised during __init__
                    ort_model = ORTModule(pt_model)
                assert "ONNX Runtime ORTModule frontend requires PyTorch version greater or equal to" in str(
                    ex_info.value
                )
    else:
        warnings.warn(
            "Skipping test_ortmodule_fallback_init__torch_version."
            f" It requires PyTorch prior to {MINIMUM_RUNTIME_PYTORCH_VERSION_STR}"
        )
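# The version gate above reduces to a single comparison. A standalone sketch of that
# logic; `pytorch_meets_minimum` is a hypothetical helper, and in the test the
# minimum version string comes from onnxruntime.training.ortmodule itself.
def pytorch_meets_minimum(minimum_version_str):
    from packaging import version

    # Drop any local build suffix (e.g. "2.0.1+cu117") before parsing.
    runtime = version.parse(torch.__version__.split("+")[0])
    return runtime >= version.parse(minimum_version_str)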
def test_ortmodule_fallback_init__missing_cpp_extensions(is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_BAD_INITIALIZATION policy to ORTModuleInitException exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    if is_torch_cpp_extensions_installed(ORTMODULE_TORCH_CPP_DIR):
        warnings.warn(
            "Skipping test_ortmodule_fallback_init__missing_cpp_extensions."
            " It requires PyTorch CPP extensions to be missing"
        )
    else:
        if fallback_enabled:
            if matching_policy:
                policy = "FALLBACK_BAD_INITIALIZATION"
            else:
                policy = "FALLBACK_UNSUPPORTED_DEVICE"
        else:
            policy = "FALLBACK_DISABLE"
        os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
        os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

        device = "cuda"
        N, D_in, H, D_out = 64, 784, 500, 10
        x = torch.randn(N, D_in, device=device)

        pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)

        for _ in range(3):
            if fallback_enabled:
                if matching_policy:
                    ort_model = ORTModule(pt_model)
                    ort_model.train(is_training)
                    pt_model.train(is_training)

                    ort_out = ort_model(x)
                    pt_out = pt_model(x)
                    _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0)
                else:
                    with pytest.raises(_fallback.ORTModuleInitException) as ex_info:
                        ort_model = ORTModule(pt_model)
                    assert "ORTModule's extensions were not detected" in str(ex_info.value)
            else:
                with pytest.raises(_fallback.ORTModuleInitException) as ex_info:
                    # The fallback policy must be set before construction because the exception is raised during __init__
                    ort_model = ORTModule(pt_model)
                assert "ORTModule's extensions were not detected" in str(ex_info.value)
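# A sketch of how the presence of compiled extensions could be probed with the
# standard library alone. `cpp_extensions_available` is hypothetical and the module
# path is an assumption, not how is_torch_cpp_extensions_installed actually works.
import importlib.util


def cpp_extensions_available(module_name="onnxruntime.training.ortmodule.torch_cpp_extensions"):
    # find_spec returns None when the package cannot be located on the import path.
    return importlib.util.find_spec(module_name) is not None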
def test_ortmodule_fallback_torch_model(is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_UNSUPPORTED_TORCH_MODEL policy to ORTModuleTorchModelException exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_UNSUPPORTED_TORCH_MODEL"
        else:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    device = "cuda"
    N, D_in, H, D_out = 64, 784, 500, 10
    x = torch.randn(N, D_in, device=device)

    pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
    pt_model = torch.nn.DataParallel(pt_model)

    for _ in range(3):
        if fallback_enabled:
            if matching_policy:
                ort_model = ORTModule(pt_model)
                ort_model.train(is_training)
                pt_model.train(is_training)

                ort_out = ort_model(x)
                pt_out = pt_model(x)
                _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=1e-3, atol=1e-6)
            else:
                with pytest.raises(_fallback.ORTModuleTorchModelException) as ex_info:
                    ort_model = ORTModule(pt_model)
                assert "ORTModule is not compatible with torch.nn.DataParallel" in str(ex_info.value)
        else:
            with pytest.raises(_fallback.ORTModuleTorchModelException) as ex_info:
                # The fallback policy must be set before construction because the exception is raised during __init__
                ort_model = ORTModule(pt_model)
            assert "ORTModule is not compatible with torch.nn.DataParallel" in str(ex_info.value)
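# ORTModule rejects torch.nn.DataParallel outright, as asserted above. A sketch of
# a multi-GPU alternative, assuming a distributed process group has already been
# initialized (e.g. via torchrun); `wrap_for_multi_gpu` is hypothetical and the
# wrapping order shown (DistributedDataParallel around ORTModule) is an assumption
# for illustration, not guidance confirmed by this test file.
def wrap_for_multi_gpu(pt_model, local_rank):
    # One process per GPU: move the model to this rank's device before wrapping.
    pt_model = pt_model.to(f"cuda:{local_rank}")
    ort_model = ORTModule(pt_model)
    return torch.nn.parallel.DistributedDataParallel(ort_model, device_ids=[local_rank])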