Example #1
    def test_dunder_package_works_from_package(self):
        """
        The attribute '__torch_package__' should be accessible from within
        the module itself, so that packaged code can detect whether it's
        being used in a packaged context or not.
        """
        import package_a.use_dunder_package as mod

        buffer = BytesIO()

        with PackageExporter(buffer, verbose=False) as pe:
            pe.intern("**")
            pe.save_module(mod.__name__)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        imported_mod = pi.import_module(mod.__name__)
        self.assertTrue(imported_mod.is_from_package())
        self.assertFalse(mod.is_from_package())
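
The helper module package_a.use_dunder_package is not shown in this listing. A minimal sketch of what its is_from_package() helper might look like, assuming the only signal it relies on is the module-level __torch_package__ attribute described in the docstring above (hypothetical file contents):

    # package_a/use_dunder_package.py (hypothetical sketch)
    # Inside a torch.package, the importer defines '__torch_package__' in the
    # module's namespace; in a normal Python environment it is absent.
    if "__torch_package__" in dir():

        def is_from_package():
            return True

    else:

        def is_from_package():
            return False
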
    def test_mock(self):
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.mock(["package_a.subpackage", "module_a"])
            # Import something that depends on package_a.subpackage
            he.save_source_string("foo", "import package_a.subpackage")
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = hi.import_module("package_a.subpackage")
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()
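
The assertion above (and in the other mock examples in this listing) relies on mocked modules being replaced by stubs whose attributes raise NotImplementedError mentioning "was mocked out" as soon as they are used. A rough sketch of how such a stub could behave; the real torch.package mock machinery may differ:

    # Hypothetical stand-in for a mocked module's contents.
    class _MockedObject:
        def __init__(self, name):
            self._name = name

        def __getattr__(self, attr):
            # Attribute access keeps returning mocks so that imports succeed.
            return _MockedObject(f"{self._name}.{attr}")

        def __call__(self, *args, **kwargs):
            # Actually using the object fails loudly.
            raise NotImplementedError(f"Object '{self._name}' was mocked out during packaging")
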
Example #3
    def test_loading_pickle(self):
        """
        Test basic saving and loading of modules and pickles from a DirectoryReader.
        """
        resnet = resnet18()

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.intern("**")
            e.save_pickle("model", "model.pkl", resnet)

        zip_file = zipfile.ZipFile(filename, "r")

        with TemporaryDirectory() as temp_dir:
            zip_file.extractall(path=temp_dir)
            importer = PackageImporter(Path(temp_dir) / Path(filename).name)
            dir_mod = importer.load_pickle("model", "model.pkl")
            input = torch.rand(1, 3, 224, 224)
            self.assertEqual(dir_mod(input), resnet(input))
Example #4
    def test_custom_requires(self):
        filename = self.temp()

        class Custom(PackageExporter):
            def require_module(self, name, dependencies):
                if name == 'module_a':
                    self.save_mock_module('module_a')
                elif name == 'package_a':
                    self.save_source_string('package_a', 'import module_a\nresult = 5\n')
                else:
                    raise NotImplementedError('wat')

        with Custom(filename, verbose=False) as he:
            he.save_source_string('main', 'import package_a\n')

        hi = PackageImporter(filename)
        hi.import_module('module_a').should_be_mocked
        bar = hi.import_module('package_a')
        self.assertEqual(bar.result, 5)
Example #5
    def test_mock_glob(self):
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.mock(['package_a.*', 'module*'])
            he.save_module('package_a')
            he.save_source_string('test_module', """\
import package_a.subpackage
import module_a
""")
        hi = PackageImporter(filename)
        import package_a.subpackage
        _ = package_a.subpackage
        import module_a
        _ = module_a

        m = hi.import_module('package_a.subpackage')
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, 'was mocked out'):
            r()
    def test_load_shared_tensors(self):
        """
        Test that tensors shared across eager modules and ScriptModules
        are the same after loading.
        """
        from package_a.test_module import (
            ModWithTensor,
            ModWithTwoSubmodsAndTensor,
        )

        shared_tensor = torch.ones(3, 3)

        scripted_mod_0 = torch.jit.script(ModWithTensor(shared_tensor))
        scripted_mod_1 = torch.jit.script(ModWithTensor(shared_tensor))

        mod1 = ModWithTwoSubmodsAndTensor(shared_tensor, scripted_mod_0,
                                          scripted_mod_1)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", mod1)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")

        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_0.tensor.storage()._cdata,
        )
        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_1.tensor.storage()._cdata,
        )

        loaded_mod_1.tensor.add_(torch.ones(3, 3))

        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_0.tensor))
        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_1.tensor))
Example #7
    def test_save_imported_module_fails(self):
        """
        Directly saving/requiring a PackageImported module should raise a specific error message.
        """
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)
        f1 = self.temp()
        with PackageExporter(f1, verbose=False) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        importer1 = PackageImporter(f1)
        loaded1 = importer1.load_pickle("obj", "obj.pkl")

        f2 = self.temp()
        pe = PackageExporter(f2, verbose=False, importer=(importer1, sys_importer))
        with self.assertRaisesRegex(ModuleNotFoundError, "torch.package"):
            pe.save_module(loaded1.__module__)
Example #8
    def test_mock(self):
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as he:
            he.mock(["package_a.subpackage", "module_a"])
            he.save_module("package_a")
            he.require_module("package_a.subpackage")
            he.require_module("module_a")
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = hi.import_module("package_a.subpackage")
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()
Example #9
    def test_ordered_importer_basic(self):
        import package_a

        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_module(package_a.__name__)

        buffer.seek(0)
        importer = PackageImporter(buffer)

        ordered_importer_sys_first = OrderedImporter(sys_importer, importer)
        self.assertIs(ordered_importer_sys_first.import_module("package_a"),
                      package_a)

        ordered_importer_package_first = OrderedImporter(
            importer, sys_importer)
        self.assertIs(
            ordered_importer_package_first.import_module("package_a"),
            importer.import_module("package_a"),
        )
Example #10
    def test_inspect_class(self):
        """Should be able to retrieve source for a packaged class."""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        packaged_class = pi.import_module(
            "package_a.subpackage").PackageASubpackageObject
        regular_class = package_a.subpackage.PackageASubpackageObject

        packaged_src = inspect.getsourcelines(packaged_class)
        regular_src = inspect.getsourcelines(regular_class)
        self.assertEqual(packaged_src, regular_src)
    def test_is_from_package(self):
        """is_from_package should work for objects and modules"""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        loaded_obj = pi.load_pickle("obj", "obj.pkl")

        self.assertFalse(is_from_package(package_a.subpackage))
        self.assertTrue(is_from_package(mod))

        self.assertFalse(is_from_package(obj))
        self.assertTrue(is_from_package(loaded_obj))
    def test_custom_requires(self):
        buffer = BytesIO()

        class Custom(PackageExporter):
            def require_module(self, name, dependencies):
                if name == "module_a":
                    self.save_mock_module("module_a")
                elif name == "package_a":
                    self.save_source_string("package_a",
                                            "import module_a\nresult = 5\n")
                else:
                    raise NotImplementedError("wat")

        with Custom(buffer, verbose=False) as he:
            he.save_source_string("main", "import package_a\n")

        buffer.seek(0)
        hi = PackageImporter(buffer)
        hi.import_module("module_a").should_be_mocked
        bar = hi.import_module("package_a")
        self.assertEqual(bar.result, 5)
Example #13
    def test_package_resource_access(self):
        """Packaged modules should be able to use the importlib.resources API to access
        resources saved in the package.
        """
        mod_src = dedent("""\
            import importlib.resources
            import my_cool_resources

            def secret_message():
                return importlib.resources.read_text(my_cool_resources, 'sekrit.txt')
            """)
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_source_string("foo.bar", mod_src)
            pe.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")

        buffer.seek(0)
        importer = PackageImporter(buffer)
        self.assertEqual(
            importer.import_module("foo.bar").secret_message(),
            "my sekrit plays")
Example #14
    def test_resource_access_by_path(self):
        """
        Tests that packaged code can use importlib.resources.path.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.save_binary("string_module", "my_string",
                           "my string".encode("utf-8"))
            src = dedent("""\
                import importlib.resources
                import string_module

                with importlib.resources.path(string_module, 'my_string') as path:
                    with open(path, mode='r', encoding='utf-8') as f:
                        s = f.read()
                """)
            he.save_source_string("main", src, is_package=True)
        buffer.seek(0)
        hi = PackageImporter(buffer)
        m = hi.import_module("main")
        self.assertEqual(m.s, "my string")
Example #15
    def test_mixing_packaged_and_inline_modules_shared_code(self):
        """
        Test saving inline and imported modules that share code in the
        same package.
        """
        class TorchVisionTestInline(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tvmod = resnet18()

            def forward(self, x):
                x = a_non_torch_leaf(x, x)
                return torch.relu(x + 3.0)

        def a_non_torch_leaf(a, b):
            return a + b

        inline_mod = TorchVisionTestInline()
        scripted_inline = torch.jit.script(inline_mod)

        from package_c.test_module import TorchVisionTest

        imported_mod = TorchVisionTest()
        scripted_imported = torch.jit.script(imported_mod)

        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as e:
            e.save_pickle("model", "inline.pkl", scripted_inline)
            e.save_pickle("model", "imported.pkl", scripted_imported)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_inline = importer.load_pickle("model", "inline.pkl")
        loaded_imported = importer.load_pickle("model", "imported.pkl")

        input = torch.rand(2, 3)
        self.assertTrue(
            torch.allclose(loaded_imported(input), imported_mod(input)))
        self.assertTrue(torch.allclose(loaded_inline(input),
                                       inline_mod(input)))
Example #16
    def test_package_importer_whichmodule_no_dunder_module(self):
        """Exercise corner case where we try to pickle an object whose
        __module__ doesn't exist because it's from a C extension.
        """
        # torch.float16 is an example of such an object: it is a C extension
        # type for which there is no __module__ defined. The default pickler
        # finds it using special logic to traverse sys.modules and look up
        # `float16` on each module (see pickle.py:whichmodule).
        #
        # We must ensure that we emulate the same behavior from PackageImporter.
        my_dtype = torch.float16

        # Set up a PackageImporter which has a torch.float16 object pickled:
        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.save_pickle("foo", "foo.pkl", my_dtype)
        buffer.seek(0)

        importer = PackageImporter(buffer)
        my_loaded_dtype = importer.load_pickle("foo", "foo.pkl")

        # Re-save a package with only our PackageImporter as the importer
        buffer2 = BytesIO()
        with PackageExporter(buffer2, importer=importer) as exporter:
            exporter.save_pickle("foo", "foo.pkl", my_loaded_dtype)

        buffer2.seek(0)

        importer2 = PackageImporter(buffer2)
        my_loaded_dtype2 = importer2.load_pickle("foo", "foo.pkl")
        self.assertIs(my_dtype, my_loaded_dtype)
        self.assertIs(my_dtype, my_loaded_dtype2)
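
The comment in the test above refers to pickle's fallback lookup for objects that lack a usable __module__. A simplified sketch of that search (see pickle.whichmodule in the standard library for the real implementation, which PackageImporter must emulate over its own module table):

    import sys

    def whichmodule_sketch(obj, name):
        # Prefer an explicit __module__ if the object carries one.
        module_name = getattr(obj, "__module__", None)
        if module_name is not None:
            return module_name
        # Otherwise scan the loaded modules for one that exposes `name`
        # as this exact object.
        for mod_name, module in sys.modules.copy().items():
            if mod_name == "__main__" or module is None:
                continue
            if getattr(module, name, None) is obj:
                return mod_name
        return "__main__"
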
    def test_saving_and_scripting_packaged_mod(self):
        """
        Test scripting a module loaded from a package
        and saving it in a new package as a script object.
        """
        from package_a.test_module import SimpleTest

        orig_mod = SimpleTest()

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.intern("**")
            e.save_pickle("model", "model.pkl", orig_mod)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)
        loaded_mod = importer_0.load_pickle("model", "model.pkl")

        input = torch.rand(2, 3)
        self.assertEqual(loaded_mod(input), orig_mod(input))

        scripted_mod = torch.jit.script(loaded_mod)

        buffer_1 = BytesIO()
        with PackageExporter(buffer_1, importer=importer_0) as e:
            e.intern("**")
            e.save_pickle("res", "scripted_mod.pkl", scripted_mod)

        buffer_1.seek(0)
        importer_1 = PackageImporter(buffer_1)
        loaded_mod_scripted = importer_1.load_pickle("res", "scripted_mod.pkl")

        self.assertEqual(loaded_mod_scripted(input), orig_mod(input))
Example #18
    def test_repackage_mocked_module(self):
        """Re-packaging a package that contains a mocked module should work correctly."""
        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.mock("package_a")
            exporter.save_source_string("foo", "import package_a")

        buffer.seek(0)
        importer = PackageImporter(buffer)
        foo = importer.import_module("foo")

        # "package_a" should be mocked out.
        with self.assertRaises(NotImplementedError):
            foo.package_a.get_something()

        # Re-package the model, but intern the previously-mocked module and mock
        # everything else.
        buffer2 = BytesIO()
        with PackageExporter(buffer2, importer=importer) as exporter:
            exporter.intern("package_a")
            exporter.mock("**")
            exporter.save_source_string("foo", "import package_a")

        buffer2.seek(0)
        importer2 = PackageImporter(buffer2)
        foo2 = importer2.import_module("foo")

        # "package_a" should still be mocked out.
        with self.assertRaises(NotImplementedError):
            foo2.package_a.get_something()
Example #19
    def test_extern_glob(self):
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.extern(["package_a.*", "module_*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        hi = PackageImporter(filename)
        import module_a
        import package_a.subpackage

        module_a_im = hi.import_module("module_a")
        hi.import_module("package_a.subpackage")
        package_a_im = hi.import_module("package_a")

        self.assertIs(module_a, module_a_im)
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)
Example #20
    def test_package_fx_package(self):
        from package_a.test_module import SimpleTest

        model = SimpleTest()
        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", model)

        f.seek(0)
        pi = PackageImporter(f)
        loaded = pi.load_pickle("model", "model.pkl")
        traced = symbolic_trace(loaded)

        # re-save the package exporter
        f2 = BytesIO()
        # This should fail, because we are referencing some globals that are
        # only in the package.
        with self.assertRaises(ObjMismatchError):
            with PackageExporter(f2) as pe:
                pe.intern("**")
                pe.save_pickle("model", "model.pkl", traced)

        f2.seek(0)
        with PackageExporter(f2, importer=(pi, sys_importer)) as pe:
            # Make the package available to the exporter's environment.
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", traced)
        f2.seek(0)
        pi2 = PackageImporter(f2)
        loaded2 = pi2.load_pickle("model", "model.pkl")

        input = torch.rand(2, 3)
        self.assertTrue(torch.allclose(loaded(input), loaded2(input)))
Example #21
    def test_extern_glob(self):
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as he:
            he.extern(["package_a.*", "module_*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import module_a
        import package_a.subpackage

        module_a_im = hi.import_module("module_a")
        hi.import_module("package_a.subpackage")
        package_a_im = hi.import_module("package_a")

        self.assertIs(module_a, module_a_im)
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)
    def test_externing_c_extension(self):
        """Externing c extensions modules should allow us to still access them especially those found in torch._C."""

        buffer = BytesIO()
        # The C extension module in question is F.gelu which comes from torch._C._nn
        model = torch.nn.TransformerEncoderLayer(
            d_model=64,
            nhead=2,
            dim_feedforward=64,
            dropout=1.0,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        with PackageExporter(buffer) as e:
            e.extern("torch.**")
            e.intern("**")

            e.save_pickle("model", "model.pkl", model)
        buffer.seek(0)
        imp = PackageImporter(buffer)
        imp.load_pickle("model", "model.pkl")
    def test_save_independent_scriptmodules(self):
        """
        Test to verify saving multiple ScriptModules with completely
        separate code works.
        """
        from package_a.test_module import ModWithTensor, SimpleTest

        scripted_mod_0 = torch.jit.script(SimpleTest())
        scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)
            e.save_pickle("res", "mod2.pkl", scripted_mod_1)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_0 = importer.load_pickle("res", "mod1.pkl")
        loaded_mod_1 = importer.load_pickle("res", "mod2.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod_0(input), scripted_mod_0(input))
        self.assertEqual(loaded_mod_1(input), scripted_mod_1(input))
    def test_package_script_class_referencing_self(self):
        import package_a.fake_script_class as fake

        obj = fake.UsesIdListFeature()
        # intentionally script here to fill the compilation cache, to make sure
        # there is no false sharing between scripted types coming from the
        # package vs. outside environment.
        torch.jit.script(obj)

        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.intern("**")
            exporter.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        obj_loaded = importer.load_pickle("obj", "obj.pkl")
        scripted_obj_loaded = torch.jit.script(obj_loaded)

        # Make sure the scripted object can be serialized without error.
        buffer2 = scripted_obj_loaded.save_to_buffer()
        torch.jit.load(BytesIO(buffer2))
Example #25
    def test_save_scriptmodules_in_container(self):
        """
        Test saving of ScriptModules inside of a container. Checks that relations
        between shared modules are upheld.
        """
        from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

        scripted_mod_a = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
        scripted_mod_b = torch.jit.script(
            ModWithSubmodAndTensor(torch.rand(1, 2, 3), scripted_mod_a))
        script_mods_list = [scripted_mod_a, scripted_mod_b]

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "list.pkl", script_mods_list)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_list = importer.load_pickle("res", "list.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod_list[0](input), scripted_mod_a(input))
        self.assertEqual(loaded_mod_list[1](input), scripted_mod_b(input))
Example #26
    def test_single_ordered_importer(self):
        import package_a
        import module_a  # noqa: F401
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_module(package_a.__name__)

        buffer.seek(0)
        importer = PackageImporter(buffer)

        # Construct an importer-only environment.
        ordered_importer = OrderedImporter(importer)

        # The module returned by this environment should be the same one that's
        # in the importer.
        self.assertIs(ordered_importer.import_module('package_a'), importer.import_module('package_a'))
        # It should not be the one available in the outer Python environment.
        self.assertIsNot(ordered_importer.import_module('package_a'), package_a)

        # We didn't package this module, so it should not be available.
        with self.assertRaises(ModuleNotFoundError):
            ordered_importer.import_module('module_a')
    def test_load_shared_tensors_repackaged(self):
        """
        Test that tensors shared across eager modules and ScriptModules
        are the same across multiple package saves and loads. This is
        an important test because not all of the tensor information is restored
        in Python between packages. The Python identity is not maintained, but
        the backing C++ TensorImpl is. We load/save storages based off of this
        C++ TensorImpl and not the Python identity.
        """
        from package_a.test_module import (
            ModWithTensor,
            ModWithTwoSubmodsAndTensor,
        )

        shared_tensor = torch.ones(3, 3)

        scripted_mod_0 = torch.jit.script(ModWithTensor(shared_tensor))
        scripted_mod_1 = torch.jit.script(ModWithTensor(shared_tensor))

        mod1 = ModWithTwoSubmodsAndTensor(shared_tensor, scripted_mod_0, scripted_mod_1)

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", mod1)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)
        loaded_mod_0 = importer_0.load_pickle("res", "mod1.pkl")

        buffer_1 = BytesIO()
        with PackageExporter(buffer_1, importer=importer_0) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", loaded_mod_0)

        buffer_1.seek(0)
        importer = PackageImporter(buffer_1)
        loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")

        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_0.tensor.storage()._cdata,
        )
        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_1.tensor.storage()._cdata,
        )

        loaded_mod_1.tensor.add_(
            torch.ones(3, 3)
        )  # all tensors should reflect this change

        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_0.tensor)
        )
        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_1.tensor)
        )
Example #28
    def test_package_fx_custom_tracer(self):
        from package_a.test_all_leaf_modules_tracer import TestAllLeafModulesTracer
        from package_a.test_module import SimpleTest, ModWithTwoSubmodsAndTensor

        class SpecialGraphModule(torch.fx.GraphModule):
            def __init__(self, root, graph, info):
                super().__init__(root, graph)
                self.info = info

        sub_module = SimpleTest()
        module = ModWithTwoSubmodsAndTensor(
            torch.ones(3),
            sub_module,
            sub_module,
        )
        tracer = TestAllLeafModulesTracer()
        graph = tracer.trace(module)

        self.assertEqual(graph._tracer_cls, TestAllLeafModulesTracer)

        gm = SpecialGraphModule(module, graph, "secret")
        self.assertEqual(gm._tracer_cls, TestAllLeafModulesTracer)

        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", gm)
        f.seek(0)

        pi = PackageImporter(f)
        loaded_gm = pi.load_pickle("model", "model.pkl")
        self.assertEqual(
            type(loaded_gm).__class__.__name__, SpecialGraphModule.__class__.__name__
        )
        self.assertEqual(loaded_gm.info, "secret")

        input_x = torch.randn(3)
        self.assertEqual(loaded_gm(input_x), gm(input_x))
Example #29
    def test_resnet(self):
        resnet = resnet18()

        f1 = self.temp()

        # create a package that will save it along with its code
        with PackageExporter(f1) as e:
            # put the pickled resnet in the package, by default
            # this will also save all the code files referenced by
            # the objects in the pickle
            e.intern("**")
            e.save_pickle("model", "model.pkl", resnet)

        # we can now load the saved model
        i = PackageImporter(f1)
        r2 = i.load_pickle("model", "model.pkl")

        # test that it works
        input = torch.rand(1, 3, 224, 224)
        ref = resnet(input)
        self.assertEqual(r2(input), ref)

        # functions also exist to access the private modules in each package
        torchvision = i.import_module("torchvision")

        f2 = BytesIO()
        # if we are doing transfer learning we might want to re-save
        # things that were loaded from a package.
        # We need to tell the exporter about any modules that
        # came from imported packages so that it can resolve
        # class names like torchvision.models.resnet.ResNet
        # to their source code.
        with PackageExporter(f2, importer=(i, sys_importer)) as e:
            # e.importers is a list of module importing functions
            # that by default contains importlib.import_module.
            # It is searched in order until the first success, and
            # that module is taken to be what torchvision.models.resnet
            # should be in this code package. In the case of name collisions,
            # such as trying to save a ResNet from two different packages,
            # we take the first thing found in the path, so only ResNet objects from
            # one importer will work. This avoids a bunch of name mangling in
            # the source code. If you need to actually mix ResNet objects,
            # we suggest reconstructing the model objects using code from a single package,
            # using functions like state_dict() and load_state_dict() to transfer state
            # to the correct code objects.
            e.intern("**")
            e.save_pickle("model", "model.pkl", r2)

        f2.seek(0)

        i2 = PackageImporter(f2)
        r3 = i2.load_pickle("model", "model.pkl")
        self.assertEqual(r3(input), ref)
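
The long comment inside test_resnet suggests that, if ResNet objects from different packages really had to be mixed, it is better to rebuild the model from a single code source and transfer weights via state dicts rather than re-pickle the packaged object. A minimal sketch of that approach using the standard nn.Module state-dict API (the names below are illustrative):

    # Rebuild the architecture from the current environment's torchvision code,
    # then copy the weights out of the model that was loaded from the package.
    fresh_resnet = resnet18()
    fresh_resnet.load_state_dict(r2.state_dict())
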
    def test_save_repeat_scriptmodules(self):
        """
        Test to verify that saving multiple different modules and
        repeats of the same ScriptModule in a package works. Also verifies that
        writing ScriptModule code files multiple times with PyTorchStreamWriter
        does not hide code from PyTorchStreamReader.
        """
        from package_a.test_module import (
            ModWithSubmodAndTensor,
            ModWithTensor,
            SimpleTest,
        )

        scripted_mod_0 = torch.jit.script(SimpleTest())
        scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
        scripted_mod_2 = torch.jit.script(
            ModWithSubmodAndTensor(
                torch.rand(1, 2, 3), ModWithTensor(torch.rand(1, 2, 3))
            )
        )

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod0.pkl", scripted_mod_0)
            e.save_pickle("res", "mod1.pkl", scripted_mod_1)
            e.save_pickle("res", "mod2.pkl", scripted_mod_0)
            e.save_pickle("res", "mod3.pkl", scripted_mod_1)
            e.save_pickle("res", "mod4.pkl", scripted_mod_2)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_0 = importer.load_pickle("res", "mod0.pkl")
        loaded_mod_1 = importer.load_pickle("res", "mod3.pkl")
        loaded_mod_2 = importer.load_pickle("res", "mod4.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod_0(input), scripted_mod_0(input))
        self.assertEqual(loaded_mod_1(input), scripted_mod_1(input))
        self.assertEqual(loaded_mod_2(input), scripted_mod_2(input))