def test_ortmodule_fallback_output(is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_UNSUPPORTED_DATA policy to the ORTModuleIOError exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_UNSUPPORTED_DATA"
        else:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    device = "cuda"
    N, D_in, H, D_out = 64, 784, 500, 10
    pt_model = NeuralNetCustomClassOutput(D_in, H, D_out).to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))
    x = torch.randn(N, D_in, device=device)
    y = torch.randn(N, D_in, device=device)
    z = torch.randn(N, D_in, device=device)
    ort_model.train(is_training)
    pt_model.train(is_training)

    for i in range(3):
        if fallback_enabled:
            if matching_policy:
                # When the fallback persists, the cached exception must survive across calls
                if i > 0 and persist_fallback:
                    assert (
                        ort_model._torch_module._execution_manager(is_training=is_training)._fallback_manager._exception
                        is not None
                    )
                ort_out = ort_model(x, y, z)
                pt_out = pt_model(x, y, z)
                _test_helpers.assert_values_are_close(ort_out.out1, pt_out.out1, rtol=0, atol=0)
                _test_helpers.assert_values_are_close(ort_out.out2, pt_out.out2, rtol=0, atol=0)
                _test_helpers.assert_values_are_close(ort_out.out3, pt_out.out3, rtol=0, atol=0)
            else:
                with pytest.raises(_fallback.ORTModuleIOError) as runtime_error:
                    ort_model(x, y, z)
                assert "ORTModule does not support the following model output type" in str(runtime_error.value)
        else:
            with pytest.raises(_fallback.ORTModuleIOError) as runtime_error:
                ort_model(x, y, z)
            assert "ORTModule does not support the following model output type" in str(runtime_error.value)
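# The fallback tests here take (is_training, fallback_enabled, matching_policy,
# persist_fallback) flags; a minimal sketch of the pytest parametrization that
# presumably drives them (the actual decorator is not part of this excerpt):
#
#     import itertools
#     import pytest
#
#     @pytest.mark.parametrize(
#         "is_training,fallback_enabled,matching_policy,persist_fallback",
#         list(itertools.product([True, False], repeat=4)),
#     )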
def test_ortmodule_fallback_warn_message(is_training, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise

    policy = "FALLBACK_UNSUPPORTED_DEVICE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)
    os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"

    data_device = "cuda"
    N, D_in, H, D_out = 64, 784, 500, 10

    pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
    ort_model = ORTModule(copy.deepcopy(pt_model))
    pt_model.train(is_training)
    ort_model.train(is_training)
    # For initial model export, use same device for data and model so that PyTorch model can be traced during export
    _ = ort_model(torch.randn(N, D_in))
    # Use data on a different device for testing
    inputs = torch.randn(N, D_in, device=data_device)

    for _ in range(3):
        with pytest.raises(RuntimeError):
            with pytest.warns(UserWarning) as warning_record:
                ort_model(inputs)
        assert "Fallback to PyTorch due to exception" in str(warning_record[0].message.args[0])

    del os.environ["ORTMODULE_SKIPCHECK_POLICY"]
def test_ortmodule_fallback_device__mismatch(is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_UNSUPPORTED_DEVICE policy to ORTModuleDeviceException exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DATA) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
        else:
            policy = "FALLBACK_UNSUPPORTED_DATA"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)
    os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"

    data_device = "cuda"
    N, D_in, H, D_out = 64, 784, 500, 10

    pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
    ort_model = ORTModule(copy.deepcopy(pt_model))
    pt_model.train(is_training)
    ort_model.train(is_training)
    # For initial model export, use same device for data and model so that PyTorch model can be traced during export
    _ = ort_model(torch.randn(N, D_in))
    # Use data on a different device for testing
    inputs = torch.randn(N, D_in, device=data_device)
    ort_model_device = ort_model._torch_module._execution_manager(ort_model._is_training())._device
    input_device = torch.device(inputs.device)

    for _ in range(3):
        if fallback_enabled:
            if matching_policy:
                # Fallback hands execution to PyTorch, which then raises its own device-mismatch RuntimeError
                with pytest.raises(RuntimeError) as e:
                    ort_model(inputs)
                assert (
                    "Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for addmm)"
                    in str(e.value)
                ) or (
                    "Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!"
                    in str(e.value)
                )
            else:
                with pytest.raises(_fallback.ORTModuleDeviceException) as e:
                    ort_model(inputs)
                assert (
                    f"Input argument to forward found on device {input_device}, "
                    f"but expected it to be on module device {ort_model_device}." in str(e.value)
                )
        else:
            with pytest.raises(_fallback.ORTModuleDeviceException) as e:
                ort_model(inputs)
            assert (
                f"Input argument to forward found on device {input_device}, "
                f"but expected it to be on module device {ort_model_device}." in str(e.value)
            )

    del os.environ["ORTMODULE_SKIPCHECK_POLICY"]
def test_ortmodule_fallback_forward(is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_FORCE_TORCH_FORWARD policy to the ORTModuleFallbackException exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_FORCE_TORCH_FORWARD"
        else:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    from dataclasses import dataclass

    @dataclass
    class Point:
        x: int
        y: int

    class UnsupportedInputModel(torch.nn.Module):
        def __init__(self):
            super(UnsupportedInputModel, self).__init__()

        def forward(self, point):
            return point.x * point.y

    pt_model = UnsupportedInputModel()
    ort_model = ORTModule(copy.deepcopy(pt_model))
    inputs = Point(x=2, y=3)

    ort_model.train(is_training)
    pt_model.train(is_training)

    for i in range(3):
        if fallback_enabled:
            if matching_policy:
                if i > 0 and persist_fallback:
                    assert (
                        ort_model._torch_module._execution_manager(is_training=is_training)._fallback_manager._exception
                        is not None
                    )
                ort_out = ort_model(inputs)
                pt_out = pt_model(inputs)
                assert ort_out == pt_out
            else:
                with pytest.raises(_fallback.ORTModuleFallbackException) as type_error:
                    ort_model(inputs)
                assert "ORTModule does not support the following model data type" in str(type_error.value)
        else:
            with pytest.raises(_fallback.ORTModuleFallbackException) as type_error:
                ort_model(inputs)
            assert "ORTModule does not support the following model data type" in str(type_error.value)
def test_ortmodule_fallback_onnx_model__custom_autograd(is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_UNSUPPORTED_ONNX_MODEL policy to the ORTModuleONNXModelException exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_UNSUPPORTED_ONNX_MODEL"
        else:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    dtype = torch.float
    device = torch.device("cuda")
    N, D_in, H, D_out = 64, 1000, 100, 10

    x = torch.randn(N, D_in, device=device, dtype=dtype)
    y = torch.randn(N, D_out, device=device, dtype=dtype)
    w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
    w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

    pt_model = MyCustomFunctionReluModel()
    ort_model = ORTModule(copy.deepcopy(pt_model))
    ort_model.train(is_training)
    pt_model.train(is_training)

    for i in range(3):
        if fallback_enabled:
            if matching_policy:
                if i > 0 and persist_fallback:
                    assert (
                        ort_model._torch_module._execution_manager(is_training=is_training)._fallback_manager._exception
                        is not None
                    )
                pt_out = pt_model(x.mm(w1)).mm(w2)
                ort_out = ort_model(x.mm(w1)).mm(w2)
                _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=1e-03, atol=1e-04)
            else:
                with pytest.raises(_fallback.ORTModuleONNXModelException) as ex_info:
                    _ = ort_model(x.mm(w1)).mm(w2)
                assert "There was an error while exporting the PyTorch model to ONNX" in str(ex_info.value)
        else:
            with pytest.raises(_fallback.ORTModuleONNXModelException) as ex_info:
                # Exception is raised during the first forward call, when the model is exported to ONNX
                _ = ort_model(x.mm(w1)).mm(w2)
            assert "There was an error while exporting the PyTorch model to ONNX" in str(ex_info.value)
def test_ortmodule_fallback_onnx_model__missing_op(is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_UNSUPPORTED_ONNX_MODEL policy to the ORTModuleONNXModelException exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_UNSUPPORTED_ONNX_MODEL"
        else:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    class CrossModule(torch.nn.Module):
        def forward(self, x, y):
            # torch.cross lacks an ONNX exporter symbolic, so export fails
            return torch.cross(x, y)

    x = torch.randn(2, 3)
    y = torch.randn(2, 3)

    pt_model = CrossModule()
    ort_model = ORTModule(copy.deepcopy(pt_model))
    ort_model.train(is_training)
    pt_model.train(is_training)

    for i in range(3):
        if fallback_enabled:
            if matching_policy:
                if i > 0 and persist_fallback:
                    assert (
                        ort_model._torch_module._execution_manager(is_training=is_training)._fallback_manager._exception
                        is not None
                    )
                pt_out = pt_model(x, y)
                ort_out = ort_model(x, y)
                _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0)
            else:
                with pytest.raises(_fallback.ORTModuleONNXModelException) as ex_info:
                    _ = ort_model(x, y)
                assert "There was an error while exporting the PyTorch model to ONNX" in str(ex_info.value)
        else:
            with pytest.raises(_fallback.ORTModuleONNXModelException) as ex_info:
                # Exception is raised during the first forward call, when the model is exported to ONNX
                _ = ort_model(x, y)
            assert "There was an error while exporting the PyTorch model to ONNX" in str(ex_info.value)
def __init__(self, lr, use_ortmodule=True):
    super().__init__()
    self.lr = lr
    self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
    self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
    if use_ortmodule:
        self.encoder = ORTModule(self.encoder)
        self.decoder = ORTModule(self.decoder)
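# Hypothetical companion methods for the autoencoder __init__ above (only
# __init__ appears in this excerpt); a minimal sketch following the standard
# pytorch_lightning.LightningModule contract, assuming
# `import torch.nn.functional as F` at module level.
def training_step(self, batch, batch_idx):
    x, _ = batch
    x = x.view(x.size(0), -1)  # flatten 28x28 MNIST images to vectors
    z = self.encoder(x)        # 784 -> 3 bottleneck
    x_hat = self.decoder(z)    # 3 -> 784 reconstruction
    return F.mse_loss(x_hat, x)


def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=self.lr)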
def test_ortmodule_fallback_input(is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_UNSUPPORTED_DATA policy to the ORTModuleIOError exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_UNSUPPORTED_DATA"
        else:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    pt_model = MyCustomClassInputNet()
    ort_model = ORTModule(copy.deepcopy(pt_model))
    inputs = torch.randn(1, 2)

    class CustomClass:
        def __init__(self, x):
            self.x = x

    ort_model.train(is_training)
    pt_model.train(is_training)

    for i in range(3):
        if fallback_enabled:
            if matching_policy:
                if i > 0 and persist_fallback:
                    assert (
                        ort_model._torch_module._execution_manager(is_training=is_training)._fallback_manager._exception
                        is not None
                    )
                ort_out = ort_model(inputs, CustomClass(1))
                pt_out = pt_model(inputs, CustomClass(1))
                _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0)
            else:
                with pytest.raises(_fallback.ORTModuleIOError) as ex_info:
                    _ = ort_model(torch.randn(1, 2), CustomClass(1))
                assert (
                    "ORTModule does not support the following model data"
                    " type <class 'orttraining_test_ortmodule_fallback."
                    "test_ortmodule_fallback_input.<locals>.CustomClass'>" in str(ex_info.value)
                )
        else:
            with pytest.raises(_fallback.ORTModuleIOError) as ex_info:
                _ = ort_model(torch.randn(1, 2), CustomClass(1))
            assert (
                "ORTModule does not support the following model data"
                " type <class 'orttraining_test_ortmodule_fallback."
                "test_ortmodule_fallback_input.<locals>.CustomClass'>" in str(ex_info.value)
            )
def test_ortmodule_fallback_device__multiple(is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_UNSUPPORTED_DEVICE policy to the ORTModuleFallbackException exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DATA) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
        else:
            policy = "FALLBACK_UNSUPPORTED_DATA"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    class ManyDevicesNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.net1 = torch.nn.Linear(10, 10).to("cuda:0")
            self.relu = torch.nn.ReLU()
            self.net2 = torch.nn.Linear(10, 5).to("cpu")

        def forward(self, x):
            x = self.relu(self.net1(x.to("cuda:0")))
            return self.net2(x.to("cpu"))

    pt_model = ManyDevicesNet()
    inputs = torch.randn(20, 10)

    for _ in range(3):
        if fallback_enabled:
            if matching_policy:
                ort_model = ORTModule(copy.deepcopy(pt_model))
                pt_model.train(is_training)
                ort_model.train(is_training)
                ort_out = ort_model(inputs)
                pt_out = pt_model(inputs)
                _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0)
            else:
                with pytest.raises(_fallback.ORTModuleFallbackException) as type_error:
                    # Initialize with fallback policy because Exception will happen during __init__
                    ort_model = ORTModule(copy.deepcopy(pt_model))
                assert "ORTModule supports a single device per model" in str(type_error.value)
        else:
            with pytest.raises(_fallback.ORTModuleFallbackException) as type_error:
                # Initialize with fallback policy because Exception will happen during __init__
                ort_model = ORTModule(copy.deepcopy(pt_model))
            assert "ORTModule supports a single device per model" in str(type_error.value)
def test_ortmodule_fallback_init__torch_version(is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_BAD_INITIALIZATION policy to the ORTModuleInitException exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    from packaging import version

    from onnxruntime.training.ortmodule import MINIMUM_RUNTIME_PYTORCH_VERSION_STR

    runtime_pytorch_version = version.parse(torch.__version__.split("+")[0])
    minimum_runtime_pytorch_version = version.parse(MINIMUM_RUNTIME_PYTORCH_VERSION_STR)
    if runtime_pytorch_version < minimum_runtime_pytorch_version:
        if fallback_enabled:
            if matching_policy:
                policy = "FALLBACK_BAD_INITIALIZATION"
            else:
                policy = "FALLBACK_UNSUPPORTED_DEVICE"
        else:
            policy = "FALLBACK_DISABLE"
        os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
        os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

        device = "cuda"
        N, D_in, H, D_out = 64, 784, 500, 10
        x = torch.randn(N, D_in, device=device)

        pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)

        for _ in range(3):
            if fallback_enabled:
                if matching_policy:
                    ort_model = ORTModule(pt_model)
                    ort_model.train(is_training)
                    pt_model.train(is_training)

                    ort_out = ort_model(x)
                    pt_out = pt_model(x)
                    _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0)
                else:
                    with pytest.raises(_fallback.ORTModuleInitException) as ex_info:
                        ort_model = ORTModule(pt_model)
                    assert "ONNX Runtime ORTModule frontend requires PyTorch version greater or equal to" in str(
                        ex_info.value
                    )
            else:
                with pytest.raises(_fallback.ORTModuleInitException) as ex_info:
                    # Initialize with fallback policy because Exception will happen during __init__
                    ort_model = ORTModule(pt_model)
                assert "ONNX Runtime ORTModule frontend requires PyTorch version greater or equal to" in str(
                    ex_info.value
                )
    else:
        warnings.warn(
            "Skipping test_ortmodule_fallback_init__torch_version."
            f" It requires PyTorch prior to {MINIMUM_RUNTIME_PYTORCH_VERSION_STR}"
        )
def test_bool_input_and_output(self):
    class NeuralNetBoolInputOutput(torch.nn.Module):
        def __init__(self, input_size, num_classes):
            super(NeuralNetBoolInputOutput, self).__init__()
            self.fc = torch.nn.Linear(input_size, num_classes)
            self.relu = torch.nn.ReLU()

        def forward(self, condition, x1, x2):
            out1 = self.relu(self.fc(torch.where(condition, x1, x2)))
            out2 = torch.tensor(out1).to(torch.bool)
            return out1, out2

    device = 'cpu'
    N, D_in, D_out = 8, 16, 2
    model = NeuralNetBoolInputOutput(D_in, D_out).to(device)
    model = ORTModule(model)
    condition = torch.randint(2, (N, D_in), dtype=torch.bool, device=device)
    x1 = torch.randn(N, D_in, device=device)
    x2 = torch.randn(N, D_in, device=device)
    y1, y2 = model(condition, x1, x2)

    assert y1 is not None
    assert y2 is not None and y2.dtype == torch.bool
def test_ortmodule_dlpack(self):
    class NeuralNetTanh(torch.nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            super(NeuralNetTanh, self).__init__()
            self.fc1 = torch.nn.Linear(input_size, hidden_size)
            self.tanh = torch.nn.Tanh()

        def forward(self, input1):
            out = self.fc1(input1)
            out = self.tanh(out)
            return out

    def run_step(model, x):
        prediction = model(x)
        loss = prediction.sum()
        loss.backward()
        return prediction, loss

    N, D_in, H, D_out = 120, 1536, 500, 1536
    pt_model = NeuralNetTanh(D_in, H, D_out)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    for step in range(10):
        pt_x = torch.randn(N, D_in, device='cpu', requires_grad=True)
        ort_x = copy.deepcopy(pt_x)

        ort_prediction, ort_loss = run_step(ort_model, ort_x)
        pt_prediction, pt_loss = run_step(pt_model, pt_x)

        _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-4)
        _test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad)
        _test_helpers.assert_values_are_close(ort_loss, pt_loss, atol=1e-4)
def forward(ctx, x, dim, device, use_ort):
    ctx.save_for_backward(x)
    ctx.device = device
    # Keep the inner module on the autograd context so backward can reuse it
    ctx.inner = InnerModel(dim, device).to(device)
    if use_ort:
        ctx.inner = ORTModule(ctx.inner)
    z = ctx.inner(x)
    return z
def test_ortmodule_fallback_init__missing_cpp_extensions(is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_BAD_INITIALIZATION policy to the ORTModuleInitException exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    if is_torch_cpp_extensions_installed(ORTMODULE_TORCH_CPP_DIR):
        warnings.warn(
            "Skipping test_ortmodule_fallback_init__missing_cpp_extensions."
            " It requires PyTorch CPP extensions to be missing"
        )
    else:
        if fallback_enabled:
            if matching_policy:
                policy = "FALLBACK_BAD_INITIALIZATION"
            else:
                policy = "FALLBACK_UNSUPPORTED_DEVICE"
        else:
            policy = "FALLBACK_DISABLE"
        os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
        os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

        device = "cuda"
        N, D_in, H, D_out = 64, 784, 500, 10
        x = torch.randn(N, D_in, device=device)

        pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)

        for _ in range(3):
            if fallback_enabled:
                if matching_policy:
                    ort_model = ORTModule(pt_model)
                    ort_model.train(is_training)
                    pt_model.train(is_training)

                    ort_out = ort_model(x)
                    pt_out = pt_model(x)
                    _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0)
                else:
                    with pytest.raises(_fallback.ORTModuleInitException) as ex_info:
                        ort_model = ORTModule(pt_model)
                    assert "ORTModule's extensions were not detected" in str(ex_info.value)
            else:
                with pytest.raises(_fallback.ORTModuleInitException) as ex_info:
                    # Initialize with fallback policy because Exception will happen during __init__
                    ort_model = ORTModule(pt_model)
                assert "ORTModule's extensions were not detected" in str(ex_info.value)
def run_with_ort_on_device(device, model, input_list, label_input, is_eval_mode=False):
    model = copy.deepcopy(model)
    model.to(device)
    model = ORTModule(model)
    enable_custom_autograd_function(model)
    if is_eval_mode:
        model.eval()
    else:
        model.train()

    inputs_on_device = [input_.to(device) for input_ in input_list]
    output = model(*inputs_on_device)
    forward_outputs = [output]
    grad_outputs = []

    if not is_eval_mode:
        criterion = torch.nn.MSELoss()
        target = label_input.to(device)
        loss = criterion(output, target)
        loss.backward()
        for name, param in model.named_parameters():
            if param.requires_grad:
                grad_outputs.append(param.grad)
    return forward_outputs, grad_outputs
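# Hypothetical PyTorch-only counterpart to the helper above, for producing the
# baseline outputs that the ORT run is compared against (a sketch; the actual
# helper is not part of this excerpt). It mirrors run_with_ort_on_device
# without the ORTModule wrapping.
def run_with_pytorch_on_device(device, model, input_list, label_input, is_eval_mode=False):
    model = copy.deepcopy(model).to(device)
    model.eval() if is_eval_mode else model.train()

    inputs_on_device = [input_.to(device) for input_ in input_list]
    output = model(*inputs_on_device)
    forward_outputs = [output]
    grad_outputs = []

    if not is_eval_mode:
        criterion = torch.nn.MSELoss()
        loss = criterion(output, label_input.to(device))
        loss.backward()
        for _, param in model.named_parameters():
            if param.requires_grad:
                grad_outputs.append(param.grad)
    return forward_outputs, grad_outputs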
def forward(ctx, x, dim, device, use_ort):
    ctx.save_for_backward(x)
    ctx.device = device
    ctx.inner = InnerModel(dim, device).to(device)
    if use_ort:
        ctx.inner = ORTModule(ctx.inner)
        enable_custom_autograd_function(ctx.inner)
    z = ctx.inner(x)
    return z
def run_with_ort_on_gpu(model, args, rank, device):
    model.to(device)
    model = ORTModule(model)
    _test_helpers.set_onnx_fallthrough_export_type(model)
    model = DDP(model, device_ids=[rank])
    cuda_args = [arg.to(device) for arg in args]
    output = model(*cuda_args)
    output.sum().backward()
    return output, [arg.grad for arg in cuda_args]
def test_ortmodule_fallback_torch_model(is_training, fallback_enabled, matching_policy, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend
    # matching_policy: True matches FALLBACK_UNSUPPORTED_TORCH_MODEL policy to the ORTModuleTorchModelException exception.
    #   Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen

    if fallback_enabled:
        if matching_policy:
            policy = "FALLBACK_UNSUPPORTED_TORCH_MODEL"
        else:
            policy = "FALLBACK_UNSUPPORTED_DEVICE"
    else:
        policy = "FALLBACK_DISABLE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)

    device = "cuda"
    N, D_in, H, D_out = 64, 784, 500, 10
    x = torch.randn(N, D_in, device=device)

    pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
    # DataParallel-wrapped models are not supported by ORTModule
    pt_model = torch.nn.DataParallel(pt_model)

    for _ in range(3):
        if fallback_enabled:
            if matching_policy:
                ort_model = ORTModule(pt_model)
                ort_model.train(is_training)
                pt_model.train(is_training)

                ort_out = ort_model(x)
                pt_out = pt_model(x)
                _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=1e-3, atol=1e-6)
            else:
                with pytest.raises(_fallback.ORTModuleTorchModelException) as ex_info:
                    ort_model = ORTModule(pt_model)
                assert "ORTModule is not compatible with torch.nn.DataParallel" in str(ex_info.value)
        else:
            with pytest.raises(_fallback.ORTModuleTorchModelException) as ex_info:
                # Initialize with fallback policy because Exception will happen during __init__
                ort_model = ORTModule(pt_model)
            assert "ORTModule is not compatible with torch.nn.DataParallel" in str(ex_info.value)
def test_load_config_from_json_2():
    device = 'cuda'
    model = ORTModule(Net().to(device))

    # load from json once.
    path_to_json = os.path.join(os.getcwd(), 'orttraining_test_ortmodule_experimental_json_config_1.json')
    load_from_json(model, path_to_json)

    # load from json another time
    path_to_json = os.path.join(os.getcwd(), 'orttraining_test_ortmodule_experimental_json_config_2.json')
    load_from_json(model, path_to_json)

    for training_mode in [True, False]:
        ort_model_attributes = model._torch_module._execution_manager(training_mode)

        # test propagate cast ops
        assert ort_model_attributes._propagate_cast_ops_strategy == C.PropagateCastOpsStrategy.INSERT_AND_REDUCE
        assert ort_model_attributes._propagate_cast_ops_level == 5
        assert ort_model_attributes._propagate_cast_ops_allow == ["XYZ", "PQR"]

        # test use external gpu allocator
        assert ort_model_attributes._use_external_gpu_allocator == True

        # test enable custom autograd function
        assert ort_model_attributes._enable_custom_autograd_function == False

        # test allow layer norm mod precision
        assert ort_model_attributes._allow_layer_norm_mod_precision == False

        # test use static shape
        assert ort_model_attributes._use_static_shape == False

        # test run symbolic shape inference
        assert ort_model_attributes._run_symbolic_shape_infer == True

        # test enable grad acc optimization
        assert ort_model_attributes._enable_grad_acc_optimization == False

        # test skip check
        assert ort_model_attributes._skip_check.value == 10

        # test debug options
        assert ort_model_attributes._debug_options.save_onnx_models.save == True
        assert ort_model_attributes._debug_options.save_onnx_models.name_prefix == 'my_other_model'
        assert ort_model_attributes._debug_options.logging.log_level.name == "INFO"

        # test use memory aware gradient builder.
        assert ort_model_attributes._use_memory_efficient_gradient == True

        # test fallback policy
        assert ort_model_attributes._fallback_manager.policy.value == 250
def test_load_config_from_json_1():
    device = 'cuda'
    model = ORTModule(Net().to(device))

    # load from json once.
    path_to_json = os.path.join(os.getcwd(), 'orttraining_test_ortmodule_experimental_json_config_2.json')
    load_from_json(model, path_to_json)

    # load from json another time
    path_to_json = os.path.join(os.getcwd(), 'orttraining_test_ortmodule_experimental_json_config_1.json')
    load_from_json(model, path_to_json)

    for training_mode in [True, False]:
        ort_model_attributes = model._torch_module._execution_manager(training_mode)

        # test propagate cast ops
        assert ort_model_attributes._propagate_cast_ops_strategy == C.PropagateCastOpsStrategy.FLOOD_FILL
        assert ort_model_attributes._propagate_cast_ops_level == 3
        assert ort_model_attributes._propagate_cast_ops_allow == ["ABC", "DEF"]

        # test use external gpu allocator
        assert ort_model_attributes._use_external_gpu_allocator == False

        # test enable custom autograd function
        assert ort_model_attributes._enable_custom_autograd_function == True

        # test allow layer norm mod precision
        assert ort_model_attributes._allow_layer_norm_mod_precision == True

        # test use static shape
        assert ort_model_attributes._use_static_shape == True

        # test run symbolic shape inference
        assert ort_model_attributes._run_symbolic_shape_infer == False

        # test enable grad acc optimization
        assert ort_model_attributes._enable_grad_acc_optimization == True

        # test skip check
        assert ort_model_attributes._skip_check.value == 14

        # test debug options
        assert ort_model_attributes._debug_options.save_onnx_models.save == True
        assert ort_model_attributes._debug_options.save_onnx_models.name_prefix == 'my_model'
        assert ort_model_attributes._debug_options.logging.log_level.name == "VERBOSE"
def run_with_ort_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False):
    with torch.no_grad():
        model = copy.deepcopy(model)
        model.to(device)
    enable_custom_autograd_function(model)
    model = ORTModule(model)
    return _run_model_on_device(device, model, input_list, label_input, is_eval_mode, run_forward_twice)
def gradient_correctness(self, name, device, debug=False):
    pt_model_cls, op_grad_type, kwargs = self.get_torch_model_name(name, device)
    if kwargs is None:
        kwargs = {}
    N = 32
    pt_model = pt_model_cls().to(device)
    D_in = pt_model.fc1.in_features
    ort_model = ORTModule(copy.deepcopy(pt_model))

    def run_step(model, x):
        prediction = model(x)
        loss = prediction.sum()
        loss.backward()
        return prediction

    for _ in range(10):
        x = torch.randn(N, D_in, device=device)
        pt_prediction = run_step(pt_model, x)
        ort_prediction = run_step(ort_model, x)
        self.assert_values_are_close(ort_prediction, pt_prediction, **kwargs)
        self.assert_gradients_match_and_reset_gradient(ort_model, pt_model, **kwargs)

    onnx_graph_inf = ort_model._torch_module._execution_manager._training_manager._onnx_models.exported_model
    onnx_graph_train = ort_model._torch_module._execution_manager._training_manager._onnx_models.optimized_model
    if debug:
        with open("debug_%s_ortmodule_infer.onnx" % name, "wb") as f:
            f.write(onnx_graph_inf.SerializeToString())
        with open("debug_%s_ortmodule_train.onnx" % name, "wb") as f:
            f.write(onnx_graph_train.SerializeToString())
    self.assertIn('op_type: "%s"' % name, str(onnx_graph_inf))
    for onnx_model in [onnx_graph_inf, onnx_graph_train]:
        for oimp in onnx_model.opset_import:
            if oimp.domain == "":
                self.assertEqual(oimp.version, 14)
    if op_grad_type is not None:
        if isinstance(op_grad_type, tuple):
            text = str(onnx_graph_train)
            if all(map(lambda op: ('op_type: "%s"' % op) not in text, op_grad_type)):
                raise AssertionError("Operator %s not found in %s." % (" or ".join(op_grad_type), text))
        else:
            self.assertIn('op_type: "%s"' % op_grad_type, str(onnx_graph_train))
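# Hypothetical call site for the helper above (a sketch; the concrete test
# methods are not part of this excerpt):
#
#     def test_onnx_ops(self):
#         for device_name in ["cuda:0", "cpu"]:
#             if device_name == "cuda:0" and not torch.cuda.is_available():
#                 continue
#             self.gradient_correctness("Softmax", torch.device(device_name))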
def demo_basic(rank, world_size, use_ort_module):
    torch.manual_seed(0)
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    if use_ort_module:
        model = ORTModule(model)
        print(f"  Rank {rank} uses ORTModule.")
    else:
        print(f"  Rank {rank} uses Pytorch's nn.Module.")

    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.Adagrad(ddp_model.parameters(), lr=0.01)

    x = torch.randn(20, 10).to(rank)
    y = torch.randn(20, 5).to(rank)

    loss_history = []

    for i in range(5):
        optimizer.zero_grad()
        p = ddp_model(x)
        loss = loss_fn(p, y)
        with torch.no_grad():
            print(f"  Rank {rank} at iteration {i} has loss {loss}.")
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            loss_history.append(torch.unsqueeze(loss, 0))

    loss_history = torch.cat(loss_history).cpu()
    expected_loss_history = torch.FloatTensor(
        [1.4909229278564453, 1.432194471359253, 1.39592707157135, 1.367714762687683, 1.3445055484771729]
    )
    assert torch.allclose(expected_loss_history, loss_history)

    cleanup()
def run():
    m = Foo().to('cuda')
    x = torch.rand((2, 2), dtype=torch.float).to('cuda')

    # Baseline.
    y_ref = m(x)
    print('Ref:')
    print(y_ref)

    m = ORTModule(m)

    # Inference mode.
    y_infer = m(x)
    print(y_infer)
    assert torch.allclose(y_ref, y_infer)

    # Training mode.
    m.train()
    y_train = m(x)
    print('Train:')
    assert torch.allclose(y_ref, y_train)
def run():
    m = Foo().to("cuda")
    x = torch.rand((2, 2), dtype=torch.float).to("cuda")

    # Baseline.
    y_ref = m(x)
    print("Ref:")
    print(y_ref)

    from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support

    enable_custom_autograd_support()

    m = ORTModule(m)

    # Inference mode.
    y_infer = m(x)
    print(y_infer)
    assert torch.allclose(y_ref, y_infer)

    # Training mode.
    m.train()
    y_train = m(x)
    print("Train:")
    assert torch.allclose(y_ref, y_train)
def test_GeLU_When_Autograd_Func_Fallback_Not_Enabled():
    @torch.jit.script
    def bias_gelu(bias, y):
        x = bias + y
        return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

    @torch.jit.script
    def bias_gelu_backward(g, bias, y):
        x = bias + y
        tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
        ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
        return ff * g

    class GeLUFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, bias):
            ctx.save_for_backward(input, bias)
            return bias_gelu(bias, input)

        @staticmethod
        def backward(ctx, grad_output):
            input, bias = ctx.saved_tensors
            tmp = bias_gelu_backward(grad_output, bias, input)
            return tmp, tmp

    class GeLUModel(torch.nn.Module):
        def __init__(self, output_size):
            super(GeLUModel, self).__init__()
            self.relu = GeLUFunction.apply
            self.bias = Parameter(
                torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float)
            )

            with torch.no_grad():
                self.bias.uniform_()

        def forward(self, model_input):
            out = self.relu(model_input, self.bias)
            return out

    output_size = 1024

    def model_builder():
        return GeLUModel(output_size)

    def input_generator():
        return torch.randn(output_size, dtype=torch.float)

    # generate a label that has the same shape as the forward output.
    label_input = torch.ones([output_size])

    m_ort = model_builder()
    x_ort = input_generator()

    try:
        device = torch.device("cpu")
        m_ort.to(device)
        model = ORTModule(m_ort)
        model.train()

        inputs_on_device = [x_ort.to(device)]
        output = model(*inputs_on_device)
    except RuntimeError as e:
        assert "Detected autograd functions usage in current model, the run will fail" in str(e)
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--train-steps', type=int, default=-1, metavar='N',
                        help='number of steps to train. Set -1 to run through whole dataset (default: -1)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=42, metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument('--pytorch-only', action='store_true', default=False,
                        help='disables ONNX Runtime training')
    parser.add_argument('--log-interval', type=int, default=300, metavar='N',
                        help='how many batches to wait before logging training status (default: 300)')
    parser.add_argument('--view-graphs', action='store_true', default=False,
                        help='views forward and backward graphs')
    parser.add_argument('--export-onnx-graphs', action='store_true', default=False,
                        help='export ONNX graphs to current directory')
    parser.add_argument('--epochs', type=int, default=5, metavar='N',
                        help='number of epochs to train (default: 5)')
    parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default='WARNING', help='Log level (default: WARNING)')
    parser.add_argument('--data-dir', type=str, default='./mnist',
                        help='Path to the mnist data directory')

    args = parser.parse_args()

    # Common setup
    torch.manual_seed(args.seed)
    onnxruntime.set_seed(args.seed)

    if not args.no_cuda and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    ## Data loader
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(args.data_dir, train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True)

    test_loader = None
    if args.test_batch_size > 0:
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.data_dir, train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=args.test_batch_size, shuffle=True)

    # Model architecture
    model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device)
    if not args.pytorch_only:
        print('Training MNIST on ORTModule....')

        # Just for future debugging
        debug_options = DebugOptions(save_onnx=args.export_onnx_graphs, onnx_prefix='MNIST')

        model = ORTModule(model, debug_options)

        # Set log level
        numeric_level = getattr(logging, args.log_level.upper(), None)
        if not isinstance(numeric_level, int):
            raise ValueError('Invalid log level: %s' % args.log_level)
        logging.basicConfig(level=numeric_level)
    else:
        print('Training MNIST on vanilla PyTorch....')
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

    # Train loop
    total_training_time, total_test_time, epoch_0_training, validation_accuracy = 0, 0, 0, 0
    for epoch in range(0, args.epochs):
        total_training_time += train(args, model, device, optimizer, my_loss, train_loader, epoch)
        if not args.pytorch_only and epoch == 0:
            epoch_0_training = total_training_time
        if args.test_batch_size > 0:
            test_time, validation_accuracy = test(args, model, device, my_loss, test_loader)
            total_test_time += test_time

    assert validation_accuracy > 0.92

    print('\n======== Global stats ========')
    if not args.pytorch_only:
        estimated_export = 0
        if args.epochs > 1:
            estimated_export = epoch_0_training - (total_training_time - epoch_0_training) / (args.epochs - 1)
            print("  Estimated ONNX export took:               {:.4f}s".format(estimated_export))
        else:
            print("  Estimated ONNX export took:               Estimate available when epochs > 1 only")
        print("  Accumulated training without export took: {:.4f}s".format(total_training_time - estimated_export))
    print("  Accumulated training took:                {:.4f}s".format(total_training_time))
    print("  Accumulated validation took:              {:.4f}s".format(total_test_time))
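# Assumed loss helper referenced by the MNIST loop above (`my_loss` is not
# defined in this excerpt); a minimal sketch using negative log-likelihood,
# assuming `import torch.nn.functional as F` at module level.
def my_loss(output, target, is_train=True):
    if is_train:
        return F.nll_loss(F.log_softmax(output, dim=1), target)
    else:
        # Sum the batch losses during evaluation so the caller can average them.
        return F.nll_loss(F.log_softmax(output, dim=1), target, reduction='sum')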
def demo_checkpoint(rank, world_size, use_ort_module):
    torch.manual_seed(rank)
    print(f"Running DDP checkpoint example on rank {rank}.")
    setup(rank, world_size)

    if use_ort_module:
        print(f"  Rank {rank} uses ORTModule.")
        model = ToyModel().to(rank)
        model = ORTModule(model)
    else:
        print(f"  Rank {rank} uses Pytorch's nn.Module.")
        model = ToyModel().to(rank)

    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = os.path.join(tempfile.gettempdir(), "model.checkpoint")
    if rank == 0:
        # All processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes.
        # Therefore, saving it in one process is sufficient.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure that process 1 loads the model after process
    # 0 saves it.
    dist.barrier()
    # configure map_location properly
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    ddp_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn = nn.MSELoss()
    loss = loss_fn(outputs, labels)
    loss.backward()
    optimizer.step()

    print(f"Rank {rank} sees loss {loss}")

    if rank == 0:
        assert torch.allclose(loss.cpu(), torch.FloatTensor([1.4909229278564453]))
    elif rank == 1:
        assert torch.allclose(loss.cpu(), torch.FloatTensor([1.0177688598632812]))
    elif rank == 2:
        assert torch.allclose(loss.cpu(), torch.FloatTensor([1.290669322013855]))
    elif rank == 3:
        assert torch.allclose(loss.cpu(), torch.FloatTensor([0.825118362903595]))
    else:
        assert False

    # Not necessary to use a dist.barrier() to guard the file deletion below
    # as the AllReduce ops in the backward pass of DDP already served as
    # a synchronization.
    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()
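# Assumed process-group helpers referenced by demo_basic and demo_checkpoint
# (`setup`/`cleanup` are not part of this excerpt); a minimal sketch using the
# standard torch.distributed bootstrap for single-node DDP, assuming
# `import torch.distributed as dist` at module level.
def setup(rank, world_size):
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    # NCCL is the usual backend for multi-GPU DDP on a single node.
    dist.init_process_group('nccl', rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()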
def main():
    # 1. Basic setup
    parser = argparse.ArgumentParser(description='PyTorch BERT Example')
    parser.add_argument('--pytorch-only', action='store_true', default=False,
                        help='disables ONNX Runtime training')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--view-graphs', action='store_true', default=False,
                        help='views forward and backward graphs')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--epochs', type=int, default=4, metavar='N',
                        help='number of epochs to train (default: 4)')
    parser.add_argument('--seed', type=int, default=42, metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument('--log-interval', type=int, default=40, metavar='N',
                        help='how many batches to wait before logging training status (default: 40)')
    parser.add_argument('--train-steps', type=int, default=-1, metavar='N',
                        help='number of steps to train. Set -1 to run through whole dataset (default: -1)')
    parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default='WARNING', help='Log level (default: WARNING)')
    parser.add_argument('--num-hidden-layers', type=int, default=1, metavar='H',
                        help='Number of hidden layers for the BERT model. A vanilla BERT has 12 hidden layers (default: 1)')
    parser.add_argument('--data-dir', type=str, default='./cola_public/raw',
                        help='Path to the bert data directory')

    args = parser.parse_args()

    # Device (CPU vs CUDA)
    if torch.cuda.is_available() and not args.no_cuda:
        device = torch.device("cuda")
        print('There are %d GPU(s) available.' % torch.cuda.device_count())
        print('We will use the GPU:', torch.cuda.get_device_name(0))
    else:
        print('No GPU available, using the CPU instead.')
        device = torch.device("cpu")

    # Set log level
    numeric_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError('Invalid log level: %s' % args.log_level)
    logging.basicConfig(level=numeric_level)

    # 2. Dataloader
    train_dataloader, validation_dataloader = load_dataset(args)

    # 3. Modeling
    # Load BertForSequenceClassification, the pretrained BERT model with a single
    # linear classification layer on top.
    config = AutoConfig.from_pretrained(
        "bert-base-uncased",
        num_labels=2,
        num_hidden_layers=args.num_hidden_layers,
        output_attentions=False,  # Whether the model returns attentions weights.
        output_hidden_states=False,  # Whether the model returns all hidden-states.
    )
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased",  # Use the 12-layer BERT model, with an uncased vocab.
        config=config,
    )

    if not args.pytorch_only:
        model = ORTModule(model)

        # Just for future debugging
        model._execution_manager(model._is_training())._save_onnx = False
        model._execution_manager(model._is_training())._save_onnx_prefix = 'BertForSequenceClassification'

    # Tell pytorch to run this model on the GPU.
    if torch.cuda.is_available() and not args.no_cuda:
        model.cuda()

    # Note: AdamW is a class from the huggingface library (as opposed to pytorch)
    optimizer = AdamW(
        model.parameters(),
        lr=2e-5,  # args.learning_rate - default is 5e-5, our notebook had 2e-5
        eps=1e-8  # args.adam_epsilon - default is 1e-8.
    )

    # Authors recommend between 2 and 4 epochs
    # Total number of training steps is number of batches * number of epochs.
    total_steps = len(train_dataloader) * args.epochs

    # Create the learning rate scheduler.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,  # Default value in run_glue.py
        num_training_steps=total_steps)

    # Seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    onnxruntime.set_seed(args.seed)
    if torch.cuda.is_available() and not args.no_cuda:
        torch.cuda.manual_seed_all(args.seed)

    # 4. Train loop (fine-tune)
    total_training_time, total_test_time, epoch_0_training, validation_accuracy = 0, 0, 0, 0
    for epoch_i in range(0, args.epochs):
        total_training_time += train(model, optimizer, scheduler, train_dataloader, epoch_i, device, args)
        if not args.pytorch_only and epoch_i == 0:
            epoch_0_training = total_training_time
        test_time, validation_accuracy = test(model, validation_dataloader, device, args)
        total_test_time += test_time

    assert validation_accuracy > 0.5

    print('\n======== Global stats ========')
    if not args.pytorch_only:
        estimated_export = 0
        if args.epochs > 1:
            estimated_export = epoch_0_training - (total_training_time - epoch_0_training) / (args.epochs - 1)
            print("  Estimated ONNX export took:               {:.4f}s".format(estimated_export))
        else:
            print("  Estimated ONNX export took:               Estimate available when epochs > 1 only")
        print("  Accumulated training without export took: {:.4f}s".format(total_training_time - estimated_export))
    print("  Accumulated training took:                {:.4f}s".format(total_training_time))
    print("  Accumulated validation took:              {:.4f}s".format(total_test_time))
def test_ortmodule_fallback_non_contiguous_tensors(is_training, persist_fallback):
    # is_training: True for torch.nn.Module training model, eval mode otherwise
    # Validate fix for issue: https://github.com/pytorch/ort/issues/92

    policy = "FALLBACK_UNSUPPORTED_DEVICE"
    os.environ["ORTMODULE_FALLBACK_POLICY"] = policy
    os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback)
    os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED"

    class PositionalEncoding(torch.nn.Module):
        def __init__(self, d_model, dropout=0.1, max_len=5000):
            super().__init__()
            self.dropout = torch.nn.Dropout(p=dropout)
            position = torch.arange(max_len).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
            pe = torch.zeros(max_len, 1, d_model)
            pe[:, 0, 0::2] = torch.sin(position * div_term)
            pe[:, 0, 1::2] = torch.cos(position * div_term)
            self.register_buffer("pe", pe)

        def forward(self, x):
            x = x + self.pe[:x.size(0)]
            return self.dropout(x)

    class TransformerModel(torch.nn.Module):
        def __init__(self, ntoken, d_model, nhead, d_hid, nlayers, dropout=0.5):
            super().__init__()
            self.model_type = "Transformer"
            encoder_layers = torch.nn.TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
            self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layers, nlayers)
            self.pos_encoder = PositionalEncoding(d_model, dropout)
            self.encoder = torch.nn.Embedding(ntoken, d_model)
            self.d_model = d_model
            self.decoder = torch.nn.Linear(d_model, ntoken)
            self.init_weights()

        def init_weights(self):
            initrange = 0.1
            self.encoder.weight.data.uniform_(-initrange, initrange)
            self.decoder.bias.data.zero_()
            self.decoder.weight.data.uniform_(-initrange, initrange)

        def forward(self, src, src_mask):
            src = self.encoder(src) * math.sqrt(self.d_model)
            src = self.pos_encoder(src)
            output = self.transformer_encoder(src, src_mask)
            output = self.decoder(output)
            return output

    def generate_square_subsequent_mask(sz):
        return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1)

    def get_batch(source, i):
        seq_len = min(bptt, len(source) - 1 - i)
        data = source[i:i + seq_len]
        target = source[i + 1:i + 1 + seq_len].reshape(-1)
        return data, target

    criterion = torch.nn.CrossEntropyLoss()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_data = np.random.randint(1, 12455, 1000)
    ends = np.random.randint(2, 20, 100).cumsum()
    ends = ends[ends < train_data.shape[0] - 2]
    train_data[ends] = 0
    train_data[-1] = 0
    train_data = torch.tensor(np.array(train_data, dtype=np.int64))
    train_data = train_data.to(torch.int64).to(device)

    bptt = 35
    src_mask = generate_square_subsequent_mask(bptt).to(device)

    ntokens, emsize, nhead, d_hid, nlayers, dropout = 12455, 200, 2, 200, 2, 0.2
    pt_model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout)
    model = ORTModule(pt_model).to(device)
    pt_model.train(is_training)
    model.train(is_training)
    optimizer = torch.optim.SGD(model.parameters(), lr=5.0)

    n_iter = 0
    for epoch in range(1, 2):
        model.train()  # turn on train mode
        num_batches = len(train_data) // bptt
        for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
            data, targets = get_batch(train_data, i)
            batch_size = data.size(0)
            if batch_size != bptt:  # only on last batch
                src_mask = src_mask[:batch_size, :batch_size]
            output = model(data, src_mask)
            nrows = min(ntokens, targets.shape[0])
            loss = criterion(output.view(nrows, -1), targets)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

            n_iter += 1
            break

    assert n_iter > 0

    del os.environ["ORTMODULE_SKIPCHECK_POLICY"]
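# The fallback tests above mutate process-wide environment variables; a hedged
# sketch of a pytest fixture that would snapshot and restore them per test
# (hypothetical helper, not part of the original file):
@pytest.fixture
def restore_fallback_env():
    keys = ("ORTMODULE_FALLBACK_POLICY", "ORTMODULE_FALLBACK_RETRY", "ORTMODULE_SKIPCHECK_POLICY")
    saved = {k: os.environ.get(k) for k in keys}
    yield
    for k, v in saved.items():
        if v is None:
            os.environ.pop(k, None)
        else:
            os.environ[k] = v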