def test_torchscripting_using_trace(self):
    """Verify save_torchscript end-to-end when the hook traces the model.

    TorchscriptHook defaults to torch.jit.trace, so no extra flags are
    needed; the shared execute_hook helper drives the hook lifecycle and
    checks the resulting torchscript output.
    """
    task_config = get_test_task_config()
    script_dir = self.base_dir + "/torchscript_end_test/"
    # Default construction -> tracing path.
    hook = TorchscriptHook(script_dir)
    self.execute_hook(task_config, script_dir, hook)
def test_torchscripting_using_script(self):
    """Verify save_torchscript end-to-end when the hook scripts the model.

    torch.jit.script requires the ResNet model itself (not its wrapper),
    so wrapper_cls is temporarily set to None while the hook runs.
    """
    config = get_test_task_config()
    # Setting wrapper_cls to None makes the ResNet model torchscriptable.
    # Bug fix: previously the attribute was left as None permanently,
    # leaking state into any test that runs afterwards and relies on the
    # wrapper class. Save the original and restore it unconditionally.
    original_wrapper_cls = ResNet.wrapper_cls
    ResNet.wrapper_cls = None
    try:
        torchscript_folder = self.base_dir + "/torchscript_end_test/"
        # create a torchscript hook using script (use_trace=False)
        torchscript_hook = TorchscriptHook(torchscript_folder, use_trace=False)
        self.execute_hook(config, torchscript_folder, torchscript_hook)
    finally:
        ResNet.wrapper_cls = original_wrapper_cls
def test_torchscripting(self):
    """Verify that the torchscript saved by the hook matches the live model.

    Runs the hook lifecycle (on_start / on_end) against a freshly built
    task, loads the serialized torchscript back, and asserts that both
    the checkpointed model and the torchscript produce the same output
    on a random input batch.
    """
    config = get_test_task_config()
    task = build_task(config)
    task.prepare()
    out_dir = self.base_dir + "/torchscript_end_test/"
    # Hook under test (default tracing mode).
    hook = TorchscriptHook(out_dir)
    # Create the checkpoint dir and verify the on_start hook runs.
    os.mkdir(out_dir)
    hook.on_start(task)
    task.train = True
    # on_end is what serializes the torchscript artifact.
    hook.on_end(task)
    # Load the torchscript file the hook just wrote.
    scripted = torch.jit.load(f"{hook.torchscript_folder}/{TORCHSCRIPT_FILE}")
    # Compare the model loaded from checkpoint against the torchscript.
    with torch.no_grad():
        batchsize = 1
        model = task.model
        sample = torch.randn(
            (batchsize, ) + model.input_shape, dtype=torch.float
        )
        if torch.cuda.is_available():
            sample = sample.cuda()
        expected = model(sample)
        actual = scripted(sample)
        self.assertTrue(torch.allclose(expected, actual))