def test_if_inference_output_is_valid(tmpdir):
    """Test that the output inferred from ONNX model is same as from PyTorch."""
    model = BoringModel()
    model.example_input_array = torch.randn(5, 32)

    trainer = Trainer(max_epochs=2)
    trainer.fit(model)

    model.eval()
    with torch.no_grad():
        torch_out = model(model.example_input_array)

    file_path = os.path.join(tmpdir, "model.onnx")
    model.to_onnx(file_path, model.example_input_array, export_params=True)

    ort_session = onnxruntime.InferenceSession(file_path)

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(model.example_input_array)}
    ort_outs = ort_session.run(None, ort_inputs)

    # compare ONNX Runtime and PyTorch results
    assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

def test_model_saves_with_example_output(tmpdir):
    """Test that ONNX model saves when provided with example output."""
    model = BoringModel()
    trainer = Trainer(max_epochs=1)
    trainer.fit(model)

    file_path = os.path.join(tmpdir, "model.onnx")
    input_sample = torch.randn((1, 32))
    model.eval()
    example_outputs = model.forward(input_sample)
    model.to_onnx(file_path, input_sample, example_outputs=example_outputs)
    assert os.path.exists(file_path)

def test_torchscript_input_output_trace():
    """Test that traced LightningModule forward works with example_inputs."""
    model = BoringModel()
    example_inputs = torch.randn(1, 32)
    script = model.to_torchscript(example_inputs=example_inputs, method='trace')
    assert isinstance(script, torch.jit.ScriptModule)

    model.eval()
    with torch.no_grad():
        model_output = model(example_inputs)

    script_output = script(example_inputs)
    assert torch.allclose(script_output, model_output)
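

# A minimal companion sketch (not part of the original suite): exporting via scripting
# instead of tracing, then saving the resulting module to disk. It assumes the same
# BoringModel fixture and LightningModule.to_torchscript API, whose default method
# is 'script'; the test name and file name below are illustrative.
def test_torchscript_script_save(tmpdir):
    """Sketch: script a LightningModule and save the ScriptModule to disk."""
    model = BoringModel()
    script = model.to_torchscript()  # method='script' is the default
    assert isinstance(script, torch.jit.ScriptModule)

    file_path = os.path.join(tmpdir, "model.pt")
    torch.jit.save(script, file_path)
    assert os.path.exists(file_path)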