def test_torchscript_retain_training_state():
    """Test that torchscript export does not alter the training mode of the original model.

    The model is exported once in training mode and once in eval mode. After each
    export, the source model must keep the mode it had before exporting. Every
    exported script is expected to be in eval mode (``training`` is False).
    """
    model = EvalModelTemplate()

    # Export while the source model is in training mode.
    model.train(True)
    script = model.to_torchscript()
    assert model.training
    assert not script.training

    # Export again while the source model is in eval mode. Keep a handle on this
    # second export so its mode is actually checked, rather than re-checking the
    # first script only.
    model.train(False)
    eval_script = model.to_torchscript()
    assert not model.training
    assert not eval_script.training
    # The earlier export must be unaffected by the later mode change and re-export.
    assert not script.training
def test_torchscript_device(device):
    """Check that the scripted module's parameters and its outputs live on ``device``.

    ``device`` is supplied by the test harness; presumably it is a ``torch.device``
    (it is compared with ``==`` against tensor ``.device`` values) — confirm against
    the parametrization.
    """
    lightning_model = EvalModelTemplate().to(device)
    scripted_module = lightning_model.to_torchscript()

    # Parameters of the scripted module should follow the source model's device.
    first_param = next(scripted_module.parameters())
    assert first_param.device == device

    # A forward pass on an input moved to the same device must produce output there too.
    sample_batch = lightning_model.example_input_array.to(device)
    prediction = scripted_module(sample_batch)
    assert prediction.device == device