def test_ortmodule_fallback_output(is_training, fallback_enabled,
                                   matching_policy, persist_fallback):
    # is_training: True to run the model in training mode, eval mode otherwise
    # fallback_enabled: when True, PyTorch executes the forward graph instead of the ORT backend
    # matching_policy: when True, the FALLBACK_UNSUPPORTED_DATA policy matches the ORTModuleIOError raised here.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen
    # persist_fallback: when True, the first fallback decision persists across subsequent calls (ORTMODULE_FALLBACK_RETRY is disabled)

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_UNSUPPORTED_DATA"
        else:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    device = "cuda"
    N, D_in, H, D_out = 64, 784, 500, 10
    pt_model = NeuralNetCustomClassOutput(D_in, H, D_out).to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))
    x = torch.randn(N, D_in, device=device)
    y = torch.randn(N, D_in, device=device)
    z = torch.randn(N, D_in, device=device)

    ort_model.train(is_training)
    pt_model.train(is_training)

    for i in range(3):
        if fallback_enabled:
            if matching_policy:
                if i > 0 and persist_fallback:
                    assert (ort_model._torch_module._execution_manager(
                        is_training=is_training)._fallback_manager._exception
                            is not None)
                ort_out = ort_model(x, y, z)
                pt_out = pt_model(x, y, z)
                _test_helpers.assert_values_are_close(ort_out.out1,
                                                      pt_out.out1,
                                                      rtol=0,
                                                      atol=0)
                _test_helpers.assert_values_are_close(ort_out.out2,
                                                      pt_out.out2,
                                                      rtol=0,
                                                      atol=0)
                _test_helpers.assert_values_are_close(ort_out.out3,
                                                      pt_out.out3,
                                                      rtol=0,
                                                      atol=0)
            else:
                with pytest.raises(
                        _fallback.ORTModuleIOError) as runtime_error:
                    ort_model(x, y, z)
                assert "ORTModule does not support the following model output type" in str(
                    runtime_error.value)
        else:
            with pytest.raises(_fallback.ORTModuleIOError) as runtime_error:
                ort_model(x, y, z)
            assert "ORTModule does not support the following model output type" in str(
                runtime_error.value)
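
The tests in this collection configure fallback purely through process-wide environment variables and never restore the previous values. A minimal sketch of a context manager that scopes the two variables used above to a single block; the variable names come from the tests, everything else is illustrative:

import os
from contextlib import contextmanager

@contextmanager
def scoped_fallback_policy(policy, persist_fallback):
    keys = ("ORTMODULE_FALLBACK_POLICY", "ORTMODULE_FALLBACK_RETRY")
    saved = {k: os.environ.get(k) for k in keys}
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)
    try:
        yield
    finally:
        # Restore (or remove) the values that were set before entering.
        for key, value in saved.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value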

def test_ortmodule_fallback_warn_message(is_training, persist_fallback):
    # is_training: True to run the model in training mode, eval mode otherwise

    policy = "FALLBACK_UNSUPPORTED_DEVICE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)
    os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"

    data_device = "cuda"
    N, D_in, H, D_out = 64, 784, 500, 10

    pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
    ort_model = ORTModule(copy.deepcopy(pt_model))
    pt_model.train(is_training)
    ort_model.train(is_training)
    # For the initial model export, use the same device for data and model so that the PyTorch model can be traced during export
    _ = ort_model(torch.randn(N, D_in))

    # Use data in different device for testing
    inputs = torch.randn(N, D_in, device=data_device)

    for _ in range(3):
        with pytest.raises(RuntimeError):
            with pytest.warns(UserWarning) as warning_record:
                ort_model(inputs)
        assert "Fallback to PyTorch due to exception" in str(
            warning_record[0].message.args[0])

    del os.environ["ORTMODULE_SKIPCHECK_POLICY"]

def test_ortmodule_fallback_device__mismatch(is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True to run the model in training mode, eval mode otherwise
    # fallback_enabled: when True, PyTorch executes the forward graph instead of the ORT backend
    # matching_policy: when True, the FALLBACK_UNSUPPORTED_DEVICE policy matches the ORTModuleDeviceException raised here.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DATA) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
        else:
            policy = "FALLBACK_UNSUPPORTED_DATA"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)
    os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"

    data_device = "cuda"
    N, D_in, H, D_out = 64, 784, 500, 10

    pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
    ort_model = ORTModule(copy.deepcopy(pt_model))
    pt_model.train(is_training)
    ort_model.train(is_training)
    # For the initial model export, use the same device for data and model so that the PyTorch model can be traced during export
    _ = ort_model(torch.randn(N, D_in))

    # Use data in different device for testing
    inputs = torch.randn(N, D_in, device=data_device)
    ort_model_device = ort_model._torch_module._execution_manager(ort_model._is_training())._device
    input_device = torch.device(inputs.device)

    for _ in range(3):
        if fallback_enabled:
            if matching_policy:
                with pytest.raises(RuntimeError) as e:
                    ort_model(inputs)
                assert (
                    "Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for addmm)"
                    in str(e.value)
                ) or (
                    "Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!"
                    in str(e.value)
                )
            else:
                with pytest.raises(_fallback.ORTModuleDeviceException) as e:
                    ort_model(inputs)
                assert (
                    f"Input argument to forward found on device {input_device}, "
                    f"but expected it to be on module device {ort_model_device}." in str(e.value)
                )
        else:
            with pytest.raises(_fallback.ORTModuleDeviceException) as e:
                ort_model(inputs)
            assert (
                f"Input argument to forward found on device {input_device}, "
                f"but expected it to be on module device {ort_model_device}." in str(e.value)
            )

    del os.environ["ORTMODULE_SKIPCHECK_POLICY"]

def test_ortmodule_fallback_forward(is_training, fallback_enabled,
                                    matching_policy, persist_fallback):
    # is_training: True to run the model in training mode, eval mode otherwise
    # fallback_enabled: when True, PyTorch executes the forward graph instead of the ORT backend
    # matching_policy: when True, the FALLBACK_FORCE_TORCH_FORWARD policy forces the fallback to PyTorch.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_FORCE_TORCH_FORWARD"
        else:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    from dataclasses import dataclass

    @dataclass
    class Point:
        x: int
        y: int

    class UnsupportedInputModel(torch.nn.Module):
        def __init__(self):
            super(UnsupportedInputModel, self).__init__()

        def forward(self, point):
            return point.x * point.y

    pt_model = UnsupportedInputModel()
    ort_model = ORTModule(copy.deepcopy(pt_model))
    inputs = Point(x=2, y=3)

    ort_model.train(is_training)
    pt_model.train(is_training)

    for i in range(3):
        if fallback_enabled:
            if matching_policy:
                if i > 0 and persist_fallback:
                    assert (ort_model._torch_module._execution_manager(
                        is_training=is_training)._fallback_manager._exception
                            is not None)
                ort_out = ort_model(inputs)
                pt_out = pt_model(inputs)
                assert ort_out == pt_out
            else:
                with pytest.raises(
                        _fallback.ORTModuleFallbackException) as type_error:
                    ort_model(inputs)
                assert "ORTModule does not support the following model data type" in str(
                    type_error.value)
        else:
            with pytest.raises(
                    _fallback.ORTModuleFallbackException) as type_error:
                ort_model(inputs)
            assert "ORTModule does not support the following model data type" in str(
                type_error.value)

def test_ortmodule_fallback_onnx_model__custom_autograd(
        is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True to run the model in training mode, eval mode otherwise
    # fallback_enabled: when True, PyTorch executes the forward graph instead of the ORT backend
    # matching_policy: when True, the FALLBACK_UNSUPPORTED_ONNX_MODEL policy matches the ORTModuleONNXModelException raised here.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_UNSUPPORTED_ONNX_MODEL"
        else:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    dtype = torch.float
    device = torch.device("cuda")
    N, D_in, H, D_out = 64, 1000, 100, 10

    x = torch.randn(N, D_in, device=device, dtype=dtype)
    y = torch.randn(N, D_out, device=device, dtype=dtype)
    w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
    w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

    pt_model = MyCustomFunctionReluModel()
    ort_model = ORTModule(copy.deepcopy(pt_model))
    ort_model.train(is_training)
    pt_model.train(is_training)

    for i in range(3):
        if fallback_enabled:
            if matching_policy:
                if i > 0 and persist_fallback:
                    assert (ort_model._torch_module._execution_manager(
                        is_training=is_training)._fallback_manager._exception
                            is not None)
                pt_out = pt_model(x.mm(w1)).mm(w2)
                ort_out = ort_model(x.mm(w1)).mm(w2)
                _test_helpers.assert_values_are_close(ort_out,
                                                      pt_out,
                                                      rtol=1e-03,
                                                      atol=1e-04)
            else:
                with pytest.raises(
                        _fallback.ORTModuleONNXModelException) as ex_info:
                    _ = ort_model(x.mm(w1)).mm(w2)
                assert "There was an error while exporting the PyTorch model to ONNX" in str(
                    ex_info.value)
        else:
            with pytest.raises(
                    _fallback.ORTModuleONNXModelException) as ex_info:
                # The exception is raised on the first forward call, which triggers the failing ONNX export
                _ = ort_model(x.mm(w1)).mm(w2)
            assert "There was an error while exporting the PyTorch model to ONNX" in str(
                ex_info.value)

def test_ortmodule_fallback_onnx_model__missing_op(is_training,
                                                   fallback_enabled,
                                                   matching_policy,
                                                   persist_fallback):
    # is_training: True to run the model in training mode, eval mode otherwise
    # fallback_enabled: when True, PyTorch executes the forward graph instead of the ORT backend
    # matching_policy: when True, the FALLBACK_UNSUPPORTED_ONNX_MODEL policy matches the ORTModuleONNXModelException raised here.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_UNSUPPORTED_ONNX_MODEL"
        else:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    class CrossModule(torch.nn.Module):
        def forward(self, x, y):
            return torch.cross(x, y)

    x = torch.randn(2, 3)
    y = torch.randn(2, 3)

    pt_model = CrossModule()
    ort_model = ORTModule(copy.deepcopy(pt_model))
    ort_model.train(is_training)
    pt_model.train(is_training)

    for i in range(3):
        if fallback_enabled:
            if matching_policy:
                if i > 0 and persist_fallback:
                    assert (ort_model._torch_module._execution_manager(
                        is_training=is_training)._fallback_manager._exception
                            is not None)
                pt_out = pt_model(x, y)
                ort_out = ort_model(x, y)
                _test_helpers.assert_values_are_close(ort_out,
                                                      pt_out,
                                                      rtol=0,
                                                      atol=0)
            else:
                with pytest.raises(
                        _fallback.ORTModuleONNXModelException) as ex_info:
                    _ = ort_model(x, y)
                assert "There was an error while exporting the PyTorch model to ONNX" in str(
                    ex_info.value)
        else:
            with pytest.raises(
                    _fallback.ORTModuleONNXModelException) as ex_info:
                # The exception is raised on the first forward call, which triggers the failing ONNX export
                _ = ort_model(x, y)
            assert "There was an error while exporting the PyTorch model to ONNX" in str(
                ex_info.value)

    def __init__(self, lr, use_ortmodule=True):
        super().__init__()
        self.lr = lr
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(),
                                     nn.Linear(64, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(),
                                     nn.Linear(64, 28 * 28))
        if use_ortmodule:
            self.encoder = ORTModule(self.encoder)
            self.decoder = ORTModule(self.decoder)
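
For context, the constructor fragment above wraps only the encoder/decoder submodules in ORTModule, leaving the outer module a plain torch.nn.Module. A hedged sketch of the matching forward pass (the surrounding class is not shown in the fragment, so this is an assumption, not the original code):

    def forward(self, x):
        # Hypothetical forward for the module whose __init__ is shown above:
        # flatten the 28x28 image, encode down to 3 dims, decode back to 784.
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        return self.decoder(z)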

def test_ortmodule_fallback_input(is_training, fallback_enabled,
                                  matching_policy, persist_fallback):
    # is_training: True to run the model in training mode, eval mode otherwise
    # fallback_enabled: when True, PyTorch executes the forward graph instead of the ORT backend
    # matching_policy: when True, the FALLBACK_UNSUPPORTED_DATA policy matches the ORTModuleIOError raised here.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_UNSUPPORTED_DATA"
        else:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    pt_model = MyCustomClassInputNet()
    ort_model = ORTModule(copy.deepcopy(pt_model))
    inputs = torch.randn(1, 2)

    class CustomClass:
        def __init__(self, x):
            self.x = x

    ort_model.train(is_training)
    pt_model.train(is_training)

    for i in range(3):
        if fallback_enabled:
            if matching_policy:
                if i > 0 and persist_fallback:
                    assert (ort_model._torch_module._execution_manager(
                        is_training=is_training)._fallback_manager._exception
                            is not None)
                ort_out = ort_model(inputs, CustomClass(1))
                pt_out = pt_model(inputs, CustomClass(1))
                _test_helpers.assert_values_are_close(ort_out,
                                                      pt_out,
                                                      rtol=0,
                                                      atol=0)
            else:
                with pytest.raises(_fallback.ORTModuleIOError) as ex_info:
                    _ = ort_model(torch.randn(1, 2), CustomClass(1))
                assert ("ORTModule does not support the following model data"
                        " type <class 'orttraining_test_ortmodule_fallback."
                        "test_ortmodule_fallback_input.<locals>.CustomClass'>"
                        in str(ex_info.value))
        else:
            with pytest.raises(_fallback.ORTModuleIOError) as ex_info:
                _ = ort_model(torch.randn(1, 2), CustomClass(1))
            assert ("ORTModule does not support the following model data"
                    " type <class 'orttraining_test_ortmodule_fallback."
                    "test_ortmodule_fallback_input.<locals>.CustomClass'>"
                    in str(ex_info.value))

def test_ortmodule_fallback_device__multiple(is_training, fallback_enabled,
                                             matching_policy,
                                             persist_fallback):
    # is_training: True to run the model in training mode, eval mode otherwise
    # fallback_enabled: when True, PyTorch executes the forward graph instead of the ORT backend
    # matching_policy: when True, the FALLBACK_UNSUPPORTED_DEVICE policy matches the ORTModuleFallbackException raised here.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DATA) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
        else:
            policy = "FALLBACK_UNSUPPORTED_DATA"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    class ManyDevicesNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.net1 = torch.nn.Linear(10, 10).to("cuda:0")
            self.relu = torch.nn.ReLU()
            self.net2 = torch.nn.Linear(10, 5).to("cpu")

        def forward(self, x):
            x = self.relu(self.net1(x.to("cuda:0")))
            return self.net2(x.to("cpu"))

    pt_model = ManyDevicesNet()
    inputs = torch.randn(20, 10)

    for _ in range(3):
        if fallback_enabled:
            if matching_policy:
                ort_model = ORTModule(copy.deepcopy(pt_model))
                pt_model.train(is_training)
                ort_model.train(is_training)
                ort_out = ort_model(inputs)
                pt_out = pt_model(inputs)
                _test_helpers.assert_values_are_close(ort_out,
                                                      pt_out,
                                                      rtol=0,
                                                      atol=0)
            else:
                with pytest.raises(
                        _fallback.ORTModuleFallbackException) as type_error:
                    # Initialize with fallback policy because Exception will happen during __init__
                    ort_model = ORTModule(copy.deepcopy(pt_model))
                assert "ORTModule supports a single device per model" in str(
                    type_error.value)
        else:
            with pytest.raises(
                    _fallback.ORTModuleFallbackException) as type_error:
                # Initialize with fallback policy because Exception will happen during __init__
                ort_model = ORTModule(copy.deepcopy(pt_model))
            assert "ORTModule supports a single device per model" in str(
                type_error.value)

def test_ortmodule_fallback_init__torch_version(is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True to run the model in training mode, eval mode otherwise
    # fallback_enabled: when True, PyTorch executes the forward graph instead of the ORT backend
    # matching_policy: when True, the FALLBACK_BAD_INITIALIZATION policy matches the ORTModuleInitException raised here.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    from packaging import version

    from onnxruntime.training.ortmodule import MINIMUM_RUNTIME_PYTORCH_VERSION_STR

    runtime_pytorch_version = version.parse(torch.__version__.split("+")[0])
    minimum_runtime_pytorch_version = version.parse(MINIMUM_RUNTIME_PYTORCH_VERSION_STR)
    if runtime_pytorch_version < minimum_runtime_pytorch_version:

        if fallback_enabled:
            if matching_policy:
                policy = "FALLBACK_BAD_INITIALIZATION"
            else:
                policy = "FALLBACK_UNSUPPORTED_DEVICE"
        else:
            policy = "FALLBACK_DISABLE"
        os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
        os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

        device = "cuda"
        N, D_in, H, D_out = 64, 784, 500, 10
        x = torch.randn(N, D_in, device=device)

        pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)

        for i in range(3):
            if fallback_enabled:
                if matching_policy:
                    ort_model = ORTModule(pt_model)
                    ort_model.train(is_training)
                    pt_model.train(is_training)

                    ort_out = ort_model(x)
                    pt_out = pt_model(x)
                    _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0)
                else:
                    with pytest.raises(_fallback.ORTModuleInitException) as ex_info:
                        ort_model = ORTModule(pt_model)
                    assert "ONNX Runtime ORTModule frontend requires PyTorch version greater or equal to" in str(
                        ex_info.value
                    )
            else:
                with pytest.raises(_fallback.ORTModuleInitException) as ex_info:
                    # Initialize with fallback policy because Exception will happen during __init__
                    ort_model = ORTModule(pt_model)
                assert "ONNX Runtime ORTModule frontend requires PyTorch version greater or equal to" in str(
                    ex_info.value
                )
    else:
        warnings.warn(
            "Skipping test_ortmodule_fallback_torch_version."
            f" It requires PyTorch prior to {MINIMUM_RUNTIME_PYTORCH_VERSION_STR}"
        )

    def test_bool_input_and_output(self):
        class NeuralNetBoolInputOutput(torch.nn.Module):
            def __init__(self, input_size, num_classes):
                super(NeuralNetBoolInputOutput, self).__init__()
                self.fc = torch.nn.Linear(input_size, num_classes)
                self.relu = torch.nn.ReLU()

            def forward(self, condition, x1, x2):
                out1 = self.relu(self.fc(torch.where(condition, x1, x2)))
                out2 = torch.tensor(out1).to(torch.bool)
                return out1, out2

        device = 'cpu'
        N, D_in, D_out = 8, 16, 2
        model = NeuralNetBoolInputOutput(D_in, D_out).to(device)
        model = ORTModule(model)
        condition = torch.randint(2, (N, D_in),
                                  dtype=torch.bool,
                                  device=device)
        x1 = torch.randn(N, D_in, device=device)
        x2 = torch.randn(N, D_in, device=device)
        y1, y2 = model(condition, x1, x2)

        assert y1 is not None
        assert y2 is not None and y2.dtype == torch.bool

    def test_ortmodule_dlpack(self):
        class NeuralNetTanh(torch.nn.Module):
            def __init__(self, input_size, hidden_size, num_classes):
                super(NeuralNetTanh, self).__init__()

                self.fc1 = torch.nn.Linear(input_size, hidden_size)
                self.tanh = torch.nn.Tanh()

            def forward(self, input1):
                out = self.fc1(input1)
                out = self.tanh(out)
                return out

        def run_step(model, x):
            prediction = model(x)
            loss = prediction.sum()
            loss.backward()
            return prediction, loss

        N, D_in, H, D_out = 120, 1536, 500, 1536
        pt_model = NeuralNetTanh(D_in, H, D_out)
        ort_model = ORTModule(copy.deepcopy(pt_model))

        for step in range(10):
            pt_x = torch.randn(N, D_in, device='cpu', requires_grad=True)
            ort_x = copy.deepcopy(pt_x)
            ort_prediction, ort_loss = run_step(ort_model, ort_x)
            pt_prediction, pt_loss = run_step(pt_model, pt_x)
            _test_helpers.assert_values_are_close(ort_prediction,
                                                  pt_prediction,
                                                  atol=1e-4)
            _test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad)
            _test_helpers.assert_values_are_close(ort_loss, pt_loss, atol=1e-4)
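
test_ortmodule_dlpack above verifies values and gradients across the DLPack boundary that ORTModule uses to exchange tensors with ONNX Runtime. For reference, the underlying zero-copy exchange with the standard torch.utils.dlpack API looks like this (independent of the test helpers):

import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

t = torch.randn(2, 3)
capsule = to_dlpack(t)       # export the tensor without copying its storage
t2 = from_dlpack(capsule)    # import it back; t2 shares memory with t
t2[0, 0] = 42.0
assert t[0, 0] == 42.0       # both names see the same underlying buffer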

    def forward(ctx, x, dim, device, use_ort):
        ctx.save_for_backward(x)
        ctx.device = device
        ctx.inner = InnerModel(dim, device).to(device)
        if use_ort:
            ctx.inner = ORTModule(ctx.inner)
        z = ctx.inner(x)
        return z

def test_ortmodule_fallback_init__missing_cpp_extensions(
        is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True to run the model in training mode, eval mode otherwise
    # fallback_enabled: when True, PyTorch executes the forward graph instead of the ORT backend
    # matching_policy: when True, the FALLBACK_BAD_INITIALIZATION policy matches the ORTModuleInitException raised here.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    if is_torch_cpp_extensions_installed(ORTMODULE_TORCH_CPP_DIR):
        warnings.warn(
            "Skipping test_ortmodule_fallback_init__missing_cpp_extensions."
            f" It requires PyTorch CPP extensions to be missing")
    else:

        if fallback_enabled:
            if matching_policy:
                policy = "FALLBACK_BAD_INITIALIZATION"
            else:
                policy = "FALLBACK_UNSUPPORTED_DEVICE"
        else:
            policy = "FALLBACK_DISABLE"
        os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
        os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

        device = "cuda"
        N, D_in, H, D_out = 64, 784, 500, 10
        x = torch.randn(N, D_in, device=device)

        pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)

        for _ in range(3):
            if fallback_enabled:
                if matching_policy:
                    ort_model = ORTModule(pt_model)
                    ort_model.train(is_training)
                    pt_model.train(is_training)

                    ort_out = ort_model(x)
                    pt_out = pt_model(x)
                    _test_helpers.assert_values_are_close(ort_out,
                                                          pt_out,
                                                          rtol=0,
                                                          atol=0)
                else:
                    with pytest.raises(
                            _fallback.ORTModuleInitException) as ex_info:
                        ort_model = ORTModule(pt_model)
                    assert "ORTModule's extensions were not detected" in str(
                        ex_info.value)
            else:
                with pytest.raises(
                        _fallback.ORTModuleInitException) as ex_info:
                    # Initialize with fallback policy because Exception will happen during __init__
                    ort_model = ORTModule(pt_model)
                assert "ORTModule's extensions were not detected" in str(
                    ex_info.value)

Example #15

def run_with_ort_on_device(device,
                           model,
                           input_list,
                           label_input,
                           is_eval_mode=False):
    model = copy.deepcopy(model)
    model.to(device)
    model = ORTModule(model)
    enable_custom_autograd_function(model)
    if is_eval_mode:
        model.eval()
    else:
        model.train()

    inputs_on_device = [input_.to(device) for input_ in input_list]
    output = model(*inputs_on_device)
    forward_outputs = [output]
    grad_outputs = []

    if not is_eval_mode:
        criterion = torch.nn.MSELoss()
        target = label_input.to(device)
        loss = criterion(output, target)
        loss.backward()
        for name, param in model.named_parameters():
            if param.requires_grad:
                grad_outputs.append(param.grad)
    return forward_outputs, grad_outputs
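
An illustrative call of the helper above; the model, shapes, and device are placeholders chosen only to match the helper's signature, and the imports used inside the helper are assumed to be in scope:

import torch

model = torch.nn.Linear(4, 2)
inputs = [torch.randn(8, 4)]
labels = torch.randn(8, 2)
forward_outputs, grad_outputs = run_with_ort_on_device("cpu", model, inputs, labels)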

    def forward(ctx, x, dim, device, use_ort):
        ctx.save_for_backward(x)
        ctx.device = device
        ctx.inner = InnerModel(dim, device).to(device)
        if use_ort:
            ctx.inner = ORTModule(ctx.inner)
            enable_custom_autograd_function(ctx.inner)
        z = ctx.inner(x)
        return z

Example #17

    def run_with_ort_on_gpu(model, args, rank, device):
        model.to(device)
        model = ORTModule(model)

        _test_helpers.set_onnx_fallthrough_export_type(model)
        model = DDP(model, device_ids=[rank])
        cuda_args = [arg.to(device) for arg in args]
        output = model(*cuda_args)
        output.sum().backward()
        return output, [arg.grad for arg in cuda_args]

def test_ortmodule_fallback_torch_model(is_training, fallback_enabled,
                                        matching_policy, persist_fallback):
    # is_training: True to run the model in training mode, eval mode otherwise
    # fallback_enabled: when True, PyTorch executes the forward graph instead of the ORT backend
    # matching_policy: when True, the FALLBACK_UNSUPPORTED_TORCH_MODEL policy matches the ORTModuleTorchModelException raised here.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_UNSUPPORTED_TORCH_MODEL"
        else:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    device = "cuda"
    N, D_in, H, D_out = 64, 784, 500, 10
    x = torch.randn(N, D_in, device=device)

    pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
    pt_model = torch.nn.DataParallel(pt_model)

    for _ in range(3):
        if fallback_enabled:
            if matching_policy:
                ort_model = ORTModule(pt_model)
                ort_model.train(is_training)
                pt_model.train(is_training)

                ort_out = ort_model(x)
                pt_out = pt_model(x)
                _test_helpers.assert_values_are_close(ort_out,
                                                      pt_out,
                                                      rtol=1e-3,
                                                      atol=1e-6)
            else:
                with pytest.raises(
                        _fallback.ORTModuleTorchModelException) as ex_info:
                    ort_model = ORTModule(pt_model)
                assert "ORTModule is not compatible with torch.nn.DataParallel" in str(
                    ex_info.value)
        else:
            with pytest.raises(
                    _fallback.ORTModuleTorchModelException) as ex_info:
                # Initialize with fallback policy because Exception will happen during __init__
                ort_model = ORTModule(pt_model)
            assert "ORTModule is not compatible with torch.nn.DataParallel" in str(
                ex_info.value)

def test_load_config_from_json_2():
    device = 'cuda'
    model = ORTModule(Net().to(device))

    # load from json once.
    path_to_json = os.path.join(os.getcwd(), 'orttraining_test_ortmodule_experimental_json_config_1.json')
    load_from_json(model, path_to_json)

    # load from json another time
    path_to_json = os.path.join(os.getcwd(), 'orttraining_test_ortmodule_experimental_json_config_2.json')
    load_from_json(model, path_to_json)

    for training_mode in [True, False]:
        ort_model_attributes = model._torch_module._execution_manager(training_mode)

        # test propagate cast ops
        assert ort_model_attributes._propagate_cast_ops_strategy == C.PropagateCastOpsStrategy.INSERT_AND_REDUCE
        assert ort_model_attributes._propagate_cast_ops_level == 5
        assert ort_model_attributes._propagate_cast_ops_allow == ["XYZ", "PQR"]

        # test use external gpu allocator
        assert ort_model_attributes._use_external_gpu_allocator == True

        # test enable custom autograd function
        assert ort_model_attributes._enable_custom_autograd_function == False

        # test allow layer norm mod precision
        assert ort_model_attributes._allow_layer_norm_mod_precision == False

        # test use static shape
        assert ort_model_attributes._use_static_shape == False

        # test run symbolic shape inference
        assert ort_model_attributes._run_symbolic_shape_infer == True

        # test enable grad acc optimization
        assert ort_model_attributes._enable_grad_acc_optimization == False

        # test skip check
        assert ort_model_attributes._skip_check.value == 10

        # test debug options
        assert ort_model_attributes._debug_options.save_onnx_models.save == True
        assert ort_model_attributes._debug_options.save_onnx_models.name_prefix == 'my_other_model'
        assert ort_model_attributes._debug_options.logging.log_level.name == "INFO"

        # test use memory aware gradient builder.
        assert ort_model_attributes._use_memory_efficient_gradient == True

        # test fallback policy
        assert ort_model_attributes._fallback_manager.policy.value == 250

def test_load_config_from_json_1():
    device = 'cuda'
    model = ORTModule(Net().to(device))

    # load from json once.
    path_to_json = os.path.join(
        os.getcwd(),
        'orttraining_test_ortmodule_experimental_json_config_2.json')
    load_from_json(model, path_to_json)

    # load from json another time
    path_to_json = os.path.join(
        os.getcwd(),
        'orttraining_test_ortmodule_experimental_json_config_1.json')
    load_from_json(model, path_to_json)

    for training_mode in [True, False]:
        ort_model_attributes = model._torch_module._execution_manager(
            training_mode)

        # test propagate cast ops
        assert ort_model_attributes._propagate_cast_ops_strategy == C.PropagateCastOpsStrategy.FLOOD_FILL
        assert ort_model_attributes._propagate_cast_ops_level == 3
        assert ort_model_attributes._propagate_cast_ops_allow == ["ABC", "DEF"]

        # test use external gpu allocator
        assert ort_model_attributes._use_external_gpu_allocator == False

        # test enable custom autograd function
        assert ort_model_attributes._enable_custom_autograd_function == True

        # test allow layer norm mod precision
        assert ort_model_attributes._allow_layer_norm_mod_precision == True

        # test use static shape
        assert ort_model_attributes._use_static_shape == True

        # test run symbolic shape inference
        assert ort_model_attributes._run_symbolic_shape_infer == False

        # test enable grad acc optimization
        assert ort_model_attributes._enable_grad_acc_optimization == True

        # test skip check
        assert ort_model_attributes._skip_check.value == 14

        # test debug options
        assert ort_model_attributes._debug_options.save_onnx_models.save == True
        assert ort_model_attributes._debug_options.save_onnx_models.name_prefix == 'my_model'
        assert ort_model_attributes._debug_options.logging.log_level.name == "VERBOSE"

Example #21

def run_with_ort_on_device(device,
                           model,
                           input_list,
                           label_input,
                           is_eval_mode=False,
                           run_forward_twice=False):
    with torch.no_grad():
        model = copy.deepcopy(model)
        model.to(device)
    enable_custom_autograd_function(model)
    model = ORTModule(model)

    return _run_model_on_device(device, model, input_list, label_input,
                                is_eval_mode, run_forward_twice)

Example #22

    def gradient_correctness(self, name, device, debug=False):
        pt_model_cls, op_grad_type, kwargs = self.get_torch_model_name(
            name, device)
        if kwargs is None:
            kwargs = {}
        N = 32
        pt_model = pt_model_cls().to(device)
        D_in = pt_model.fc1.in_features
        ort_model = ORTModule(copy.deepcopy(pt_model))

        def run_step(model, x):
            prediction = model(x)
            loss = prediction.sum()
            loss.backward()
            return prediction

        for _ in range(10):
            x = torch.randn(N, D_in, device=device)
            pt_prediction = run_step(pt_model, x)
            ort_prediction = run_step(ort_model, x)

            self.assert_values_are_close(ort_prediction, pt_prediction,
                                         **kwargs)
            self.assert_gradients_match_and_reset_gradient(
                ort_model, pt_model, **kwargs)

        onnx_graph_inf = ort_model._torch_module._execution_manager._training_manager._onnx_models.exported_model
        onnx_graph_train = ort_model._torch_module._execution_manager._training_manager._onnx_models.optimized_model
        if debug:
            with open("debug_%s_ortmodule_infer.onnx" % name, "wb") as f:
                f.write(onnx_graph_inf.SerializeToString())
            with open("debug_%s_ortmodule_train.onnx" % name, "wb") as f:
                f.write(onnx_graph_train.SerializeToString())
        self.assertIn('op_type: "%s"' % name, str(onnx_graph_inf))
        for onnx_model in [onnx_graph_inf, onnx_graph_train]:
            for oimp in onnx_model.opset_import:
                if oimp.domain == "":
                    self.assertEqual(oimp.version, 14)
        if op_grad_type is not None:
            if isinstance(op_grad_type, tuple):
                text = str(onnx_graph_train)
                if all(
                        map(lambda op: ('op_type: "%s"' % op) not in text,
                            op_grad_type)):
                    raise AssertionError("Operator %s not found in %s." %
                                         (" or ".join(op_grad_type), text))
            else:
                self.assertIn('op_type: "%s"' % op_grad_type,
                              str(onnx_graph_train))

Example #23

def demo_basic(rank, world_size, use_ort_module):
    torch.manual_seed(0)
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    if use_ort_module:
        model = ORTModule(model)
        print(f"  Rank {rank} uses ORTModule.")
    else:
        print(f"  Rank {rank} uses Pytorch's nn.Module.")

    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.Adagrad(ddp_model.parameters(), lr=0.01)

    x = torch.randn(20, 10).to(rank)
    y = torch.randn(20, 5).to(rank)

    loss_history = []

    for i in range(5):
        optimizer.zero_grad()
        p = ddp_model(x)
        loss = loss_fn(p, y)
        with torch.no_grad():
            print(f"  Rank {rank} at iteration {i} has loss {loss}.")
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            loss_history.append(torch.unsqueeze(loss, 0))

    loss_history = torch.cat(loss_history).cpu()
    expected_loss_history = torch.FloatTensor([
        1.4909229278564453, 1.432194471359253, 1.39592707157135,
        1.367714762687683, 1.3445055484771729
    ])

    assert torch.allclose(expected_loss_history, loss_history)

    cleanup()
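
demo_basic is a per-rank entry point, so it is normally launched once per device. A typical launcher using the standard torch.multiprocessing API (the world size and the use_ort_module flag value below are illustrative):

import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = 4  # one process per GPU
    # mp.spawn supplies the rank as the first argument of demo_basic.
    mp.spawn(demo_basic, args=(world_size, True), nprocs=world_size, join=True)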

Example #24

    def run():
        m = Foo().to('cuda')
        x = torch.rand((2, 2), dtype=torch.float).to('cuda')

        # Baseline.
        y_ref = m(x)
        print('Ref:')
        print(y_ref)

        m = ORTModule(m)

        # Inference mode.
        y_infer = m(x)
        print(y_infer)
        assert torch.allclose(y_ref, y_infer)

        # Training mode.
        m.train()
        y_train = m(x)
        print('Train:')
        assert torch.allclose(y_ref, y_train)

Example #25

    def run():
        m = Foo().to("cuda")
        x = torch.rand((2, 2), dtype=torch.float).to("cuda")

        # Baseline.
        y_ref = m(x)
        print("Ref:")
        print(y_ref)

        from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support

        enable_custom_autograd_support()
        m = ORTModule(m)

        # Inference mode.
        y_infer = m(x)
        print(y_infer)
        assert torch.allclose(y_ref, y_infer)

        # Training mode.
        m.train()
        y_train = m(x)
        print("Train:")
        assert torch.allclose(y_ref, y_train)

def test_GeLU_When_Autograd_Func_Fallback_Not_Enabled():
    @torch.jit.script
    def bias_gelu(bias, y):
        x = bias + y
        return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x *
                                           (1 + 0.044715 * x * x)))

    @torch.jit.script
    def bias_gelu_backward(g, bias, y):
        x = bias + y
        tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
        ff = 0.5 * x * (
            (1 - tanh_out * tanh_out) *
            (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
        return ff * g

    class GeLUFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, bias):
            ctx.save_for_backward(input, bias)
            return bias_gelu(bias, input)

        @staticmethod
        def backward(ctx, grad_output):
            input, bias = ctx.saved_tensors
            tmp = bias_gelu_backward(grad_output, bias, input)
            return tmp, tmp

    class GeLUModel(torch.nn.Module):
        def __init__(self, output_size):
            super(GeLUModel, self).__init__()
            self.relu = GeLUFunction.apply
            self.bias = Parameter(
                torch.empty(output_size,
                            device=torch.cuda.current_device(),
                            dtype=torch.float))

            with torch.no_grad():
                self.bias.uniform_()

        def forward(self, model_input):
            out = self.relu(model_input, self.bias)
            return out

    output_size = 1024

    def model_builder():
        return GeLUModel(output_size)

    def input_generator():
        return torch.randn(output_size, dtype=torch.float)

    # generate a label that has the same shape as the forward output.
    label_input = torch.ones([output_size])

    m_ort = model_builder()
    x_ort = input_generator()

    try:
        device = torch.device("cpu")
        m_ort.to(device)
        model = ORTModule(m_ort)
        model.train()

        inputs_on_device = [x_ort.to(device)]
        output = model(*inputs_on_device)
    except RuntimeError as e:
        assert "Detected autograd functions usage in current model, the run will fail" in str(
            e)
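
Note that the try/except pattern above passes silently if no exception is raised at all. A stricter variant of the same check is sketched below with pytest.raises, which fails when the expected RuntimeError does not occur; it reuses the model_builder and input_generator helpers defined in the test above:

import pytest

with pytest.raises(RuntimeError,
                   match="Detected autograd functions usage in current model"):
    model = ORTModule(model_builder().to("cpu"))
    model.train()
    model(input_generator().to("cpu"))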

Example #27

def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument(
        '--train-steps',
        type=int,
        default=-1,
        metavar='N',
        help=
        'number of steps to train. Set -1 to run through whole dataset (default: -1)'
    )
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--batch-size',
                        type=int,
                        default=32,
                        metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument('--pytorch-only',
                        action='store_true',
                        default=False,
                        help='disables ONNX Runtime training')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=300,
        metavar='N',
        help=
        'how many batches to wait before logging training status (default: 300)'
    )
    parser.add_argument('--view-graphs',
                        action='store_true',
                        default=False,
                        help='views forward and backward graphs')
    parser.add_argument('--export-onnx-graphs',
                        action='store_true',
                        default=False,
                        help='export ONNX graphs to current directory')
    parser.add_argument('--epochs',
                        type=int,
                        default=5,
                        metavar='N',
                        help='number of epochs to train (default: 5)')
    parser.add_argument(
        '--log-level',
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
        default='WARNING',
        help='Log level (default: WARNING)')
    parser.add_argument('--data-dir',
                        type=str,
                        default='./mnist',
                        help='Path to the mnist data directory')

    args = parser.parse_args()

    # Common setup
    torch.manual_seed(args.seed)
    onnxruntime.set_seed(args.seed)

    if not args.no_cuda and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    ## Data loader
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        args.data_dir,
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True)
    test_loader = None
    if args.test_batch_size > 0:
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.data_dir,
                           train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307, ), (0.3081, ))
                           ])),
            batch_size=args.test_batch_size,
            shuffle=True)

    # Model architecture
    model = NeuralNet(input_size=784, hidden_size=500,
                      num_classes=10).to(device)
    if not args.pytorch_only:
        print('Training MNIST on ORTModule....')

        # Just for future debugging
        debug_options = DebugOptions(save_onnx=args.export_onnx_graphs,
                                     onnx_prefix='MNIST')

        model = ORTModule(model, debug_options)

        # Set log level
        numeric_level = getattr(logging, args.log_level.upper(), None)
        if not isinstance(numeric_level, int):
            raise ValueError('Invalid log level: %s' % args.log_level)
        logging.basicConfig(level=numeric_level)
    else:
        print('Training MNIST on vanilla PyTorch....')
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

    # Train loop
    total_training_time, total_test_time, epoch_0_training, validation_accuracy = 0, 0, 0, 0
    for epoch in range(0, args.epochs):
        total_training_time += train(args, model, device, optimizer, my_loss,
                                     train_loader, epoch)
        if not args.pytorch_only and epoch == 0:
            epoch_0_training = total_training_time
        if args.test_batch_size > 0:
            test_time, validation_accuracy = test(args, model, device, my_loss,
                                                  test_loader)
            total_test_time += test_time

    assert validation_accuracy > 0.92

    print('\n======== Global stats ========')
    if not args.pytorch_only:
        estimated_export = 0
        if args.epochs > 1:
            estimated_export = epoch_0_training - (
                total_training_time - epoch_0_training) / (args.epochs - 1)
            print("  Estimated ONNX export took:               {:.4f}s".format(
                estimated_export))
        else:
            print(
                "  Estimated ONNX export took:               Estimate available when epochs > 1 only"
            )
        print("  Accumulated training without export took: {:.4f}s".format(
            total_training_time - estimated_export))
    print("  Accumulated training took:                {:.4f}s".format(
        total_training_time))
    print("  Accumulated validation took:              {:.4f}s".format(
        total_test_time))
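
Assuming this main() lives in a standalone script (the file name below is an assumption), typical invocations would be:

# python ortmodule_mnist.py                      # train with ORTModule (default)
# python ortmodule_mnist.py --pytorch-only       # vanilla PyTorch baseline
# python ortmodule_mnist.py --epochs 1 --log-level INFO --export-onnx-graphs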

Example #28

def demo_checkpoint(rank, world_size, use_ort_module):
    torch.manual_seed(rank)
    print(f"Running DDP checkpoint example on rank {rank}.")
    setup(rank, world_size)

    if use_ort_module:
        print(f"  Rank {rank} uses ORTModule.")
        model = ToyModel().to(rank)
        model = ORTModule(model)
    else:
        print(f"  Rank {rank} uses Pytorch's nn.Module.")
        model = ToyModel().to(rank)

    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = os.path.join(tempfile.gettempdir(), "model.checkpoint")
    if rank == 0:
        # All processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes.
        # Therefore, saving it in one process is sufficient.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure that process 1 loads the model after process
    # 0 saves it.
    dist.barrier()
    # configure map_location properly
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn = nn.MSELoss()
    loss = loss_fn(outputs, labels)
    loss.backward()
    optimizer.step()

    print(f"Rank {rank} sees loss {loss}")

    if rank == 0:
        assert torch.allclose(loss.cpu(),
                              torch.FloatTensor([1.4909229278564453]))
    elif rank == 1:
        assert torch.allclose(loss.cpu(),
                              torch.FloatTensor([1.0177688598632812]))
    elif rank == 2:
        assert torch.allclose(loss.cpu(),
                              torch.FloatTensor([1.290669322013855]))
    elif rank == 3:
        assert torch.allclose(loss.cpu(),
                              torch.FloatTensor([0.825118362903595]))
    else:
        assert False

    # Not necessary to use a dist.barrier() to guard the file deletion below
    # as the AllReduce ops in the backward pass of DDP already served as
    # a synchronization.

    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()

Example #29

def main():
    # 1. Basic setup
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--pytorch-only',
                        action='store_true',
                        default=False,
                        help='disables ONNX Runtime training')
    parser.add_argument('--batch-size',
                        type=int,
                        default=32,
                        metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--view-graphs',
                        action='store_true',
                        default=False,
                        help='views forward and backward graphs')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--epochs',
                        type=int,
                        default=4,
                        metavar='N',
                        help='number of epochs to train (default: 4)')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=40,
        metavar='N',
        help=
        'how many batches to wait before logging training status (default: 40)'
    )
    parser.add_argument(
        '--train-steps',
        type=int,
        default=-1,
        metavar='N',
        help=
        'number of steps to train. Set -1 to run through whole dataset (default: -1)'
    )
    parser.add_argument(
        '--log-level',
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
        default='WARNING',
        help='Log level (default: WARNING)')
    parser.add_argument(
        '--num-hidden-layers',
        type=int,
        default=1,
        metavar='H',
        help=
        'Number of hidden layers for the BERT model. A vanilla BERT has 12 hidden layers (default: 1)'
    )
    parser.add_argument('--data-dir',
                        type=str,
                        default='./cola_public/raw',
                        help='Path to the bert data directory')

    args = parser.parse_args()

    # Device (CPU vs CUDA)
    if torch.cuda.is_available() and not args.no_cuda:
        device = torch.device("cuda")
        print('There are %d GPU(s) available.' % torch.cuda.device_count())
        print('We will use the GPU:', torch.cuda.get_device_name(0))
    else:
        print('No GPU available, using the CPU instead.')
        device = torch.device("cpu")

    # Set log level
    numeric_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError('Invalid log level: %s' % args.log_level)
    logging.basicConfig(level=numeric_level)

    # 2. Dataloader
    train_dataloader, validation_dataloader = load_dataset(args)

    # 3. Modeling
    # Load BertForSequenceClassification, the pretrained BERT model with a single
    # linear classification layer on top.
    config = AutoConfig.from_pretrained(
        "bert-base-uncased",
        num_labels=2,
        num_hidden_layers=args.num_hidden_layers,
        output_attentions=False,  # Whether the model returns attentions weights.
        output_hidden_states=False,  # Whether the model returns all hidden-states.
    )
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased",  # Use the 12-layer BERT model, with an uncased vocab.
        config=config,
    )

    if not args.pytorch_only:
        model = ORTModule(model)

        # Just for future debugging; guarded because only ORTModule exposes
        # an execution manager
        debug_manager = model._execution_manager(model._is_training())
        debug_manager._save_onnx = False
        debug_manager._save_onnx_prefix = 'BertForSequenceClassification'

    # Tell pytorch to run this model on the GPU.
    if torch.cuda.is_available() and not args.no_cuda:
        model.cuda()

    # Note: AdamW is a class from the huggingface library (as opposed to pytorch)
    optimizer = AdamW(
        model.parameters(),
        lr=2e-5,  # args.learning_rate - default is 5e-5, our notebook had 2e-5
        eps=1e-8  # args.adam_epsilon  - default is 1e-8.
    )

    # Authors recommend between 2 and 4 epochs
    # Total number of training steps is number of batches * number of epochs.
    total_steps = len(train_dataloader) * args.epochs

    # Create the learning rate scheduler.
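    # With num_warmup_steps=0, the schedule simply decays the learning rate
    # linearly from its initial value to 0 over total_steps.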
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,  # Default value in run_glue.py
        num_training_steps=total_steps)

    # Seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    onnxruntime.set_seed(args.seed)
    if torch.cuda.is_available() and not args.no_cuda:
        torch.cuda.manual_seed_all(args.seed)

    # 4. Train loop (fine-tune)
    total_training_time, total_test_time, epoch_0_training, validation_accuracy = 0, 0, 0, 0
    for epoch_i in range(0, args.epochs):
        total_training_time += train(model, optimizer, scheduler,
                                     train_dataloader, epoch_i, device, args)
        if not args.pytorch_only and epoch_i == 0:
            epoch_0_training = total_training_time
        test_time, validation_accuracy = test(model, validation_dataloader,
                                              device, args)
        total_test_time += test_time

    assert validation_accuracy > 0.5

    print('\n======== Global stats ========')
    if not args.pytorch_only:
        estimated_export = 0
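        # Epoch 0 includes the one-time ONNX export; subtracting the average
        # steady-state epoch time (epochs 1..N-1) from epoch 0 approximates
        # the export overhead.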
        if args.epochs > 1:
            estimated_export = epoch_0_training - (
                total_training_time - epoch_0_training) / (args.epochs - 1)
            print("  Estimated ONNX export took:               {:.4f}s".format(
                estimated_export))
        else:
            print(
                "  Estimated ONNX export took:               Estimate available only when epochs > 1"
            )
        print("  Accumulated training without export took: {:.4f}s".format(
            total_training_time - estimated_export))
    print("  Accumulated training took:                {:.4f}s".format(
        total_training_time))
    print("  Accumulated validation took:              {:.4f}s".format(
        total_test_time))


def test_ortmodule_fallback_non_contiguous_tensors(is_training, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # Validate fix for issue: https://github.com/pytorch/ort/issues/92
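    # (Non-contiguous tensors arise e.g. from transpose or narrow views; the
    # fix ensures ORTModule handles them without an explicit .contiguous().)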

    policy = "FALLBACK_UNSUPPORTED_DEVICE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)
    os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"

    class PositionalEncoding(torch.nn.Module):
        def __init__(self, d_model, dropout=0.1, max_len=5000):
            super().__init__()
            self.dropout = torch.nn.Dropout(p=dropout)
            position = torch.arange(max_len).unsqueeze(1)
            div_term = torch.exp(
                torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
            pe = torch.zeros(max_len, 1, d_model)
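            # Even feature indices get sine, odd indices cosine: the standard
            # sinusoidal encoding from "Attention Is All You Need".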
            pe[:, 0, 0::2] = torch.sin(position * div_term)
            pe[:, 0, 1::2] = torch.cos(position * div_term)
            self.register_buffer("pe", pe)

        def forward(self, x):
            x = x + self.pe[:x.size(0)]
            return self.dropout(x)

    class TransformerModel(torch.nn.Module):
        def __init__(self,
                     ntoken,
                     d_model,
                     nhead,
                     d_hid,
                     nlayers,
                     dropout=0.5):
            super().__init__()
            self.model_type = "Transformer"
            encoder_layers = torch.nn.TransformerEncoderLayer(
                d_model, nhead, d_hid, dropout)
            self.transformer_encoder = torch.nn.TransformerEncoder(
                encoder_layers, nlayers)
            self.pos_encoder = PositionalEncoding(d_model, dropout)
            self.encoder = torch.nn.Embedding(ntoken, d_model)
            self.d_model = d_model
            self.decoder = torch.nn.Linear(d_model, ntoken)
            self.init_weights()

        def init_weights(self):
            initrange = 0.1
            self.encoder.weight.data.uniform_(-initrange, initrange)
            self.decoder.bias.data.zero_()
            self.decoder.weight.data.uniform_(-initrange, initrange)

        def forward(self, src, src_mask):
            src = self.encoder(src) * math.sqrt(self.d_model)
            src = self.pos_encoder(src)
            output = self.transformer_encoder(src, src_mask)
            output = self.decoder(output)
            return output

    def generate_square_subsequent_mask(sz):
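        # Upper-triangular -inf mask so position i attends only to positions <= i.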
        return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1)

    def get_batch(source, i):
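        # Language-modeling batch: the target is the input shifted by one token.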
        seq_len = min(bptt, len(source) - 1 - i)
        data = source[i:i + seq_len]
        target = source[i + 1:i + 1 + seq_len].reshape(-1)
        return data, target

    criterion = torch.nn.CrossEntropyLoss()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
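    # Synthetic token stream: random ids in [1, ntokens); zeros mark sequence ends.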
    train_data = np.random.randint(1, 12455, 1000)
    ends = np.random.randint(2, 20, 100).cumsum()
    ends = ends[ends < train_data.shape[0] - 2]
    train_data[ends] = 0
    train_data[-1] = 0

    train_data = torch.tensor(train_data, dtype=torch.int64, device=device)
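    # bptt: backprop-through-time window, i.e. sequence length per batch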
    bptt = 35
    src_mask = generate_square_subsequent_mask(bptt).to(device)
    ntokens, emsize, nhead, d_hid, nlayers, dropout = 12455, 200, 2, 200, 2, 0.2
    pt_model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers,
                                dropout)
    model = ORTModule(pt_model).to(device)
    pt_model.train(is_training)
    model.train(is_training)
    optimizer = torch.optim.SGD(model.parameters(), lr=5.0)

    n_iter = 0
    for epoch in range(1, 2):
        model.train()  # turn on train mode

        for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
            data, targets = get_batch(train_data, i)
            batch_size = data.size(0)
            if batch_size != bptt:  # only on last batch
                src_mask = src_mask[:batch_size, :batch_size]
            output = model(data, src_mask)
            nrows = min(ntokens, targets.shape[0])
            loss = criterion(output.view(nrows, -1), targets)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            n_iter += 1
            break

    assert n_iter > 0

    del os.environ["ORTMODULE_SKIPCHECK_POLICY"]