Exemplo n.º 1
0
def load_from_file(load_path, load_data=True, custom_objects=None):
    """Load model data from a .zip archive.

    Falls back to the legacy cloudpickle format (with a DeprecationWarning)
    when ``load_path`` is not a zip archive.

    :param load_path: (str, os.PathLike or file-like) Where to load model from.
        If a path does not exist as given, ``load_path + ".zip"`` is tried.
    :param load_data: (bool) Whether we should load and return data
        (class parameters). Mainly used by `load_parameters` to
        only load model parameters (weights).
    :param custom_objects: (dict) Dictionary of objects to replace
        upon loading. If a variable is present in this dictionary as a
        key, it will not be deserialized and the corresponding item
        will be used instead. Similar to custom_objects in
        `keras.models.load_model`. Useful when you have an object in
        file that can not be deserialized.
    :return: (dict, OrderedDict) Class parameters and model parameters.
        Either element is None when the corresponding entry is absent
        from the archive; the class parameters are also None when
        ``load_data`` is False.
    :raises ValueError: if ``load_path`` is a path and neither it nor
        ``load_path + ".zip"`` exists.
    """
    # Normalise path-like objects (e.g. pathlib.Path) to plain strings so
    # they get the same existence check and ".zip" fallback as str paths,
    # and are never mistaken for a file-like object in the fallback below.
    if isinstance(load_path, os.PathLike):
        load_path = os.fspath(load_path)

    # Check if file exists if load_path is a string
    if isinstance(load_path, str) and not os.path.exists(load_path):
        if os.path.exists(load_path + ".zip"):
            load_path += ".zip"
        else:
            raise ValueError(
                "Error: the file {} could not be found".format(load_path))

    # Open the zip archive and load data.
    try:
        with zipfile.ZipFile(load_path, "r") as file_:
            namelist = file_.namelist()
            # If data or parameters is not in the
            # zip archive, assume they were stored
            # as None (_save_to_file allows this).
            data = None
            params = None
            if "data" in namelist and load_data:
                # Load class parameters and convert to string
                # (Required for json library in Python 3.5)
                json_data = file_.read("data").decode()
                data = json_to_data(json_data, custom_objects=custom_objects)

            if "parameters" in namelist:
                # Load parameter list and parameters
                parameter_list_json = file_.read("parameter_list").decode()
                parameter_list = json.loads(parameter_list_json)
                serialized_params = file_.read("parameters")
                params = bytes_to_params(serialized_params, parameter_list)
    except zipfile.BadZipFile:
        # load_path wasn't a zip file. Possibly a cloudpickle
        # file. Show a warning and fall back to loading cloudpickle.
        warnings.warn(
            "It appears you are loading from a file with old format. " +
            "Older cloudpickle format has been replaced with zip-archived " +
            "models. Consider saving the model with new format.",
            DeprecationWarning)
        # ZipFile may have moved the stream position while probing a
        # file-like object, so rewind it before re-reading. Paths have
        # no seek() and are reopened by the loader itself.
        if hasattr(load_path, "seek"):
            load_path.seek(0)
        data, params = BaseRLModel._load_from_file_cloudpickle(load_path)

    return data, params
Exemplo n.º 2
0
    def __init__(self, load_path):
        """Load class data and model parameters from a stable-baselines save.

        :param load_path: (str) Path to a ``.pkl`` (cloudpickle) or ``.zip``
            save file. The suffix selects the format.
        :raises RuntimeError: if ``load_path`` ends in neither ``.pkl`` nor
            ``.zip``.
        """
        self.load_path = load_path
        if load_path.endswith(".pkl"):
            # NOTE: cloudpickle deserialisation can execute arbitrary code;
            # only load files from trusted sources.
            with open(load_path, "rb") as file_:
                self.data, self.params = cloudpickle.load(file_)

        elif load_path.endswith(".zip"):
            # If data or parameters is not in the zip archive, assume they
            # were stored as None (_save_to_file allows this). Set both up
            # front so the attributes always exist after construction.
            self.data = None
            self.params = None
            with zipfile.ZipFile(load_path, "r") as file_:
                namelist = file_.namelist()
                if "data" in namelist:
                    # Load class parameters and convert to string
                    # (Required for json library in Python 3.5)
                    json_data = file_.read("data").decode()
                    self.data = json_to_data(json_data, custom_objects=None)
                if "parameters" in namelist:
                    # Load parameter list and parameters
                    parameter_list_json = file_.read("parameter_list").decode()
                    parameter_list = json.loads(parameter_list_json)
                    serialized_params = file_.read("parameters")
                    self.params = bytes_to_params(serialized_params,
                                                  parameter_list)
        else:
            raise RuntimeError("bad file path for stable baselines load")
Exemplo n.º 3
0
    def load(self, path):
        """Read the saved class data and locate the weights under ``path``.

        :param path: (str) Directory holding a ``params`` file (JSON text that
            is handed to ``json_to_data``) and a ``model`` entry for weights.
        :return: (tuple) ``(data, weights_path)``: the decoded ``params``
            contents and the joined ``path``/``model`` location. The weights
            themselves are not read here.
        """
        params_file = osp.join(path, 'params')
        with open(params_file, 'r') as handle:
            serialized = handle.read()
        restored = json_to_data(serialized)

        # Only the weights location is returned; loading is left to the caller.
        weights_file = osp.join(path, 'model')
        return restored, weights_file
Exemplo n.º 4
0
    def _load_from_file(load_path, load_data=True, custom_objects=None):
        """Load model data from a .zip archive

        :param load_path: (str, os.PathLike or file-like) Where to load model from. If a path does not
            exist as given, ``load_path + ".zip"`` is tried.
        :param load_data: (bool) Whether we should load and return data (class parameters). Mainly used by `load_parameters` to
            only load model parameters (weights).
        :param custom_objects: (dict) Dictionary of objects to replace
            upon loading. If a variable is present in this dictionary as a key, it will not be deserialized and the corresponding item
            will be used instead. Similar to custom_objects in `keras.models.load_model`. Useful when you have an object in
            file that can not be deserialized.
        :return: (dict, OrderedDict) Class parameters and model parameters. Either is None when the corresponding
            entry is absent from the archive; the class parameters are also None when ``load_data`` is False.
        :raises ValueError: if ``load_path`` is a path and neither it nor ``load_path + ".zip"`` exists.
        """
        # Normalise path-like objects (e.g. pathlib.Path) so they get the same existence check and ".zip"
        # suffix fallback as plain string paths.
        if isinstance(load_path, os.PathLike):
            load_path = os.fspath(load_path)

        # Check if file exists if load_path is a string
        if isinstance(load_path, str) and not os.path.exists(load_path):
            if os.path.exists(load_path + ".zip"):
                load_path += ".zip"
            else:
                raise ValueError("Error: the file {} could not be found".format(load_path))

        # Open the zip archive and load data.
        with zipfile.ZipFile(load_path, "r") as file_:
            namelist = file_.namelist()
            # If data or parameters is not in the zip archive, assume they were stored as None (_save_to_file allows this).
            data = None
            params = None
            if "data" in namelist and load_data:
                # Load class parameters and convert to string
                # (Required for json library in Python 3.5)
                json_data = file_.read("data").decode()
                data = save_util.json_to_data(json_data, custom_objects=custom_objects)

            if "parameters" in namelist:
                # Load parameter list and parameters
                parameter_list_json = file_.read("parameter_list").decode()
                parameter_list = json.loads(parameter_list_json)
                serialized_params = file_.read("parameters")
                params = save_util.bytes_to_params(serialized_params, parameter_list)

        return data, params