def test_save_shared_tensors(self):
    """
    Test tensors shared across eager and ScriptModules are serialized once.

    A single tensor is saved directly, captured by a ScriptModule, and held by
    two eager modules that both wrap that ScriptModule. The package must
    contain exactly one storage file, and both eager modules must load back
    producing the same outputs as the originals.
    """
    from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

    shared_tensor = torch.rand(2, 3, 4)
    scripted_mod = torch.jit.script(ModWithTensor(shared_tensor))

    mod1 = ModWithSubmodAndTensor(shared_tensor, scripted_mod)
    mod2 = ModWithSubmodAndTensor(shared_tensor, scripted_mod)

    buffer = BytesIO()
    with PackageExporter(buffer) as e:
        # Intern every dependency ("**") so the eager module classes are
        # bundled into the package alongside the pickled data.
        e.intern("**")
        e.save_pickle("res", "tensor.pkl", shared_tensor)
        e.save_pickle("res", "mod1.pkl", mod1)
        e.save_pickle("res", "mod2.pkl", mod2)

    buffer.seek(0)
    importer = PackageImporter(buffer)
    loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")
    loaded_mod_2 = importer.load_pickle("res", "mod2.pkl")

    # Every reference to shared_tensor must map onto a single stored storage.
    # assertEqual (rather than assertTrue(... == 1)) reports the actual count
    # when this fails.
    file_structure = importer.file_structure(include=".data/*.storage")
    storage_files = file_structure.children[".data"].children
    self.assertEqual(len(storage_files), 1)

    sample = torch.rand(2, 3, 4)
    self.assertEqual(loaded_mod_1(sample), mod1(sample))
    self.assertEqual(loaded_mod_2(sample), mod2(sample))
def test_scriptmodules_repeat_save(self):
    """
    Check that a ScriptModule survives being saved, loaded, and saved again
    across more than one package.

    The first package holds a single scripted module. The module loaded from
    it is then written into a second package next to another scripted module.
    Every loaded copy must produce the same output as its source.
    """
    from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

    first_scripted = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
    nested_eager = ModWithSubmodAndTensor(
        torch.rand(1, 2, 3), ModWithTensor(torch.rand(1, 2, 3))
    )
    second_scripted = torch.jit.script(nested_eager)

    # Package #1: only the first scripted module.
    first_buffer = BytesIO()
    with PackageExporter(first_buffer) as exporter:
        exporter.save_pickle("res", "mod1.pkl", first_scripted)

    first_buffer.seek(0)
    first_importer = PackageImporter(first_buffer)
    first_loaded = first_importer.load_pickle("res", "mod1.pkl")

    # Package #2: a fresh scripted module plus the module that was loaded
    # from package #1 (the "repeat save").
    second_buffer = BytesIO()
    with PackageExporter(second_buffer) as exporter:
        exporter.save_pickle("res", "mod1.pkl", second_scripted)
        exporter.save_pickle("res", "mod2.pkl", first_loaded)

    second_buffer.seek(0)
    second_importer = PackageImporter(second_buffer)
    second_loaded = second_importer.load_pickle("res", "mod1.pkl")
    first_reloaded = second_importer.load_pickle("res", "mod2.pkl")

    sample = torch.rand(1, 2, 3)
    self.assertEqual(first_loaded(sample), first_scripted(sample))
    self.assertEqual(first_loaded(sample), first_reloaded(sample))
    self.assertEqual(second_loaded(sample), second_scripted(sample))
def test_save_scriptmodules_in_container(self):
    """
    Save a list holding two ScriptModules and verify both after loading.

    The second module is built around the first, which is passed to
    ModWithSubmodAndTensor. Each loaded module must reproduce the output of
    the corresponding original, including the one that depends on the shared
    module.
    """
    from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

    inner = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
    outer = torch.jit.script(ModWithSubmodAndTensor(torch.rand(1, 2, 3), inner))
    originals = [inner, outer]

    buf = BytesIO()
    with PackageExporter(buf) as exporter:
        exporter.save_pickle("res", "list.pkl", originals)

    buf.seek(0)
    package_importer = PackageImporter(buf)
    restored = package_importer.load_pickle("res", "list.pkl")

    sample = torch.rand(1, 2, 3)
    for idx, expected in enumerate(originals):
        self.assertEqual(restored[idx](sample), expected(sample))
def test_save_repeat_scriptmodules(self):
    """
    Test saving several different ScriptModules, plus repeated saves of the
    same ScriptModule objects, into a single package.

    Also guards against PyTorchStreamWriter writing the same ScriptModule code
    files multiple times in a way that hides code from PyTorchStreamReader.
    Every saved resource is therefore loaded back, including each repeated
    save, and must reproduce the output of the module it was saved from.
    """
    from package_a.test_module import (
        ModWithSubmodAndTensor,
        ModWithTensor,
        SimpleTest,
    )

    scripted_mod_0 = torch.jit.script(SimpleTest())
    scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
    scripted_mod_2 = torch.jit.script(
        ModWithSubmodAndTensor(
            torch.rand(1, 2, 3), ModWithTensor(torch.rand(1, 2, 3))
        )
    )

    # Resource name -> module it is saved from. scripted_mod_0 and
    # scripted_mod_1 each appear twice on purpose (the repeated saves).
    saved = {
        "mod0.pkl": scripted_mod_0,
        "mod1.pkl": scripted_mod_1,
        "mod2.pkl": scripted_mod_0,
        "mod3.pkl": scripted_mod_1,
        "mod4.pkl": scripted_mod_2,
    }

    buffer = BytesIO()
    with PackageExporter(buffer) as e:
        for resource, mod in saved.items():
            e.save_pickle("res", resource, mod)

    buffer.seek(0)
    importer = PackageImporter(buffer)

    sample = torch.rand(1, 2, 3)
    # Check every resource, not just one copy of each repeat, so a broken
    # first or second save of the same module is caught.
    for resource, expected_mod in saved.items():
        with self.subTest(resource=resource):
            loaded_mod = importer.load_pickle("res", resource)
            self.assertEqual(loaded_mod(sample), expected_mod(sample))