def test_ordered_importer_whichmodule(self):
    """OrderedImporter's implementation of whichmodule should try each
    underlying importer's whichmodule in order.
    """

    class DummyImporter(Importer):
        def __init__(self, whichmodule_return):
            self._whichmodule_return = whichmodule_return

        def import_module(self, module_name):
            raise NotImplementedError()

        def whichmodule(self, obj, name):
            return self._whichmodule_return

    class DummyClass:
        pass

    dummy_importer_foo = DummyImporter("foo")
    dummy_importer_bar = DummyImporter("bar")
    dummy_importer_not_found = DummyImporter(
        "__main__"
    )  # __main__ is used as a proxy for "not found" by CPython

    foo_then_bar = OrderedImporter(dummy_importer_foo, dummy_importer_bar)
    self.assertEqual(foo_then_bar.whichmodule(DummyClass(), ""), "foo")

    bar_then_foo = OrderedImporter(dummy_importer_bar, dummy_importer_foo)
    self.assertEqual(bar_then_foo.whichmodule(DummyClass(), ""), "bar")

    notfound_then_foo = OrderedImporter(dummy_importer_not_found, dummy_importer_foo)
    self.assertEqual(notfound_then_foo.whichmodule(DummyClass(), ""), "foo")
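# A minimal sketch, assuming only the behavior the test above asserts:
# whichmodule is tried on each underlying importer in order, and "__main__"
# (CPython's "not found" sentinel) falls through to the next importer. This
# is an illustration, not the verbatim torch.package implementation.
class _OrderedImporterWhichmoduleSketch:
    def __init__(self, *importers):
        self._importers = list(importers)

    def whichmodule(self, obj, name):
        for importer in self._importers:
            module_name = importer.whichmodule(obj, name)
            if module_name != "__main__":
                return module_name
        return "__main__"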
def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == "storage":
            # TODO: Once we decide to break serialization FC, we can
            # stop wrapping with TypedStorage
            storage = serialized_storages[data[0]]
            dtype = serialized_dtypes[data[0]]
            return torch.storage.TypedStorage(
                wrap_storage=storage.untyped(), dtype=dtype
            )

        if typename == "reduce_deploy":
            reduce_id, func, args = data
            if reduce_id not in _loaded_reduces:
                _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
            return _loaded_reduces[reduce_id]

        return None

    importer: Importer
    if zip_reader is not None:
        importer = OrderedImporter(_get_package(zip_reader), sys_importer)
    else:
        importer = sys_importer

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load  # type: ignore[assignment]
    result = _deploy_objects[id] = unpickler.load()
    return result
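# A hedged sketch (hypothetical class, not from this file) of the
# __reduce_deploy__ contract implied by the "reduce_deploy" branch above:
# the hook is assumed to return a (func, args) pair, and on load `func` is
# called as func(raw_package, *args), with _raw_packages[zip_reader]
# prepended as the first argument.
class _DeployReducibleSketch:
    def __init__(self, payload):
        self.payload = payload

    @staticmethod
    def _rebuild(raw_package, payload):
        # On the load side, `raw_package` is the _raw_packages[zip_reader]
        # entry; this sketch ignores it and just restores the payload.
        return _DeployReducibleSketch(payload)

    def __reduce_deploy__(self, importers):
        return (_DeployReducibleSketch._rebuild, (self.payload,))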
def _save_storages(importer, obj):
    serialized_storages = []
    serialized_dtypes = []

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if torch.is_storage(obj):
            serialized_storages.append(obj)
            serialized_dtypes.append(obj.dtype)
            return ('storage', len(serialized_storages) - 1)
        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    importer = importer if isinstance(importer, torch.package.PackageImporter) else None
    importers: Importer
    if importer is not None:
        importers = OrderedImporter(importer, sys_importer)
    else:
        importers = sys_importer
    pickler = create_custom_import_pickler(data_buf, importers)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    return (
        data_value,
        serialized_storages,
        serialized_dtypes,
        importer.zip_reader if importer else None,
    )
def _load_storages(id, zip_reader, obj_bytes, serialized_storages):
    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'storage':
            return serialized_storages[data[0]]

        if typename == 'reduce_deploy':
            reduce_id, func, args = data
            if reduce_id not in _loaded_reduces:
                _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
            return _loaded_reduces[reduce_id]

        return None

    importer: Importer
    if zip_reader is not None:
        importer = OrderedImporter(_get_package(zip_reader), sys_importer)
    else:
        importer = sys_importer

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load
    result = _deploy_objects[id] = unpickler.load()
    return result
def test_single_ordered_importer(self):
    import module_a  # noqa: F401
    import package_a

    buffer = BytesIO()
    with PackageExporter(buffer) as pe:
        pe.save_module(package_a.__name__)

    buffer.seek(0)
    importer = PackageImporter(buffer)

    # Construct an importer-only environment.
    ordered_importer = OrderedImporter(importer)

    # The module returned by this environment should be the same one that's
    # in the importer.
    self.assertIs(
        ordered_importer.import_module("package_a"),
        importer.import_module("package_a"),
    )
    # It should not be the one available in the outer Python environment.
    self.assertIsNot(ordered_importer.import_module("package_a"), package_a)

    # We didn't package this module, so it should not be available.
    with self.assertRaises(ModuleNotFoundError):
        ordered_importer.import_module("module_a")
def _save_storages(importer, obj):
    serialized_storages = []
    serialized_dtypes = []

    importer = importer if isinstance(importer, torch.package.PackageImporter) else None
    importers: Importer
    if importer is not None:
        importers = OrderedImporter(importer, sys_importer)
    else:
        importers = sys_importer

    def persistent_id(obj):
        if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                storage = obj._storage
                dtype = obj.dtype
            else:
                storage = obj
                dtype = torch.uint8

            serialized_storages.append(obj)
            serialized_dtypes.append(dtype)
            return ("storage", len(serialized_storages) - 1)

        if hasattr(obj, "__reduce_deploy__"):
            if _serialized_reduces.get(id(obj)) is None:
                _serialized_reduces[id(obj)] = (
                    "reduce_deploy",
                    id(obj),
                    *obj.__reduce_deploy__(importers),
                )
            return _serialized_reduces[id(obj)]

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = create_pickler(data_buf, importers)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    return (
        data_value,
        serialized_storages,
        serialized_dtypes,
        importer.zip_reader if importer else None,
    )
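# A hedged round-trip sketch showing how this dtype-aware _save_storages is
# assumed to pair with the dtype-aware _load_storages above. The id `0` and
# the bare tensor argument are illustrative; the real call sites live in the
# deploy interpreter bindings.
def _roundtrip_sketch(tensor):
    data, storages, dtypes, zip_reader = _save_storages(None, tensor)
    # With no PackageImporter, zip_reader is None and _load_storages falls
    # back to sys_importer for unpickling.
    return _load_storages(0, zip_reader, data, storages, dtypes)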
def test_ordered_importer_basic(self):
    import package_a

    buffer = BytesIO()
    with PackageExporter(buffer, verbose=False) as pe:
        pe.save_module(package_a.__name__)

    buffer.seek(0)
    importer = PackageImporter(buffer)

    # With sys_importer first, the module from the outer Python environment wins.
    ordered_importer_sys_first = OrderedImporter(sys_importer, importer)
    self.assertIs(ordered_importer_sys_first.import_module('package_a'), package_a)

    # With the package importer first, the packaged module wins.
    ordered_importer_package_first = OrderedImporter(importer, sys_importer)
    self.assertIs(
        ordered_importer_package_first.import_module('package_a'),
        importer.import_module('package_a'),
    )
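# A minimal sketch, assuming the ordering semantics the two tests above rely
# on: import_module tries each underlying importer in order, letting a
# ModuleNotFoundError fall through to the next one. Illustrative only, not
# the verbatim torch.package implementation.
class _OrderedImporterImportSketch:
    def __init__(self, *importers):
        self._importers = list(importers)

    def import_module(self, module_name):
        last_err = None
        for importer in self._importers:
            try:
                return importer.import_module(module_name)
            except ModuleNotFoundError as err:
                last_err = err
        if last_err is not None:
            raise last_err
        raise ModuleNotFoundError(f"No module named {module_name!r}")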
def _load_storages(id, zip_reader, obj_bytes, serialized_storages):
    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]
        assert typename == 'storage', \
            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        return serialized_storages[data[0]]

    importer: Importer
    if zip_reader is not None:
        importer = OrderedImporter(_get_package(zip_reader), sys_importer)
    else:
        importer = sys_importer

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load
    result = _deploy_objects[id] = unpickler.load()
    return result