def save_to_zip_file(save_path, data=None, params=None, pytorch_variables=None, verbose=0):
    """
    Save model data to a zip archive.

    The archive may contain the following entries, each written only when the
    matching argument is not ``None``:

    - ``data``: the output of ``data_to_json(data)``
    - ``pytorch_variables.pth``: ``pytorch_variables`` written with ``torch.save``
    - ``<name>.pth``: one entry per ``(name, value)`` pair in ``params``

    A ``_stable_baselines3_version`` entry holding the library version is
    always written.

    :param save_path: Where to store the model. Passed through ``open_path``
        (mode ``"w"``, suffix ``"zip"``) before being handed to ``zipfile.ZipFile``.
    :param data: Class parameters being stored (non-PyTorch variables)
    :param params: Model parameters being stored expected to contain an entry for every
        state_dict
    :param pytorch_variables: Other PyTorch variables expected to contain name and value
        of the variable
    :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
        Forwarded to ``open_path``.
    """
    # Forward the caller's verbosity: the parameter used to be accepted and
    # documented but silently replaced by a hard-coded 0.
    save_path = open_path(save_path, "w", verbose=verbose, suffix="zip")
    # NOTE(review): if open_path returns a handle that it opened itself, nothing
    # in this function closes it -- confirm who owns that handle.

    # data/params can be None, so do not try to serialize them blindly.
    # Serialization happens before the archive is created, so a failure in
    # data_to_json is raised before any zip entry is written.
    if data is not None:
        serialized_data = data_to_json(data)

    # Create a zip-archive and write our objects there.
    with zipfile.ZipFile(save_path, mode="w") as archive:
        # Do not try to save "None" elements
        if data is not None:
            archive.writestr("data", serialized_data)

        if pytorch_variables is not None:
            with archive.open("pytorch_variables.pth", mode="w") as pytorch_variables_file:
                torch.save(pytorch_variables, pytorch_variables_file, _use_new_zipfile_serialization=False)

        if params is not None:
            for file_name, dict_ in params.items():
                with archive.open(file_name + ".pth", mode="w") as param_file:
                    torch.save(dict_, param_file, _use_new_zipfile_serialization=False)

        # Save metadata: library version when file was saved
        archive.writestr("_stable_baselines3_version", stable_baselines3.__version__)
def _save_to_file_zip(save_path: str, data: Dict[str, Any] = None, params: Dict[str, Any] = None, tensors: Dict[str, Any] = None) -> None:
    """
    Write a model to a zip archive.

    Entries are only created for arguments that are not ``None``:

    - ``data``: the output of ``data_to_json(data)``
    - ``tensors.pth``: ``tensors`` written with ``th.save``
    - ``<name>.pth``: one entry per ``(name, value)`` pair in ``params``

    :param save_path: Where to store the model. When it is a string without a
        file extension, ``.zip`` is appended; any other object is handed to
        ``zipfile.ZipFile`` unchanged (e.g. a file-like object).
    :param data: Class parameters being stored
    :param params: Model parameters being stored expected to contain an entry
        for every state_dict with its name and the state_dict
    :param tensors: Extra tensor variables expected to contain name and value
        of tensors
    """
    has_data = data is not None

    # Serialize up front (and only when there is something to serialize) so
    # this step runs before the archive is opened.
    json_payload = data_to_json(data) if has_data else None

    # Only string paths get the default extension; other targets are left alone.
    if isinstance(save_path, str) and not os.path.splitext(save_path)[1]:
        save_path = save_path + ".zip"

    # ZipFile accepts either a path string or a file-like object here.
    with zipfile.ZipFile(save_path, "w") as zip_archive:
        if has_data:
            zip_archive.writestr("data", json_payload)

        if tensors is not None:
            with zip_archive.open('tensors.pth', mode="w") as tensors_entry:
                th.save(tensors, tensors_entry)

        if params is not None:
            for entry_name, state in params.items():
                with zip_archive.open(entry_name + '.pth', mode="w") as params_entry:
                    th.save(state, params_entry)