def load_from_zip_file(load_path, load_data=True, custom_objects=None, device="auto", verbose=0):
    """
    Load model data from a .zip archive.

    :param load_path: Where to load the model from
    :param load_data: Whether we should load and return data
        (the class parameters stored under the "data" entry)
    :param custom_objects: Dictionary of objects to replace upon loading,
        forwarded to ``json_to_data``
    :param device: Device on which the code should run
    :param verbose: Verbosity level, forwarded to ``open_path``
    :return: Class parameters, model state_dicts (aka "params", dict of state_dict)
        and dict of pytorch variables
    :raises ValueError: if ``load_path`` is not a zip archive
    """
    # NOTE(review): ``open_path`` is defined elsewhere. It is assumed to return
    # something ``zipfile.ZipFile`` accepts (a path or a binary file object) --
    # confirm against its implementation.
    opened = open_path(load_path, "r", verbose=verbose, suffix="zip")

    # if data or parameters is not in the zip archive, assume they were stored as None
    data = None
    pytorch_variables = None
    params = {}

    try:
        # set device to cpu if cuda is not available
        device = get_device(device=device)

        # open the zip archive and load data
        with zipfile.ZipFile(opened) as archive:
            namelist = archive.namelist()

            if load_data and "data" in namelist:
                # load class parameters that are stored with either JSON or pickle
                # (not PyTorch variables)
                json_data = archive.read("data").decode()
                data = json_to_data(json_data, custom_objects=custom_objects)

            # check for all .pth files and load them using torch.load
            pth_files = [name for name in namelist if os.path.splitext(name)[1] == ".pth"]
            for file_path in pth_files:
                with archive.open(file_path, mode="r") as param_file:
                    # file has to be seekable, but param_file is not,
                    # so copy it into an in-memory buffer first
                    file_content = io.BytesIO(param_file.read())

                # SECURITY: torch.load unpickles the payload; only load archives
                # that come from a trusted source.
                th_object = torch.load(file_content, map_location=device)

                if file_path in ("pytorch_variables.pth", "tensors.pth"):
                    # PyTorch variables (not state_dicts)
                    pytorch_variables = th_object
                else:
                    # state dicts, keyed by the entry name without the .pth ending
                    params[os.path.splitext(file_path)[0]] = th_object
    except zipfile.BadZipFile as err:
        # load_path wasn't a zip file; report the path as given by the caller
        raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from err
    finally:
        # Only release what was opened on the caller's behalf: when a path was
        # passed in and a different, closable object came back. A file object
        # supplied by the caller is left open for the caller to manage.
        if isinstance(load_path, (str, os.PathLike)) and opened is not load_path:
            close = getattr(opened, "close", None)
            if callable(close):
                close()

    return data, params, pytorch_variables
def _load_from_file(load_path: str, load_data: bool = True) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
    """
    Load model data from a .zip archive

    :param load_path: Where to load the model from. When it is a string that
        does not exist on disk, the same path with ".zip" appended is tried.
    :param load_data: Whether we should load and return data
        (class parameters) and the extra tensors stored in "tensors.pth".
        Mainly used by 'load_parameters' to only load model parameters (weights)
    :return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict)
        and dict of extra tensors
    :raises ValueError: if the file cannot be found or is not a zip archive
    """
    # Check if file exists if load_path is a string
    if isinstance(load_path, str) and not os.path.exists(load_path):
        if not os.path.exists(load_path + ".zip"):
            raise ValueError(f"Error: the file {load_path} could not be found")
        load_path += ".zip"

    # set device to cpu if cuda is not available
    device = get_device()

    def _load_entry(archive, entry_name):
        """Load one .pth entry of ``archive`` onto ``device``."""
        with archive.open(entry_name, mode="r") as entry:
            # The loader needs a seekable file, so copy the entry into memory
            # first (per the original note, this is fixed in python >= 3.7)
            file_content = io.BytesIO(entry.read())
        # SECURITY: th.load unpickles the payload; only load archives
        # that come from a trusted source.
        return th.load(file_content, map_location=device)

    # If data or parameters is not in the zip archive, assume they were
    # stored as None (_save_to_file_zip allows this).
    data = None
    tensors = None
    params = {}

    # Open the zip archive and load data
    try:
        with zipfile.ZipFile(load_path, "r") as archive:
            namelist = archive.namelist()

            if load_data and "data" in namelist:
                # Load class parameters and convert to string
                json_data = archive.read("data").decode()
                data = json_to_data(json_data)

            if load_data and "tensors.pth" in namelist:
                # Load extra tensors
                tensors = _load_entry(archive, "tensors.pth")

            # Every other entry ending with .pth (i.e. not "tensors.pth") is
            # stored in params under its name without the .pth ending
            for file_path in namelist:
                stem, extension = os.path.splitext(file_path)
                if extension == ".pth" and file_path != "tensors.pth":
                    params[stem] = _load_entry(archive, file_path)
    except zipfile.BadZipFile as err:
        # load_path wasn't a zip file
        raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from err

    return data, params, tensors