def test_package_interface(self): """Packaging an interface class should work correctly.""" import package_a.fake_interface as fake uses_interface = fake.UsesInterface() scripted = torch.jit.script(uses_interface) scripted.proxy_mod = torch.jit.script(fake.NewModule()) buffer = BytesIO() with PackageExporter(buffer, verbose=False) as pe: pe.save_pickle("model", "model.pkl", uses_interface) buffer.seek(0) package_importer = PackageImporter(buffer) loaded = package_importer.load_pickle("model", "model.pkl") scripted_loaded = torch.jit.script(loaded) scripted_loaded.proxy_mod = torch.jit.script(fake.NewModule()) input = torch.tensor(1) self.assertTrue(torch.allclose(scripted(input), scripted_loaded(input)))
def test_save_imported_module_fails(self): """ Directly saving/requiring an PackageImported module should raise a specific error message. """ import package_a.subpackage obj = package_a.subpackage.PackageASubpackageObject() obj2 = package_a.PackageAObject(obj) f1 = self.temp() with PackageExporter(f1, verbose=False) as pe: pe.save_pickle("obj", "obj.pkl", obj) importer1 = PackageImporter(f1) loaded1 = importer1.load_pickle("obj", "obj.pkl") f2 = self.temp() pe = PackageExporter(f2, verbose=False, importer=(importer1, sys_importer)) with self.assertRaisesRegex(ModuleNotFoundError, "torch.package"): pe.require_module(loaded1.__module__) with self.assertRaisesRegex(ModuleNotFoundError, "torch.package"): pe.save_module(loaded1.__module__)
def test_resnet(self): resnet = resnet18() f1 = self.temp() # create a package that will save it along with its code with PackageExporter(f1, verbose=False) as e: # put the pickled resnet in the package, by default # this will also save all the code files references by # the objects in the pickle e.save_pickle("model", "model.pkl", resnet) # check th debug graph has something reasonable: buf = StringIO() debug_graph = e._write_dep_graph(failing_module="torch") self.assertIn("torchvision.models.resnet", debug_graph) # we can now load the saved model i = PackageImporter(f1) r2 = i.load_pickle("model", "model.pkl") # test that it works input = torch.rand(1, 3, 224, 224) ref = resnet(input) self.assertTrue(torch.allclose(r2(input), ref)) # functions exist also to get at the private modules in each package torchvision = i.import_module("torchvision") f2 = BytesIO() # if we are doing transfer learning we might want to re-save # things that were loaded from a package. # We need to tell the exporter about any modules that # came from imported packages so that it can resolve # class names like torchvision.models.resnet.ResNet # to their source code. with PackageExporter(f2, verbose=False, importer=(i, sys_importer)) as e: # e.importers is a list of module importing functions # that by default contains importlib.import_module. # it is searched in order until the first success and # that module is taken to be what torchvision.models.resnet # should be in this code package. In the case of name collisions, # such as trying to save a ResNet from two different packages, # we take the first thing found in the path, so only ResNet objects from # one importer will work. This avoids a bunch of name mangling in # the source code. If you need to actually mix ResNet objects, # we suggest reconstructing the model objects using code from a single package # using functions like save_state_dict and load_state_dict to transfer state # to the correct code objects. e.save_pickle("model", "model.pkl", r2) f2.seek(0) i2 = PackageImporter(f2) r3 = i2.load_pickle("model", "model.pkl") self.assertTrue(torch.allclose(r3(input), ref)) # test we can load from a directory import zipfile zf = zipfile.ZipFile(f1, "r") with TemporaryDirectory() as td: zf.extractall(path=td) iz = PackageImporter(str(Path(td) / Path(f1).name)) r4 = iz.load_pickle("model", "model.pkl") self.assertTrue(torch.allclose(r4(input), ref))
def test_model_save(self): # This example shows how you might package a model # so that the creator of the model has flexibility about # how they want to save it but the 'server' can always # use the same API to load the package. # The convension is for each model to provide a # 'model' package with a 'load' function that actual # reads the model out of the archive. # How the load function is implemented is up to the # the packager. # get our normal torchvision resnet resnet = resnet18() f1 = BytesIO() # Option 1: save by pickling the whole model # + single-line, similar to torch.jit.save # - more difficult to edit the code after the model is created with PackageExporter(f1, verbose=False) as e: e.save_pickle("model", "pickled", resnet) # note that this source is the same for all models in this approach # so it can be made part of an API that just takes the model and # packages it with this source. src = dedent("""\ import importlib import torch_package_importer as resources # server knows to call model.load() to get the model, # maybe in the future it passes options as arguments by convension def load(): return resources.load_pickle('model', 'pickled') """) e.save_source_string("model", src, is_package=True) f2 = BytesIO() # Option 2: save with state dict # - more code to write to save/load the model # + but this code can be edited later to adjust adapt the model later with PackageExporter(f2, verbose=False) as e: e.save_pickle("model", "state_dict", resnet.state_dict()) src = dedent("""\ import importlib import torch_package_importer as resources from torchvision.models.resnet import resnet18 def load(): # if you want, you can later edit how resnet is constructed here # to edit the model in the package, while still loading the original # state dict weights r = resnet18() state_dict = resources.load_pickle('model', 'state_dict') r.load_state_dict(state_dict) return r """) e.save_source_string("model", src, is_package=True) # regardless of how we chose to package, we can now use the model in a server in the same way input = torch.rand(1, 3, 224, 224) results = [] for m in [f1, f2]: m.seek(0) importer = PackageImporter(m) the_model = importer.import_module("model").load() r = the_model(input) results.append(r) self.assertTrue(torch.allclose(*results))
def test_file_structure(self): """ Tests package's Folder structure representation of a zip file. Ensures that the returned Folder prints what is expected and filters inputs/outputs correctly. """ buffer = BytesIO() export_plain = dedent("""\ ├── main │ └── main ├── obj │ └── obj.pkl ├── package_a │ ├── __init__.py │ └── subpackage.py └── module_a.py """) export_include = dedent("""\ ├── obj │ └── obj.pkl └── package_a └── subpackage.py """) import_exclude = dedent("""\ ├── .data │ ├── extern_modules │ └── version ├── main │ └── main ├── obj │ └── obj.pkl ├── package_a │ ├── __init__.py │ └── subpackage.py └── module_a.py """) with PackageExporter(buffer, verbose=False) as he: import module_a import package_a import package_a.subpackage obj = package_a.subpackage.PackageASubpackageObject() he.save_module(module_a.__name__) he.save_module(package_a.__name__) he.save_pickle("obj", "obj.pkl", obj) he.save_text("main", "main", "my string") export_file_structure = he.file_structure() # remove first line from testing because WINDOW/iOS/Unix treat the buffer differently self.assertEqual( dedent("\n".join(str(export_file_structure).split("\n")[1:])), export_plain, ) export_file_structure = he.file_structure( include=["**/subpackage.py", "**/*.pkl"]) self.assertEqual( dedent("\n".join(str(export_file_structure).split("\n")[1:])), export_include, ) buffer.seek(0) hi = PackageImporter(buffer) import_file_structure = hi.file_structure(exclude="**/*.storage") self.assertEqual( dedent("\n".join(str(import_file_structure).split("\n")[1:])), import_exclude, )
parser = argparse.ArgumentParser(description="Generate Examples") parser.add_argument("--install_dir", help="Root directory for all output files") if __name__ == "__main__": args = parser.parse_args() if args.install_dir is None: p = Path(__file__).parent / "generated" p.mkdir(exist_ok=True) else: p = Path(args.install_dir) resnet = resnet18() resnet.eval() resnet_eg = torch.rand(1, 3, 224, 224) resnet_traced = torch.jit.trace(resnet, resnet_eg) save("resnet", resnet, resnet_traced, (resnet_eg, )) simple = Simple(10, 20) save("simple", simple, torch.jit.script(simple), (torch.rand(10, 20), )) multi_return = MultiReturn() save("multi_return", multi_return, torch.jit.script(multi_return), (torch.rand(10, 20), ), multi_return_metadata) with PackageExporter(str(p / "load_library")) as e: e.mock("iopath.**") e.intern("**") e.save_pickle("fn", "fn.pkl", load_library)
def make_exporter(): pe = PackageExporter(f2, verbose=False) # Ensure that the importer finds the 'PackageAObject' defined in 'importer1' first. pe.importers.insert(0, importer1.import_module) return pe
def make_exporter(): pe = PackageExporter(f2, importer=[importer1, sys_importer]) # Ensure that the importer finds the 'PackageAObject' defined in 'importer1' first. return pe
def test_resource_reader(self): """Tests DirectoryReader as the base for get_resource_reader.""" filename = self.temp() with PackageExporter(filename, verbose=False) as pe: # Layout looks like: # package # ├── one/ # │ ├── a.txt # │ ├── b.txt # │ ├── c.txt # │ └── three/ # │ ├── d.txt # │ └── e.txt # └── two/ # ├── f.txt # └── g.txt pe.save_text("one", "a.txt", "hello, a!") pe.save_text("one", "b.txt", "hello, b!") pe.save_text("one", "c.txt", "hello, c!") pe.save_text("one.three", "d.txt", "hello, d!") pe.save_text("one.three", "e.txt", "hello, e!") pe.save_text("two", "f.txt", "hello, f!") pe.save_text("two", "g.txt", "hello, g!") zip_file = zipfile.ZipFile(filename, "r") with TemporaryDirectory() as temp_dir: zip_file.extractall(path=temp_dir) importer = PackageImporter(Path(temp_dir) / Path(filename).name) reader_one = importer.get_resource_reader("one") # Different behavior from still zipped archives resource_path = os.path.join(Path(temp_dir), Path(filename).name, "one", "a.txt") self.assertEqual(reader_one.resource_path("a.txt"), resource_path) self.assertTrue(reader_one.is_resource("a.txt")) self.assertEqual( reader_one.open_resource("a.txt").getbuffer(), b"hello, a!") self.assertFalse(reader_one.is_resource("three")) reader_one_contents = list(reader_one.contents()) reader_one_contents.sort() self.assertSequenceEqual(reader_one_contents, ["a.txt", "b.txt", "c.txt", "three"]) reader_two = importer.get_resource_reader("two") self.assertTrue(reader_two.is_resource("f.txt")) self.assertEqual( reader_two.open_resource("f.txt").getbuffer(), b"hello, f!") reader_two_contents = list(reader_two.contents()) reader_two_contents.sort() self.assertSequenceEqual(reader_two_contents, ["f.txt", "g.txt"]) reader_one_three = importer.get_resource_reader("one.three") self.assertTrue(reader_one_three.is_resource("d.txt")) self.assertEqual( reader_one_three.open_resource("d.txt").getbuffer(), b"hello, d!") reader_one_three_contents = list(reader_one_three.contents()) reader_one_three_contents.sort() self.assertSequenceEqual(reader_one_three_contents, ["d.txt", "e.txt"]) self.assertIsNone( importer.get_resource_reader("nonexistent_package"))
def save(name, model, model_jit, eg): with PackageExporter(str(p / name)) as e: e.mock('iopath.**') e.save_pickle('model', 'model.pkl', model) e.save_pickle('model', 'example.pkl', eg) model_jit.save(str(p / (name + '_jit')))
def test_file_structure(self): filename = self.temp() export_plain = dedent("""\ ├── main │ └── main ├── obj │ └── obj.pkl ├── package_a │ ├── __init__.py │ └── subpackage.py └── module_a.py """) export_include = dedent("""\ ├── obj │ └── obj.pkl └── package_a └── subpackage.py """) import_exclude = dedent("""\ ├── .data │ ├── extern_modules │ └── version ├── main │ └── main ├── obj │ └── obj.pkl ├── package_a │ ├── __init__.py │ └── subpackage.py └── module_a.py """) with PackageExporter(filename, verbose=False) as he: import module_a import package_a import package_a.subpackage obj = package_a.subpackage.PackageASubpackageObject() he.save_module(module_a.__name__) he.save_module(package_a.__name__) he.save_pickle("obj", "obj.pkl", obj) he.save_text("main", "main", "my string") export_file_structure = he.file_structure() # remove first line from testing because WINDOW/iOS/Unix treat the filename differently self.assertEqual( dedent("\n".join(str(export_file_structure).split("\n")[1:])), export_plain, ) export_file_structure = he.file_structure( include=["**/subpackage.py", "**/*.pkl"]) self.assertEqual( dedent("\n".join(str(export_file_structure).split("\n")[1:])), export_include, ) hi = PackageImporter(filename) import_file_structure = hi.file_structure(exclude="**/*.storage") self.assertEqual( dedent("\n".join(str(import_file_structure).split("\n")[1:])), import_exclude, )
resnet_eg = torch.rand(1, 3, 224, 224) resnet_traced = torch.jit.trace(resnet, resnet_eg) save("resnet", resnet, resnet_traced, (resnet_eg, )) simple = Simple(10, 20) save("simple", simple, torch.jit.script(simple), (torch.rand(10, 20), )) multi_return = MultiReturn() save("multi_return", multi_return, torch.jit.script(multi_return), (torch.rand(10, 20), ), multi_return_metadata) # used for torch deploy/package tests in predictor batched_model = BatchedModel() save("batched_model", batched_model) with PackageExporter(str(p / "load_library")) as e: e.mock("iopath.**") e.intern("**") e.save_pickle("fn", "fn.pkl", load_library) generate_fx_example() with PackageExporter(p / "uses_distributed") as e: e.save_source_string( "uses_distributed", "import torch.distributed; assert torch.distributed.is_available()" ) with PackageExporter(str(p / "make_trt_module")) as e: e.extern("tensorrt") e.add_dependency("tensorrt")
resnet = resnet18() resnet.eval() resnet_eg = torch.rand(1, 3, 224, 224) resnet_traced = torch.jit.trace(resnet, resnet_eg) save("resnet", resnet, resnet_traced, (resnet_eg, )) simple = Simple(10, 20) save("simple", simple, torch.jit.script(simple), (torch.rand(10, 20), )) multi_return = MultiReturn() save("multi_return", multi_return, torch.jit.script(multi_return), (torch.rand(10, 20), ), multi_return_metadata) # used for torch deploy/package tests in predictor batched_model = BatchedModel() save("batched_model", batched_model) with PackageExporter(str(p / "load_library")) as e: e.mock("iopath.**") e.intern("**") e.save_pickle("fn", "fn.pkl", load_library) generate_fx_example() with PackageExporter(p / "uses_distributed") as e: e.save_source_string( "uses_distributed", "import torch.distributed; assert torch.distributed.is_available()" )