コード例 #1
0
ファイル: test_importer.py プロジェクト: zacker150/pytorch
    def test_package_importer_whichmodule_no_dunder_module(self):
        """Round-trip an object that has no usable ``__module__`` attribute.

        ``torch.float16`` is a C-extension object with no ``__module__``; the
        stock pickler finds it by scanning ``sys.modules`` for a ``float16``
        attribute (see ``pickle.whichmodule``). PackageImporter has to
        emulate that lookup, both on the first load and when a package is
        re-exported with the importer as the only module source.
        """
        original = torch.float16

        # First package: pickle the dtype with the default importer.
        first_buf = BytesIO()
        with PackageExporter(first_buf) as exporter:
            exporter.save_pickle("foo", "foo.pkl", original)
        first_buf.seek(0)

        first_importer = PackageImporter(first_buf)
        first_loaded = first_importer.load_pickle("foo", "foo.pkl")

        # Second package: re-save using only the PackageImporter above as
        # the exporter's importer.
        second_buf = BytesIO()
        with PackageExporter(second_buf, importer=first_importer) as exporter:
            exporter.save_pickle("foo", "foo.pkl", first_loaded)
        second_buf.seek(0)

        second_importer = PackageImporter(second_buf)
        second_loaded = second_importer.load_pickle("foo", "foo.pkl")

        # The test expects the very same dtype object back after each trip.
        for roundtripped in (first_loaded, second_loaded):
            self.assertIs(original, roundtripped)
コード例 #2
0
ファイル: test_package_script.py プロジェクト: xsacha/pytorch
    def test_scriptmodules_repeat_save(self):
        """
        A ScriptModule loaded from one package must be re-savable into a
        second package (next to another ScriptModule) and still behave the
        same after loading.
        """
        from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

        plain_scripted = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
        nested_scripted = torch.jit.script(
            ModWithSubmodAndTensor(
                torch.rand(1, 2, 3), ModWithTensor(torch.rand(1, 2, 3))
            )
        )

        first_pkg = BytesIO()
        with PackageExporter(first_pkg) as exporter:
            exporter.save_pickle("res", "mod1.pkl", plain_scripted)
        first_pkg.seek(0)
        first_importer = PackageImporter(first_pkg)
        plain_loaded = first_importer.load_pickle("res", "mod1.pkl")

        # The second package mixes a fresh ScriptModule with one that
        # already came out of a package.
        second_pkg = BytesIO()
        with PackageExporter(second_pkg) as exporter:
            exporter.save_pickle("res", "mod1.pkl", nested_scripted)
            exporter.save_pickle("res", "mod2.pkl", plain_loaded)
        second_pkg.seek(0)
        second_importer = PackageImporter(second_pkg)
        nested_loaded = second_importer.load_pickle("res", "mod1.pkl")
        plain_reloaded = second_importer.load_pickle("res", "mod2.pkl")

        x = torch.rand(1, 2, 3)
        self.assertEqual(plain_loaded(x), plain_scripted(x))
        self.assertEqual(plain_loaded(x), plain_reloaded(x))
        self.assertEqual(nested_loaded(x), nested_scripted(x))
コード例 #3
0
    def test_tensor_sharing_pickle(self):
        """Saving a ScriptModule and, separately, a plain tensor into the
        same package should load back without issues.
        """

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.foo = torch.ones(2, 3)

            def forward(self):
                return self.foo

        script_module = torch.jit.script(M())
        empty_tensor = torch.ones(0)

        pkg = BytesIO()
        with torch.package.PackageExporter(pkg) as exporter:
            resources = (
                ("model.pkl", script_module),
                ("input.pkl", empty_tensor),
            )
            for resource, obj in resources:
                exporter.save_pickle("model", resource, obj)
        pkg.seek(0)

        # Both resources must be loadable from the one package.
        importer = PackageImporter(pkg)
        module_back = importer.load_pickle("model", "model.pkl")
        tensor_back = importer.load_pickle("model", "input.pkl")

        self.assertEqual(script_module.foo, module_back.foo)
        self.assertEqual(empty_tensor, tensor_back)
コード例 #4
0
    def test_package_fx_package(self):
        """Re-export an FX-traced module whose code lives only in a package.

        Exporting with the default importer must raise ObjMismatchError,
        because the traced module references globals that exist only inside
        the package. Exporting with the package's importer ahead of
        ``sys_importer`` must succeed and yield an equivalent model.
        """
        from package_a.test_module import SimpleTest

        model = SimpleTest()
        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", model)

        f.seek(0)
        pi = PackageImporter(f)
        loaded = pi.load_pickle("model", "model.pkl")
        traced = symbolic_trace(loaded)

        # Re-saving with only the default importer should fail, because we
        # are referencing some globals that are only in the package.
        failed_buffer = BytesIO()
        with self.assertRaises(ObjMismatchError):
            with PackageExporter(failed_buffer) as pe:
                pe.intern("**")
                pe.save_pickle("model", "model.pkl", traced)

        # Use a fresh buffer for the successful export: the failed attempt
        # above may already have written bytes into its buffer, and rewinding
        # and overwriting that buffer could leave stale data after the new
        # archive.
        f2 = BytesIO()
        with PackageExporter(f2, importer=(pi, sys_importer)) as pe:
            # Make the package available to the exporter's environment.
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", traced)
        f2.seek(0)
        pi2 = PackageImporter(f2)
        loaded2 = pi2.load_pickle("model", "model.pkl")

        inp = torch.rand(2, 3)
        self.assertTrue(torch.allclose(loaded(inp), loaded2(inp)))
コード例 #5
0
    def test_saving_and_scripting_packaged_mod(self):
        """
        Script a module that was loaded from a package, store the scripted
        result in a second package, and check it still matches the original
        eager module.
        """
        from package_a.test_module import SimpleTest

        eager = SimpleTest()

        src_pkg = BytesIO()
        with PackageExporter(src_pkg) as exporter:
            exporter.intern("**")
            exporter.save_pickle("model", "model.pkl", eager)
        src_pkg.seek(0)
        src_importer = PackageImporter(src_pkg)
        from_pkg = src_importer.load_pickle("model", "model.pkl")

        x = torch.rand(2, 3)
        self.assertEqual(from_pkg(x), eager(x))

        scripted = torch.jit.script(from_pkg)

        # Export the scripted module through the first package's importer.
        dst_pkg = BytesIO()
        with PackageExporter(dst_pkg, importer=src_importer) as exporter:
            exporter.intern("**")
            exporter.save_pickle("res", "scripted_mod.pkl", scripted)
        dst_pkg.seek(0)
        dst_importer = PackageImporter(dst_pkg)
        scripted_back = dst_importer.load_pickle("res", "scripted_mod.pkl")

        self.assertEqual(scripted_back(x), eager(x))
コード例 #6
0
ファイル: test_model.py プロジェクト: ydcjeff/pytorch
    def test_resnet(self):
        """Package a torchvision ResNet, load it back, and re-package it.

        Checks that the loaded model matches the original, that the
        exporter's dependency graph mentions ``torchvision.models.resnet``,
        and that a model loaded from a package can be saved again when the
        exporter is given the package's importer ahead of ``sys_importer``.
        """
        resnet = resnet18()

        f1 = self.temp()

        # create a package that will save it along with its code
        with PackageExporter(f1, verbose=False) as e:
            # put the pickled resnet in the package; this will also save all
            # the code files referenced by the objects in the pickle
            e.intern("**")
            e.save_pickle("model", "model.pkl", resnet)

            # check the debug graph has something reasonable:
            debug_graph = e._write_dep_graph(failing_module="torch")
            self.assertIn("torchvision.models.resnet", debug_graph)

        # we can now load the saved model
        i = PackageImporter(f1)
        r2 = i.load_pickle("model", "model.pkl")

        # test that it works
        input = torch.rand(1, 3, 224, 224)
        ref = resnet(input)
        self.assertTrue(torch.allclose(r2(input), ref))

        # functions also exist to get at the private modules in each package
        i.import_module("torchvision")

        f2 = BytesIO()
        # if we are doing transfer learning we might want to re-save
        # things that were loaded from a package.
        # We need to tell the exporter about any modules that
        # came from imported packages so that it can resolve
        # class names like torchvision.models.resnet.ResNet
        # to their source code.
        with PackageExporter(f2, verbose=False,
                             importer=(i, sys_importer)) as e:
            # `importer` is an ordered sequence of importers (here the
            # package's importer, then sys_importer). It is searched in
            # order until the first success, and that module is taken to be
            # what torchvision.models.resnet should be in this code package.
            # In the case of name collisions, such as trying to save a ResNet
            # from two different packages, we take the first thing found in
            # the path, so only ResNet objects from one importer will work.
            # This avoids a bunch of name mangling in the source code. If you
            # need to actually mix ResNet objects, we suggest reconstructing
            # the model objects using code from a single package, using
            # state_dict and load_state_dict to transfer state to the
            # correct code objects.
            e.intern("**")
            e.save_pickle("model", "model.pkl", r2)

        f2.seek(0)

        i2 = PackageImporter(f2)
        r3 = i2.load_pickle("model", "model.pkl")
        self.assertTrue(torch.allclose(r3(input), ref))
コード例 #7
0
    def test_load_shared_tensors_repackaged(self):
        """
        Tensors shared between eager modules and ScriptModules must still be
        shared after a save/load/re-save/load cycle across two packages.

        Not all tensor information is restored in Python between packages:
        Python identity is lost, but the backing C++ TensorImpl is kept, and
        storages are saved/loaded based on that TensorImpl rather than on
        Python identity.
        """
        from package_a.test_module import (
            ModWithTensor,
            ModWithTwoSubmodsAndTensor,
        )

        shared = torch.ones(3, 3)
        left = torch.jit.script(ModWithTensor(shared))
        right = torch.jit.script(ModWithTensor(shared))
        parent = ModWithTwoSubmodsAndTensor(shared, left, right)

        first_pkg = BytesIO()
        with PackageExporter(first_pkg) as exporter:
            exporter.intern("**")
            exporter.save_pickle("res", "mod1.pkl", parent)
        first_pkg.seek(0)
        first_importer = PackageImporter(first_pkg)
        once_loaded = first_importer.load_pickle("res", "mod1.pkl")

        # Re-export the loaded module through the first package's importer.
        second_pkg = BytesIO()
        with PackageExporter(second_pkg, importer=first_importer) as exporter:
            exporter.intern("**")
            exporter.save_pickle("res", "mod1.pkl", once_loaded)
        second_pkg.seek(0)
        second_importer = PackageImporter(second_pkg)
        twice_loaded = second_importer.load_pickle("res", "mod1.pkl")

        children = (twice_loaded.sub_mod_0, twice_loaded.sub_mod_1)
        for child in children:
            self.assertEqual(
                twice_loaded.tensor.storage()._cdata,
                child.tensor.storage()._cdata,
            )

        # An in-place update of the parent's tensor must show up in every child.
        twice_loaded.tensor.add_(torch.ones(3, 3))

        for child in children:
            self.assertTrue(torch.allclose(twice_loaded.tensor, child.tensor))
コード例 #8
0
ファイル: test_package.py プロジェクト: yiqxiaobai/pytorch
    def test_pickle_mocked(self):
        """Loading an object whose dependency was mocked at export time must
        raise NotImplementedError."""
        import package_a.subpackage
        inner = package_a.subpackage.PackageASubpackageObject()
        outer = package_a.PackageAObject(inner)

        path = self.temp()
        with PackageExporter(path, verbose=False) as exporter:
            # Mock package_a.subpackage instead of packaging it.
            exporter.mock(include='package_a.subpackage')
            exporter.save_pickle('obj', 'obj.pkl', outer)

        importer = PackageImporter(path)
        with self.assertRaises(NotImplementedError):
            importer.load_pickle('obj', 'obj.pkl')
コード例 #9
0
    def test_scriptobject_failure_message(self):
        """
        Loading a ScriptModule from a PackageImporter created from an
        extracted directory (instead of a package archive) is currently not
        supported and must raise a RuntimeError with a clear message.
        """
        from package_a.test_module import ModWithTensor

        scripted_mod = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        filename = self.temp()
        with PackageExporter(filename, verbose=False) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        with TemporaryDirectory() as temp_dir:
            # Close the archive as soon as it has been extracted so its file
            # handle is not leaked.
            with zipfile.ZipFile(filename, "r") as zip_file:
                zip_file.extractall(path=temp_dir)
            dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)

            # Only the load itself is expected to fail; extraction and
            # importer construction are kept outside the assertion so an
            # unrelated error there is not mistaken for the expected one.
            with self.assertRaisesRegex(
                    RuntimeError,
                    "Loading ScriptObjects from a PackageImporter created from a "
                    "directory is not supported. Use a package archive file instead.",
            ):
                dir_importer.load_pickle("res", "mod.pkl")
コード例 #10
0
    def test_repackage_import_indirectly_via_parent_module(self):
        """Export ImportsIndirectlyFromSubPackage through an importer chain
        of (package importer, sys_importer), after
        ImportsDirectlyFromSubSubPackage has been loaded from that package.

        There are no assertions: the test passes if the second export does
        not raise.
        """
        from package_d.imports_directly import ImportsDirectlyFromSubSubPackage
        from package_d.imports_indirectly import ImportsIndirectlyFromSubPackage

        direct_model = ImportsDirectlyFromSubSubPackage()
        first_pkg = BytesIO()
        with PackageExporter(first_pkg) as exporter:
            exporter.intern("**")
            exporter.save_pickle("default", "model.py", direct_model)
        first_pkg.seek(0)

        pkg_importer = PackageImporter(first_pkg)
        _loaded_model = pkg_importer.load_pickle("default", "model.py")  # result intentionally unused

        indirect_model = ImportsIndirectlyFromSubPackage()
        second_pkg = BytesIO()
        importer_chain = (pkg_importer, sys_importer)
        with PackageExporter(second_pkg, importer=importer_chain) as exporter:
            exporter.intern("**")
            exporter.save_pickle("default", "model_b.py", indirect_model)
コード例 #11
0
    def test_package_interface(self):
        """An interface class must survive packaging and still script to a
        module that behaves like the original."""

        import package_a.fake_interface as fake

        eager = fake.UsesInterface()
        scripted_ref = torch.jit.script(eager)
        scripted_ref.proxy_mod = torch.jit.script(fake.NewModule())

        pkg = BytesIO()
        with PackageExporter(pkg) as exporter:
            exporter.intern("**")
            exporter.save_pickle("model", "model.pkl", eager)
        pkg.seek(0)

        importer = PackageImporter(pkg)
        eager_back = importer.load_pickle("model", "model.pkl")

        scripted_back = torch.jit.script(eager_back)
        scripted_back.proxy_mod = torch.jit.script(fake.NewModule())

        x = torch.tensor(1)
        self.assertTrue(torch.allclose(scripted_ref(x), scripted_back(x)))
コード例 #12
0
    def test_pickle_mocked(self):
        """Loading an object whose dependency was mocked at export time must
        raise NotImplementedError."""
        import package_a.subpackage

        leaf = package_a.subpackage.PackageASubpackageObject()
        wrapper = package_a.PackageAObject(leaf)

        pkg = BytesIO()
        with PackageExporter(pkg, verbose=False) as exporter:
            # Mock package_a.subpackage instead of packaging it.
            exporter.mock(include="package_a.subpackage")
            exporter.save_pickle("obj", "obj.pkl", wrapper)
        pkg.seek(0)

        importer = PackageImporter(pkg)
        with self.assertRaises(NotImplementedError):
            importer.load_pickle("obj", "obj.pkl")
コード例 #13
0
    def test_dunder_imports(self):
        """Package ``package_b`` and check that it, its subpackages, and the
        stdlib modules it pulls in (the test name suggests via ``__import__``
        calls) can all be imported from the resulting PackageImporter."""
        pkg = BytesIO()
        with PackageExporter(pkg) as exporter:
            import package_b

            target = package_b.PackageBObject
            exporter.intern("**")
            exporter.save_pickle("res", "obj.pkl", target)
        pkg.seek(0)

        importer = PackageImporter(pkg)
        _ = importer.load_pickle("res", "obj.pkl")

        # (module to import from the package, attribute to read, expected value)
        expectations = (
            ("package_b", "result", "package_b"),
            ("math", "__name__", "math"),
            ("xml.sax.xmlreader", "__name__", "xml.sax.xmlreader"),
            ("package_b.subpackage_1", "result", "subpackage_1"),
            ("package_b.subpackage_2", "result", "subpackage_2"),
            ("package_b.subpackage_0.subsubpackage_0", "result", "subsubpackage_0"),
        )
        for module_name, attr, expected in expectations:
            module = importer.import_module(module_name)
            self.assertEqual(getattr(module, attr), expected)
コード例 #14
0
    def test_package_fx_with_imports(self):
        """Package a hand-built FX GraphModule that calls a non-torch function.

        The loaded GraphModule must compute the same result as the original,
        and the module providing the called function must come from the
        package rather than from the outer environment.
        """
        import package_a.subpackage

        # Manually construct a graph that invokes a leaf function
        graph = Graph()
        a = graph.placeholder("x")
        b = graph.placeholder("y")
        c = graph.call_function(package_a.subpackage.leaf_function, (a, b))
        d = graph.call_function(torch.sin, (c, ))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)

        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", gm)
        f.seek(0)

        pi = PackageImporter(f)
        loaded_gm = pi.load_pickle("model", "model.pkl")
        input_x = torch.rand(2, 3)
        input_y = torch.rand(2, 3)

        self.assertTrue(
            torch.allclose(loaded_gm(input_x, input_y), gm(input_x, input_y)))

        # Check that the packaged version of the leaf_function dependency is
        # not the same as in the outer env. assertIsNot reports both objects
        # on failure, unlike assertTrue(a is not b).
        packaged_dependency = pi.import_module("package_a.subpackage")
        self.assertIsNot(packaged_dependency, package_a.subpackage)
コード例 #15
0
    def test_save_shared_tensors(self):
        """
        Test tensors shared across eager and ScriptModules are serialized once.

        Both eager modules saved into the package are loaded back and checked
        against their originals.
        """
        from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

        shared_tensor = torch.rand(2, 3, 4)
        scripted_mod = torch.jit.script(ModWithTensor(shared_tensor))

        mod1 = ModWithSubmodAndTensor(shared_tensor, scripted_mod)
        mod2 = ModWithSubmodAndTensor(shared_tensor, scripted_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "tensor", shared_tensor)
            e.save_pickle("res", "mod1.pkl", mod1)
            e.save_pickle("res", "mod2.pkl", mod2)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")
        loaded_mod_2 = importer.load_pickle("res", "mod2.pkl")

        # assert that there is only one storage stored in package; assertEqual
        # reports the actual count if this fails.
        file_structure = importer.file_structure(include=".data/*.storage")
        self.assertEqual(len(file_structure.children[".data"].children), 1)

        inp = torch.rand(2, 3, 4)
        self.assertEqual(loaded_mod_1(inp), mod1(inp))
        self.assertEqual(loaded_mod_2(inp), mod2(inp))
コード例 #16
0
    def test_script_resnet(self):
        """A ResNet pickled into a package should script after loading, and
        the scripted model should survive torch.jit.save/torch.jit.load."""
        net = resnet18()

        pkg = BytesIO()
        # Option 1: save by pickling the whole model
        # + single-line, similar to torch.jit.save
        # - more difficult to edit the code after the model is created
        with PackageExporter(pkg, verbose=False) as exporter:
            exporter.save_pickle("model", "pickled", net)
        pkg.seek(0)

        importer = PackageImporter(pkg)
        from_package = importer.load_pickle("model", "pickled")

        # Model should script successfully.
        scripted = torch.jit.script(from_package)

        # Scripted model should save and load successfully.
        jit_buf = BytesIO()
        torch.jit.save(scripted, jit_buf)
        jit_buf.seek(0)
        jit_loaded = torch.jit.load(jit_buf)

        x = torch.rand(1, 3, 224, 224)
        self.assertTrue(torch.allclose(jit_loaded(x), net(x)))
コード例 #17
0
    def test_load_shared_scriptmodules(self):
        """
        Test loading of single ScriptModule shared by multiple eager
        modules in single pickle (ScriptModule objects should be the same).
        """
        from package_a.test_module import (
            ModWithMultipleSubmods,
            ModWithSubmod,
            SimpleTest,
        )

        scripted_mod = torch.jit.script(SimpleTest())

        mod1 = ModWithSubmod(scripted_mod)
        mod2 = ModWithSubmod(scripted_mod)

        mod_parent = ModWithMultipleSubmods(mod1, mod2)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "mod.pkl", mod_parent)

        buffer.seek(0)
        importer = PackageImporter(buffer)

        loaded_mod = importer.load_pickle("res", "mod.pkl")
        # assertIs checks object identity directly and reports both objects
        # on failure, unlike comparing id() values inside assertTrue.
        self.assertIs(loaded_mod.mod1.script_mod, loaded_mod.mod2.script_mod)
コード例 #18
0
ファイル: test_package.py プロジェクト: yiqxiaobai/pytorch
    def test_unique_module_names(self):
        """Types loaded through different PackageImporters must report module
        names distinct from each other and from the original module."""
        import package_a.subpackage
        inner = package_a.subpackage.PackageASubpackageObject()
        outer = package_a.PackageAObject(inner)
        path = self.temp()
        with PackageExporter(path, verbose=False) as exporter:
            exporter.save_pickle("obj", "obj.pkl", outer)

        first_importer = PackageImporter(path)
        first = first_importer.load_pickle("obj", "obj.pkl")
        second_importer = PackageImporter(path)
        second = second_importer.load_pickle("obj", "obj.pkl")

        # Modules from loaded packages should not shadow the names of modules.
        # See mangling.md for more info.
        outer_module = type(outer).__module__
        first_module = type(first).__module__
        second_module = type(second).__module__
        self.assertNotEqual(outer_module, first_module)
        self.assertNotEqual(first_module, second_module)
コード例 #19
0
    def test_exporting_mismatched_code(self):
        """
        If an object with the same qualified name is loaded from different
        packages, the user should get an error if they try to re-save the
        object with the wrong package's source code.
        """
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)
        f1 = self.temp()
        with PackageExporter(f1, verbose=False) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj2)

        importer1 = PackageImporter(f1)
        loaded1 = importer1.load_pickle("obj", "obj.pkl")
        importer2 = PackageImporter(f1)
        loaded2 = importer2.load_pickle("obj", "obj.pkl")

        f2 = self.temp()

        def make_exporter():
            # Ensure that the importer finds the 'PackageAObject' defined in
            # 'importer1' first.
            return PackageExporter(f2,
                                   verbose=False,
                                   importer=[importer1, sys_importer])

        # Every exporter below is used as a context manager so it is closed
        # (and its handle on 'f2' released) before the next one opens the
        # same path.

        # This should fail. The 'PackageAObject' type defined from 'importer1'
        # is not necessarily the same as 'obj2's version of 'PackageAObject'.
        with self.assertRaises(pickle.PicklingError):
            with make_exporter() as pe:
                pe.save_pickle("obj", "obj.pkl", obj2)

        # This should also fail. The 'PackageAObject' type defined from
        # 'importer1' is not necessarily the same as the one defined from
        # 'importer2'.
        with self.assertRaises(pickle.PicklingError):
            with make_exporter() as pe:
                pe.save_pickle("obj", "obj.pkl", loaded2)

        # This should succeed. The 'PackageAObject' type defined from
        # 'importer1' is a match for the one used by loaded1.
        with make_exporter() as pe:
            pe.save_pickle("obj", "obj.pkl", loaded1)
コード例 #20
0
    def test_mixing_packaged_and_inline_modules_shared_code(self):
        """
        One package can hold both an inline-defined ScriptModule and one
        built from an importable module that share code (the inline module
        embeds torchvision's resnet18; TorchVisionTest presumably does too).
        """
        class TorchVisionTestInline(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tvmod = resnet18()

            def forward(self, x):
                x = a_non_torch_leaf(x, x)
                return torch.relu(x + 3.0)

        def a_non_torch_leaf(a, b):
            return a + b

        eager_inline = TorchVisionTestInline()
        script_inline = torch.jit.script(eager_inline)

        from package_c.test_module import TorchVisionTest

        eager_imported = TorchVisionTest()
        script_imported = torch.jit.script(eager_imported)

        pkg = BytesIO()
        with PackageExporter(pkg, verbose=False) as exporter:
            exporter.save_pickle("model", "inline.pkl", script_inline)
            exporter.save_pickle("model", "imported.pkl", script_imported)
        pkg.seek(0)

        importer = PackageImporter(pkg)
        inline_back = importer.load_pickle("model", "inline.pkl")
        imported_back = importer.load_pickle("model", "imported.pkl")

        x = torch.rand(2, 3)
        self.assertTrue(torch.allclose(imported_back(x), eager_imported(x)))
        self.assertTrue(torch.allclose(inline_back(x), eager_inline(x)))
コード例 #21
0
    def test_load_shared_tensors(self):
        """
        A tensor shared by an eager module and its ScriptModule children must
        still be shared (same storage) after a package round-trip.
        """
        from package_a.test_module import (
            ModWithTensor,
            ModWithTwoSubmodsAndTensor,
        )

        shared = torch.ones(3, 3)
        left = torch.jit.script(ModWithTensor(shared))
        right = torch.jit.script(ModWithTensor(shared))
        parent = ModWithTwoSubmodsAndTensor(shared, left, right)

        # Precondition: scripting kept the children on the shared storage.
        for scripted in (left, right):
            self.assertEqual(
                shared.storage()._cdata,
                scripted.tensor.storage()._cdata,
            )

        pkg = BytesIO()
        with PackageExporter(pkg) as exporter:
            exporter.intern("**")
            exporter.save_pickle("res", "mod1.pkl", parent)
        pkg.seek(0)

        importer = PackageImporter(pkg)
        restored = importer.load_pickle("res", "mod1.pkl")

        children = (restored.sub_mod_0, restored.sub_mod_1)
        for child in children:
            self.assertEqual(
                restored.tensor.storage()._cdata,
                child.tensor.storage()._cdata,
            )

        # An in-place update of the parent's tensor must show up in every child.
        restored.tensor.add_(torch.ones(3, 3))

        for child in children:
            self.assertTrue(torch.allclose(restored.tensor, child.tensor))
コード例 #22
0
ファイル: test_package.py プロジェクト: yiqxiaobai/pytorch
    def test_script_resnet(self):
        """A ResNet pickled into a file-backed package should be scriptable
        once loaded back."""
        net = resnet18()

        path = self.temp()
        # Option 1: save by pickling the whole model
        # + single-line, similar to torch.jit.save
        # - more difficult to edit the code after the model is created
        with PackageExporter(path, verbose=False) as exporter:
            exporter.save_pickle('model', 'pickled', net)

        importer = PackageImporter(path)
        restored = importer.load_pickle('model', 'pickled')
        torch.jit.script(restored)
コード例 #23
0
    def test_package_then_fx(self):
        """A model loaded from a package should be traceable with FX, and the
        traced module should produce the same output as the loaded one."""
        from package_a.test_module import SimpleTest
        original = SimpleTest()
        pkg = BytesIO()
        with PackageExporter(pkg, verbose=False) as exporter:
            exporter.save_pickle('model', 'model.pkl', original)
        pkg.seek(0)

        importer = PackageImporter(pkg)
        from_pkg = importer.load_pickle('model', 'model.pkl')
        graph_module = symbolic_trace(from_pkg)
        x = torch.rand(2, 3)
        self.assertTrue(torch.allclose(from_pkg(x), graph_module(x)))
コード例 #24
0
    def test_externing_c_extension(self):
        """A model that uses C extension functions (here from torch._C) must
        still load when ``torch.**`` is externed."""

        pkg = BytesIO()
        # The C extension module in question is F.gelu which comes from torch._C._nn
        layer_config = {
            "d_model": 64,
            "nhead": 2,
            "dim_feedforward": 64,
            "dropout": 1.0,
            "batch_first": True,
            "activation": "gelu",
            "norm_first": True,
        }
        model = torch.nn.TransformerEncoderLayer(**layer_config)
        with PackageExporter(pkg) as exporter:
            exporter.extern("torch.**")
            exporter.intern("**")

            exporter.save_pickle("model", "model.pkl", model)
        pkg.seek(0)

        importer = PackageImporter(pkg)
        importer.load_pickle("model", "model.pkl")
コード例 #25
0
    def test_save_repeat_scriptmodules(self):
        """
        Test to verify that saving multiple different ScriptModules, plus
        repeats of the same ScriptModule, in one package works. Also tests
        that PyTorchStreamReader isn't having code hidden from
        PyTorchStreamWriter writing ScriptModule code files multiple times.

        Every saved resource is loaded back and compared with the module it
        was saved from, so a broken first copy or a broken repeat is caught.
        """
        from package_a.test_module import (
            ModWithSubmodAndTensor,
            ModWithTensor,
            SimpleTest,
        )

        scripted_mod_0 = torch.jit.script(SimpleTest())
        scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
        scripted_mod_2 = torch.jit.script(
            ModWithSubmodAndTensor(
                torch.rand(1, 2, 3), ModWithTensor(torch.rand(1, 2, 3))
            )
        )

        # (resource name, module saved under it), in save order; mod2/mod3
        # deliberately repeat mod0/mod1.
        saved = [
            ("mod0.pkl", scripted_mod_0),
            ("mod1.pkl", scripted_mod_1),
            ("mod2.pkl", scripted_mod_0),
            ("mod3.pkl", scripted_mod_1),
            ("mod4.pkl", scripted_mod_2),
        ]

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            for resource, mod in saved:
                e.save_pickle("res", resource, mod)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded = [
            (importer.load_pickle("res", resource), mod)
            for resource, mod in saved
        ]

        inp = torch.rand(1, 2, 3)
        for loaded_mod, mod in loaded:
            self.assertEqual(loaded_mod(inp), mod(inp))
コード例 #26
0
    def test_save_scriptmodules_submod_redefinition(self):
        """
        Test to verify that saving multiple ScriptModules with the same top
        module but different submodules works. The submodule class is
        redefined between the two instantiations of the top module to check
        that the different concrete types of the modules are thoroughly
        recognized by the serialization code.
        """

        class Submod(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input: str):
                input = input + "_submod"
                return input

        class TopMod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # `Submod` is looked up in the enclosing function scope each
                # time __init__ runs, so it picks up whichever definition is
                # current at instantiation time.
                self.modB = Submod()

            def forward(self, input: str):
                return self.modB(input)

        # Built with the first Submod definition ("_submod").
        scripted_mod_0 = torch.jit.script(TopMod())

        # Redefinition is intentional: only the string literal appended in
        # forward changes, which should still trigger a new module type.
        class Submod(torch.nn.Module):  # noqa: F811
            def __init__(self):
                super().__init__()

            def forward(self, input: str):
                input = input + "_submod(changed)"
                return input

        # Built with the redefined Submod ("_submod(changed)").
        scripted_mod_1 = torch.jit.script(TopMod())

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)
            e.save_pickle("res", "mod2.pkl", scripted_mod_1)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_0 = importer.load_pickle("res", "mod1.pkl")
        loaded_mod_1 = importer.load_pickle("res", "mod2.pkl")
        self.assertEqual(loaded_mod_0("input"), scripted_mod_0("input"))
        self.assertEqual(loaded_mod_1("input"), scripted_mod_1("input"))
        # The two loaded modules must keep their distinct submodule code.
        self.assertNotEqual(loaded_mod_0("input"), loaded_mod_1("input"))
コード例 #27
0
    def test_mixing_packaged_and_inline_modules(self):
        """
        One package can hold both a ScriptModule defined inline in the test
        and one built from an importable module, with independent code.
        """

        class InlineMod(torch.nn.Module):
            def __init__(self, name: str):
                super().__init__()
                self.name = name
                self.tensor = torch.rand(1, 2, 3)

            def forward(self, input: str):
                input = input + "_modInline:" + self.name
                return input, (self.tensor * 4)

        eager_inline = InlineMod("inline")
        script_inline = torch.jit.script(eager_inline)

        from package_a.test_module import SimpleTest

        eager_imported = SimpleTest()
        script_imported = torch.jit.script(eager_imported)

        pkg = BytesIO()
        with PackageExporter(pkg) as exporter:
            exporter.save_pickle("model", "inline.pkl", script_inline)
            exporter.save_pickle("model", "imported.pkl", script_imported)
        pkg.seek(0)

        importer = PackageImporter(pkg)
        inline_back = importer.load_pickle("model", "inline.pkl")
        imported_back = importer.load_pickle("model", "imported.pkl")

        x = torch.rand(2, 3)
        self.assertEqual(imported_back(x), eager_imported(x))
        self.assertEqual(inline_back("input"), eager_inline("input"))
コード例 #28
0
    def test_package_then_fx(self):
        """A model loaded from a package (with all modules interned) should be
        traceable with FX, and the traced module should match the loaded one."""
        from package_a.test_module import SimpleTest

        original = SimpleTest()
        pkg = BytesIO()
        with PackageExporter(pkg) as exporter:
            exporter.intern("**")
            exporter.save_pickle("model", "model.pkl", original)
        pkg.seek(0)

        importer = PackageImporter(pkg)
        from_pkg = importer.load_pickle("model", "model.pkl")
        graph_module = symbolic_trace(from_pkg)
        x = torch.rand(2, 3)
        self.assertEqual(from_pkg(x), graph_module(x))
コード例 #29
0
    def test_save_scriptmodule_file(self):
        """
        Round-trip a ScriptModule through a package written to a file path.
        """
        from package_a.test_module import ModWithTensor

        original = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        path = self.temp()
        with PackageExporter(path) as exporter:
            exporter.save_pickle("res", "mod.pkl", original)

        importer = PackageImporter(path)
        restored = importer.load_pickle("res", "mod.pkl")
        x = torch.rand(1, 2, 3)
        self.assertEqual(restored(x), original(x))
コード例 #30
0
    def test_save_scriptmodule(self):
        """
        Round-trip a ScriptModule through an in-memory package, loading it
        back onto the CPU.
        """
        from package_a.test_module import ModWithTensor

        original = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        pkg = BytesIO()
        with PackageExporter(pkg) as exporter:
            exporter.save_pickle("res", "mod.pkl", original)
        pkg.seek(0)

        importer = PackageImporter(pkg)
        restored = importer.load_pickle("res", "mod.pkl", map_location="cpu")
        x = torch.rand(1, 2, 3)
        self.assertEqual(restored(x), original(x))