def test_ortmodule_fallback_init__torch_version(is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_BAD_INITIALIZATION policy to the ORTModuleInitException
    #   raised during ORTModule.__init__ for an unsupported PyTorch version.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen
    # persist_fallback: True disables retrying ORT after the first fallback

    from packaging import version

    from onnxruntime.training.ortmodule import MINIMUM_RUNTIME_PYTORCH_VERSION_STR

    # Strip any local build suffix (e.g. "+cu117") before the version comparison.
    runtime_pytorch_version = version.parse(torch.__version__.split("+")[0])
    minimum_runtime_pytorch_version = version.parse(MINIMUM_RUNTIME_PYTORCH_VERSION_STR)
    if runtime_pytorch_version < minimum_runtime_pytorch_version:

        if fallback_enabled:
            if matching_policy:
                policy = "FALLBACK_BAD_INITIALIZATION"
            else:
                policy = "FALLBACK_UNSUPPORTED_DEVICE"
        else:
            policy = "FALLBACK_DISABLE"
        os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
        os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

        device = "cuda"
        N, D_in, H, D_out = 64, 784, 500, 10
        x = torch.randn(N, D_in, device=device)

        pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)

        # Run a few iterations to verify the fallback (or the exception) is stable across calls.
        for _ in range(3):
            if fallback_enabled:
                if matching_policy:
                    ort_model = ORTModule(pt_model)
                    ort_model.train(is_training)
                    pt_model.train(is_training)

                    ort_out = ort_model(x)
                    pt_out = pt_model(x)
                    _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0)
                else:
                    with pytest.raises(_fallback.ORTModuleInitException) as ex_info:
                        ort_model = ORTModule(pt_model)
                    assert "ONNX Runtime ORTModule frontend requires PyTorch version greater or equal to" in str(
                        ex_info.value
                    )
            else:
                with pytest.raises(_fallback.ORTModuleInitException) as ex_info:
                    # Initialize with fallback policy because Exception will happen during __init__
                    ort_model = ORTModule(pt_model)
                assert "ONNX Runtime ORTModule frontend requires PyTorch version greater or equal to" in str(
                    ex_info.value
                )
    else:
        # Fix: the warning previously referenced the wrong test name
        # ("test_ortmodule_fallback_torch_version").
        warnings.warn(
            "Skipping test_ortmodule_fallback_init__torch_version."
            f" It requires PyTorch prior to {MINIMUM_RUNTIME_PYTORCH_VERSION_STR}"
        )
def test_ortmodule_fallback_device__multiple(is_training, fallback_enabled,
                                             matching_policy,
                                             persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_UNSUPPORTED_DEVICE policy to ORTModuleDeviceException exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DATA) is used to verify that the fallback does not happen

    # Select the fallback policy under test and publish it via the environment.
    if not fallback_enabled:
        policy = "FALLBACK_DISABLE"
    elif matching_policy:
        policy = "FALLBACK_UNSUPPORTED_DEVICE"
    else:
        policy = "FALLBACK_UNSUPPORTED_DATA"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    class ManyDevicesNet(torch.nn.Module):
        # Toy model that deliberately spans two devices (cuda:0 and cpu),
        # which ORTModule does not support.
        def __init__(self):
            super().__init__()
            self.net1 = torch.nn.Linear(10, 10).to("cuda:0")
            self.relu = torch.nn.ReLU()
            self.net2 = torch.nn.Linear(10, 5).to("cpu")

        def forward(self, x):
            x = self.relu(self.net1(x.to("cuda:0")))
            return self.net2(x.to("cpu"))

    pt_model = ManyDevicesNet()
    inputs = torch.randn(20, 10)

    for _ in range(3):
        if fallback_enabled and matching_policy:
            # Matching policy: ORTModule silently falls back to PyTorch execution.
            ort_model = ORTModule(copy.deepcopy(pt_model))
            pt_model.train(is_training)
            ort_model.train(is_training)
            ort_out = ort_model(inputs)
            pt_out = pt_model(inputs)
            _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0)
        else:
            # Wrong policy or fallback disabled: construction must raise.
            with pytest.raises(_fallback.ORTModuleFallbackException) as type_error:
                # Initialize with fallback policy because Exception will happen during __init__
                ort_model = ORTModule(copy.deepcopy(pt_model))
            assert "ORTModule supports a single device per model" in str(type_error.value)
def test_ortmodule_fallback_onnx_model__missing_op(is_training,
                                                   fallback_enabled,
                                                   matching_policy,
                                                   persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_UNSUPPORTED_ONNX_MODEL policy to ORTModuleONNXModelException exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    # Select the fallback policy under test and publish it via the environment.
    if not fallback_enabled:
        policy = "FALLBACK_DISABLE"
    elif matching_policy:
        policy = "FALLBACK_UNSUPPORTED_ONNX_MODEL"
    else:
        policy = "FALLBACK_UNSUPPORTED_DEVICE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    class CrossModule(torch.nn.Module):
        # torch.cross has no ONNX export support here, forcing an export failure.
        def forward(self, x, y):
            return torch.cross(x, y)

    x = torch.randn(2, 3)
    y = torch.randn(2, 3)

    pt_model = CrossModule()
    ort_model = ORTModule(copy.deepcopy(pt_model))
    ort_model.train(is_training)
    pt_model.train(is_training)

    for step in range(3):
        if fallback_enabled and matching_policy:
            if step > 0 and persist_fallback:
                # After the first failure, the fallback exception must stay cached
                # when retrying ORT is disabled.
                manager = ort_model._torch_module._execution_manager(is_training=is_training)
                assert manager._fallback_manager._exception is not None
            pt_out = pt_model(x, y)
            ort_out = ort_model(x, y)
            _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0)
        else:
            with pytest.raises(_fallback.ORTModuleONNXModelException) as ex_info:
                # Initialize with fallback policy because Exception will happen during __init__
                _ = ort_model(x, y)
            assert "There was an error while exporting the PyTorch model to ONNX" in str(ex_info.value)
def test_ortmodule_fallback_onnx_model__custom_autograd(
        is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_UNSUPPORTED_ONNX_MODEL policy to the
    #   ORTModuleONNXModelException exception raised while exporting the custom autograd function.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_UNSUPPORTED_ONNX_MODEL"
        else:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    dtype = torch.float
    device = torch.device("cuda")
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Fix: removed the unused target tensor `y`; only x, w1 and w2 feed the
    # pt-vs-ort parity check, so the actual random values are irrelevant.
    x = torch.randn(N, D_in, device=device, dtype=dtype)
    w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
    w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

    pt_model = MyCustomFunctionReluModel()
    ort_model = ORTModule(copy.deepcopy(pt_model))
    ort_model.train(is_training)
    pt_model.train(is_training)

    for i in range(3):
        if fallback_enabled:
            if matching_policy:
                if i > 0 and persist_fallback:
                    # After the first fallback, the exception must stay cached
                    # when retrying ORT is disabled.
                    assert (ort_model._torch_module._execution_manager(
                        is_training=is_training)._fallback_manager._exception
                            is not None)
                pt_out = pt_model(x.mm(w1)).mm(w2)
                ort_out = ort_model(x.mm(w1)).mm(w2)
                _test_helpers.assert_values_are_close(ort_out,
                                                      pt_out,
                                                      rtol=1e-03,
                                                      atol=1e-04)
            else:
                with pytest.raises(
                        _fallback.ORTModuleONNXModelException) as ex_info:
                    _ = ort_model(x.mm(w1)).mm(w2)
                assert "There was an error while exporting the PyTorch model to ONNX" in str(
                    ex_info.value)
        else:
            with pytest.raises(
                    _fallback.ORTModuleONNXModelException) as ex_info:
                # Initialize with fallback policy because Exception will happen during __init__
                _ = ort_model(x.mm(w1)).mm(w2)
            assert "There was an error while exporting the PyTorch model to ONNX" in str(
                ex_info.value)
def test_ortmodule_fallback_init__missing_cpp_extensions(
        is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_BAD_INITIALIZATION policy to the ORTModuleInitException
    #   raised during ORTModule.__init__ when the CPP extensions are absent.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    if is_torch_cpp_extensions_installed(ORTMODULE_TORCH_CPP_DIR):
        # Fix: dropped the extraneous f-string prefix (no placeholders).
        warnings.warn(
            "Skipping test_ortmodule_fallback_init__missing_cpp_extensions."
            " It requires PyTorch CPP extensions to be missing")
    else:

        if fallback_enabled:
            if matching_policy:
                policy = "FALLBACK_BAD_INITIALIZATION"
            else:
                policy = "FALLBACK_UNSUPPORTED_DEVICE"
        else:
            policy = "FALLBACK_DISABLE"
        os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
        os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

        device = "cuda"
        N, D_in, H, D_out = 64, 784, 500, 10
        x = torch.randn(N, D_in, device=device)

        pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)

        # Run a few iterations to verify the fallback (or the exception) is stable across calls.
        for _ in range(3):
            if fallback_enabled:
                if matching_policy:
                    ort_model = ORTModule(pt_model)
                    ort_model.train(is_training)
                    pt_model.train(is_training)

                    ort_out = ort_model(x)
                    pt_out = pt_model(x)
                    _test_helpers.assert_values_are_close(ort_out,
                                                          pt_out,
                                                          rtol=0,
                                                          atol=0)
                else:
                    with pytest.raises(
                            _fallback.ORTModuleInitException) as ex_info:
                        ort_model = ORTModule(pt_model)
                    assert "ORTModule's extensions were not detected" in str(
                        ex_info.value)
            else:
                with pytest.raises(
                        _fallback.ORTModuleInitException) as ex_info:
                    # Initialize with fallback policy because Exception will happen during __init__
                    ort_model = ORTModule(pt_model)
                assert "ORTModule's extensions were not detected" in str(
                    ex_info.value)
def test_ortmodule_fallback_input(is_training, fallback_enabled,
                                  matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_UNSUPPORTED_DATA policy to ORTModuleIOError exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    # Select the fallback policy under test and publish it via the environment.
    if not fallback_enabled:
        policy = "FALLBACK_DISABLE"
    elif matching_policy:
        policy = "FALLBACK_UNSUPPORTED_DATA"
    else:
        policy = "FALLBACK_UNSUPPORTED_DEVICE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    pt_model = MyCustomClassInputNet()
    ort_model = ORTModule(copy.deepcopy(pt_model))
    inputs = torch.randn(1, 2)

    class CustomClass:
        # User-defined input type that ORTModule cannot serialize.
        def __init__(self, x):
            self.x = x

    ort_model.train(is_training)
    pt_model.train(is_training)

    expected_message = ("ORTModule does not support the following model data"
                        " type <class 'orttraining_test_ortmodule_fallback."
                        "test_ortmodule_fallback_input.<locals>.CustomClass'>")

    for step in range(3):
        if fallback_enabled and matching_policy:
            if step > 0 and persist_fallback:
                # After the first fallback, the exception must stay cached
                # when retrying ORT is disabled.
                manager = ort_model._torch_module._execution_manager(is_training=is_training)
                assert manager._fallback_manager._exception is not None
            ort_out = ort_model(inputs, CustomClass(1))
            pt_out = pt_model(inputs, CustomClass(1))
            _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0)
        else:
            with pytest.raises(_fallback.ORTModuleIOError) as ex_info:
                _ = ort_model(torch.randn(1, 2), CustomClass(1))
            assert expected_message in str(ex_info.value)
def test_ortmodule_fallback_torch_model(is_training, fallback_enabled,
                                        matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_UNSUPPORTED_TORCH_MODEL policy to ORTModuleTorchModelException exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    # Select the fallback policy under test and publish it via the environment.
    if not fallback_enabled:
        policy = "FALLBACK_DISABLE"
    elif matching_policy:
        policy = "FALLBACK_UNSUPPORTED_TORCH_MODEL"
    else:
        policy = "FALLBACK_UNSUPPORTED_DEVICE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    device = "cuda"
    N, D_in, H, D_out = 64, 784, 500, 10
    x = torch.randn(N, D_in, device=device)

    pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
    # DataParallel-wrapped models are unsupported by ORTModule and trigger the exception.
    pt_model = torch.nn.DataParallel(pt_model)

    for _ in range(3):
        if fallback_enabled and matching_policy:
            # Matching policy: ORTModule silently falls back to PyTorch execution.
            ort_model = ORTModule(pt_model)
            ort_model.train(is_training)
            pt_model.train(is_training)

            ort_out = ort_model(x)
            pt_out = pt_model(x)
            _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=1e-3, atol=1e-6)
        else:
            # Wrong policy or fallback disabled: construction must raise.
            with pytest.raises(_fallback.ORTModuleTorchModelException) as ex_info:
                # Initialize with fallback policy because Exception will happen during __init__
                ort_model = ORTModule(pt_model)
            assert "ORTModule is not compatible with torch.nn.DataParallel" in str(ex_info.value)
def test_ortmodule_fallback_output(is_training, fallback_enabled,
                                   matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_UNSUPPORTED_DATA policy to ORTModuleIOError exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    # Select the fallback policy under test and publish it via the environment.
    if not fallback_enabled:
        policy = "FALLBACK_DISABLE"
    elif matching_policy:
        policy = "FALLBACK_UNSUPPORTED_DATA"
    else:
        policy = "FALLBACK_UNSUPPORTED_DEVICE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    device = "cuda"
    N, D_in, H, D_out = 64, 784, 500, 10
    pt_model = NeuralNetCustomClassOutput(D_in, H, D_out).to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))
    x = torch.randn(N, D_in, device=device)
    y = torch.randn(N, D_in, device=device)
    z = torch.randn(N, D_in, device=device)

    ort_model.train(is_training)
    pt_model.train(is_training)

    for step in range(3):
        if fallback_enabled and matching_policy:
            if step > 0 and persist_fallback:
                # After the first fallback, the exception must stay cached
                # when retrying ORT is disabled.
                manager = ort_model._torch_module._execution_manager(is_training=is_training)
                assert manager._fallback_manager._exception is not None
            ort_out = ort_model(x, y, z)
            pt_out = pt_model(x, y, z)
            # Compare each field of the custom output class.
            for field in ("out1", "out2", "out3"):
                _test_helpers.assert_values_are_close(getattr(ort_out, field),
                                                      getattr(pt_out, field),
                                                      rtol=0,
                                                      atol=0)
        else:
            with pytest.raises(_fallback.ORTModuleIOError) as runtime_error:
                ort_model(x, y, z)
            assert "ORTModule does not support the following model output type" in str(runtime_error.value)
    def test_ortmodule_dlpack(self):
        """CPU training parity between ORTModule and plain PyTorch on a tiny tanh net."""

        class NeuralNetTanh(torch.nn.Module):
            # num_classes is accepted for signature parity with sibling test nets
            # but is not used by this model.
            def __init__(self, input_size, hidden_size, num_classes):
                super().__init__()

                self.fc1 = torch.nn.Linear(input_size, hidden_size)
                self.tanh = torch.nn.Tanh()

            def forward(self, input1):
                return self.tanh(self.fc1(input1))

        def run_step(model, x):
            # One forward/backward pass; returns the prediction and scalar loss.
            prediction = model(x)
            loss = prediction.sum()
            loss.backward()
            return prediction, loss

        N, D_in, H, D_out = 120, 1536, 500, 1536
        pt_model = NeuralNetTanh(D_in, H, D_out)
        ort_model = ORTModule(copy.deepcopy(pt_model))

        for _ in range(10):
            pt_x = torch.randn(N, D_in, device='cpu', requires_grad=True)
            ort_x = copy.deepcopy(pt_x)
            ort_prediction, ort_loss = run_step(ort_model, ort_x)
            pt_prediction, pt_loss = run_step(pt_model, pt_x)
            # Predictions, input gradients and losses must all match.
            _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-4)
            _test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad)
            _test_helpers.assert_values_are_close(ort_loss, pt_loss, atol=1e-4)