示例#1
0
def save_to_zip_file(save_path,
                     data=None,
                     params=None,
                     pytorch_variables=None,
                     verbose=0):
    """
    Save model data to a zip archive.

    Archive members (each written only when its argument is not None):

    * ``data`` -- the string returned by ``data_to_json(data)``;
    * ``pytorch_variables.pth`` -- ``pytorch_variables`` written with ``torch.save``;
    * ``<name>.pth`` -- one member per ``(name, value)`` item of ``params``.

    A ``_stable_baselines3_version`` member holding
    ``stable_baselines3.__version__`` is always written.

    :param save_path: Where to store the model. It is resolved through
        ``open_path(save_path, "w", suffix="zip")``; presumably a path
        (``str`` / ``os.PathLike``) or a writable file object -- whatever
        ``open_path`` accepts. When a path is given, the file opened for it
        is closed before returning. A caller-supplied file object is left open.
    :param data: Class parameters being stored (non-PyTorch variables)
    :param params: Model parameters being stored expected to contain an entry for every state_dict
    :param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable
    :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
        Forwarded to ``open_path``.
    """
    import os

    # Serialize before touching the destination: if data_to_json raises,
    # no file has been created/truncated and no handle is left open.
    serialized_data = data_to_json(data) if data is not None else None

    stream = open_path(save_path, "w", verbose=verbose, suffix="zip")
    try:
        # zipfile.ZipFile does not close a file object that is passed to it,
        # so releasing `stream` is handled explicitly in `finally` below.
        with zipfile.ZipFile(stream, mode="w") as archive:
            # Do not try to save "None" elements.
            if data is not None:
                archive.writestr("data", serialized_data)
            if pytorch_variables is not None:
                with archive.open("pytorch_variables.pth",
                                  mode="w") as pytorch_variables_file:
                    # Legacy torch serialization format, kept as in the original.
                    torch.save(pytorch_variables,
                               pytorch_variables_file,
                               _use_new_zipfile_serialization=False)
            if params is not None:
                for file_name, dict_ in params.items():
                    with archive.open(file_name + ".pth",
                                      mode="w") as param_file:
                        torch.save(dict_,
                                   param_file,
                                   _use_new_zipfile_serialization=False)

            # Save metadata: library version when file was saved.
            archive.writestr("_stable_baselines3_version",
                             stable_baselines3.__version__)
    finally:
        # Only close a handle that open_path opened on our behalf for a path
        # argument; never close a stream the caller handed in.
        # NOTE(review): assumes open_path returns a new open file object for
        # path inputs and the caller's object unchanged for streams -- confirm.
        if isinstance(save_path, (str, os.PathLike)) and stream is not save_path:
            stream.close()
示例#2
0
    def _save_to_file_zip(save_path: str,
                          data: Dict[str, Any] = None,
                          params: Dict[str, Any] = None,
                          tensors: Dict[str, Any] = None) -> None:
        """
        Write model components into a single zip archive.

        Every argument that is not None becomes one or more archive members:

        * ``data`` -> member ``"data"``, holding ``data_to_json(data)``;
        * ``tensors`` -> member ``"tensors.pth"``, written with ``th.save``;
        * ``params`` -> one member ``"<name>.pth"`` per ``(name, value)``
          item, each written with ``th.save``.

        :param save_path: Destination of the archive, either a path string or
            a writable file-like object. A string without an extension gets
            ``".zip"`` appended. Anything else is handed to
            ``zipfile.ZipFile`` untouched.
        :param data: Class parameters to store
        :param params: Model parameters, expected to map each state_dict name
            to its state_dict
        :param tensors: Extra tensor variables, mapping name to value
        """
        # JSON-encode the class parameters up front; None means nothing to store.
        json_blob = data_to_json(data) if data is not None else None

        # Only plain string paths receive a default ".zip" suffix.
        if isinstance(save_path, str):
            extension = os.path.splitext(save_path)[1]
            if not extension:
                save_path = save_path + ".zip"

        # ZipFile accepts either a filename or a writable file-like object.
        with zipfile.ZipFile(save_path, "w") as zip_out:
            if data is not None:
                zip_out.writestr("data", json_blob)
            if tensors is not None:
                with zip_out.open('tensors.pth', mode="w") as tensor_stream:
                    th.save(tensors, tensor_stream)
            if params is not None:
                for member_stem, state in params.items():
                    member = member_stem + '.pth'
                    with zip_out.open(member, mode="w") as member_stream:
                        th.save(state, member_stream)