コード例 #1
0
    def test_mock_glob(self):
        """Modules matched by glob patterns handed to ``mock`` are replaced by stubs."""
        buf = BytesIO()
        with PackageExporter(buf) as exporter:
            # "package_a.*" is expected to match package_a.subpackage and
            # "module*" to match module_a.
            exporter.mock(["package_a.*", "module*"])
            exporter.save_module("package_a")
            exporter.save_source_string(
                "test_module",
                dedent("""\
                    import package_a.subpackage
                    import module_a
                    """),
            )
        buf.seek(0)
        importer = PackageImporter(buf)

        # The real modules must still be importable in the outer environment.
        import package_a.subpackage
        _ = package_a.subpackage
        import module_a
        _ = module_a

        mocked_subpackage = importer.import_module("package_a.subpackage")
        mocked_result = mocked_subpackage.result
        # Calling anything on a mocked module raises.
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            mocked_result()
コード例 #2
0
    def test_resource_access_by_path(self):
        """
        Tests that packaged code can use importlib.resources.path.
        """
        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_binary("string_module", "my_string", "my string".encode("utf-8"))
            src = dedent(
                """\
                import importlib.resources
                import string_module

                with importlib.resources.path(string_module, 'my_string') as path:
                    with open(path, mode='r', encoding='utf-8') as f:
                        s = f.read()
                """
            )
            e.save_source_string("main", src, is_package=True)

        # Manage the archive with `with` so its file handle is always closed
        # (the previous version leaked it).
        with zipfile.ZipFile(filename, "r") as zip_file, TemporaryDirectory() as temp_dir:
            zip_file.extractall(path=temp_dir)
            # Presumably the archive extracts into a folder named after the
            # package file; import from that extracted location.
            dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
            m = dir_importer.import_module("main")
            self.assertEqual(m.s, "my string")
コード例 #3
0
    def test_save_module_binary(self):
        """Modules saved by name load back as package-private copies."""
        buf = BytesIO()
        with PackageExporter(buf, verbose=False) as exporter:
            import module_a
            import package_a

            for saved in (module_a, package_a):
                exporter.save_module(saved.__name__)
        buf.seek(0)
        importer = PackageImporter(buf)

        loaded_module_a = importer.import_module("module_a")
        self.assertEqual(loaded_module_a.result, "module_a")
        self.assertIsNot(module_a, loaded_module_a)

        loaded_package_a = importer.import_module("package_a")
        self.assertEqual(loaded_package_a.result, "package_a")
        self.assertIsNot(loaded_package_a, package_a)
コード例 #4
0
    def test_package_resource_access(self):
        """Packaged modules should be able to use the importlib.resources API to access
        resources saved in the package.
        """
        mod_src = dedent(
            """\
            import importlib.resources
            import my_cool_resources

            def secret_message():
                return importlib.resources.read_text(my_cool_resources, 'sekrit.txt')
            """
        )
        filename = self.temp()
        with PackageExporter(filename) as pe:
            pe.save_source_string("foo.bar", mod_src)
            pe.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")

        # Manage the archive with `with` so its file handle is always closed
        # (the previous version leaked it).
        with zipfile.ZipFile(filename, "r") as zip_file, TemporaryDirectory() as temp_dir:
            zip_file.extractall(path=temp_dir)
            dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
            self.assertEqual(
                dir_importer.import_module("foo.bar").secret_message(),
                "my sekrit plays",
            )
コード例 #5
0
    def test_different_package_script_class(self):
        """Test a case where the script class defined in the package is
        different than the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        import package_a.fake_script_class as fake

        # Simulate a package that contains a different version of the
        # script class, with the attribute `bar` instead of `foo`.
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe2:
            pe2.save_source_string(
                fake.__name__,
                dedent("""\
                    import torch

                    @torch.jit.script
                    class MyScriptClass:
                        def __init__(self, x):
                            self.bar = x
                    """),
            )
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)
        # Named `inp` rather than `input` so the builtin is not shadowed.
        inp = torch.rand(2, 3)
        loaded_script_class = diff_fake.MyScriptClass(inp)
        orig_script_class = fake.MyScriptClass(inp)
        # Each version exposes the constructor argument under its own
        # attribute name: `bar` for the packaged one, `foo` for the outer one.
        self.assertTrue(
            torch.allclose(loaded_script_class.bar, orig_script_class.foo))
コード例 #6
0
    def test_package_fx_with_imports(self):
        """An FX GraphModule that calls a function from package_a.subpackage can
        be pickled into a package and reloaded. The reloaded module must compute
        the same result, and the package must carry its own copy of
        package_a.subpackage.
        """
        import package_a.subpackage

        # Manually construct a graph that invokes a leaf function
        graph = Graph()
        a = graph.placeholder("x")
        b = graph.placeholder("y")
        c = graph.call_function(package_a.subpackage.leaf_function, (a, b))
        d = graph.call_function(torch.sin, (c, ))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)

        f = BytesIO()
        with PackageExporter(f, verbose=False) as pe:
            pe.save_pickle("model", "model.pkl", gm)
        f.seek(0)

        pi = PackageImporter(f)
        loaded_gm = pi.load_pickle("model", "model.pkl")
        input_x = torch.rand(2, 3)
        input_y = torch.rand(2, 3)

        self.assertTrue(
            torch.allclose(loaded_gm(input_x, input_y), gm(input_x, input_y)))

        # Check that the packaged version of the leaf_function dependency is
        # not the same as in the outer env. assertIsNot reports both objects
        # on failure, unlike assertTrue(... is not ...).
        packaged_dependency = pi.import_module("package_a.subpackage")
        self.assertIsNot(packaged_dependency, package_a.subpackage)
コード例 #7
0
    def test_single_ordered_importer(self):
        """An OrderedImporter built from a lone PackageImporter sees only what
        was packaged."""
        import module_a  # noqa: F401
        import package_a

        buf = BytesIO()
        with PackageExporter(buf, verbose=False) as exporter:
            exporter.save_module(package_a.__name__)
        buf.seek(0)

        package_importer = PackageImporter(buf)
        # No sys_importer here: the package is the only source of modules.
        package_only = OrderedImporter(package_importer)

        # Must hand back the very module object the package importer yields ...
        self.assertIs(
            package_only.import_module("package_a"),
            package_importer.import_module("package_a"),
        )
        # ... and not the copy living in the outer Python environment.
        self.assertIsNot(package_only.import_module("package_a"), package_a)

        # module_a was never saved into the package, so it cannot be found.
        with self.assertRaises(ModuleNotFoundError):
            package_only.import_module("module_a")
コード例 #8
0
    def test_extern(self):
        """Externed modules resolve to the outer environment's module objects."""
        path = self.temp()
        with PackageExporter(path, verbose=False) as exporter:
            exporter.extern_modules(['package_a.subpackage', 'module_a'])
            exporter.save_module('package_a')
        importer = PackageImporter(path)
        import package_a.subpackage
        import module_a

        loaded_module_a = importer.import_module('module_a')
        importer.import_module('package_a.subpackage')
        loaded_package_a = importer.import_module('package_a')

        # Externed module: shared with the outer environment.
        self.assertIs(module_a, loaded_module_a)
        # package_a itself was saved, so the package has its own copy ...
        self.assertIsNot(package_a, loaded_package_a)
        # ... but that copy's externed subpackage is the outer one.
        self.assertIs(package_a.subpackage, loaded_package_a.subpackage)
コード例 #9
0
ファイル: test_model.py プロジェクト: ydcjeff/pytorch
    def test_resnet(self):
        """Package a torchvision ResNet-18 and load it back. Then re-export the
        loaded model into a second package and check its outputs still match
        the original model.
        """
        resnet = resnet18()

        f1 = self.temp()

        # create a package that will save it along with its code
        with PackageExporter(f1, verbose=False) as e:
            # put the pickled resnet in the package, by default
            # this will also save all the code files referenced by
            # the objects in the pickle
            e.intern("**")
            e.save_pickle("model", "model.pkl", resnet)

            # check the debug graph has something reasonable:
            debug_graph = e._write_dep_graph(failing_module="torch")
            self.assertIn("torchvision.models.resnet", debug_graph)

        # we can now load the saved model
        i = PackageImporter(f1)
        r2 = i.load_pickle("model", "model.pkl")

        # test that it works (`inp`, not `input`, to avoid shadowing the builtin)
        inp = torch.rand(1, 3, 224, 224)
        ref = resnet(inp)
        self.assertTrue(torch.allclose(r2(inp), ref))

        # functions exist also to get at the private modules in each package
        packaged_torchvision = i.import_module("torchvision")  # noqa: F841

        f2 = BytesIO()
        # if we are doing transfer learning we might want to re-save
        # things that were loaded from a package.
        # We need to tell the exporter about any modules that
        # came from imported packages so that it can resolve
        # class names like torchvision.models.resnet.ResNet
        # to their source code.
        with PackageExporter(f2, verbose=False,
                             importer=(i, sys_importer)) as e:
            # e.importers is a list of module importing functions
            # that by default contains importlib.import_module.
            # it is searched in order until the first success and
            # that module is taken to be what torchvision.models.resnet
            # should be in this code package. In the case of name collisions,
            # such as trying to save a ResNet from two different packages,
            # we take the first thing found in the path, so only ResNet objects from
            # one importer will work. This avoids a bunch of name mangling in
            # the source code. If you need to actually mix ResNet objects,
            # we suggest reconstructing the model objects using code from a single package
            # using functions like save_state_dict and load_state_dict to transfer state
            # to the correct code objects.
            e.intern("**")
            e.save_pickle("model", "model.pkl", r2)

        f2.seek(0)

        i2 = PackageImporter(f2)
        r3 = i2.load_pickle("model", "model.pkl")
        self.assertTrue(torch.allclose(r3(inp), ref))
コード例 #10
0
    def test_saving_string(self):
        """Source saved from a string resolves stdlib imports to the outer
        environment's modules."""
        path = self.temp()
        with PackageExporter(path) as exporter:
            source = dedent(
                """\
                import math
                the_math = math
                """
            )
            exporter.save_source_string("my_mod", source)
        importer = PackageImporter(path)

        stdlib_math = importer.import_module("math")
        import math

        # `math` was not saved into the package, so the importer hands back
        # the outer module object.
        self.assertIs(stdlib_math, math)
        saved_mod = importer.import_module("my_mod")
        self.assertIs(saved_mod.math, math)
コード例 #11
0
    def test_extern(self):
        """Externed dependencies of saved source come from the outer environment."""
        buf = BytesIO()
        with PackageExporter(buf, verbose=False) as exporter:
            exporter.extern(["package_a.subpackage", "module_a"])
            exporter.save_source_string("foo", "import package_a.subpackage; import module_a")
        buf.seek(0)
        importer = PackageImporter(buf)
        import module_a
        import package_a.subpackage

        loaded_module_a = importer.import_module("module_a")
        importer.import_module("package_a.subpackage")
        loaded_package_a = importer.import_module("package_a")

        # Externed: identical object to the outer module.
        self.assertIs(module_a, loaded_module_a)
        # package_a as seen through the importer is a distinct module object ...
        self.assertIsNot(package_a, loaded_package_a)
        # ... whose subpackage attribute is still the outer, externed one.
        self.assertIs(package_a.subpackage, loaded_package_a.subpackage)
コード例 #12
0
ファイル: test_dependency_api.py プロジェクト: yj4889/pytorch
    def test_extern(self):
        """Modules that are externed and then explicitly required are not bundled."""
        path = self.temp()
        with PackageExporter(path, verbose=False) as exporter:
            exporter.extern(["package_a.subpackage", "module_a"])
            for required in ("package_a.subpackage", "module_a"):
                exporter.require_module(required)
            exporter.save_module("package_a")
        importer = PackageImporter(path)
        import module_a
        import package_a.subpackage

        loaded_module_a = importer.import_module("module_a")
        importer.import_module("package_a.subpackage")
        loaded_package_a = importer.import_module("package_a")

        # Externed module is shared with the outer environment.
        self.assertIs(module_a, loaded_module_a)
        # The saved package_a is a separate copy ...
        self.assertIsNot(package_a, loaded_package_a)
        # ... that still points at the outer, externed subpackage.
        self.assertIs(package_a.subpackage, loaded_package_a.subpackage)
コード例 #13
0
    def test_custom_requires(self):
        """A PackageExporter subclass can override require_module to decide how
        each dependency gets saved."""
        path = self.temp()

        class Custom(PackageExporter):
            def require_module(self, name, dependencies):
                # Mock module_a, hand-write package_a, reject everything else.
                if name == 'module_a':
                    self.save_mock_module('module_a')
                    return
                if name == 'package_a':
                    self.save_source_string('package_a', 'import module_a\nresult = 5\n')
                    return
                raise NotImplementedError('wat')

        with Custom(path, verbose=False) as exporter:
            exporter.save_source_string('main', 'import package_a\n')

        importer = PackageImporter(path)
        # Attribute access on the mocked module must not raise.
        _ = importer.import_module('module_a').should_be_mocked
        custom_package_a = importer.import_module('package_a')
        self.assertEqual(custom_package_a.result, 5)
コード例 #14
0
    def test_pickle(self):
        """Pickling an object captures the modules it depends on, but not
        unrelated modules such as module_a."""
        import package_a.subpackage
        inner = package_a.subpackage.PackageASubpackageObject()
        outer = package_a.PackageAObject(inner)

        path = self.temp()
        with PackageExporter(path, verbose=False) as exporter:
            exporter.save_pickle('obj', 'obj.pkl', outer)
        importer = PackageImporter(path)

        # The pickled object's module dependency was captured ...
        packaged_sub = importer.import_module('package_a.subpackage')
        # ... but an unrelated module was not.
        with self.assertRaises(ImportError):
            importer.import_module('module_a')

        restored = importer.load_pickle('obj', 'obj.pkl')
        self.assertIsNot(outer, restored)
        self.assertIsInstance(restored.obj, packaged_sub.PackageASubpackageObject)
        self.assertIsNot(package_a.subpackage.PackageASubpackageObject,
                         packaged_sub.PackageASubpackageObject)
コード例 #15
0
    def test_custom_requires(self):
        """Overriding require_module on an in-memory exporter controls how each
        dependency is stored."""
        buf = BytesIO()

        class Custom(PackageExporter):
            def require_module(self, name, dependencies):
                # Stub out module_a, supply package_a by hand, refuse the rest.
                if name == "module_a":
                    self.save_mock_module("module_a")
                    return
                if name == "package_a":
                    self.save_source_string("package_a",
                                            "import module_a\nresult = 5\n")
                    return
                raise NotImplementedError("wat")

        with Custom(buf, verbose=False) as exporter:
            exporter.save_source_string("main", "import package_a\n")

        buf.seek(0)
        importer = PackageImporter(buf)
        # Reading an attribute off the mocked module must succeed.
        _ = importer.import_module("module_a").should_be_mocked
        custom_package_a = importer.import_module("package_a")
        self.assertEqual(custom_package_a.result, 5)
コード例 #16
0
    def test_dunder_imports(self):
        """After pickling an object from package_b with everything interned, the
        importer resolves package_b, its nested subpackages, and stdlib modules
        such as math and xml.sax.xmlreader."""
        buf = BytesIO()
        with PackageExporter(buf) as exporter:
            import package_b

            target = package_b.PackageBObject
            exporter.intern("**")
            exporter.save_pickle("res", "obj.pkl", target)

        buf.seek(0)
        importer = PackageImporter(buf)
        _ = importer.load_pickle("res", "obj.pkl")

        # (module name, attribute to inspect, expected value), checked in order.
        expectations = (
            ("package_b", "result", "package_b"),
            ("math", "__name__", "math"),
            ("xml.sax.xmlreader", "__name__", "xml.sax.xmlreader"),
            ("package_b.subpackage_1", "result", "subpackage_1"),
            ("package_b.subpackage_2", "result", "subpackage_2"),
            ("package_b.subpackage_0.subsubpackage_0", "result", "subsubpackage_0"),
        )
        for module_name, attr, expected in expectations:
            self.assertEqual(
                getattr(importer.import_module(module_name), attr), expected)
コード例 #17
0
ファイル: test_save_load.py プロジェクト: yanboliang/pytorch
    def test_pickle(self):
        """With everything interned, pickling captures the object's module
        dependencies, but not unrelated modules such as module_a."""
        import package_a.subpackage

        inner = package_a.subpackage.PackageASubpackageObject()
        outer = package_a.PackageAObject(inner)

        path = self.temp()
        with PackageExporter(path) as exporter:
            exporter.intern("**")
            exporter.save_pickle("obj", "obj.pkl", outer)
        importer = PackageImporter(path)

        # Dependency of the pickled object: present in the package.
        packaged_sub = importer.import_module("package_a.subpackage")
        # Unrelated module: absent.
        with self.assertRaises(ImportError):
            importer.import_module("module_a")

        restored = importer.load_pickle("obj", "obj.pkl")
        self.assertIsNot(outer, restored)
        self.assertIsInstance(restored.obj, packaged_sub.PackageASubpackageObject)
        self.assertIsNot(package_a.subpackage.PackageASubpackageObject,
                         packaged_sub.PackageASubpackageObject)
コード例 #18
0
    def test_ordered_importer_basic(self):
        """OrderedImporter returns the module from whichever importer is listed first."""
        import package_a
        buf = BytesIO()
        with PackageExporter(buf, verbose=False) as exporter:
            exporter.save_module(package_a.__name__)

        buf.seek(0)
        package_importer = PackageImporter(buf)

        # System importer first: the outer environment's module wins.
        sys_first = OrderedImporter(sys_importer, package_importer)
        self.assertIs(sys_first.import_module('package_a'), package_a)

        # Package importer first: the packaged copy wins.
        package_first = OrderedImporter(package_importer, sys_importer)
        self.assertIs(package_first.import_module('package_a'),
                      package_importer.import_module('package_a'))
コード例 #19
0
ファイル: test_package.py プロジェクト: yiqxiaobai/pytorch
    def test_resources(self):
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.save_text('main', 'main', "my string")
            he.save_binary('main', 'main_binary', "my string".encode('utf-8'))
            src = """\
import resources
t = resources.load_text('main', 'main')
b = resources.load_binary('main', 'main_binary')
"""
            he.save_source_string('main', src, is_package=True)
        hi = PackageImporter(filename)
        m = hi.import_module('main')
        self.assertEqual(m.t, "my string")
        self.assertEqual(m.b, "my string".encode('utf-8'))
コード例 #20
0
    def test_mock(self):
        """Modules passed to mock_modules are replaced by stubs that raise on use."""
        path = self.temp()
        with PackageExporter(path, verbose=False) as exporter:
            exporter.mock_modules(['package_a.subpackage', 'module_a'])
            exporter.save_module('package_a')
        importer = PackageImporter(path)

        # The real modules remain importable outside the package.
        import package_a.subpackage
        _ = package_a.subpackage
        import module_a
        _ = module_a

        mocked_subpackage = importer.import_module('package_a.subpackage')
        mocked_result = mocked_subpackage.result
        with self.assertRaisesRegex(NotImplementedError, 'was mocked out'):
            mocked_result()
コード例 #21
0
ファイル: test_package_script.py プロジェクト: xsacha/pytorch
    def test_package_script_class(self):
        """A module that uses a TorchScript class gives the same result whether
        it is imported from the package or from the outer environment."""
        import package_a.fake_script_class as fake

        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_module(fake.__name__)
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.import_module(fake.__name__)

        # `inp`, not `input`, so the builtin is not shadowed.
        inp = torch.tensor(1)
        self.assertTrue(
            torch.allclose(fake.uses_script_class(inp),
                           loaded.uses_script_class(inp)))
コード例 #22
0
    def test_extern_glob(self):
        """Modules matched by glob patterns given to ``extern`` resolve to the
        outer environment."""
        buf = BytesIO()
        with PackageExporter(buf) as exporter:
            exporter.extern(["package_a.*", "module_*"])
            exporter.save_module("package_a")
            exporter.save_source_string(
                "test_module",
                dedent("""\
                    import package_a.subpackage
                    import module_a
                    """),
            )
        buf.seek(0)
        importer = PackageImporter(buf)
        import module_a
        import package_a.subpackage

        loaded_module_a = importer.import_module("module_a")
        importer.import_module("package_a.subpackage")
        loaded_package_a = importer.import_module("package_a")

        # Matched by "module_*": shared with the outer environment.
        self.assertIs(module_a, loaded_module_a)
        # package_a was saved, so the package holds its own copy ...
        self.assertIsNot(package_a, loaded_package_a)
        # ... while its subpackage (matched by "package_a.*") is the outer one.
        self.assertIs(package_a.subpackage, loaded_package_a.subpackage)
コード例 #23
0
    def test_dunder_package_present(self):
        """
        Modules imported from a package should carry a '__torch_package__' attribute.
        """
        import package_a.subpackage

        buf = BytesIO()
        payload = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buf, verbose=False) as exporter:
            exporter.save_pickle("obj", "obj.pkl", payload)

        buf.seek(0)
        importer = PackageImporter(buf)
        packaged_sub = importer.import_module("package_a.subpackage")
        self.assertTrue(hasattr(packaged_sub, "__torch_package__"))
コード例 #24
0
    def test_inspect_class(self):
        """inspect.getsourcelines works on a class loaded from a package."""
        import package_a.subpackage
        buf = BytesIO()
        instance = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buf, verbose=False) as exporter:
            exporter.save_pickle('obj', 'obj.pkl', instance)

        buf.seek(0)
        importer = PackageImporter(buf)
        packaged_mod = importer.import_module('package_a.subpackage')

        # The packaged class and the outer class should report identical source.
        self.assertEqual(
            inspect.getsourcelines(packaged_mod.PackageASubpackageObject),
            inspect.getsourcelines(package_a.subpackage.PackageASubpackageObject),
        )
コード例 #25
0
    def test_different_package_interface(self):
        """Script a module whose TorchScript interface comes from a package, where
        that interface differs from the same-named one in the loading
        environment. TorchScript must keep the two apart.
        """
        # The outer environment's version of the interface module.
        import package_a.fake_interface as fake

        # Build a package holding an alternative definition of that module,
        # saved under exactly the same module name.
        buf = BytesIO()
        with PackageExporter(buf) as exporter:
            exporter.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch
                    from torch import Tensor

                    @torch.jit.interface
                    class ModuleInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            pass

                    class ImplementsInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            return inp1 + 1

                    class UsesInterface(torch.nn.Module):
                        proxy_mod: ModuleInterface

                        def __init__(self):
                            super().__init__()
                            self.proxy_mod = ImplementsInterface()

                        def forward(self, input: Tensor) -> Tensor:
                            return self.proxy_mod.one(input)
                    """
                ),
            )
        buf.seek(0)

        importer = PackageImporter(buf)
        packaged_fake = importer.import_module(fake.__name__)
        # Scripting must succeed despite the name clash.
        torch.jit.script(packaged_fake.UsesInterface())
コード例 #26
0
    def test_loading_module(self):
        """
        Test basic saving and loading of a package from a DirectoryReader.
        """
        import package_a

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_module("package_a")

        # Manage the archive with `with` so its file handle is always closed
        # (the previous version leaked it).
        with zipfile.ZipFile(filename, "r") as zip_file, TemporaryDirectory() as temp_dir:
            zip_file.extractall(path=temp_dir)
            dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
            dir_mod = dir_importer.import_module("package_a")
            self.assertEqual(dir_mod.result, package_a.result)
コード例 #27
0
    def test_save_module_with_module_object(self):
        """
        save_module accepts a module object as well as a module name.
        """
        buf = BytesIO()

        with PackageExporter(buf, verbose=False) as exporter:
            import module_a

            # Pass the module itself rather than module_a.__name__.
            exporter.save_module(module_a)

        buf.seek(0)
        importer = PackageImporter(buf)
        loaded_module_a = importer.import_module("module_a")
        self.assertEqual(loaded_module_a.result, "module_a")
        self.assertIsNot(module_a, loaded_module_a)
コード例 #28
0
ファイル: test_resources.py プロジェクト: yj4889/pytorch
    def test_importer_access(self):
        """Packaged code can reach its importer as `torch_package_importer` and
        load saved text and binary resources through it."""
        path = self.temp()
        with PackageExporter(path, verbose=False) as exporter:
            exporter.save_text("main", "main", "my string")
            exporter.save_binary("main", "main_binary", "my string".encode("utf-8"))
            src = dedent("""\
                import importlib
                import torch_package_importer as resources

                t = resources.load_text('main', 'main')
                b = resources.load_binary('main', 'main_binary')
                """)
            exporter.save_source_string("main", src, is_package=True)
        importer = PackageImporter(path)
        main_mod = importer.import_module("main")
        expected = "my string"
        self.assertEqual(main_mod.t, expected)
        self.assertEqual(main_mod.b, expected.encode("utf-8"))
コード例 #29
0
    def test_dunder_package_works_from_package(self):
        """
        Packaged code can read '__torch_package__' from inside its own module,
        which lets it tell whether it is running from a package.
        """
        import package_a.use_dunder_package as mod

        buf = BytesIO()

        with PackageExporter(buf, verbose=False) as exporter:
            exporter.save_module(mod.__name__)

        buf.seek(0)
        importer = PackageImporter(buf)
        packaged_mod = importer.import_module(mod.__name__)
        # Same function, different answers depending on where it was loaded from.
        self.assertTrue(packaged_mod.is_from_package())
        self.assertFalse(mod.is_from_package())
コード例 #30
0
ファイル: test_resources.py プロジェクト: xsacha/pytorch
    def test_importer_access(self):
        """In an in-memory package, packaged code can use `torch_package_importer`
        to load saved text and binary resources."""
        buf = BytesIO()
        with PackageExporter(buf) as exporter:
            exporter.save_text("main", "main", "my string")
            exporter.save_binary("main", "main_binary", "my string".encode("utf-8"))
            src = dedent("""\
                import importlib
                import torch_package_importer as resources

                t = resources.load_text('main', 'main')
                b = resources.load_binary('main', 'main_binary')
                """)
            exporter.save_source_string("main", src, is_package=True)
        buf.seek(0)
        importer = PackageImporter(buf)
        main_mod = importer.import_module("main")
        expected = "my string"
        self.assertEqual(main_mod.t, expected)
        self.assertEqual(main_mod.b, expected.encode("utf-8"))