def test_dunder_package_works_from_package(self):
    """
    Packaged code must be able to see ``__torch_package__`` from inside the
    module itself, so it can tell whether it was loaded from a package.
    """
    import package_a.use_dunder_package as mod

    buf = BytesIO()
    with PackageExporter(buf, verbose=False) as exporter:
        exporter.intern("**")
        exporter.save_module(mod.__name__)
    buf.seek(0)

    importer = PackageImporter(buf)
    packaged_mod = importer.import_module(mod.__name__)

    # Only the copy loaded through the importer reports being packaged.
    self.assertTrue(packaged_mod.is_from_package())
    self.assertFalse(mod.is_from_package())
def test_mock(self):
    # NOTE(review): another ``test_mock`` appears in this chunk; if both are
    # defined in the same class, the later definition shadows this one.
    buf = BytesIO()
    with PackageExporter(buf) as exporter:
        exporter.mock(["package_a.subpackage", "module_a"])
        # Save a module whose source depends on package_a.subpackage.
        exporter.save_source_string("foo", "import package_a.subpackage")
    buf.seek(0)
    importer = PackageImporter(buf)

    # Import the real modules in the outer environment as well.
    import package_a.subpackage
    _ = package_a.subpackage
    import module_a
    _ = module_a

    mocked_subpackage = importer.import_module("package_a.subpackage")
    mocked_result = mocked_subpackage.result
    # Calling anything on a mocked module raises.
    with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
        mocked_result()
def test_loading_pickle(self):
    """
    Test basic saving and loading of modules and pickles from a DirectoryReader.

    The package is written as a zip archive, extracted into a temporary
    directory, and loaded back through ``PackageImporter`` pointed at the
    extracted directory rather than at the archive.
    """
    resnet = resnet18()

    filename = self.temp()
    with PackageExporter(filename) as e:
        e.intern("**")
        e.save_pickle("model", "model.pkl", resnet)

    with TemporaryDirectory() as temp_dir:
        # Close the archive as soon as it has been extracted instead of
        # leaking the open file handle for the rest of the test.
        with zipfile.ZipFile(filename, "r") as zip_file:
            zip_file.extractall(path=temp_dir)
        # NOTE(review): assumes the archive's entries live under a top-level
        # folder named after the archive file — confirm against the exporter.
        importer = PackageImporter(Path(temp_dir) / Path(filename).name)
        dir_mod = importer.load_pickle("model", "model.pkl")
        example_input = torch.rand(1, 3, 224, 224)
        self.assertEqual(dir_mod(example_input), resnet(example_input))
def test_custom_requires(self):
    # NOTE(review): another ``test_custom_requires`` appears in this chunk; if
    # both are defined in the same class, the later definition shadows this one.
    filename = self.temp()

    class Custom(PackageExporter):
        """Exporter that resolves ``module_a`` and ``package_a`` by hand."""

        def require_module(self, name, dependencies):
            if name == 'module_a':
                self.save_mock_module('module_a')
                return
            if name == 'package_a':
                self.save_source_string('package_a', 'import module_a\nresult = 5\n')
                return
            raise NotImplementedError('wat')

    with Custom(filename, verbose=False) as exporter:
        exporter.save_source_string('main', 'import package_a\n')

    importer = PackageImporter(filename)
    # Access an attribute on the mocked module_a.
    importer.import_module('module_a').should_be_mocked
    pkg = importer.import_module('package_a')
    self.assertEqual(pkg.result, 5)
def test_mock_glob(self):
    filename = self.temp()
    with PackageExporter(filename, verbose=False) as exporter:
        # Glob patterns: every submodule of package_a and anything starting with "module".
        exporter.mock(['package_a.*', 'module*'])
        exporter.save_module('package_a')
        exporter.save_source_string('test_module', """\
import package_a.subpackage
import module_a
""")

    importer = PackageImporter(filename)

    # Import the real modules in the outer environment as well.
    import package_a.subpackage
    _ = package_a.subpackage
    import module_a
    _ = module_a

    mocked_subpackage = importer.import_module('package_a.subpackage')
    mocked_result = mocked_subpackage.result
    with self.assertRaisesRegex(NotImplementedError, 'was mocked out'):
        mocked_result()
def test_load_shared_tensors(self):
    """
    Test tensors shared across eager and ScriptModules on load are the same.

    One tensor is shared by an eager module and two scripted submodules.
    After a package round trip, all three must still share one storage, and
    an in-place update through one reference must be visible through the others.
    """
    from package_a.test_module import (
        ModWithTensor,
        ModWithTwoSubmodsAndTensor,
    )

    shared_tensor = torch.ones(3, 3)

    scripted_mod_0 = torch.jit.script(ModWithTensor(shared_tensor))
    scripted_mod_1 = torch.jit.script(ModWithTensor(shared_tensor))

    mod1 = ModWithTwoSubmodsAndTensor(shared_tensor, scripted_mod_0, scripted_mod_1)

    buffer = BytesIO()
    with PackageExporter(buffer) as e:
        e.intern("**")
        e.save_pickle("res", "mod1.pkl", mod1)

    buffer.seek(0)
    importer = PackageImporter(buffer)
    loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")

    # Compare storage pointers with assertEqual: assertTrue(a, b) treats ``b``
    # as a failure message and never compares the two values. Check both
    # submodules against the parent tensor.
    self.assertEqual(
        loaded_mod_1.tensor.storage()._cdata,
        loaded_mod_1.sub_mod_0.tensor.storage()._cdata,
    )
    self.assertEqual(
        loaded_mod_1.tensor.storage()._cdata,
        loaded_mod_1.sub_mod_1.tensor.storage()._cdata,
    )

    # All tensors should reflect this in-place change.
    loaded_mod_1.tensor.add_(torch.ones(3, 3))
    self.assertTrue(
        torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_0.tensor)
    )
    self.assertTrue(
        torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_1.tensor)
    )
def test_save_imported_module_fails(self):
    """
    Directly saving/requiring a PackageImported module should raise a specific error message.
    """
    import package_a.subpackage

    obj = package_a.subpackage.PackageASubpackageObject()
    f1 = self.temp()
    with PackageExporter(f1, verbose=False) as pe:
        pe.intern("**")
        pe.save_pickle("obj", "obj.pkl", obj)

    importer1 = PackageImporter(f1)
    # loaded1's class comes from a module loaded through importer1.
    loaded1 = importer1.load_pickle("obj", "obj.pkl")

    f2 = self.temp()
    # NOTE(review): deliberately not used as a context manager, presumably so
    # the exporter is not finalized after the expected failure — confirm.
    pe = PackageExporter(f2, verbose=False, importer=(importer1, sys_importer))
    with self.assertRaisesRegex(ModuleNotFoundError, "torch.package"):
        pe.save_module(loaded1.__module__)
def test_mock(self):
    # NOTE(review): another ``test_mock`` appears in this chunk; if both are
    # defined in the same class, this (later) definition shadows the earlier one.
    buf = BytesIO()
    with PackageExporter(buf, verbose=False) as exporter:
        exporter.mock(["package_a.subpackage", "module_a"])
        exporter.save_module("package_a")
        # Explicitly require the mocked modules so they end up in the package.
        exporter.require_module("package_a.subpackage")
        exporter.require_module("module_a")
    buf.seek(0)
    importer = PackageImporter(buf)

    # Import the real modules in the outer environment as well.
    import package_a.subpackage
    _ = package_a.subpackage
    import module_a
    _ = module_a

    mocked_subpackage = importer.import_module("package_a.subpackage")
    mocked_result = mocked_subpackage.result
    with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
        mocked_result()
def test_ordered_importer_basic(self):
    import package_a

    buf = BytesIO()
    with PackageExporter(buf) as exporter:
        exporter.save_module(package_a.__name__)
    buf.seek(0)
    importer = PackageImporter(buf)

    # With sys_importer first, the outer-environment module wins.
    sys_first = OrderedImporter(sys_importer, importer)
    self.assertIs(sys_first.import_module("package_a"), package_a)

    # With the package importer first, the packaged module wins.
    package_first = OrderedImporter(importer, sys_importer)
    self.assertIs(
        package_first.import_module("package_a"),
        importer.import_module("package_a"),
    )
def test_inspect_class(self):
    """Should be able to retrieve source for a packaged class."""
    import package_a.subpackage

    buf = BytesIO()
    original_obj = package_a.subpackage.PackageASubpackageObject()
    with PackageExporter(buf) as exporter:
        exporter.intern("**")
        exporter.save_pickle("obj", "obj.pkl", original_obj)
    buf.seek(0)
    importer = PackageImporter(buf)

    packaged_cls = importer.import_module(
        "package_a.subpackage"
    ).PackageASubpackageObject
    outer_cls = package_a.subpackage.PackageASubpackageObject

    # inspect must yield identical source lines for both copies of the class.
    self.assertEqual(
        inspect.getsourcelines(packaged_cls),
        inspect.getsourcelines(outer_cls),
    )
def test_is_from_package(self):
    """is_from_package should work for objects and modules"""
    import package_a.subpackage

    buf = BytesIO()
    original_obj = package_a.subpackage.PackageASubpackageObject()
    with PackageExporter(buf, verbose=False) as exporter:
        exporter.save_pickle("obj", "obj.pkl", original_obj)
    buf.seek(0)

    importer = PackageImporter(buf)
    packaged_mod = importer.import_module("package_a.subpackage")
    unpickled_obj = importer.load_pickle("obj", "obj.pkl")

    # Outer-environment module and object -> False; packaged ones -> True.
    self.assertFalse(is_from_package(package_a.subpackage))
    self.assertTrue(is_from_package(packaged_mod))
    self.assertFalse(is_from_package(original_obj))
    self.assertTrue(is_from_package(unpickled_obj))
def test_custom_requires(self):
    # NOTE(review): another ``test_custom_requires`` appears in this chunk; if
    # both are defined in the same class, this (later) definition shadows the
    # earlier one.
    buf = BytesIO()

    class Custom(PackageExporter):
        """Exporter that resolves ``module_a`` and ``package_a`` by hand."""

        def require_module(self, name, dependencies):
            if name == "module_a":
                self.save_mock_module("module_a")
                return
            if name == "package_a":
                self.save_source_string("package_a", "import module_a\nresult = 5\n")
                return
            raise NotImplementedError("wat")

    with Custom(buf, verbose=False) as exporter:
        exporter.save_source_string("main", "import package_a\n")
    buf.seek(0)

    importer = PackageImporter(buf)
    # Access an attribute on the mocked module_a.
    importer.import_module("module_a").should_be_mocked
    pkg = importer.import_module("package_a")
    self.assertEqual(pkg.result, 5)
def test_package_resource_access(self):
    """Packaged modules can read resources saved in the same package through
    the importlib.resources API.
    """
    packaged_src = dedent(
        """\
        import importlib.resources
        import my_cool_resources

        def secret_message():
            return importlib.resources.read_text(my_cool_resources, 'sekrit.txt')
        """
    )

    buf = BytesIO()
    with PackageExporter(buf, verbose=False) as exporter:
        exporter.save_source_string("foo.bar", packaged_src)
        exporter.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")
    buf.seek(0)

    importer = PackageImporter(buf)
    message = importer.import_module("foo.bar").secret_message()
    self.assertEqual(message, "my sekrit plays")
def test_resource_access_by_path(self):
    """
    Packaged code can use importlib.resources.path to open a saved binary resource.
    """
    buf = BytesIO()
    with PackageExporter(buf) as exporter:
        exporter.save_binary("string_module", "my_string", "my string".encode("utf-8"))
        main_src = dedent(
            """\
            import importlib.resources
            import string_module

            with importlib.resources.path(string_module, 'my_string') as path:
                with open(path, mode='r', encoding='utf-8') as f:
                    s = f.read()
            """
        )
        exporter.save_source_string("main", main_src, is_package=True)
    buf.seek(0)

    importer = PackageImporter(buf)
    main_mod = importer.import_module("main")
    self.assertEqual(main_mod.s, "my string")
def test_mixing_packaged_and_inline_modules_shared_code(self):
    """
    An inline (test-local) scripted module and an imported scripted module
    that share code can be saved into the same package and both load correctly.
    """

    class TorchVisionTestInline(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.tvmod = resnet18()

        def forward(self, x):
            x = a_non_torch_leaf(x, x)
            return torch.relu(x + 3.0)

    def a_non_torch_leaf(a, b):
        return a + b

    inline_mod = TorchVisionTestInline()
    scripted_inline = torch.jit.script(inline_mod)

    from package_c.test_module import TorchVisionTest

    imported_mod = TorchVisionTest()
    scripted_imported = torch.jit.script(imported_mod)

    buf = BytesIO()
    with PackageExporter(buf, verbose=False) as exporter:
        exporter.save_pickle("model", "inline.pkl", scripted_inline)
        exporter.save_pickle("model", "imported.pkl", scripted_imported)
    buf.seek(0)

    importer = PackageImporter(buf)
    reloaded_inline = importer.load_pickle("model", "inline.pkl")
    reloaded_imported = importer.load_pickle("model", "imported.pkl")

    example_input = torch.rand(2, 3)
    self.assertTrue(
        torch.allclose(reloaded_imported(example_input), imported_mod(example_input))
    )
    self.assertTrue(
        torch.allclose(reloaded_inline(example_input), inline_mod(example_input))
    )
def test_package_importer_whichmodule_no_dunder_module(self):
    """Exercise corner case where we try to pickle an object whose
    __module__ doesn't exist because it's from a C extension.
    """
    # torch.float16 is a C-extension object with no __module__. The stock
    # pickler locates such objects by scanning sys.modules for the attribute
    # name (pickle.py:whichmodule); PackageImporter has to emulate that.
    my_dtype = torch.float16

    # First round trip: pickle the dtype with the default environment.
    first_buf = BytesIO()
    with PackageExporter(first_buf) as exporter:
        exporter.save_pickle("foo", "foo.pkl", my_dtype)
    first_buf.seek(0)
    first_importer = PackageImporter(first_buf)
    dtype_once = first_importer.load_pickle("foo", "foo.pkl")

    # Second round trip: re-save using only the PackageImporter as the environment.
    second_buf = BytesIO()
    with PackageExporter(second_buf, importer=first_importer) as exporter:
        exporter.save_pickle("foo", "foo.pkl", dtype_once)
    second_buf.seek(0)
    second_importer = PackageImporter(second_buf)
    dtype_twice = second_importer.load_pickle("foo", "foo.pkl")

    # The very same dtype object must come back both times.
    self.assertIs(my_dtype, dtype_once)
    self.assertIs(my_dtype, dtype_twice)
def test_saving_and_scripting_packaged_mod(self):
    """
    A module loaded from a package can be scripted and the script object
    saved into a new package that is based on the first one.
    """
    from package_a.test_module import SimpleTest

    orig_mod = SimpleTest()

    first_buf = BytesIO()
    with PackageExporter(first_buf) as exporter:
        exporter.intern("**")
        exporter.save_pickle("model", "model.pkl", orig_mod)
    first_buf.seek(0)

    first_importer = PackageImporter(first_buf)
    reloaded = first_importer.load_pickle("model", "model.pkl")

    example_input = torch.rand(2, 3)
    self.assertEqual(reloaded(example_input), orig_mod(example_input))

    scripted = torch.jit.script(reloaded)

    second_buf = BytesIO()
    with PackageExporter(second_buf, importer=first_importer) as exporter:
        exporter.intern("**")
        exporter.save_pickle("res", "scripted_mod.pkl", scripted)
    second_buf.seek(0)

    second_importer = PackageImporter(second_buf)
    reloaded_scripted = second_importer.load_pickle("res", "scripted_mod.pkl")
    self.assertEqual(reloaded_scripted(example_input), orig_mod(example_input))
def test_repackage_mocked_module(self):
    """Re-packaging a package that contains a mocked module should work correctly."""
    first_buf = BytesIO()
    with PackageExporter(first_buf) as exporter:
        exporter.mock("package_a")
        exporter.save_source_string("foo", "import package_a")
    first_buf.seek(0)

    first_importer = PackageImporter(first_buf)
    foo = first_importer.import_module("foo")
    # The package_a seen by foo is the mock, so calling into it raises.
    with self.assertRaises(NotImplementedError):
        foo.package_a.get_something()

    # Repackage from the first importer: intern package_a (the mocked copy
    # there) and mock everything else.
    second_buf = BytesIO()
    with PackageExporter(second_buf, importer=first_importer) as exporter:
        exporter.intern("package_a")
        exporter.mock("**")
        exporter.save_source_string("foo", "import package_a")
    second_buf.seek(0)

    second_importer = PackageImporter(second_buf)
    foo_again = second_importer.import_module("foo")
    # package_a must still behave as a mock after the round trip.
    with self.assertRaises(NotImplementedError):
        foo_again.package_a.get_something()
def test_extern_glob(self):
    # NOTE(review): another ``test_extern_glob`` appears in this chunk; if both
    # are defined in the same class, the later definition shadows this one.
    filename = self.temp()
    with PackageExporter(filename, verbose=False) as exporter:
        exporter.extern(["package_a.*", "module_*"])
        exporter.save_module("package_a")
        exporter.save_source_string(
            "test_module",
            dedent(
                """\
                import package_a.subpackage
                import module_a
                """
            ),
        )

    importer = PackageImporter(filename)
    import module_a
    import package_a.subpackage

    packaged_module_a = importer.import_module("module_a")
    importer.import_module("package_a.subpackage")
    packaged_package_a = importer.import_module("package_a")

    # module_a matches "module_*" and is externed, so the outer module comes
    # back; package_a was saved into the package and is a distinct object,
    # while its externed subpackage is shared with the outer environment.
    self.assertIs(module_a, packaged_module_a)
    self.assertIsNot(package_a, packaged_package_a)
    self.assertIs(package_a.subpackage, packaged_package_a.subpackage)
def test_package_fx_package(self):
    """
    A symbolically traced module built from packaged code can be re-packaged
    only when the exporter can see the original package's modules.
    """
    from package_a.test_module import SimpleTest

    model = SimpleTest()
    f = BytesIO()
    with PackageExporter(f) as pe:
        pe.intern("**")
        pe.save_pickle("model", "model.pkl", model)

    f.seek(0)
    pi = PackageImporter(f)
    loaded = pi.load_pickle("model", "model.pkl")
    traced = symbolic_trace(loaded)

    # This should fail, because the traced module references some globals
    # that are only in the package.
    failed_buf = BytesIO()
    with self.assertRaises(ObjMismatchError):
        with PackageExporter(failed_buf) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", traced)

    # Write the successful export to a fresh buffer. The failed attempt may
    # have written data to its buffer, and seeking to 0 without truncating
    # would leave stale trailing bytes after a shorter archive.
    f2 = BytesIO()
    # Make the package available to the exporter's environment.
    with PackageExporter(f2, importer=(pi, sys_importer)) as pe:
        pe.intern("**")
        pe.save_pickle("model", "model.pkl", traced)

    f2.seek(0)
    pi2 = PackageImporter(f2)
    loaded2 = pi2.load_pickle("model", "model.pkl")

    example_input = torch.rand(2, 3)
    self.assertTrue(torch.allclose(loaded(example_input), loaded2(example_input)))
def test_extern_glob(self):
    # NOTE(review): another ``test_extern_glob`` appears in this chunk; if both
    # are defined in the same class, this (later) definition shadows the earlier one.
    buf = BytesIO()
    with PackageExporter(buf, verbose=False) as exporter:
        exporter.extern(["package_a.*", "module_*"])
        exporter.save_module("package_a")
        exporter.save_source_string(
            "test_module",
            dedent(
                """\
                import package_a.subpackage
                import module_a
                """
            ),
        )
    buf.seek(0)

    importer = PackageImporter(buf)
    import module_a
    import package_a.subpackage

    packaged_module_a = importer.import_module("module_a")
    importer.import_module("package_a.subpackage")
    packaged_package_a = importer.import_module("package_a")

    # Externed modules are shared with the outer environment; package_a
    # itself was saved into the package and is a distinct object.
    self.assertIs(module_a, packaged_module_a)
    self.assertIsNot(package_a, packaged_package_a)
    self.assertIs(package_a.subpackage, packaged_package_a.subpackage)
def test_externing_c_extension(self):
    """Externing C extension modules, especially those under torch._C,
    must still leave them accessible from the packaged model."""
    buf = BytesIO()
    # The C extension function of interest is F.gelu, which comes from torch._C._nn.
    model = torch.nn.TransformerEncoderLayer(
        d_model=64,
        nhead=2,
        dim_feedforward=64,
        dropout=1.0,
        batch_first=True,
        activation="gelu",
        norm_first=True,
    )
    with PackageExporter(buf) as exporter:
        exporter.extern("torch.**")
        exporter.intern("**")
        exporter.save_pickle("model", "model.pkl", model)
    buf.seek(0)

    importer = PackageImporter(buf)
    # Loading must succeed; the result itself is not inspected.
    importer.load_pickle("model", "model.pkl")
def test_save_independent_scriptmodules(self):
    """
    Several ScriptModules whose code is completely separate can be saved
    into one package and each loads back correctly.
    """
    from package_a.test_module import ModWithTensor, SimpleTest

    first_scripted = torch.jit.script(SimpleTest())
    second_scripted = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

    buf = BytesIO()
    with PackageExporter(buf) as exporter:
        exporter.save_pickle("res", "mod1.pkl", first_scripted)
        exporter.save_pickle("res", "mod2.pkl", second_scripted)
    buf.seek(0)

    importer = PackageImporter(buf)
    first_loaded = importer.load_pickle("res", "mod1.pkl")
    second_loaded = importer.load_pickle("res", "mod2.pkl")

    example_input = torch.rand(1, 2, 3)
    self.assertEqual(first_loaded(example_input), first_scripted(example_input))
    self.assertEqual(second_loaded(example_input), second_scripted(example_input))
def test_package_script_class_referencing_self(self):
    import package_a.fake_script_class as fake

    original = fake.UsesIdListFeature()
    # Script here on purpose to fill the compilation cache, so that any false
    # sharing between scripted types from the package and from the outer
    # environment would show up.
    torch.jit.script(original)

    buf = BytesIO()
    with PackageExporter(buf) as exporter:
        exporter.intern("**")
        exporter.save_pickle("obj", "obj.pkl", original)
    buf.seek(0)

    importer = PackageImporter(buf)
    unpickled = importer.load_pickle("obj", "obj.pkl")
    scripted_unpickled = torch.jit.script(unpickled)

    # The scripted object must serialize and load back without error.
    serialized = scripted_unpickled.save_to_buffer()
    torch.jit.load(BytesIO(serialized))
def test_save_scriptmodules_in_container(self):
    """
    ScriptModules stored inside a container (a list here) save and load
    correctly, and the relationship between shared modules is preserved.
    """
    from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

    mod_a = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
    # mod_b holds mod_a as a submodule, so the two are related.
    mod_b = torch.jit.script(ModWithSubmodAndTensor(torch.rand(1, 2, 3), mod_a))
    container = [mod_a, mod_b]

    buf = BytesIO()
    with PackageExporter(buf) as exporter:
        exporter.save_pickle("res", "list.pkl", container)
    buf.seek(0)

    importer = PackageImporter(buf)
    loaded_container = importer.load_pickle("res", "list.pkl")

    example_input = torch.rand(1, 2, 3)
    self.assertEqual(loaded_container[0](example_input), mod_a(example_input))
    self.assertEqual(loaded_container[1](example_input), mod_b(example_input))
def test_single_ordered_importer(self):
    import package_a
    import module_a  # noqa: F401

    buf = BytesIO()
    with PackageExporter(buf, verbose=False) as exporter:
        exporter.save_module(package_a.__name__)
    buf.seek(0)
    importer = PackageImporter(buf)

    # An environment backed solely by the package importer.
    only_package = OrderedImporter(importer)

    # It resolves package_a to the importer's copy...
    self.assertIs(only_package.import_module('package_a'),
                  importer.import_module('package_a'))
    # ...which is not the copy from the outer Python environment.
    self.assertIsNot(only_package.import_module('package_a'), package_a)

    # module_a was never packaged, so it cannot be found here.
    with self.assertRaises(ModuleNotFoundError):
        only_package.import_module('module_a')
def test_load_shared_tensors_repackaged(self):
    """
    Tensors shared between eager modules and ScriptModules must still share
    storage after being packaged, loaded, re-packaged and loaded again.

    Python object identity does not survive the trip between packages, but
    the backing C++ TensorImpl does, and storages are saved and loaded based
    on that TensorImpl rather than on Python identity.
    """
    from package_a.test_module import (
        ModWithTensor,
        ModWithTwoSubmodsAndTensor,
    )

    shared_tensor = torch.ones(3, 3)
    scripted_mod_0 = torch.jit.script(ModWithTensor(shared_tensor))
    scripted_mod_1 = torch.jit.script(ModWithTensor(shared_tensor))
    mod1 = ModWithTwoSubmodsAndTensor(shared_tensor, scripted_mod_0, scripted_mod_1)

    # First package round trip.
    first_buf = BytesIO()
    with PackageExporter(first_buf) as exporter:
        exporter.intern("**")
        exporter.save_pickle("res", "mod1.pkl", mod1)
    first_buf.seek(0)
    first_importer = PackageImporter(first_buf)
    once_loaded = first_importer.load_pickle("res", "mod1.pkl")

    # Second round trip, re-saving what came out of the first package.
    second_buf = BytesIO()
    with PackageExporter(second_buf, importer=first_importer) as exporter:
        exporter.intern("**")
        exporter.save_pickle("res", "mod1.pkl", once_loaded)
    second_buf.seek(0)
    second_importer = PackageImporter(second_buf)
    twice_loaded = second_importer.load_pickle("res", "mod1.pkl")

    parent_storage = twice_loaded.tensor.storage()._cdata
    self.assertEqual(parent_storage, twice_loaded.sub_mod_0.tensor.storage()._cdata)
    self.assertEqual(parent_storage, twice_loaded.sub_mod_1.tensor.storage()._cdata)

    # All tensors should reflect this in-place change.
    twice_loaded.tensor.add_(torch.ones(3, 3))
    self.assertTrue(torch.allclose(twice_loaded.tensor, twice_loaded.sub_mod_0.tensor))
    self.assertTrue(torch.allclose(twice_loaded.tensor, twice_loaded.sub_mod_1.tensor))
def test_package_fx_custom_tracer(self):
    """
    A GraphModule subclass traced with a custom tracer keeps its tracer
    class, extra attributes and behaviour through a package round trip.
    """
    from package_a.test_all_leaf_modules_tracer import TestAllLeafModulesTracer
    from package_a.test_module import ModWithTwoSubmodsAndTensor, SimpleTest

    class SpecialGraphModule(torch.fx.GraphModule):
        """GraphModule subclass carrying an extra ``info`` attribute."""

        def __init__(self, root, graph, info):
            super().__init__(root, graph)
            self.info = info

    sub_module = SimpleTest()
    module = ModWithTwoSubmodsAndTensor(
        torch.ones(3),
        sub_module,
        sub_module,
    )

    tracer = TestAllLeafModulesTracer()
    graph = tracer.trace(module)
    # Both the graph and the GraphModule built from it record the tracer class.
    self.assertEqual(graph._tracer_cls, TestAllLeafModulesTracer)

    gm = SpecialGraphModule(module, graph, "secret")
    self.assertEqual(gm._tracer_cls, TestAllLeafModulesTracer)

    f = BytesIO()
    with PackageExporter(f) as pe:
        pe.intern("**")
        pe.save_pickle("model", "model.pkl", gm)
    f.seek(0)

    pi = PackageImporter(f)
    loaded_gm = pi.load_pickle("model", "model.pkl")

    # Compare class names directly: ``type(x).__class__.__name__`` is the
    # metaclass name ("type") and would match for any two classes.
    # NOTE(review): compared against type(gm) rather than SpecialGraphModule
    # because GraphModule may rename the instance's class; confirm whether
    # the SpecialGraphModule subclass itself is expected to survive packaging.
    self.assertEqual(type(loaded_gm).__name__, type(gm).__name__)
    self.assertEqual(loaded_gm.info, "secret")

    input_x = torch.randn(3)
    self.assertEqual(loaded_gm(input_x), gm(input_x))
def test_resnet(self):
    resnet = resnet18()

    first_path = self.temp()

    # Package the model together with the code it needs.
    with PackageExporter(first_path) as exporter:
        # Intern everything so that, by default, the source of every module
        # referenced by the pickled objects is saved alongside the pickle.
        exporter.intern("**")
        exporter.save_pickle("model", "model.pkl", resnet)

    # Load the saved model back out of the package.
    first_importer = PackageImporter(first_path)
    reloaded = first_importer.load_pickle("model", "model.pkl")

    # The reloaded model must produce the same output as the original.
    example_input = torch.rand(1, 3, 224, 224)
    expected = resnet(example_input)
    self.assertEqual(reloaded(example_input), expected)

    # The private copies of modules inside a package can be imported directly too.
    _packaged_torchvision = first_importer.import_module("torchvision")

    second_buf = BytesIO()
    # Re-saving objects that came out of a package (e.g. for transfer
    # learning) requires telling the exporter about the importer they came
    # from, so it can resolve class names like torchvision.models.resnet.ResNet
    # to their source code. The importers are searched in order and the
    # first success is taken as the module's definition. On name collisions
    # (say, ResNet objects from two different packages) only objects from
    # the first matching importer will work; this avoids name mangling in
    # the saved source. To genuinely mix such objects, rebuild the models
    # with code from a single package and move state across with
    # state_dict / load_state_dict.
    with PackageExporter(second_buf, importer=(first_importer, sys_importer)) as exporter:
        exporter.intern("**")
        exporter.save_pickle("model", "model.pkl", reloaded)

    second_buf.seek(0)
    second_importer = PackageImporter(second_buf)
    round_tripped = second_importer.load_pickle("model", "model.pkl")
    self.assertEqual(round_tripped(example_input), expected)
def test_save_repeat_scriptmodules(self):
    """
    Saving several different ScriptModules, including repeats of the same
    one, into a single package works. This also checks that
    PyTorchStreamReader isn't hiding code from PyTorchStreamWriter when
    ScriptModule code files would be written multiple times.
    """
    from package_a.test_module import (
        ModWithSubmodAndTensor,
        ModWithTensor,
        SimpleTest,
    )

    scripted_mod_0 = torch.jit.script(SimpleTest())
    scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
    scripted_mod_2 = torch.jit.script(
        ModWithSubmodAndTensor(
            torch.rand(1, 2, 3), ModWithTensor(torch.rand(1, 2, 3))
        )
    )

    # mod2/mod3 deliberately repeat mod0/mod1.
    resources = (
        ("mod0.pkl", scripted_mod_0),
        ("mod1.pkl", scripted_mod_1),
        ("mod2.pkl", scripted_mod_0),
        ("mod3.pkl", scripted_mod_1),
        ("mod4.pkl", scripted_mod_2),
    )

    buf = BytesIO()
    with PackageExporter(buf) as exporter:
        for resource_name, scripted in resources:
            exporter.save_pickle("res", resource_name, scripted)
    buf.seek(0)

    importer = PackageImporter(buf)
    reloaded_0 = importer.load_pickle("res", "mod0.pkl")
    reloaded_1 = importer.load_pickle("res", "mod3.pkl")
    reloaded_2 = importer.load_pickle("res", "mod4.pkl")

    example_input = torch.rand(1, 2, 3)
    self.assertEqual(reloaded_0(example_input), scripted_mod_0(example_input))
    self.assertEqual(reloaded_1(example_input), scripted_mod_1(example_input))
    self.assertEqual(reloaded_2(example_input), scripted_mod_2(example_input))