def load_pickle(self, package: str, resource: str, map_location=None) -> Any:
    """Unpickles the resource from the package, loading any modules that are
    needed to construct the objects using :meth:`import_module`.

    Args:
        package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
        resource (str): The unique name for the resource.
        map_location: Passed to `torch.load` to determine how tensors are
            mapped to devices. Defaults to ``None``.

    Returns:
        Any: The unpickled object.
    """
    pickle_file = self._zipfile_path(package, resource)
    restore_location = _get_restore_location(map_location)
    loaded_storages = {}

    def load_tensor(data_type, size, key, location, restore_location):
        name = f'.data/{key}.storage'
        dtype = data_type(0).dtype

        storage = self.zip_reader.get_storage_from_record(name, size, dtype).storage()
        loaded_storages[key] = restore_location(storage, location)

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        assert typename == 'storage', \
            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        data_type, key, location, size = data
        if key not in loaded_storages:
            load_tensor(data_type, size, key, _maybe_decode_ascii(location), restore_location)
        storage = loaded_storages[key]
        return storage

    # Load the data (which may in turn use `persistent_load` to load tensors)
    data_file = io.BytesIO(self.zip_reader.get_record(pickle_file))
    unpickler = self.Unpickler(data_file)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    # TODO from zdevito:
    # This stateful weird function will need to be removed in our efforts
    # to unify the format. It has a race condition if multiple python
    # threads try to read independent files
    torch._utils._validate_loaded_sparse_tensors()

    return result

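# For reference, the persistent id consumed by `persistent_load` above is the
# tuple the package pickler emits for each storage. Its shape is inferred from
# the unpacking in the code, not an extra contract:
#
#     ('storage', data_type, key, location, size)
#
# `key` names the `.data/{key}.storage` record inside the zip archive, and
# `location` is a device tag (e.g. 'cpu') that `restore_location` may remap.
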
def load_pickle(self, package: str, resource: str, map_location=None) -> Any:
    """Unpickles the resource from the package, loading any modules that are
    needed to construct the objects using :meth:`import_module`.

    Args:
        package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
        resource (str): The unique name for the resource.
        map_location: Passed to `torch.load` to determine how tensors are
            mapped to devices. Defaults to ``None``.

    Returns:
        Any: The unpickled object.
    """
    pickle_file = self._zipfile_path(package, resource)
    restore_location = _get_restore_location(map_location)
    loaded_storages = {}
    loaded_reduces = {}
    storage_context = torch._C.DeserializationStorageContext()

    def load_tensor(dtype, size, key, location, restore_location):
        name = f"{key}.storage"

        if storage_context.has_storage(name):
            storage = storage_context.get_storage(name, dtype).storage()
        else:
            tensor = self.zip_reader.get_storage_from_record(
                ".data/" + name, size, dtype)
            if isinstance(self.zip_reader, torch._C.PyTorchFileReader):
                storage_context.add_storage(name, tensor)
            storage = tensor.storage()
        loaded_storages[key] = restore_location(storage, location)

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == "storage":
            storage_type, key, location, size = data
            dtype = storage_type.dtype

            if key not in loaded_storages:
                load_tensor(
                    dtype,
                    size,
                    key,
                    _maybe_decode_ascii(location),
                    restore_location,
                )
            storage = loaded_storages[key]
            # TODO: Once we decide to break serialization FC, we can
            # stop wrapping with _TypedStorage
            return torch.storage._TypedStorage(
                wrap_storage=storage._untyped(), dtype=dtype)
        elif typename == "reduce_package":
            # to fix BC breaking change, objects on this load path
            # will be loaded multiple times erroneously
            if len(data) == 2:
                func, args = data
                return func(self, *args)
            reduce_id, func, args = data
            if reduce_id not in loaded_reduces:
                loaded_reduces[reduce_id] = func(self, *args)
            return loaded_reduces[reduce_id]
        else:
            raise RuntimeError(
                f"Unknown typename for persistent_load, expected "
                f"'storage' or 'reduce_package' but got '{typename}'")

    # Load the data (which may in turn use `persistent_load` to load tensors)
    data_file = io.BytesIO(self.zip_reader.get_record(pickle_file))
    unpickler = self.Unpickler(data_file)
    unpickler.persistent_load = persistent_load

    @contextmanager
    def set_deserialization_context():
        # to let reduce_package access deserialization context
        self.storage_context = storage_context
        self.last_map_location = map_location
        try:
            yield
        finally:
            self.storage_context = None
            self.last_map_location = None

    with set_deserialization_context():
        result = unpickler.load()

    # TODO from zdevito:
    # This stateful weird function will need to be removed in our efforts
    # to unify the format. It has a race condition if multiple python
    # threads try to read independent files
    torch._utils._validate_loaded_sparse_tensors()

    return result

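# Usage sketch for the method above (illustrative; the archive path and
# resource names are hypothetical):
#
#     importer = torch.package.PackageImporter("my_archive.pt")
#     model = importer.load_pickle("my_package", "model.pkl",
#                                  map_location="cpu")
#
# `map_location` is forwarded to `restore_location`, so storages saved on a
# CUDA device can be remapped onto the CPU at load time.
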
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
    deserialized_objects: Dict[int, Any] = {}

    restore_location = _get_restore_location(map_location)

    def _check_container_source(container_type, source_file, original_source):
        try:
            current_source = ''.join(get_source_lines_and_file(container_type)[0])
        except Exception:
            # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + container_type.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
            return

        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + '.patch'
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    with open(file_name, 'a+') as f:
                        file_size = f.seek(0, 2)
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            raise IOError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except IOError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Module.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
            warnings.warn(msg, SourceChangeWarning)

    def legacy_load(f, obj=None):
        deserialized_objects: Dict[int, Any] = {}

        def persistent_load(saved_id):
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]

        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:

            tar.extract('storages', path=tmpdir)
            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
                num_storages = pickle_module.load(f, **pickle_load_args)
                for i in range(num_storages):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, location, storage_type = args
                    obj = storage_type._new_with_file(f)
                    obj = restore_location(obj, location)
                    deserialized_objects[key] = obj

                storage_views = pickle_module.load(f, **pickle_load_args)
                for target_cdata, root_cdata, offset, size in storage_views:
                    root = deserialized_objects[root_cdata]
                    deserialized_objects[target_cdata] = root[offset:offset + size]

            tar.extract('tensors', path=tmpdir)
            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
                num_tensors = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_tensors):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, storage_id, original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    tensor_type = storage_to_tensor_type(storage)
                    ndim, = struct.unpack('<i', f.read(4))
                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
                    f.read(4)
                    size = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    storage_offset, = struct.unpack('<q', f.read(8))
                    tensor = tensor_type().set_(storage, storage_offset, size, stride)
                    deserialized_objects[key] = tensor

            pickle_file = tar.extractfile('pickle')
            unpickler = pickle_module.Unpickler(pickle_file, **pickle_load_args)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result
    deserialized_objects = {}

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == 'storage':
            data_type, root_key, location, size, view_metadata = data
            location = _maybe_decode_ascii(location)
            if root_key not in deserialized_objects:
                obj = data_type(size)
                obj._torch_load_uninitialized = True
                # materialize the uninitialized storage in a '<root_key>.bint'
                # file once, then re-open it memory-mapped (shared) so the
                # storage is backed by the file rather than resident memory
                s = str(root_key) + '.bint'
                if not os.path.isfile(s):
                    with open(s, 'wb') as ff:
                        obj._write_file(ff, True, False)
                obj = obj.__class__.from_file(s, shared=1, size=size)
                deserialized_objects[root_key] = restore_location(obj, location)
            storage = deserialized_objects[root_key]
            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                if view_key not in deserialized_objects:
                    deserialized_objects[view_key] = storage[offset:offset + view_size]
                return deserialized_objects[view_key]
            else:
                return storage
        else:
            raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

    _check_seekable(f)
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        # legacy_load requires that f has fileno()
        # only if offset is zero we can attempt the legacy tar file loader
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if _is_zipfile(f):
                # .zip is used for torch.jit.save and will throw an un-pickling error here
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
            # if not a tarfile, reset file offset and proceed
            f.seek(0)

    if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
        raise RuntimeError(
            "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
            f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this "
            "functionality.")

    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError("Invalid protocol version: %s" % protocol_version)

    _sys_info = pickle_module.load(f, **pickle_load_args)
    unpickler = pickle_module.Unpickler(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

    offset = f.tell() if f_should_read_directly else None
    for key in tqdm(deserialized_storage_keys):
        assert key in deserialized_objects
        deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly)
        if offset is not None:
            offset = f.tell()

    torch._utils._validate_loaded_sparse_tensors()

    return result

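# The legacy tar container unpacked by the inner `legacy_load` above has three
# members, read in this order (layout reconstructed from the reads in the code):
#
#     storages -- pickled count, then (key, location, storage_type) triples,
#                 each followed by the raw storage bytes, then a pickled list
#                 of (target, root, offset, size) storage views
#     tensors  -- pickled count, then per-tensor metadata: key, storage_id,
#                 tensor type, ndim, sizes, strides, storage offset
#     pickle   -- the pickled top-level object; storages and tensors are
#                 referenced via persistent ids resolved by `persistent_load`
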
def _safe_legacy_load(f):
    MAGIC_NUMBER = 0x1950A86A20F9469CFC6C
    PROTOCOL_VERSION = 1001

    deserialized_objects = {}
    restore_location = _get_restore_location(None)

    def _check_container_source(container_type, source_file, original_source):
        try:
            current_source = "".join(get_source_lines_and_file(container_type)[0])
        except Exception:
            # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + container_type.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
            return

        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + ".patch"
                diff = difflib.unified_diff(
                    current_source.split("\n"),
                    original_source.split("\n"),
                    source_file,
                    source_file,
                    lineterm="",
                )
                lines = "\n".join(diff)
                try:
                    with open(file_name, "a+") as f:
                        file_size = f.seek(0, 2)
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            raise IOError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except IOError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Module.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = "source code of class '{container_type}' has changed. {msg}".format(
                container_type=torch.typename(container_type), msg=msg)
            warnings.warn(msg, SourceChangeWarning)

    def legacy_load(f):
        deserialized_objects = {}

        def persistent_load(saved_id):
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]

        with closing(
                tarfile.open(
                    fileobj=f, mode="r:",
                    format=tarfile.PAX_FORMAT)) as tar, mkdtemp() as tmpdir:
            tar.extract("storages", path=tmpdir)
            with open(os.path.join(tmpdir, "storages"), "rb", 0) as f:
                num_storages = RestrictedUnpickler(f).load()
                for _ in range(num_storages):
                    args = RestrictedUnpickler(f).load()
                    key, location, storage_type = args
                    obj = storage_type._new_with_file(f)
                    obj = restore_location(obj, location)
                    deserialized_objects[key] = obj

                storage_views = RestrictedUnpickler(f).load()
                for target_cdata, root_cdata, offset, size in storage_views:
                    root = deserialized_objects[root_cdata]
                    deserialized_objects[target_cdata] = root[offset:offset + size]

            tar.extract("tensors", path=tmpdir)
            with open(os.path.join(tmpdir, "tensors"), "rb", 0) as f:
                num_tensors = RestrictedUnpickler(f).load()
                for _ in range(num_tensors):
                    args = RestrictedUnpickler(f).load()
                    key, storage_id, original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    tensor_type = storage_to_tensor_type(storage)
                    (ndim, ) = struct.unpack("<i", f.read(4))
                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
                    f.read(4)
                    size = struct.unpack("<{}q".format(ndim), f.read(8 * ndim))
                    stride = struct.unpack("<{}q".format(ndim), f.read(8 * ndim))
                    (storage_offset, ) = struct.unpack("<q", f.read(8))
                    tensor = tensor_type().set_(storage, storage_offset, size, stride)
                    deserialized_objects[key] = tensor

            pickle_file = tar.extractfile("pickle")
            unpickler = RestrictedUnpickler(pickle_file)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result

    deserialized_objects = {}
    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == "module":
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == "storage":
            data_type, root_key, location, size, view_metadata = data
            location = _maybe_decode_ascii(location)
            if root_key not in deserialized_objects:
                obj = data_type(size)
                obj._torch_load_uninitialized = True
                deserialized_objects[root_key] = restore_location(obj, location)
            storage = deserialized_objects[root_key]
            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                if view_key not in deserialized_objects:
                    deserialized_objects[view_key] = storage[offset:offset + view_size]
                return deserialized_objects[view_key]
            else:
                return storage
        else:
            raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

    _check_seekable(f)
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        # legacy_load requires that f has fileno()
        # only if offset is zero we can attempt the legacy tar file loader
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if _is_zipfile(f):
                # .zip is used for torch.jit.save and will throw an un-pickling error
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)")
            # if not a tarfile, reset file offset and proceed
            f.seek(0)

    if not hasattr(f, "readinto") and (3, 8, 0) <= sys.version_info < (3, 8, 2):
        raise RuntimeError(
            "torch.load does not work with file-like objects that do not implement "
            "readinto on Python 3.8.0 and 3.8.1. Received object of type "
            '"{}". Please update to Python 3.8.2 or newer to restore this '
            "functionality.".format(type(f)))

    magic_number = RestrictedUnpickler(f).load()
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = RestrictedUnpickler(f).load()
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError("Invalid protocol version: %s" % protocol_version)

    _ = RestrictedUnpickler(f).load()  # _sys_info

    unpickler = RestrictedUnpickler(f)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    deserialized_storage_keys = RestrictedUnpickler(f).load()

    offset = f.tell() if f_should_read_directly else None
    for key in deserialized_storage_keys:
        assert key in deserialized_objects
        deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly)
        if offset is not None:
            offset = f.tell()

    return result
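
# `RestrictedUnpickler` is assumed to be defined elsewhere in this module.
# A minimal sketch of the usual allow-list pattern it follows (illustrative
# only; the real class will permit a different set of globals):
#
#     import pickle
#
#     class RestrictedUnpickler(pickle.Unpickler):
#         _ALLOWED = {
#             ("collections", "OrderedDict"),
#             ("torch._utils", "_rebuild_tensor_v2"),
#         }
#
#         def find_class(self, module, name):
#             # resolve only globals on the allow-list; reject everything else
#             if (module, name) in self._ALLOWED:
#                 return super().find_class(module, name)
#             raise pickle.UnpicklingError(
#                 f"global '{module}.{name}' is forbidden")
#
# Overriding `find_class` this way is the standard hardening technique from
# the Python `pickle` documentation: it blocks arbitrary-code-execution
# gadgets while still allowing the handful of types a checkpoint needs.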