예제 #1
0
    def test_save_shared_tensors(self):
        """
        Test tensors shared across eager and ScriptModules are serialized once.

        The same tensor is saved directly as a pickle, held by a scripted
        ``ModWithTensor``, and held by two eager ``ModWithSubmodAndTensor``
        wrappers that also share that scripted module. The resulting package
        must still contain exactly one ``.storage`` file, and the reloaded
        module must compute the same output as the original.
        """
        from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

        shared_tensor = torch.rand(2, 3, 4)
        scripted_mod = torch.jit.script(ModWithTensor(shared_tensor))

        mod1 = ModWithSubmodAndTensor(shared_tensor, scripted_mod)
        mod2 = ModWithSubmodAndTensor(shared_tensor, scripted_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "tensor", shared_tensor)
            e.save_pickle("res", "mod1.pkl", mod1)
            e.save_pickle("res", "mod2.pkl", mod2)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")

        # Assert that only one storage is stored in the package, even though
        # the tensor is referenced from three separate pickles.
        file_structure = importer.file_structure(include=".data/*.storage")
        # assertEqual (rather than assertTrue(... == 1)) reports the actual
        # number of storages when the check fails.
        self.assertEqual(len(file_structure.children[".data"].children), 1)

        # Named to avoid shadowing the builtin ``input``.
        sample_input = torch.rand(2, 3, 4)
        self.assertEqual(loaded_mod_1(sample_input), mod1(sample_input))
예제 #2
0
    def test_file_structure(self):
        filename = self.temp()

        export_plain = """\
    ├── main
    │   └── main
    ├── obj
    │   └── obj.pkl
    ├── package_a
    │   ├── __init__.py
    │   └── subpackage.py
    └── module_a.py
"""
        export_include = """\
    ├── obj
    │   └── obj.pkl
    └── package_a
        └── subpackage.py
"""
        import_exclude = """\
    ├── .data
    │   ├── extern_modules
    │   └── version
    ├── main
    │   └── main
    ├── obj
    │   └── obj.pkl
    ├── package_a
    │   ├── __init__.py
    │   └── subpackage.py
    └── module_a.py
"""

        with PackageExporter(filename, verbose=False) as he:
            import module_a
            import package_a
            import package_a.subpackage
            obj = package_a.subpackage.PackageASubpackageObject()
            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
            he.save_pickle('obj', 'obj.pkl', obj)
            he.save_text('main', 'main', "my string")

            export_file_structure = he.file_structure()
            # remove first line from testing because WINDOW/iOS/Unix treat the filename differently
            self.assertEqual(
                '\n'.join(str(export_file_structure).split('\n')[1:]),
                export_plain)
            export_file_structure = he.file_structure(
                include=["**/subpackage.py", "**/*.pkl"])
            self.assertEqual(
                '\n'.join(str(export_file_structure).split('\n')[1:]),
                export_include)

        hi = PackageImporter(filename)
        import_file_structure = hi.file_structure(exclude="**/*.storage")
        self.assertEqual('\n'.join(str(import_file_structure).split('\n')[1:]),
                         import_exclude)
예제 #3
0
    def test_file_structure_has_file(self):
        """
        Exercise Directory.has_file() on an imported package: after pickling
        a ``PackageASubpackageObject``, the path "package_a/subpackage.py" is
        expected to be present, while the same path without its ".py" suffix
        must not match.
        """
        buf = BytesIO()
        with PackageExporter(buf, verbose=False) as exporter:
            import package_a.subpackage

            exporter.save_pickle(
                "obj", "obj.pkl", package_a.subpackage.PackageASubpackageObject()
            )

        buf.seek(0)

        package_importer = PackageImporter(buf)
        tree = package_importer.file_structure()
        self.assertTrue(tree.has_file("package_a/subpackage.py"))
        self.assertFalse(tree.has_file("package_a/subpackage"))
예제 #4
0
    def test_save_eager_mods_sharing_scriptmodule(self):
        """
        Test that a single ScriptModule shared by multiple eager modules is
        saved just once, even though it is contained in multiple pickles.

        After both wrappers are pickled, the package must contain the
        TorchScript code entry ``.data/ts_code/0`` and no ``.data/ts_code/1``.
        """
        from package_a.test_module import ModWithSubmod, SimpleTest

        shared_script = torch.jit.script(SimpleTest())
        mod1 = ModWithSubmod(shared_script)
        mod2 = ModWithSubmod(shared_script)

        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.intern("**")
            for resource, module in (("mod1.pkl", mod1), ("mod2.pkl", mod2)):
                exporter.save_pickle("res", resource, module)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        structure = importer.file_structure()
        self.assertTrue(structure.has_file(".data/ts_code/0"))
        self.assertFalse(structure.has_file(".data/ts_code/1"))
예제 #5
0
    def test_save_scriptmodule_only_necessary_code(self):
        """
        Test that when multiple packages are saved from ScriptModules sharing
        the same CU (TorchScript compilation unit), each package includes only
        the TorchScript code files it needs. The TorchVision code should only
        be saved in the package whose module relies on it.
        """
        from package_a.test_module import ModWithTensor

        class ModWithTorchVision(torch.nn.Module):
            # The original took an unused ``name`` argument; it was dropped
            # because nothing read it (this class is local to the test).
            def __init__(self):
                super().__init__()
                self.tvmod = resnet18()

            def forward(self, input):
                return input * 4

        scripted_mod_0 = torch.jit.script(ModWithTorchVision())
        scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)

        buffer_1 = BytesIO()
        with PackageExporter(buffer_1) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_1)

        buffer_1.seek(0)
        importer_1 = PackageImporter(buffer_1)

        # assertIn/assertNotIn print the searched text on failure, unlike
        # assertTrue/assertFalse on a bare ``in`` expression.
        self.assertIn("torchvision", str(importer_0.file_structure()))
        self.assertNotIn("torchvision", str(importer_1.file_structure()))
예제 #6
0
    def test_file_structure(self):
        """
        Verify the Folder tree printed for a package held in an in-memory
        buffer. The exporter's full view, the exporter's view filtered by
        ``include`` globs, and the importer's view filtered by an ``exclude``
        glob must each render exactly as expected.
        """
        buffer = BytesIO()

        def rendered_tree(structure):
            # Skip the root line, whose label differs across platforms, then
            # strip the common indentation so the result can be compared with
            # the dedented expectations below.
            return dedent(str(structure).partition("\n")[2])

        export_plain = dedent("""\
                ├── main
                │   └── main
                ├── obj
                │   └── obj.pkl
                ├── package_a
                │   ├── __init__.py
                │   └── subpackage.py
                └── module_a.py
            """)
        export_include = dedent("""\
                ├── obj
                │   └── obj.pkl
                └── package_a
                    └── subpackage.py
            """)
        import_exclude = dedent("""\
                ├── .data
                │   ├── extern_modules
                │   └── version
                ├── main
                │   └── main
                ├── obj
                │   └── obj.pkl
                ├── package_a
                │   ├── __init__.py
                │   └── subpackage.py
                └── module_a.py
            """)

        with PackageExporter(buffer, verbose=False) as he:
            import module_a
            import package_a
            import package_a.subpackage

            obj = package_a.subpackage.PackageASubpackageObject()
            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
            he.save_pickle("obj", "obj.pkl", obj)
            he.save_text("main", "main", "my string")

            self.assertEqual(rendered_tree(he.file_structure()), export_plain)
            filtered = he.file_structure(
                include=["**/subpackage.py", "**/*.pkl"])
            self.assertEqual(rendered_tree(filtered), export_include)

        buffer.seek(0)
        hi = PackageImporter(buffer)
        without_storage = hi.file_structure(exclude="**/*.storage")
        self.assertEqual(rendered_tree(without_storage), import_exclude)