    def test_package_interface(self):
        """Packaging an interface class should work correctly."""

        import package_a.fake_interface as fake

        uses_interface = fake.UsesInterface()
        scripted = torch.jit.script(uses_interface)
        scripted.proxy_mod = torch.jit.script(fake.NewModule())

        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_pickle("model", "model.pkl", uses_interface)
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.load_pickle("model", "model.pkl")

        scripted_loaded = torch.jit.script(loaded)
        scripted_loaded.proxy_mod = torch.jit.script(fake.NewModule())

        input = torch.tensor(1)

        self.assertTrue(torch.allclose(scripted(input),
                                       scripted_loaded(input)))
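
The package_a.fake_interface fixture used above is not shown on this page. Below is a minimal, hedged sketch of what such a fixture typically looks like; only NewModule and UsesInterface appear in the test itself, while ModuleInterface and the method bodies here are illustrative assumptions.

# Hedged sketch (not the real fixture source): an interface-typed submodule
# that can be swapped after scripting, which is what the test exercises.
import torch


@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    def one(self, inp1: torch.Tensor, inp2: torch.Tensor) -> torch.Tensor:
        pass


class NewModule(torch.nn.Module):
    def one(self, inp1: torch.Tensor, inp2: torch.Tensor) -> torch.Tensor:
        return inp1 * inp2 + 1


class UsesInterface(torch.nn.Module):
    # Typing proxy_mod against the interface lets torch.jit.script accept any
    # implementation assigned later (e.g. scripted.proxy_mod = ...).
    proxy_mod: ModuleInterface

    def __init__(self):
        super().__init__()
        self.proxy_mod = NewModule()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proxy_mod.one(x, x)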
Example #2
    def test_save_imported_module_fails(self):
        """
        Directly saving or requiring a module that was imported from a package should raise a specific error message.
        """
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)
        f1 = self.temp()
        with PackageExporter(f1, verbose=False) as pe:
            pe.save_pickle("obj", "obj.pkl", obj)

        importer1 = PackageImporter(f1)
        loaded1 = importer1.load_pickle("obj", "obj.pkl")

        f2 = self.temp()
        pe = PackageExporter(f2,
                             verbose=False,
                             importer=(importer1, sys_importer))
        with self.assertRaisesRegex(ModuleNotFoundError, "torch.package"):
            pe.require_module(loaded1.__module__)
        with self.assertRaisesRegex(ModuleNotFoundError, "torch.package"):
            pe.save_module(loaded1.__module__)
Example #3
    def test_resnet(self):
        resnet = resnet18()

        f1 = self.temp()

        # create a package that will save it along with its code
        with PackageExporter(f1, verbose=False) as e:
            # put the pickled resnet in the package; by default
            # this will also save all the code files referenced by
            # the objects in the pickle
            e.save_pickle("model", "model.pkl", resnet)

            # check that the debug graph has something reasonable:
            buf = StringIO()
            debug_graph = e._write_dep_graph(failing_module="torch")
            self.assertIn("torchvision.models.resnet", debug_graph)

        # we can now load the saved model
        i = PackageImporter(f1)
        r2 = i.load_pickle("model", "model.pkl")

        # test that it works
        input = torch.rand(1, 3, 224, 224)
        ref = resnet(input)
        self.assertTrue(torch.allclose(r2(input), ref))

        # functions also exist to get at the private modules in each package
        torchvision = i.import_module("torchvision")

        f2 = BytesIO()
        # If we are doing transfer learning, we might want to re-save
        # things that were loaded from a package.
        # We need to tell the exporter about any modules that
        # came from imported packages so that it can resolve
        # class names like torchvision.models.resnet.ResNet
        # to their source code.
        with PackageExporter(f2, verbose=False,
                             importer=(i, sys_importer)) as e:
            # e.importers is a list of module-importing functions
            # that by default contains importlib.import_module.
            # It is searched in order until the first success, and
            # that module is taken to be what torchvision.models.resnet
            # should be in this code package. In the case of name collisions,
            # such as trying to save a ResNet from two different packages,
            # we take the first thing found in the path, so only ResNet objects from
            # one importer will work. This avoids a bunch of name mangling in
            # the source code. If you need to actually mix ResNet objects,
            # we suggest reconstructing the model objects using code from a single
            # package, using functions like save_state_dict and load_state_dict
            # to transfer state to the correct code objects.
            e.save_pickle("model", "model.pkl", r2)

        f2.seek(0)

        i2 = PackageImporter(f2)
        r3 = i2.load_pickle("model", "model.pkl")
        self.assertTrue(torch.allclose(r3(input), ref))

        # test we can load from a directory
        import zipfile

        zf = zipfile.ZipFile(f1, "r")

        with TemporaryDirectory() as td:
            zf.extractall(path=td)
            iz = PackageImporter(str(Path(td) / Path(f1).name))
            r4 = iz.load_pickle("model", "model.pkl")
            self.assertTrue(torch.allclose(r4(input), ref))
Example #4
    def test_model_save(self):

        # This example shows how you might package a model
        # so that the creator of the model has flexibility about
        # how they want to save it, but the 'server' can always
        # use the same API to load the package.

        # The convention is for each model to provide a
        # 'model' package with a 'load' function that actually
        # reads the model out of the archive.

        # How the load function is implemented is up to
        # the packager.

        # get our normal torchvision resnet
        resnet = resnet18()

        f1 = BytesIO()
        # Option 1: save by pickling the whole model
        # + single-line, similar to torch.jit.save
        # - more difficult to edit the code after the model is created
        with PackageExporter(f1, verbose=False) as e:
            e.save_pickle("model", "pickled", resnet)
            # note that this source is the same for all models in this approach
            # so it can be made part of an API that just takes the model and
            # packages it with this source.
            src = dedent("""\
                import importlib
                import torch_package_importer as resources

                # server knows to call model.load() to get the model,
                # maybe in the future it passes options as arguments by convention
                def load():
                    return resources.load_pickle('model', 'pickled')
                """)
            e.save_source_string("model", src, is_package=True)

        f2 = BytesIO()
        # Option 2: save with state dict
        # - more code to write to save/load the model
        # + but this code can be edited later to adapt the model
        with PackageExporter(f2, verbose=False) as e:
            e.save_pickle("model", "state_dict", resnet.state_dict())
            src = dedent("""\
                import importlib
                import torch_package_importer as resources

                from torchvision.models.resnet import resnet18
                def load():
                    # if you want, you can edit how resnet is constructed here
                    # to adjust the model in the package, while still loading
                    # the original state dict weights
                    r = resnet18()
                    state_dict = resources.load_pickle('model', 'state_dict')
                    r.load_state_dict(state_dict)
                    return r
                """)
            e.save_source_string("model", src, is_package=True)

        # regardless of how we chose to package, we can now use the model in a server in the same way
        input = torch.rand(1, 3, 224, 224)
        results = []
        for m in [f1, f2]:
            m.seek(0)
            importer = PackageImporter(m)
            the_model = importer.import_module("model").load()
            r = the_model(input)
            results.append(r)

        self.assertTrue(torch.allclose(*results))
Example #5
    def test_file_structure(self):
        """
        Tests the package's Folder structure representation of a zip file.
        Ensures that the returned Folder prints what is expected and filters
        inputs/outputs correctly.
        """
        buffer = BytesIO()

        export_plain = dedent("""\
                ├── main
                │   └── main
                ├── obj
                │   └── obj.pkl
                ├── package_a
                │   ├── __init__.py
                │   └── subpackage.py
                └── module_a.py
            """)
        export_include = dedent("""\
                ├── obj
                │   └── obj.pkl
                └── package_a
                    └── subpackage.py
            """)
        import_exclude = dedent("""\
                ├── .data
                │   ├── extern_modules
                │   └── version
                ├── main
                │   └── main
                ├── obj
                │   └── obj.pkl
                ├── package_a
                │   ├── __init__.py
                │   └── subpackage.py
                └── module_a.py
            """)

        with PackageExporter(buffer, verbose=False) as he:
            import module_a
            import package_a
            import package_a.subpackage

            obj = package_a.subpackage.PackageASubpackageObject()
            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
            he.save_pickle("obj", "obj.pkl", obj)
            he.save_text("main", "main", "my string")

            export_file_structure = he.file_structure()
            # remove the first line from the comparison because Windows/macOS/Unix treat the buffer differently
            self.assertEqual(
                dedent("\n".join(str(export_file_structure).split("\n")[1:])),
                export_plain,
            )
            export_file_structure = he.file_structure(
                include=["**/subpackage.py", "**/*.pkl"])
            self.assertEqual(
                dedent("\n".join(str(export_file_structure).split("\n")[1:])),
                export_include,
            )

        buffer.seek(0)
        hi = PackageImporter(buffer)
        import_file_structure = hi.file_structure(exclude="**/*.storage")
        self.assertEqual(
            dedent("\n".join(str(import_file_structure).split("\n")[1:])),
            import_exclude,
        )
Example #6
parser = argparse.ArgumentParser(description="Generate Examples")
parser.add_argument("--install_dir",
                    help="Root directory for all output files")

if __name__ == "__main__":
    args = parser.parse_args()
    if args.install_dir is None:
        p = Path(__file__).parent / "generated"
        p.mkdir(exist_ok=True)
    else:
        p = Path(args.install_dir)

    resnet = resnet18()
    resnet.eval()
    resnet_eg = torch.rand(1, 3, 224, 224)
    resnet_traced = torch.jit.trace(resnet, resnet_eg)
    save("resnet", resnet, resnet_traced, (resnet_eg, ))

    simple = Simple(10, 20)
    save("simple", simple, torch.jit.script(simple), (torch.rand(10, 20), ))

    multi_return = MultiReturn()
    save("multi_return", multi_return, torch.jit.script(multi_return),
         (torch.rand(10, 20), ), multi_return_metadata)

    with PackageExporter(str(p / "load_library")) as e:
        e.mock("iopath.**")
        e.intern("**")
        e.save_pickle("fn", "fn.pkl", load_library)
Example #7
def make_exporter():
    pe = PackageExporter(f2, verbose=False)
    # Ensure that the importer finds the 'PackageAObject' defined in 'importer1' first.
    pe.importers.insert(0, importer1.import_module)
    return pe
Example #8
def make_exporter():
    pe = PackageExporter(f2, importer=[importer1, sys_importer])
    # Ensure that the importer finds the 'PackageAObject' defined in 'importer1' first.
    return pe
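
Examples #7 and #8 are fragments: f2, importer1, and the package_a fixtures come from tests like Example #2. A hedged, self-contained sketch of the full round trip they set up follows.

# Sketch only: export an object, load it with a PackageImporter, then re-export
# the loaded object by letting the second exporter consult importer1 before the
# ordinary system importer.
from io import BytesIO

import package_a.subpackage
from torch.package import PackageExporter, PackageImporter, sys_importer

f1, f2 = BytesIO(), BytesIO()
obj = package_a.PackageAObject(package_a.subpackage.PackageASubpackageObject())

with PackageExporter(f1, verbose=False) as pe:
    pe.save_pickle("obj", "obj.pkl", obj)
f1.seek(0)

importer1 = PackageImporter(f1)
loaded1 = importer1.load_pickle("obj", "obj.pkl")

# Listing importer1 first makes names such as package_a.PackageAObject resolve
# to the source stored in f1 rather than the copy on the local filesystem.
with PackageExporter(f2, verbose=False, importer=[importer1, sys_importer]) as pe:
    pe.save_pickle("obj", "obj.pkl", loaded1)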
Example #9
    def test_resource_reader(self):
        """Tests DirectoryReader as the base for get_resource_reader."""
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as pe:
            # Layout looks like:
            #    package
            #    ├── one/
            #    │   ├── a.txt
            #    │   ├── b.txt
            #    │   ├── c.txt
            #    │   └── three/
            #    │       ├── d.txt
            #    │       └── e.txt
            #    └── two/
            #        ├── f.txt
            #        └── g.txt
            pe.save_text("one", "a.txt", "hello, a!")
            pe.save_text("one", "b.txt", "hello, b!")
            pe.save_text("one", "c.txt", "hello, c!")

            pe.save_text("one.three", "d.txt", "hello, d!")
            pe.save_text("one.three", "e.txt", "hello, e!")

            pe.save_text("two", "f.txt", "hello, f!")
            pe.save_text("two", "g.txt", "hello, g!")

        zip_file = zipfile.ZipFile(filename, "r")

        with TemporaryDirectory() as temp_dir:
            zip_file.extractall(path=temp_dir)
            importer = PackageImporter(Path(temp_dir) / Path(filename).name)
            reader_one = importer.get_resource_reader("one")

            # Different behavior from still-zipped archives
            resource_path = os.path.join(Path(temp_dir),
                                         Path(filename).name, "one", "a.txt")
            self.assertEqual(reader_one.resource_path("a.txt"), resource_path)

            self.assertTrue(reader_one.is_resource("a.txt"))
            self.assertEqual(
                reader_one.open_resource("a.txt").getbuffer(), b"hello, a!")
            self.assertFalse(reader_one.is_resource("three"))
            reader_one_contents = list(reader_one.contents())
            reader_one_contents.sort()
            self.assertSequenceEqual(reader_one_contents,
                                     ["a.txt", "b.txt", "c.txt", "three"])

            reader_two = importer.get_resource_reader("two")
            self.assertTrue(reader_two.is_resource("f.txt"))
            self.assertEqual(
                reader_two.open_resource("f.txt").getbuffer(), b"hello, f!")
            reader_two_contents = list(reader_two.contents())
            reader_two_contents.sort()
            self.assertSequenceEqual(reader_two_contents, ["f.txt", "g.txt"])

            reader_one_three = importer.get_resource_reader("one.three")
            self.assertTrue(reader_one_three.is_resource("d.txt"))
            self.assertEqual(
                reader_one_three.open_resource("d.txt").getbuffer(),
                b"hello, d!")
            reader_one_three_contents = list(reader_one_three.contents())
            reader_one_three_contents.sort()
            self.assertSequenceEqual(reader_one_three_contents,
                                     ["d.txt", "e.txt"])

            self.assertIsNone(
                importer.get_resource_reader("nonexistent_package"))
Example #10
def save(name, model, model_jit, eg):
    with PackageExporter(str(p / name)) as e:
        e.mock('iopath.**')
        e.save_pickle('model', 'model.pkl', model)
        e.save_pickle('model', 'example.pkl', eg)
    model_jit.save(str(p / (name + '_jit')))
Example #11
    def test_file_structure(self):
        filename = self.temp()

        export_plain = dedent("""\
                ├── main
                │   └── main
                ├── obj
                │   └── obj.pkl
                ├── package_a
                │   ├── __init__.py
                │   └── subpackage.py
                └── module_a.py
            """)
        export_include = dedent("""\
                ├── obj
                │   └── obj.pkl
                └── package_a
                    └── subpackage.py
            """)
        import_exclude = dedent("""\
                ├── .data
                │   ├── extern_modules
                │   └── version
                ├── main
                │   └── main
                ├── obj
                │   └── obj.pkl
                ├── package_a
                │   ├── __init__.py
                │   └── subpackage.py
                └── module_a.py
            """)

        with PackageExporter(filename, verbose=False) as he:
            import module_a
            import package_a
            import package_a.subpackage

            obj = package_a.subpackage.PackageASubpackageObject()
            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
            he.save_pickle("obj", "obj.pkl", obj)
            he.save_text("main", "main", "my string")

            export_file_structure = he.file_structure()
            # remove the first line from the comparison because Windows/macOS/Unix treat the filename differently
            self.assertEqual(
                dedent("\n".join(str(export_file_structure).split("\n")[1:])),
                export_plain,
            )
            export_file_structure = he.file_structure(
                include=["**/subpackage.py", "**/*.pkl"])
            self.assertEqual(
                dedent("\n".join(str(export_file_structure).split("\n")[1:])),
                export_include,
            )

        hi = PackageImporter(filename)
        import_file_structure = hi.file_structure(exclude="**/*.storage")
        self.assertEqual(
            dedent("\n".join(str(import_file_structure).split("\n")[1:])),
            import_exclude,
        )
Example #12
    resnet_eg = torch.rand(1, 3, 224, 224)
    resnet_traced = torch.jit.trace(resnet, resnet_eg)
    save("resnet", resnet, resnet_traced, (resnet_eg, ))

    simple = Simple(10, 20)
    save("simple", simple, torch.jit.script(simple), (torch.rand(10, 20), ))

    multi_return = MultiReturn()
    save("multi_return", multi_return, torch.jit.script(multi_return),
         (torch.rand(10, 20), ), multi_return_metadata)

    # used for torch deploy/package tests in predictor
    batched_model = BatchedModel()
    save("batched_model", batched_model)

    with PackageExporter(str(p / "load_library")) as e:
        e.mock("iopath.**")
        e.intern("**")
        e.save_pickle("fn", "fn.pkl", load_library)

    generate_fx_example()

    with PackageExporter(p / "uses_distributed") as e:
        e.save_source_string(
            "uses_distributed",
            "import torch.distributed; assert torch.distributed.is_available()"
        )

    with PackageExporter(str(p / "make_trt_module")) as e:
        e.extern("tensorrt")
        e.add_dependency("tensorrt")
Example #13
    resnet = resnet18()
    resnet.eval()
    resnet_eg = torch.rand(1, 3, 224, 224)
    resnet_traced = torch.jit.trace(resnet, resnet_eg)
    save("resnet", resnet, resnet_traced, (resnet_eg, ))

    simple = Simple(10, 20)
    save("simple", simple, torch.jit.script(simple), (torch.rand(10, 20), ))

    multi_return = MultiReturn()
    save("multi_return", multi_return, torch.jit.script(multi_return),
         (torch.rand(10, 20), ), multi_return_metadata)

    # used for torch deploy/package tests in predictor
    batched_model = BatchedModel()
    save("batched_model", batched_model)

    with PackageExporter(str(p / "load_library")) as e:
        e.mock("iopath.**")
        e.intern("**")
        e.save_pickle("fn", "fn.pkl", load_library)

    generate_fx_example()

    with PackageExporter(p / "uses_distributed") as e:
        e.save_source_string(
            "uses_distributed",
            "import torch.distributed; assert torch.distributed.is_available()"
        )
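
A hedged sketch (not part of the generator script) of how a consumer might load the "uses_distributed" package written above; importing the saved module runs the embedded assert.

# Sketch only: p is the generator's output directory defined earlier in the
# script; import_module executes the saved source string inside the package.
from torch.package import PackageImporter

imp = PackageImporter(str(p / "uses_distributed"))
imp.import_module("uses_distributed")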