def test_torchcript_invalid_method(tmpdir):
    """Check that ``to_torchscript`` rejects an unsupported export method.

    Only ``'script'`` and ``'trace'`` are accepted; any other value must
    raise a ``ValueError`` whose message names the supported options.
    """
    # ``tmpdir`` is not used by this test; the parameter is kept so the
    # test signature stays unchanged.
    unsupported_method = 'temp'
    expected_message = "only supports 'script' or 'trace'"

    model = BoringModel()
    model.train(True)

    with pytest.raises(ValueError, match=expected_message):
        model.to_torchscript(method=unsupported_method)
def test_torchscript_retain_training_state():
    """ Test that torchscript export does not alter the training mode of original model.

    The model is exported twice: once in training mode and once in eval mode.
    After each export the source model must keep the mode it had before the
    export. Every exported script is expected to not be in training mode.
    """
    model = BoringModel()

    # Export while the source model is in training mode.
    model.train(True)
    train_mode_script = model.to_torchscript()
    assert model.training
    assert not train_mode_script.training

    # Export again while the source model is in eval mode. Keep the result so
    # this second export's training state is actually verified, instead of
    # only re-checking the first script.
    model.train(False)
    eval_mode_script = model.to_torchscript()
    assert not model.training
    assert not eval_mode_script.training
    # The earlier script must be unaffected by the later mode change/export.
    assert not train_mode_script.training