Пример #1
0
def load_from_zip_file(load_path,
                       load_data=True,
                       custom_objects=None,
                       device="auto",
                       verbose=0):
    """
    Load model data from a .zip archive.

    :param load_path: Where to load the model from; passed through ``open_path``
        in read mode with ``suffix="zip"``
    :param load_data: Whether we should load and return data (the class
        parameters stored in the archive's ``data`` entry)
    :param custom_objects: Dictionary of objects to replace upon loading,
        forwarded to ``json_to_data``
    :param device: Device on which the code should run; resolved with ``get_device``
    :param verbose: Verbosity level, forwarded to ``open_path``
    :return: Class parameters (``None`` if the archive has no ``data`` entry or
        ``load_data`` is False), model state_dicts (aka "params", dict of
        state_dict keyed by archive member name without the ``.pth`` extension)
        and pytorch variables (``None`` unless the archive contains
        ``pytorch_variables.pth`` or ``tensors.pth``)
    :raises ValueError: if ``load_path`` is not a valid zip archive
    """
    # NOTE(review): if open_path opens a new file object here, nothing below
    # closes it (zipfile.ZipFile only closes files it opened itself) -- confirm ownership.
    load_path = open_path(load_path, "r", verbose=verbose, suffix="zip")

    # set device to cpu if cuda is not available
    device = get_device(device=device)

    # open the zip archive and load data
    try:
        with zipfile.ZipFile(load_path) as archive:
            namelist = archive.namelist()
            # if data or parameters is not in the zip archive, assume they were stored as None
            data = None
            pytorch_variables = None
            params = {}

            if "data" in namelist and load_data:
                # load class parameters that are stored with either JSON or pickle (not PyTorch variables)
                json_data = archive.read("data").decode()
                data = json_to_data(json_data, custom_objects=custom_objects)

            # check for all .pth files and load them using torch.load
            pth_files = [
                file_name for file_name in namelist
                if os.path.splitext(file_name)[1] == ".pth"
            ]
            for file_path in pth_files:
                with archive.open(file_path, mode="r") as param_file:
                    # torch.load needs a seekable stream and the archive member may not be
                    # seekable, so buffer its bytes in memory first (BytesIO starts at offset 0)
                    file_content = io.BytesIO(param_file.read())
                # load with the right map_location.
                # NOTE: torch.load unpickles its input -- only load archives from trusted sources.
                th_object = torch.load(file_content, map_location=device)
                if file_path in ("pytorch_variables.pth", "tensors.pth"):
                    # PyTorch variables (not state_dicts)
                    pytorch_variables = th_object
                else:
                    # state dicts, keyed by member name with the .pth ending removed
                    params[os.path.splitext(file_path)[0]] = th_object

    except zipfile.BadZipFile as err:
        # load_path wasn't a zip file; keep the original error as the cause.
        # NOTE(review): load_path is open_path's return value here, so the message
        # may show that object's repr rather than the path the caller passed.
        raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from err

    return data, params, pytorch_variables
Пример #2
0
    def _load_from_file(load_path: str, load_data: bool = True) -> (Tuple[Optional[Dict[str, Any]],
                                                                          Optional[TensorDict],
                                                                          Optional[TensorDict]]):
        """ Load model data from a .zip archive

        :param load_path: Where to load the model from. If it is a string that does not
            exist as given, ``load_path + ".zip"`` is tried instead. Non-string values skip
            the existence check and are handed directly to ``zipfile.ZipFile``.
        :param load_data: Whether we should load and return data
            (class parameters and the extra tensors in ``tensors.pth``).
            Mainly used by 'load_parameters' to only load model parameters (weights)
        :return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict,
            keyed by archive member name without the ``.pth`` extension) and dict of extra
            tensors. Entries missing from the archive (or skipped because ``load_data`` is
            False) are returned as ``None``; the state_dict mapping is empty if there are none.
        :raises ValueError: if the file cannot be found, or if it is not a zip archive
        """
        # Check if file exists if load_path is a string, falling back to the ".zip" suffix
        if isinstance(load_path, str) and not os.path.exists(load_path):
            if os.path.exists(load_path + ".zip"):
                load_path += ".zip"
            else:
                raise ValueError(f"Error: the file {load_path} could not be found")

        # set device to cpu if cuda is not available
        device = get_device()

        def _load_member(archive: zipfile.ZipFile, member_name: str):
            """Deserialize one ``.pth`` archive member with ``th.load`` onto ``device``."""
            with archive.open(member_name, mode="r") as member_file:
                # The archive member is not seekable before Python 3.7, but th.load needs
                # a seekable stream, so buffer the bytes in memory (BytesIO starts at offset 0)
                file_content = io.BytesIO(member_file.read())
            # NOTE: th.load unpickles its input -- only load archives from trusted sources.
            return th.load(file_content, map_location=device)

        # Open the zip archive and load data
        try:
            with zipfile.ZipFile(load_path, "r") as archive:
                namelist = archive.namelist()
                # If data or parameters is not in the
                # zip archive, assume they were stored
                # as None (_save_to_file_zip allows this).
                data = None
                tensors = None
                params = {}

                if "data" in namelist and load_data:
                    # Load class parameters and convert to string
                    json_data = archive.read("data").decode()
                    data = json_to_data(json_data)

                if "tensors.pth" in namelist and load_data:
                    # Load extra tensors
                    tensors = _load_member(archive, "tensors.pth")

                # Every other .pth member (anything except "tensors.pth") is loaded into
                # `params`, keyed by its name without the ".pth" extension
                for file_path in namelist:
                    if os.path.splitext(file_path)[1] == ".pth" and file_path != "tensors.pth":
                        params[os.path.splitext(file_path)[0]] = _load_member(archive, file_path)

        except zipfile.BadZipFile as err:
            # load_path wasn't a zip file; keep the original error as the cause
            raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from err
        return data, params, tensors