def test_mock_glob(self):
    """Glob patterns passed to ``mock`` mock every module they match."""
    buffer = BytesIO()
    with PackageExporter(buffer) as exporter:
        exporter.mock(["package_a.*", "module*"])
        exporter.save_module("package_a")
        exporter.save_source_string(
            "test_module",
            dedent(
                """\
                import package_a.subpackage
                import module_a
                """
            ),
        )

    buffer.seek(0)
    importer = PackageImporter(buffer)

    # Import the real modules in the outer environment as well.
    import package_a.subpackage  # noqa: F401
    import module_a  # noqa: F401

    mocked_subpackage = importer.import_module("package_a.subpackage")
    mocked_result = mocked_subpackage.result
    with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
        mocked_result()
def test_resource_access_by_path(self):
    """
    Tests that packaged code can use importlib.resources.path.
    """
    filename = self.temp()
    with PackageExporter(filename) as e:
        e.save_binary("string_module", "my_string", "my string".encode("utf-8"))
        src = dedent(
            """\
            import importlib.resources
            import string_module
            with importlib.resources.path(string_module, 'my_string') as path:
                with open(path, mode='r', encoding='utf-8') as f:
                    s = f.read()
            """
        )
        e.save_source_string("main", src, is_package=True)

    with TemporaryDirectory() as temp_dir:
        # Close the archive as soon as it has been extracted; the original
        # code left the ZipFile handle open for the rest of the test.
        with zipfile.ZipFile(filename, "r") as zip_file:
            zip_file.extractall(path=temp_dir)
        dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
        m = dir_importer.import_module("main")
        self.assertEqual(m.s, "my string")
def test_save_module_binary(self):
    """Modules saved by name come back from the package as distinct copies."""
    archive = BytesIO()
    with PackageExporter(archive, verbose=False) as exporter:
        import module_a
        import package_a

        for mod in (module_a, package_a):
            exporter.save_module(mod.__name__)
    archive.seek(0)
    importer = PackageImporter(archive)

    packaged_module_a = importer.import_module("module_a")
    self.assertEqual(packaged_module_a.result, "module_a")
    self.assertIsNot(module_a, packaged_module_a)

    packaged_package_a = importer.import_module("package_a")
    self.assertEqual(packaged_package_a.result, "package_a")
    self.assertIsNot(packaged_package_a, package_a)
def test_package_resource_access(self):
    """Packaged modules should be able to use the importlib.resources API to access
    resources saved in the package.
    """
    mod_src = dedent(
        """\
        import importlib.resources
        import my_cool_resources

        def secret_message():
            return importlib.resources.read_text(my_cool_resources, 'sekrit.txt')
        """
    )
    filename = self.temp()
    with PackageExporter(filename) as pe:
        pe.save_source_string("foo.bar", mod_src)
        pe.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")

    with TemporaryDirectory() as temp_dir:
        # Close the archive right after extraction instead of leaking the handle.
        with zipfile.ZipFile(filename, "r") as zip_file:
            zip_file.extractall(path=temp_dir)
        dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
        self.assertEqual(
            dir_importer.import_module("foo.bar").secret_message(),
            "my sekrit plays",
        )
def test_different_package_script_class(self):
    """Test a case where the script class defined in the package is
    different than the one defined in the loading environment, to make
    sure TorchScript can distinguish between the two.
    """
    import package_a.fake_script_class as fake

    # Simulate a package that contains a different version of the
    # script class, with the attribute `bar` instead of `foo`.
    buffer = BytesIO()
    with PackageExporter(buffer, verbose=False) as pe:
        pe.save_source_string(
            fake.__name__,
            dedent(
                """\
                import torch

                @torch.jit.script
                class MyScriptClass:
                    def __init__(self, x):
                        self.bar = x
                """
            ),
        )
    buffer.seek(0)

    package_importer = PackageImporter(buffer)
    diff_fake = package_importer.import_module(fake.__name__)

    # Named `inp` rather than `input` so the builtin is not shadowed.
    inp = torch.rand(2, 3)
    loaded_script_class = diff_fake.MyScriptClass(inp)
    orig_script_class = fake.MyScriptClass(inp)
    # The packaged class stores the tensor as `bar`, the local one as `foo`.
    self.assertTrue(torch.allclose(loaded_script_class.bar, orig_script_class.foo))
def test_package_fx_with_imports(self):
    """An FX GraphModule that calls a function from another module can be
    pickled into a package; the package gets its own copy of that module."""
    import package_a.subpackage

    # Manually construct a graph that invokes a leaf function
    graph = Graph()
    a = graph.placeholder("x")
    b = graph.placeholder("y")
    c = graph.call_function(package_a.subpackage.leaf_function, (a, b))
    d = graph.call_function(torch.sin, (c,))
    graph.output(d)
    gm = GraphModule(torch.nn.Module(), graph)

    f = BytesIO()
    with PackageExporter(f, verbose=False) as pe:
        pe.save_pickle("model", "model.pkl", gm)
    f.seek(0)

    pi = PackageImporter(f)
    loaded_gm = pi.load_pickle("model", "model.pkl")
    input_x = torch.rand(2, 3)
    input_y = torch.rand(2, 3)

    self.assertTrue(
        torch.allclose(loaded_gm(input_x, input_y), gm(input_x, input_y))
    )

    # Check that the packaged version of the leaf_function dependency is
    # not the same as in the outer env. assertIsNot reports both objects
    # on failure, unlike assertTrue(... is not ...).
    packaged_dependency = pi.import_module("package_a.subpackage")
    self.assertIsNot(packaged_dependency, package_a.subpackage)
def test_single_ordered_importer(self):
    """An OrderedImporter holding only a PackageImporter resolves from that package."""
    import module_a  # noqa: F401
    import package_a

    archive = BytesIO()
    with PackageExporter(archive, verbose=False) as exporter:
        exporter.save_module(package_a.__name__)
    archive.seek(0)

    package_importer = PackageImporter(archive)
    # An environment made up of nothing but the package importer.
    only_package = OrderedImporter(package_importer)

    # Lookups must yield the package importer's module ...
    self.assertIs(
        only_package.import_module("package_a"),
        package_importer.import_module("package_a"),
    )
    # ... which is not the module from the outer Python environment.
    self.assertIsNot(only_package.import_module("package_a"), package_a)

    # module_a was never packaged, so it must not be importable here.
    with self.assertRaises(ModuleNotFoundError):
        only_package.import_module("module_a")
def test_extern(self):
    """Externed modules resolve to the outer environment; the rest is packaged."""
    filename = self.temp()
    with PackageExporter(filename, verbose=False) as exporter:
        exporter.extern_modules(['package_a.subpackage', 'module_a'])
        exporter.save_module('package_a')
    importer = PackageImporter(filename)

    import package_a.subpackage
    import module_a

    externed_module_a = importer.import_module('module_a')
    importer.import_module('package_a.subpackage')
    packaged_package_a = importer.import_module('package_a')

    # Externed: same object as the outer environment's import.
    self.assertIs(module_a, externed_module_a)
    # Not externed: a separate packaged copy.
    self.assertIsNot(package_a, packaged_package_a)
    # The packaged parent exposes the externed (outer) subpackage.
    self.assertIs(package_a.subpackage, packaged_package_a.subpackage)
def test_resnet(self):
    """A torchvision ResNet can be pickled into a package, loaded back, and
    re-saved from the loaded copy without changing its outputs."""
    resnet = resnet18()

    f1 = self.temp()

    # create a package that will save it along with its code
    with PackageExporter(f1, verbose=False) as e:
        # put the pickled resnet in the package, by default
        # this will also save all the code files referenced by
        # the objects in the pickle
        e.intern("**")
        e.save_pickle("model", "model.pkl", resnet)

        # check the debug graph has something reasonable:
        debug_graph = e._write_dep_graph(failing_module="torch")
        self.assertIn("torchvision.models.resnet", debug_graph)

    # we can now load the saved model
    i = PackageImporter(f1)
    r2 = i.load_pickle("model", "model.pkl")

    # test that it works (named `inp` so the builtin `input` is not shadowed)
    inp = torch.rand(1, 3, 224, 224)
    ref = resnet(inp)
    self.assertTrue(torch.allclose(r2(inp), ref))

    # functions exist also to get at the private modules in each package
    i.import_module("torchvision")

    f2 = BytesIO()
    # if we are doing transfer learning we might want to re-save
    # things that were loaded from a package.
    # We need to tell the exporter about any modules that
    # came from imported packages so that it can resolve
    # class names like torchvision.models.resnet.ResNet
    # to their source code.
    with PackageExporter(f2, verbose=False, importer=(i, sys_importer)) as e:
        # e.importers is a list of module importing functions
        # that by default contains importlib.import_module.
        # it is searched in order until the first success and
        # that module is taken to be what torchvision.models.resnet
        # should be in this code package. In the case of name collisions,
        # such as trying to save a ResNet from two different packages,
        # we take the first thing found in the path, so only ResNet objects from
        # one importer will work. This avoids a bunch of name mangling in
        # the source code. If you need to actually mix ResNet objects,
        # we suggest reconstructing the model objects using code from a single package
        # using functions like save_state_dict and load_state_dict to transfer state
        # to the correct code objects.
        e.intern("**")
        e.save_pickle("model", "model.pkl", r2)

    f2.seek(0)

    i2 = PackageImporter(f2)
    r3 = i2.load_pickle("model", "model.pkl")
    self.assertTrue(torch.allclose(r3(inp), ref))
def test_saving_string(self):
    """Source saved from a string may import stdlib modules; they resolve to the real ones."""
    filename = self.temp()
    with PackageExporter(filename) as exporter:
        source = dedent(
            """\
            import math
            the_math = math
            """
        )
        exporter.save_source_string("my_mod", source)
    importer = PackageImporter(filename)

    packaged_math = importer.import_module("math")
    import math

    self.assertIs(packaged_math, math)
    self.assertIs(importer.import_module("my_mod").math, math)
def test_extern(self):
    """Modules marked extern resolve to the outer environment; the rest is packaged."""
    archive = BytesIO()
    with PackageExporter(archive, verbose=False) as exporter:
        exporter.extern(["package_a.subpackage", "module_a"])
        exporter.save_source_string("foo", "import package_a.subpackage; import module_a")
    archive.seek(0)
    importer = PackageImporter(archive)

    import module_a
    import package_a.subpackage

    externed_module_a = importer.import_module("module_a")
    importer.import_module("package_a.subpackage")
    packaged_package_a = importer.import_module("package_a")

    self.assertIs(module_a, externed_module_a)
    self.assertIsNot(package_a, packaged_package_a)
    self.assertIs(package_a.subpackage, packaged_package_a.subpackage)
def test_extern(self):
    """Extern'd modules required explicitly are taken from the outer environment."""
    filename = self.temp()
    with PackageExporter(filename, verbose=False) as exporter:
        exporter.extern(["package_a.subpackage", "module_a"])
        for required in ("package_a.subpackage", "module_a"):
            exporter.require_module(required)
        exporter.save_module("package_a")
    importer = PackageImporter(filename)

    import module_a
    import package_a.subpackage

    externed_module_a = importer.import_module("module_a")
    importer.import_module("package_a.subpackage")
    packaged_package_a = importer.import_module("package_a")

    self.assertIs(module_a, externed_module_a)
    self.assertIsNot(package_a, packaged_package_a)
    self.assertIs(package_a.subpackage, packaged_package_a.subpackage)
def test_custom_requires(self):
    """A PackageExporter subclass can override require_module to control how
    each dependency is saved."""
    filename = self.temp()

    class Custom(PackageExporter):
        def require_module(self, name, dependencies):
            if name == 'module_a':
                self.save_mock_module('module_a')
                return
            if name == 'package_a':
                self.save_source_string('package_a', 'import module_a\nresult = 5\n')
                return
            raise NotImplementedError('wat')

    with Custom(filename, verbose=False) as exporter:
        exporter.save_source_string('main', 'import package_a\n')

    importer = PackageImporter(filename)
    # The attribute access is itself the check; presumably it would fail
    # if module_a were not the mocked stub.
    importer.import_module('module_a').should_be_mocked
    custom_package_a = importer.import_module('package_a')
    self.assertEqual(custom_package_a.result, 5)
def test_pickle(self):
    """Pickling an object packages the modules its classes come from, but not
    unrelated modules such as module_a."""
    import package_a.subpackage

    inner = package_a.subpackage.PackageASubpackageObject()
    outer = package_a.PackageAObject(inner)

    filename = self.temp()
    with PackageExporter(filename, verbose=False) as exporter:
        exporter.save_pickle('obj', 'obj.pkl', outer)
    importer = PackageImporter(filename)

    # dependencies were picked up ...
    packaged_sub = importer.import_module('package_a.subpackage')
    # ... but nothing else was
    with self.assertRaises(ImportError):
        importer.import_module('module_a')

    restored = importer.load_pickle('obj', 'obj.pkl')
    self.assertIsNot(outer, restored)
    self.assertIsInstance(restored.obj, packaged_sub.PackageASubpackageObject)
    self.assertIsNot(
        package_a.subpackage.PackageASubpackageObject,
        packaged_sub.PackageASubpackageObject,
    )
def test_custom_requires(self):
    """A PackageExporter subclass can override require_module to control how
    each dependency is saved (in-memory buffer variant)."""
    archive = BytesIO()

    class Custom(PackageExporter):
        def require_module(self, name, dependencies):
            if name == "module_a":
                self.save_mock_module("module_a")
                return
            if name == "package_a":
                self.save_source_string("package_a", "import module_a\nresult = 5\n")
                return
            raise NotImplementedError("wat")

    with Custom(archive, verbose=False) as exporter:
        exporter.save_source_string("main", "import package_a\n")

    archive.seek(0)
    importer = PackageImporter(archive)
    # The attribute access is itself the check; presumably it would fail
    # if module_a were not the mocked stub.
    importer.import_module("module_a").should_be_mocked
    custom_package_a = importer.import_module("package_a")
    self.assertEqual(custom_package_a.result, 5)
def test_dunder_imports(self):
    """After pickling package_b.PackageBObject with everything interned, the
    importer can load package_b, its nested subpackages, and the stdlib
    modules math and xml.sax.xmlreader."""
    archive = BytesIO()
    with PackageExporter(archive) as exporter:
        import package_b

        pickled = package_b.PackageBObject
        exporter.intern("**")
        exporter.save_pickle("res", "obj.pkl", pickled)

    archive.seek(0)
    importer = PackageImporter(archive)
    importer.load_pickle("res", "obj.pkl")

    # (module to import, attribute to read, expected value), checked in order.
    expectations = [
        ("package_b", "result", "package_b"),
        ("math", "__name__", "math"),
        ("xml.sax.xmlreader", "__name__", "xml.sax.xmlreader"),
        ("package_b.subpackage_1", "result", "subpackage_1"),
        ("package_b.subpackage_2", "result", "subpackage_2"),
        ("package_b.subpackage_0.subsubpackage_0", "result", "subsubpackage_0"),
    ]
    for module_name, attr, expected in expectations:
        self.assertEqual(getattr(importer.import_module(module_name), attr), expected)
def test_pickle(self):
    """Pickling an object with everything interned packages the modules its
    classes come from, but not unrelated modules such as module_a."""
    import package_a.subpackage

    inner = package_a.subpackage.PackageASubpackageObject()
    outer = package_a.PackageAObject(inner)

    filename = self.temp()
    with PackageExporter(filename) as exporter:
        exporter.intern("**")
        exporter.save_pickle("obj", "obj.pkl", outer)
    importer = PackageImporter(filename)

    # dependencies were picked up ...
    packaged_sub = importer.import_module("package_a.subpackage")
    # ... but nothing else was
    with self.assertRaises(ImportError):
        importer.import_module("module_a")

    restored = importer.load_pickle("obj", "obj.pkl")
    self.assertIsNot(outer, restored)
    self.assertIsInstance(restored.obj, packaged_sub.PackageASubpackageObject)
    self.assertIsNot(
        package_a.subpackage.PackageASubpackageObject,
        packaged_sub.PackageASubpackageObject,
    )
def test_ordered_importer_basic(self):
    """OrderedImporter resolves a module from whichever importer is listed first."""
    import package_a

    archive = BytesIO()
    with PackageExporter(archive, verbose=False) as exporter:
        exporter.save_module(package_a.__name__)
    archive.seek(0)
    package_importer = PackageImporter(archive)

    sys_first = OrderedImporter(sys_importer, package_importer)
    self.assertIs(sys_first.import_module('package_a'), package_a)

    package_first = OrderedImporter(package_importer, sys_importer)
    self.assertIs(
        package_first.import_module('package_a'),
        package_importer.import_module('package_a'),
    )
def test_resources(self):
    """Packaged code can read text and binary resources via the `resources` module."""
    expected_text = "my string"
    filename = self.temp()
    with PackageExporter(filename, verbose=False) as exporter:
        exporter.save_text('main', 'main', expected_text)
        exporter.save_binary('main', 'main_binary', expected_text.encode('utf-8'))
        src = (
            "import resources\n"
            "t = resources.load_text('main', 'main')\n"
            "b = resources.load_binary('main', 'main_binary')\n"
        )
        exporter.save_source_string('main', src, is_package=True)

    importer = PackageImporter(filename)
    main_mod = importer.import_module('main')
    self.assertEqual(main_mod.t, expected_text)
    self.assertEqual(main_mod.b, expected_text.encode('utf-8'))
def test_mock(self):
    """Mocked modules are saved as stubs whose members raise when called."""
    filename = self.temp()
    with PackageExporter(filename, verbose=False) as exporter:
        exporter.mock_modules(['package_a.subpackage', 'module_a'])
        exporter.save_module('package_a')
    importer = PackageImporter(filename)

    # Import the real modules in the outer environment as well.
    import package_a.subpackage  # noqa: F401
    import module_a  # noqa: F401

    stub = importer.import_module('package_a.subpackage').result
    with self.assertRaisesRegex(NotImplementedError, 'was mocked out'):
        stub()
def test_package_script_class(self):
    """A module that uses a TorchScript class behaves the same whether it is
    imported normally or loaded from a package."""
    import package_a.fake_script_class as fake

    buffer = BytesIO()
    with PackageExporter(buffer) as pe:
        pe.save_module(fake.__name__)
    buffer.seek(0)

    package_importer = PackageImporter(buffer)
    loaded = package_importer.import_module(fake.__name__)

    # Named `inp` rather than `input` so the builtin is not shadowed.
    inp = torch.tensor(1)
    self.assertTrue(
        torch.allclose(fake.uses_script_class(inp), loaded.uses_script_class(inp))
    )
def test_extern_glob(self):
    """Glob patterns passed to ``extern`` extern every module they match."""
    archive = BytesIO()
    with PackageExporter(archive) as exporter:
        exporter.extern(["package_a.*", "module_*"])
        exporter.save_module("package_a")
        exporter.save_source_string(
            "test_module",
            dedent(
                """\
                import package_a.subpackage
                import module_a
                """
            ),
        )
    archive.seek(0)
    importer = PackageImporter(archive)

    import module_a
    import package_a.subpackage

    externed_module_a = importer.import_module("module_a")
    importer.import_module("package_a.subpackage")
    packaged_package_a = importer.import_module("package_a")

    self.assertIs(module_a, externed_module_a)
    self.assertIsNot(package_a, packaged_package_a)
    self.assertIs(package_a.subpackage, packaged_package_a.subpackage)
def test_dunder_package_present(self):
    """
    Modules imported from a package should carry a '__torch_package__' attribute.
    """
    import package_a.subpackage

    archive = BytesIO()
    payload = package_a.subpackage.PackageASubpackageObject()
    with PackageExporter(archive, verbose=False) as exporter:
        exporter.save_pickle("obj", "obj.pkl", payload)
    archive.seek(0)

    packaged = PackageImporter(archive).import_module("package_a.subpackage")
    self.assertTrue(hasattr(packaged, "__torch_package__"))
def test_inspect_class(self):
    """Should be able to retrieve source for a packaged class."""
    import package_a.subpackage

    archive = BytesIO()
    payload = package_a.subpackage.PackageASubpackageObject()
    with PackageExporter(archive, verbose=False) as exporter:
        exporter.save_pickle('obj', 'obj.pkl', payload)
    archive.seek(0)
    importer = PackageImporter(archive)

    packaged_cls = importer.import_module('package_a.subpackage').PackageASubpackageObject
    regular_cls = package_a.subpackage.PackageASubpackageObject
    # Source lines (and starting line number) must match the original class.
    self.assertEqual(
        inspect.getsourcelines(packaged_cls),
        inspect.getsourcelines(regular_cls),
    )
def test_different_package_interface(self):
    """Test a case where the interface defined in the package is
    different than the one defined in the loading environment, to make
    sure TorchScript can distinguish between the two.
    """
    # One version of the interface comes from the normal environment.
    import package_a.fake_interface as fake

    # The package carries a different interface under the exact same name.
    interface_src = dedent(
        """\
        import torch
        from torch import Tensor

        @torch.jit.interface
        class ModuleInterface(torch.nn.Module):
            def one(self, inp1: Tensor) -> Tensor:
                pass

        class ImplementsInterface(torch.nn.Module):
            def one(self, inp1: Tensor) -> Tensor:
                return inp1 + 1

        class UsesInterface(torch.nn.Module):
            proxy_mod: ModuleInterface

            def __init__(self):
                super().__init__()
                self.proxy_mod = ImplementsInterface()

            def forward(self, input: Tensor) -> Tensor:
                return self.proxy_mod.one(input)
        """
    )
    archive = BytesIO()
    with PackageExporter(archive) as exporter:
        exporter.save_source_string(fake.__name__, interface_src)
    archive.seek(0)

    packaged_fake = PackageImporter(archive).import_module(fake.__name__)

    # Scripting must succeed despite the name clash.
    torch.jit.script(packaged_fake.UsesInterface())
def test_loading_module(self):
    """
    Test basic saving and loading of a package from a DirectoryReader.
    """
    import package_a

    filename = self.temp()
    with PackageExporter(filename) as e:
        e.save_module("package_a")

    with TemporaryDirectory() as temp_dir:
        # Close the archive right after extraction instead of leaking the handle.
        with zipfile.ZipFile(filename, "r") as zip_file:
            zip_file.extractall(path=temp_dir)
        dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
        dir_mod = dir_importer.import_module("package_a")
        self.assertEqual(dir_mod.result, package_a.result)
def test_save_module_with_module_object(self):
    """
    save_module should accept a module object as well as a module name.
    """
    archive = BytesIO()
    with PackageExporter(archive, verbose=False) as exporter:
        import module_a

        exporter.save_module(module_a)
    archive.seek(0)

    packaged = PackageImporter(archive).import_module("module_a")
    self.assertEqual(packaged.result, "module_a")
    self.assertIsNot(module_a, packaged)
def test_importer_access(self):
    """Packaged code can reach its importer as `torch_package_importer` and use
    it to load text and binary resources."""
    expected_text = "my string"
    filename = self.temp()
    with PackageExporter(filename, verbose=False) as exporter:
        exporter.save_text("main", "main", expected_text)
        exporter.save_binary("main", "main_binary", expected_text.encode("utf-8"))
        src = dedent(
            """\
            import importlib
            import torch_package_importer as resources

            t = resources.load_text('main', 'main')
            b = resources.load_binary('main', 'main_binary')
            """
        )
        exporter.save_source_string("main", src, is_package=True)

    main_mod = PackageImporter(filename).import_module("main")
    self.assertEqual(main_mod.t, expected_text)
    self.assertEqual(main_mod.b, expected_text.encode("utf-8"))
def test_dunder_package_works_from_package(self):
    """
    '__torch_package__' must be visible from inside a packaged module, so that
    code can tell whether it is running from a package.
    """
    import package_a.use_dunder_package as mod

    archive = BytesIO()
    with PackageExporter(archive, verbose=False) as exporter:
        exporter.save_module(mod.__name__)
    archive.seek(0)

    packaged_mod = PackageImporter(archive).import_module(mod.__name__)
    self.assertTrue(packaged_mod.is_from_package())
    self.assertFalse(mod.is_from_package())
def test_importer_access(self):
    """Packaged code can reach its importer as `torch_package_importer` and use
    it to load text and binary resources (in-memory buffer variant)."""
    expected_text = "my string"
    archive = BytesIO()
    with PackageExporter(archive) as exporter:
        exporter.save_text("main", "main", expected_text)
        exporter.save_binary("main", "main_binary", expected_text.encode("utf-8"))
        src = dedent(
            """\
            import importlib
            import torch_package_importer as resources

            t = resources.load_text('main', 'main')
            b = resources.load_binary('main', 'main_binary')
            """
        )
        exporter.save_source_string("main", src, is_package=True)
    archive.seek(0)

    main_mod = PackageImporter(archive).import_module("main")
    self.assertEqual(main_mod.t, expected_text)
    self.assertEqual(main_mod.b, expected_text.encode("utf-8"))