def load_from_file(load_path, load_data=True, custom_objects=None):
    """Load model data from a .zip archive, falling back to the legacy cloudpickle format.

    :param load_path: (str or file-like) Where to load model from. When a
        string path does not exist, ``load_path + ".zip"`` is tried instead.
    :param load_data: (bool) Whether we should load and return data (class
        parameters). Mainly used by `load_parameters` to only load model
        parameters (weights).
    :param custom_objects: (dict) Dictionary of objects to replace upon
        loading. If a variable is present in this dictionary as a key, it
        will not be deserialized and the corresponding item will be used
        instead. Similar to custom_objects in `keras.models.load_model`.
        Useful when you have an object in file that can not be deserialized.
    :return: (dict, OrderedDict) Class parameters and model parameters. From
        a zip archive, either item is None when its member is absent (data is
        also None when ``load_data`` is False).
    :raises ValueError: if ``load_path`` is a string and neither it nor
        ``load_path + ".zip"`` exists.
    :raises zipfile.BadZipFile: if the archive opens but a member it reads is
        corrupt (e.g. a CRC mismatch).
    """
    # Check if file exists if load_path is a string
    if isinstance(load_path, str) and not os.path.exists(load_path):
        if os.path.exists(load_path + ".zip"):
            load_path += ".zip"
        else:
            raise ValueError(
                "Error: the file {} could not be found".format(load_path))

    # Only the ZipFile constructor decides which on-disk format we have.
    # A BadZipFile raised later (e.g. a CRC error while reading a member)
    # means a corrupt archive, not a legacy pickle, so it must propagate
    # instead of triggering the cloudpickle fallback.
    try:
        archive = zipfile.ZipFile(load_path, "r")
    except zipfile.BadZipFile:
        # load_path wasn't a zip file. Possibly a cloudpickle
        # file. Show a warning and fall back to loading cloudpickle.
        warnings.warn(
            "It appears you are loading from a file with old format. " +
            "Older cloudpickle format has been replaced with zip-archived " +
            "models. Consider saving the model with new format.",
            DeprecationWarning)
        # If load_path is file-like, seek back to beginning of file,
        # since the failed zip probe may have moved the read position.
        if not isinstance(load_path, str):
            load_path.seek(0)
        data, params = BaseRLModel._load_from_file_cloudpickle(load_path)
        return data, params

    with archive as file_:
        namelist = file_.namelist()
        # If data or parameters is not in the zip archive, assume they were
        # stored as None (_save_to_file allows this).
        data = None
        params = None
        if "data" in namelist and load_data:
            # Load class parameters and convert to string
            # (Required for json library in Python 3.5)
            json_data = file_.read("data").decode()
            data = json_to_data(json_data, custom_objects=custom_objects)
        if "parameters" in namelist:
            # Load parameter list and parameters
            parameter_list_json = file_.read("parameter_list").decode()
            parameter_list = json.loads(parameter_list_json)
            serialized_params = file_.read("parameters")
            params = bytes_to_params(serialized_params, parameter_list)
    return data, params
def __init__(self, load_path):
    """Load class data and parameters from a saved stable-baselines model file.

    The format is chosen purely from the file extension: ``.pkl`` is read
    with cloudpickle, ``.zip`` is read as a zip archive.

    :param load_path: (str) Path to a ``.pkl`` or ``.zip`` model file.
    :raises RuntimeError: if ``load_path`` ends with neither ``.pkl`` nor ``.zip``.
    """
    self.load_path = load_path
    if load_path.endswith(".pkl"):
        # NOTE(review): cloudpickle.load can execute arbitrary code embedded
        # in the file; only load .pkl files from trusted sources.
        with open(load_path, "rb") as file_:
            self.data, self.params = cloudpickle.load(file_)
    elif load_path.endswith(".zip"):
        with zipfile.ZipFile(load_path, "r") as file_:
            namelist = file_.namelist()
            # If data or parameters is not in the zip archive, assume they
            # were stored as None (_save_to_file allows this). Both
            # attributes are always assigned so later access never fails.
            data = None
            params = None
            if "data" in namelist:
                # Load class parameters and convert to string
                # (Required for json library in Python 3.5)
                json_data = file_.read("data").decode()
                data = json_to_data(json_data, custom_objects=None)
            if "parameters" in namelist:
                # Load parameter list and parameters
                parameter_list_json = file_.read("parameter_list").decode()
                parameter_list = json.loads(parameter_list_json)
                serialized_params = file_.read("parameters")
                params = bytes_to_params(serialized_params, parameter_list)
            self.data = data
            self.params = params
    else:
        raise RuntimeError("bad file path for stable baselines load")
def load(self, path):
    """Read the saved class parameters and locate the weights file under ``path``.

    :param path: (str) Directory holding a ``params`` JSON file and a
        ``model`` entry.
    :return: (tuple) The decoded class parameters, and the path to ``model``.
        The weights themselves are not read here, only their location.
    """
    # Pull the raw JSON text first so the file handle is closed before decoding.
    with open(osp.join(path, 'params'), 'r') as handle:
        raw_json = handle.read()
    decoded = json_to_data(raw_json)
    weights_path = osp.join(path, 'model')
    return decoded, weights_path
def _load_from_file(load_path, load_data=True, custom_objects=None):
    """Read class data and model parameters out of a saved .zip model archive.

    :param load_path: (str or file-like) Location of the archive. A string
        path that does not exist is retried with ".zip" appended.
    :param load_data: (bool) When False, the "data" member is skipped and None
        is returned in its place; mainly used by `load_parameters` to fetch
        only the model parameters (weights).
    :param custom_objects: (dict) Objects substituted during deserialization:
        any key present here is not deserialized and the mapped value is used
        instead (similar to custom_objects in `keras.models.load_model`).
        Useful for objects in the file that cannot be deserialized.
    :return: (dict, OrderedDict) Class parameters and model parameters; each
        is None when its member is absent from the archive.
    :raises ValueError: if a string path is missing both as given and with ".zip".
    """
    if isinstance(load_path, str) and not os.path.exists(load_path):
        zipped_path = load_path + ".zip"
        if not os.path.exists(zipped_path):
            raise ValueError("Error: the file {} could not be found".format(load_path))
        load_path = zipped_path

    # Absent members mean the value was saved as None (_save_to_file allows this).
    class_data, model_params = None, None
    with zipfile.ZipFile(load_path, "r") as archive:
        members = archive.namelist()
        if load_data and "data" in members:
            # Decode bytes to str: the json library on Python 3.5 requires str.
            class_data = save_util.json_to_data(
                archive.read("data").decode(), custom_objects=custom_objects
            )
        if "parameters" in members:
            # The JSON-encoded "parameter_list" member is read before the raw
            # "parameters" bytes and passed alongside them to the decoder.
            param_list = json.loads(archive.read("parameter_list").decode())
            model_params = save_util.bytes_to_params(archive.read("parameters"), param_list)
    return class_data, model_params