def test_torchcript_invalid_method(tmpdir):
    """Exporting with an unknown method name must raise a ValueError.

    The error message states that only ``'script'`` or ``'trace'`` are
    supported by ``to_torchscript``.
    """
    boring = BoringModel()
    boring.train(True)
    expected_msg = "only supports 'script' or 'trace'"
    with pytest.raises(ValueError, match=expected_msg):
        boring.to_torchscript(method='temp')
def test_torchscript_with_no_input(tmpdir):
    """Tracing without any example input must raise a ValueError.

    With ``example_input_array`` cleared and no ``example_inputs`` argument
    passed, ``to_torchscript(method='trace')`` has no input to trace with.
    """
    boring = BoringModel()
    boring.example_input_array = None
    missing_input_msg = 'requires either `example_inputs` or `model.example_input_array`'
    with pytest.raises(ValueError, match=missing_input_msg):
        boring.to_torchscript(method='trace')
def test_torchscript_retain_training_state():
    """Test that torchscript export does not alter the training mode of the original model.

    The exported module is expected to be in eval mode whichever mode the
    source model was in at export time, while the source model keeps its
    own mode.
    """
    model = BoringModel()

    # Export while the source model is in training mode.
    model.train(True)
    script = model.to_torchscript()
    assert model.training
    assert not script.training

    # Export again while the source model is in eval mode. Check the *new*
    # export; previously its result was discarded and only the first
    # export was re-checked.
    model.train(False)
    eval_script = model.to_torchscript()
    assert not model.training
    assert not eval_script.training
    # The earlier export must remain unaffected by the later one.
    assert not script.training
def test_torchscript_device(device):
    """Test that the scripted module lives on, and produces output on, the requested device.

    ``device`` is presumably a ``torch.device`` supplied by a parametrize
    decorator or fixture that is not visible here; TODO confirm.
    """
    model = BoringModel().to(device)
    model.example_input_array = torch.randn(5, 32)
    script = model.to_torchscript()

    first_param = next(script.parameters())
    assert first_param.device == device

    inputs_on_device = model.example_input_array.to(device)
    output = script(inputs_on_device)
    assert output.device == device
def test_torchscript_input_output_trace():
    """Test that a traced LightningModule reproduces forward() when exported with example_inputs."""
    model = BoringModel()
    sample = torch.randn(1, 32)
    traced = model.to_torchscript(example_inputs=sample, method='trace')
    assert isinstance(traced, torch.jit.ScriptModule)

    # Reference output from the eager model, in eval mode with autograd disabled.
    model.eval()
    with torch.no_grad():
        expected = model(sample)

    actual = traced(sample)
    assert torch.allclose(actual, expected)