def test_torchcript_invalid_method(tmpdir):
    """Check that ``to_torchscript`` raises ``ValueError`` for an unknown method.

    The ``tmpdir`` fixture is requested but not used by the test body.
    """
    boring_model = BoringModel()
    boring_model.train(True)

    # ``match`` is applied as a regex search against the exception message.
    expected_failure = pytest.raises(
        ValueError, match="only supports 'script' or 'trace'"
    )
    with expected_failure:
        boring_model.to_torchscript(method='temp')
def test_torchscript_retain_training_state():
    """Test that torchscript export does not alter the training mode of original model.

    The model is exported once while in training mode and once while in eval
    mode. After each export the test checks that the source model still
    reports the mode it was put in, and that the exported script module
    reports ``training`` as false.
    """
    model = BoringModel()

    # Export from a model in training mode.
    model.train(True)
    script = model.to_torchscript()
    assert model.training
    assert not script.training

    # Export from a model in eval mode. The result is bound to ``script`` so
    # the final assertion checks this export rather than re-checking the
    # object produced by the first export.
    model.train(False)
    script = model.to_torchscript()
    assert not model.training
    assert not script.training