def test_ortmodule_fallback_warn_message(is_training, persist_fallback):
    """Verify that a device-mismatch fallback emits a "Fallback to PyTorch" warning.

    The model is exported with CPU inputs, then called with CUDA inputs under the
    FALLBACK_UNSUPPORTED_DEVICE policy. Each call is expected to raise RuntimeError
    (from the PyTorch fallback path) and to emit a UserWarning about the fallback.

    Args:
        is_training: True for torch.nn.Module training mode, eval mode otherwise.
        persist_fallback: when True, ORTMODULE_FALLBACK_RETRY is set to "False".
    """
    os.environ["ORTMODULE_FALLBACK_POLICY"] = "FALLBACK_UNSUPPORTED_DEVICE"
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)
    os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"

    try:
        data_device = "cuda"
        N, D_in, H, D_out = 64, 784, 500, 10

        pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
        ort_model = ORTModule(copy.deepcopy(pt_model))
        pt_model.train(is_training)
        ort_model.train(is_training)
        # For initial model export, use same device for data and model so that PyTorch model can be traced during export
        _ = ort_model(torch.randn(N, D_in))

        # Use data in different device for testing
        inputs = torch.randn(N, D_in, device=data_device)

        for _ in range(3):
            with pytest.raises(RuntimeError):
                with pytest.warns(UserWarning) as warning_record:
                    ort_model(inputs)
            # Search every recorded warning instead of only the first one, so an
            # unrelated UserWarning emitted earlier does not break the check.
            assert any(
                "Fallback to PyTorch due to exception" in str(record.message)
                for record in warning_record
            )
    finally:
        # Always remove the override, even when an assertion above fails,
        # so it cannot leak into subsequent tests.
        os.environ.pop("ORTMODULE_SKIPCHECK_POLICY", None)
def test_ortmodule_fallback_device__mismatch(is_training, fallback_enabled, matching_policy, persist_fallback):
    """Verify ORTModule behaviour when forward inputs live on a different device than the module.

    Args:
        is_training: True for torch.nn.Module training mode, eval mode otherwise.
        fallback_enabled: True lets PyTorch execute the forward graph instead of the ORT backend.
        matching_policy: True uses FALLBACK_UNSUPPORTED_DEVICE, which matches the
            ORTModuleDeviceException raised here. Otherwise an unrelated policy
            (FALLBACK_UNSUPPORTED_DATA) is used to verify that the fallback does not happen.
        persist_fallback: when True, ORTMODULE_FALLBACK_RETRY is set to "False".
    """
    if not fallback_enabled:
        policy = "FALLBACK_DISABLE"
    elif matching_policy:
        policy = "FALLBACK_UNSUPPORTED_DEVICE"
    else:
        policy = "FALLBACK_UNSUPPORTED_DATA"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)
    os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"

    try:
        data_device = "cuda"
        N, D_in, H, D_out = 64, 784, 500, 10

        pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
        ort_model = ORTModule(copy.deepcopy(pt_model))
        pt_model.train(is_training)
        ort_model.train(is_training)
        # For initial model export, use same device for data and model so that PyTorch model can be traced during export
        _ = ort_model(torch.randn(N, D_in))

        # Use data in different device for testing
        inputs = torch.randn(N, D_in, device=data_device)
        ort_model_device = ort_model._torch_module._execution_manager(ort_model._is_training())._device
        input_device = inputs.device
        device_error_message = (
            f"Input argument to forward found on device {input_device}, "
            f"but expected it to be on module device {ort_model_device}."
        )

        for _ in range(3):
            if fallback_enabled and matching_policy:
                # Fallback to PyTorch happens, and PyTorch itself then fails on the
                # device mismatch. The wording differs between PyTorch versions.
                with pytest.raises(RuntimeError) as e:
                    ort_model(inputs)
                assert (
                    "Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for addmm)"
                    in str(e.value)
                ) or (
                    "Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!"
                    in str(e.value)
                )
            else:
                # Fallback is disabled, or the policy does not cover device errors:
                # ORTModule's own device check surfaces to the caller.
                with pytest.raises(_fallback.ORTModuleDeviceException) as e:
                    ort_model(inputs)
                assert device_error_message in str(e.value)
    finally:
        # Always remove the override, even when an assertion above fails,
        # so it cannot leak into subsequent tests.
        os.environ.pop("ORTMODULE_SKIPCHECK_POLICY", None)
def test_ortmodule_fallback_init__torch_version(is_training, fallback_enabled, matching_policy, persist_fallback):
    """Verify fallback when ORTModule construction fails due to an unsupported PyTorch version.

    Only meaningful when the installed PyTorch is older than
    MINIMUM_RUNTIME_PYTORCH_VERSION_STR; otherwise a warning is emitted and the test returns.

    Args:
        is_training: True for torch.nn.Module training mode, eval mode otherwise.
        fallback_enabled: True lets PyTorch execute the forward graph instead of the ORT backend.
        matching_policy: True uses FALLBACK_BAD_INITIALIZATION, which matches the
            ORTModuleInitException raised during __init__. Otherwise an unrelated policy
            (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen.
        persist_fallback: when True, ORTMODULE_FALLBACK_RETRY is set to "False".
    """
    from packaging import version

    from onnxruntime.training.ortmodule import MINIMUM_RUNTIME_PYTORCH_VERSION_STR

    runtime_pytorch_version = version.parse(torch.__version__.split("+")[0])
    minimum_runtime_pytorch_version = version.parse(MINIMUM_RUNTIME_PYTORCH_VERSION_STR)
    if runtime_pytorch_version >= minimum_runtime_pytorch_version:
        warnings.warn(
            "Skipping test_ortmodule_fallback_init__torch_version."
            f" It requires PyTorch prior to {MINIMUM_RUNTIME_PYTORCH_VERSION_STR}"
        )
        return

    if not fallback_enabled:
        policy = "FALLBACK_DISABLE"
    elif matching_policy:
        policy = "FALLBACK_BAD_INITIALIZATION"
    else:
        policy = "FALLBACK_UNSUPPORTED_DEVICE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    device = "cuda"
    N, D_in, H, D_out = 64, 784, 500, 10
    x = torch.randn(N, D_in, device=device)

    pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)

    for _ in range(3):
        if fallback_enabled and matching_policy:
            # Fallback applies: ORTModule must produce exactly PyTorch's output.
            ort_model = ORTModule(pt_model)
            ort_model.train(is_training)
            pt_model.train(is_training)

            ort_out = ort_model(x)
            pt_out = pt_model(x)
            _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0)
        else:
            # No applicable fallback: the exception happens during __init__ and propagates.
            with pytest.raises(_fallback.ORTModuleInitException) as ex_info:
                ORTModule(pt_model)
            assert "ONNX Runtime ORTModule frontend requires PyTorch version greater or equal to" in str(
                ex_info.value
            )
def test_ortmodule_fallback_init__missing_cpp_extensions(
    is_training, fallback_enabled, matching_policy, persist_fallback
):
    """Verify fallback when ORTModule construction fails because its C++ extensions are missing.

    Only meaningful when the PyTorch C++ extensions under ORTMODULE_TORCH_CPP_DIR are not
    installed; otherwise a warning is emitted and the test returns.

    Args:
        is_training: True for torch.nn.Module training mode, eval mode otherwise.
        fallback_enabled: True lets PyTorch execute the forward graph instead of the ORT backend.
        matching_policy: True uses FALLBACK_BAD_INITIALIZATION, which matches the
            ORTModuleInitException raised during __init__. Otherwise an unrelated policy
            (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen.
        persist_fallback: when True, ORTMODULE_FALLBACK_RETRY is set to "False".
    """
    if is_torch_cpp_extensions_installed(ORTMODULE_TORCH_CPP_DIR):
        warnings.warn(
            "Skipping test_ortmodule_fallback_init__missing_cpp_extensions."
            " It requires PyTorch CPP extensions to be missing"
        )
        return

    if not fallback_enabled:
        selected_policy = "FALLBACK_DISABLE"
    elif matching_policy:
        selected_policy = "FALLBACK_BAD_INITIALIZATION"
    else:
        selected_policy = "FALLBACK_UNSUPPORTED_DEVICE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = selected_policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    target_device = "cuda"
    batch, in_features, hidden, out_features = 64, 784, 500, 10
    sample = torch.randn(batch, in_features, device=target_device)

    torch_net = NeuralNetSinglePositionalArgument(in_features, hidden, out_features).to(target_device)
    expects_fallback = fallback_enabled and matching_policy

    for _ in range(3):
        if not expects_fallback:
            # No applicable fallback: the exception happens during __init__ and propagates.
            with pytest.raises(_fallback.ORTModuleInitException) as ex_info:
                ORTModule(torch_net)
            assert "ORTModule's extensions were not detected" in str(ex_info.value)
            continue

        # Fallback applies: ORTModule must produce exactly PyTorch's output.
        wrapped = ORTModule(torch_net)
        wrapped.train(is_training)
        torch_net.train(is_training)

        wrapped_result = wrapped(sample)
        reference_result = torch_net(sample)
        _test_helpers.assert_values_are_close(wrapped_result, reference_result, rtol=0, atol=0)
def test_ortmodule_fallback_torch_model(is_training, fallback_enabled, matching_policy, persist_fallback):
    """Verify fallback when ORTModule wraps an unsupported model (torch.nn.DataParallel).

    Args:
        is_training: True for torch.nn.Module training mode, eval mode otherwise.
        fallback_enabled: True lets PyTorch execute the forward graph instead of the ORT backend.
        matching_policy: True uses FALLBACK_UNSUPPORTED_TORCH_MODEL, which matches the
            ORTModuleTorchModelException raised during __init__. Otherwise an unrelated policy
            (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen.
        persist_fallback: when True, ORTMODULE_FALLBACK_RETRY is set to "False".
    """
    if not fallback_enabled:
        selected_policy = "FALLBACK_DISABLE"
    elif matching_policy:
        selected_policy = "FALLBACK_UNSUPPORTED_TORCH_MODEL"
    else:
        selected_policy = "FALLBACK_UNSUPPORTED_DEVICE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = selected_policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    target_device = "cuda"
    batch, in_features, hidden, out_features = 64, 784, 500, 10
    sample = torch.randn(batch, in_features, device=target_device)

    # Wrapping in DataParallel makes the model unsupported by ORTModule.
    torch_net = torch.nn.DataParallel(
        NeuralNetSinglePositionalArgument(in_features, hidden, out_features).to(target_device)
    )
    expects_fallback = fallback_enabled and matching_policy

    for _ in range(3):
        if not expects_fallback:
            # No applicable fallback: the exception happens during __init__ and propagates.
            with pytest.raises(_fallback.ORTModuleTorchModelException) as ex_info:
                ORTModule(torch_net)
            assert "ORTModule is not compatible with torch.nn.DataParallel" in str(ex_info.value)
            continue

        # Fallback applies: ORTModule output must match PyTorch's within tolerance.
        wrapped = ORTModule(torch_net)
        wrapped.train(is_training)
        torch_net.train(is_training)

        wrapped_result = wrapped(sample)
        reference_result = torch_net(sample)
        _test_helpers.assert_values_are_close(wrapped_result, reference_result, rtol=1e-3, atol=1e-6)