def test_torchcript_invalid_method(tmpdir):
    """Check that ``to_torchscript`` rejects an unknown export method with a ValueError."""
    lightning_module = BoringModel()
    lightning_module.train(True)

    # Only 'script' and 'trace' are accepted; 'temp' must be refused.
    expected_msg = "only supports 'script' or 'trace'"
    with pytest.raises(ValueError, match=expected_msg):
        lightning_module.to_torchscript(method='temp')
def test_torchscript_with_no_input(tmpdir):
    """Check that tracing without any example input raises a ValueError."""
    lightning_module = BoringModel()
    # Clear the module-level fallback so tracing has no input to work with.
    lightning_module.example_input_array = None

    expected_msg = 'requires either `example_inputs` or `model.example_input_array`'
    with pytest.raises(ValueError, match=expected_msg):
        lightning_module.to_torchscript(method='trace')
def test_torchscript_retain_training_state():
    """Test that torchscript export does not alter the training mode of the original model.

    The model is exported twice: once while in training mode and once while in
    eval mode. After each export, the original module's ``training`` flag must
    be unchanged, and the exported script must not be in training mode.
    """
    model = BoringModel()

    model.train(True)
    script_from_train = model.to_torchscript()
    assert model.training
    assert not script_from_train.training

    model.train(False)
    script_from_eval = model.to_torchscript()
    assert not model.training
    # Verify the second export as well, not only the earlier one.
    assert not script_from_eval.training
    assert not script_from_train.training
def test_torchscript_device(device):
    """Test that the scripted module's parameters and outputs live on ``device``."""
    # NOTE(review): ``device`` is not a built-in pytest fixture; presumably it is
    # supplied by a parametrize decorator or conftest outside this view — confirm.
    model = BoringModel().to(device)
    model.example_input_array = torch.randn(5, 32)

    script = model.to_torchscript()
    first_param = next(script.parameters())
    assert first_param.device == device

    inputs_on_device = model.example_input_array.to(device)
    assert script(inputs_on_device).device == device
def test_torchscript_input_output_trace():
    """Test that a traced LightningModule reproduces the eager forward output."""
    model = BoringModel()
    sample = torch.randn(1, 32)
    traced = model.to_torchscript(example_inputs=sample, method='trace')
    assert isinstance(traced, torch.jit.ScriptModule)

    # Reference output from the eager model, in eval mode and without autograd.
    model.eval()
    with torch.no_grad():
        expected = model(sample)

    actual = traced(sample)
    assert torch.allclose(actual, expected)