def test_save_shared_tensors(self):
    """
    Test tensors shared across eager and ScriptModules are serialized once.

    One tensor is saved directly and is also referenced by a scripted
    ``ModWithTensor`` and by two eager ``ModWithSubmodAndTensor`` wrappers.
    The package must contain exactly one ``.storage`` file. Both eager
    modules must round-trip and produce the same output as the originals.
    """
    from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

    shared_tensor = torch.rand(2, 3, 4)
    scripted_mod = torch.jit.script(ModWithTensor(shared_tensor))

    mod1 = ModWithSubmodAndTensor(shared_tensor, scripted_mod)
    mod2 = ModWithSubmodAndTensor(shared_tensor, scripted_mod)

    buffer = BytesIO()
    with PackageExporter(buffer) as e:
        e.intern("**")
        e.save_pickle("res", "tensor", shared_tensor)
        e.save_pickle("res", "mod1.pkl", mod1)
        e.save_pickle("res", "mod2.pkl", mod2)

    buffer.seek(0)
    importer = PackageImporter(buffer)
    loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")
    # mod2 was saved too; load it so the second pickle is also verified.
    loaded_mod_2 = importer.load_pickle("res", "mod2.pkl")

    # assert that there is only one storage stored in package
    file_structure = importer.file_structure(include=".data/*.storage")
    # assertEqual (not assertTrue(... == 1)) reports the actual count on failure.
    self.assertEqual(len(file_structure.children[".data"].children), 1)

    # Named test_input to avoid shadowing the builtin `input`.
    test_input = torch.rand(2, 3, 4)
    self.assertEqual(loaded_mod_1(test_input), mod1(test_input))
    self.assertEqual(loaded_mod_2(test_input), mod2(test_input))
def test_file_structure(self):
    """
    Check the printed file structure of a package written to a file path.

    The exporter's structure is checked unfiltered and then with ``include``
    globs. The importer's structure is checked with an ``exclude`` glob.
    """
    # NOTE(review): a later method in this same chunk is also named
    # `test_file_structure` (the BytesIO/dedent variant). If both are in the
    # same class body, that later definition replaces this one and this test
    # never runs. Consider renaming or deleting this older variant; TODO confirm.
    filename = self.temp()

    # Expected exporter-side tree. `.data` is absent because it has not been
    # written yet while the exporter is still open.
    export_plain = """\
    ├── main
    │   └── main
    ├── obj
    │   └── obj.pkl
    ├── package_a
    │   ├── __init__.py
    │   └── subpackage.py
    └── module_a.py
"""
    # Expected tree when only subpackage.py and *.pkl files are included.
    export_include = """\
    ├── obj
    │   └── obj.pkl
    └── package_a
        └── subpackage.py
"""
    # Expected importer-side tree with *.storage files excluded.
    import_exclude = """\
    ├── .data
    │   ├── extern_modules
    │   └── version
    ├── main
    │   └── main
    ├── obj
    │   └── obj.pkl
    ├── package_a
    │   ├── __init__.py
    │   └── subpackage.py
    └── module_a.py
"""

    with PackageExporter(filename, verbose=False) as he:
        import module_a
        import package_a
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        he.save_module(module_a.__name__)
        he.save_module(package_a.__name__)
        he.save_pickle('obj', 'obj.pkl', obj)
        he.save_text('main', 'main', "my string")

        export_file_structure = he.file_structure()
        # Drop the first line (the root/archive name) before comparing,
        # because Windows/macOS/Unix render the file name differently.
        self.assertEqual(
            '\n'.join(str(export_file_structure).split('\n')[1:]), export_plain)
        export_file_structure = he.file_structure(
            include=["**/subpackage.py", "**/*.pkl"])
        self.assertEqual(
            '\n'.join(str(export_file_structure).split('\n')[1:]), export_include)

    hi = PackageImporter(filename)
    import_file_structure = hi.file_structure(exclude="**/*.storage")
    self.assertEqual('\n'.join(str(import_file_structure).split('\n')[1:]), import_exclude)
def test_file_structure_has_file(self):
    """
    Exercise ``has_file()`` on an imported package's file structure.

    The source path ``package_a/subpackage.py`` must be reported as present.
    The same path without its ``.py`` suffix must not be.
    """
    buffer = BytesIO()
    with PackageExporter(buffer, verbose=False) as exporter:
        import package_a.subpackage

        exporter.save_pickle(
            "obj", "obj.pkl", package_a.subpackage.PackageASubpackageObject()
        )

    buffer.seek(0)
    reader = PackageImporter(buffer)
    tree = reader.file_structure()

    self.assertTrue(tree.has_file("package_a/subpackage.py"))
    self.assertFalse(tree.has_file("package_a/subpackage"))
def test_save_eager_mods_sharing_scriptmodule(self):
    """
    A single ScriptModule shared by several eager modules is saved only once.

    Two ``ModWithSubmod`` wrappers around the same scripted ``SimpleTest``
    are pickled separately. The package must hold exactly one TorchScript
    code entry (``.data/ts_code/0``), even though the scripted module appears
    in both pickles.
    """
    from package_a.test_module import ModWithSubmod, SimpleTest

    shared_script = torch.jit.script(SimpleTest())
    first_eager = ModWithSubmod(shared_script)
    second_eager = ModWithSubmod(shared_script)

    buffer = BytesIO()
    with PackageExporter(buffer) as exporter:
        exporter.intern("**")
        for pickle_name, eager_mod in (
            ("mod1.pkl", first_eager),
            ("mod2.pkl", second_eager),
        ):
            exporter.save_pickle("res", pickle_name, eager_mod)

    buffer.seek(0)
    reader = PackageImporter(buffer)
    tree = reader.file_structure()

    # Exactly one ts_code entry: index 0 exists, index 1 does not.
    self.assertTrue(tree.has_file(".data/ts_code/0"))
    self.assertFalse(tree.has_file(".data/ts_code/1"))
def test_save_scriptmodule_only_necessary_code(self):
    """
    Test to verify when saving multiple packages with same CU
    that packages don't include unnecessary torchscript code files.
    The TorchVision code should only be saved in the package that
    relies on it.

    Two modules are scripted: one wraps a torchvision ``resnet18`` and one
    does not. Each is packaged on its own. Only the first package's file
    structure may mention ``torchvision``.
    """
    from package_a.test_module import ModWithTensor

    class ModWithTorchVision(torch.nn.Module):
        def __init__(self, name: str):
            super().__init__()
            self.tvmod = resnet18()

        def forward(self, input):
            return input * 4

    def _package(scripted_mod):
        # Save one scripted module into a fresh in-memory package and reopen it.
        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod)
        buffer.seek(0)
        return PackageImporter(buffer)

    scripted_mod_0 = torch.jit.script(ModWithTorchVision("foo"))
    scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

    importer_0 = _package(scripted_mod_0)
    importer_1 = _package(scripted_mod_1)

    # assertIn/assertNotIn print the searched text on failure, unlike assertTrue.
    self.assertIn("torchvision", str(importer_0.file_structure()))
    self.assertNotIn("torchvision", str(importer_1.file_structure()))
def test_file_structure(self):
    """
    Render a package's file structure as a tree and compare it to the
    expected text.

    The exporter side is checked unfiltered and then restricted with
    ``include`` globs. The importer side is checked with an ``exclude`` glob.
    """

    def _rendered_tree(structure):
        # The first printed line is the archive root, which is rendered
        # differently across platforms, so drop it. Then strip the common
        # indentation so the result lines up with the dedented expectations.
        printed_lines = str(structure).split("\n")
        return dedent("\n".join(printed_lines[1:]))

    export_plain = dedent(
        """\
        ├── main
        │   └── main
        ├── obj
        │   └── obj.pkl
        ├── package_a
        │   ├── __init__.py
        │   └── subpackage.py
        └── module_a.py
        """
    )
    export_include = dedent(
        """\
        ├── obj
        │   └── obj.pkl
        └── package_a
            └── subpackage.py
        """
    )
    import_exclude = dedent(
        """\
        ├── .data
        │   ├── extern_modules
        │   └── version
        ├── main
        │   └── main
        ├── obj
        │   └── obj.pkl
        ├── package_a
        │   ├── __init__.py
        │   └── subpackage.py
        └── module_a.py
        """
    )

    buffer = BytesIO()
    with PackageExporter(buffer, verbose=False) as exporter:
        import module_a
        import package_a
        import package_a.subpackage

        pickled_obj = package_a.subpackage.PackageASubpackageObject()
        exporter.save_module(module_a.__name__)
        exporter.save_module(package_a.__name__)
        exporter.save_pickle("obj", "obj.pkl", pickled_obj)
        exporter.save_text("main", "main", "my string")

        # Checked while the exporter is still open.
        self.assertEqual(_rendered_tree(exporter.file_structure()), export_plain)
        self.assertEqual(
            _rendered_tree(
                exporter.file_structure(include=["**/subpackage.py", "**/*.pkl"])
            ),
            export_include,
        )

    buffer.seek(0)
    importer = PackageImporter(buffer)
    self.assertEqual(
        _rendered_tree(importer.file_structure(exclude="**/*.storage")),
        import_exclude,
    )