def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs):
    """
    Load the model from a zip-file.

    :param load_path: the location of the saved data
    :param env: the new environment to run the loaded model on
        (can be None if you only need prediction from a trained model);
        has priority over any saved environment
    :param kwargs: extra arguments to change the model when loading
    :return: a new ``cls`` instance whose ``__dict__`` was updated with the saved
        data and ``kwargs``
    :raises ValueError: if ``policy_kwargs`` is given and differs from the stored
        one, or if the spaces needed to verify an environment are unavailable
    :raises NotImplementedError: if state dicts were saved but the model has no
        ``_setup_model()`` method
    """
    data, params, tensors = load_from_zip_file(load_path)

    # The stored device is discarded: the model is re-created with device="auto" below.
    if "policy_kwargs" in data:
        data["policy_kwargs"].pop("device", None)

    if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
        raise ValueError(
            "The specified policy kwargs do not equal the stored policy kwargs. "
            f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
        )

    # Either the spaces or a stored env must be part of the saved parameters
    missing_spaces_msg = "The observation_space and action_space was not given, can't verify new environments"
    has_spaces = "observation_space" in data and "action_space" in data
    if not has_spaces and "env" not in data:
        raise ValueError(missing_spaces_msg)

    if env is not None:
        # Check that the given env is valid
        if has_spaces:
            observation_space = data["observation_space"]
            action_space = data["action_space"]
        else:
            # The archive only holds an env: verify against its spaces instead of
            # failing with a bare KeyError on the missing "observation_space" entry.
            stored_env = data["env"]
            if stored_env is None:
                raise ValueError(missing_spaces_msg)
            observation_space = stored_env.observation_space
            action_space = stored_env.action_space
        check_for_correct_spaces(env, observation_space, action_space)
    elif "env" in data:
        # No new env was given: use the stored one if possible
        env = data["env"]

    # noinspection PyArgumentList
    model = cls(policy=data["policy_class"], env=env, device="auto", _init_setup_model=False)

    # Load parameters; kwargs are applied last so they override the saved values
    model.__dict__.update(data)
    model.__dict__.update(kwargs)

    if hasattr(model, "_setup_model"):
        model._setup_model()
    elif len(params) > 0:
        # State dicts cannot be restored onto a model that never built its networks
        raise NotImplementedError(f"{cls} has no ``_setup_model()`` method")

    # Put state_dicts back in place
    for name in params:
        attr = recursive_getattr(model, name)
        attr.load_state_dict(params[name])

    # Put tensors back in place
    if tensors is not None:
        for name in tensors:
            recursive_setattr(model, name, tensors[name])

    # Sample gSDE exploration matrix, so it uses the right device
    # see issue #44
    if model.use_sde:
        model.policy.reset_noise()
    return model
def load(
    cls,
    path: Union[str, pathlib.Path, io.BufferedIOBase],
    env: Optional[GymEnv] = None,
    device: Union[th.device, str] = "auto",
    custom_objects: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> "BaseAlgorithm":
    """
    Load the model from a zip-file.

    :param path: path to the file (or a file-like) where to
        load the agent from
    :param env: the new environment to run the loaded model on
        (can be None if you only need prediction from a trained model);
        has priority over any saved environment
    :param device: Device on which the code should run.
    :param custom_objects: Dictionary of objects to replace
        upon loading. If a variable is present in this dictionary as a
        key, it will not be deserialized and the corresponding item
        will be used instead. Similar to custom_objects in
        ``keras.models.load_model``. Useful when you have an object in
        file that can not be deserialized.
    :param kwargs: extra arguments to change the model when loading
    :return: a new ``cls`` instance with the saved data and parameters applied
    :raises ValueError: if ``policy_kwargs`` is given and differs from the stored one
    :raises KeyError: if the archive lacks ``observation_space`` or ``action_space``
    """
    data, params, pytorch_variables = load_from_zip_file(path, device=device, custom_objects=custom_objects)

    # Remove stored device information and replace with ours
    if "policy_kwargs" in data:
        data["policy_kwargs"].pop("device", None)

    if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
        raise ValueError(
            "The specified policy kwargs do not equal the stored policy kwargs. "
            f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
        )

    if "observation_space" not in data or "action_space" not in data:
        raise KeyError("The observation_space and action_space were not given, can't verify new environments")

    if env is not None:
        # Wrap first if needed
        env = cls._wrap_env(env, data["verbose"])
        # Check if given env is valid
        check_for_correct_spaces(env, data["observation_space"], data["action_space"])
    elif "env" in data:
        # Use stored env, if one exists. If not, continue as is (can be used for predict)
        env = data["env"]

    # noinspection PyArgumentList
    model = cls(  # pytype: disable=not-instantiable,wrong-keyword-args
        policy=data["policy_class"],
        env=env,
        device=device,
        _init_setup_model=False,  # pytype: disable=not-instantiable,wrong-keyword-args
    )

    # Load parameters; kwargs are applied last so they override the saved values
    model.__dict__.update(data)
    model.__dict__.update(kwargs)
    model._setup_model()

    # Put state_dicts back in place
    model.set_parameters(params, exact_match=True, device=device)

    # Put other pytorch variables back in place
    if pytorch_variables is not None:
        for name in pytorch_variables:
            # Set the data attribute directly to avoid issue when using optimizers
            # See https://github.com/DLR-RM/stable-baselines3/issues/391
            recursive_setattr(model, name + ".data", pytorch_variables[name].data)

    # Sample gSDE exploration matrix, so it uses the right device
    # see issue #44
    if model.use_sde:
        model.policy.reset_noise()  # pytype: disable=attribute-error
    return model
def load(
    cls,
    path: Union[str, pathlib.Path, io.BufferedIOBase],
    env: Optional[GymEnv] = None,
    device: Union[th.device, str] = "auto",
    custom_objects: Optional[Dict[str, Any]] = None,
    print_system_info: bool = False,
    force_reset: bool = True,
    **kwargs,
) -> "BaseAlgorithm":
    """
    Load the model from a zip-file.
    Warning: ``load`` re-creates the model from scratch, it does not update it in-place!
    For an in-place load use ``set_parameters`` instead.

    :param path: path to the file (or a file-like) where to
        load the agent from
    :param env: the new environment to run the loaded model on
        (can be None if you only need prediction from a trained model);
        has priority over any saved environment
    :param device: Device on which the code should run.
    :param custom_objects: Dictionary of objects to replace
        upon loading. If a variable is present in this dictionary as a
        key, it will not be deserialized and the corresponding item
        will be used instead. Similar to custom_objects in
        ``keras.models.load_model``. Useful when you have an object in
        file that can not be deserialized.
    :param print_system_info: Whether to print system info from the saved model
        and the current system info (useful to debug loading issues)
    :param force_reset: Force call to ``reset()`` before training
        to avoid unexpected behavior.
        See https://github.com/DLR-RM/stable-baselines3/issues/597
    :param kwargs: extra arguments to change the model when loading
    :return: new model instance with loaded parameters
    :raises ValueError: if ``policy_kwargs`` is given and differs from the stored one
    :raises KeyError: if the archive lacks ``observation_space`` or ``action_space``
    """
    if print_system_info:
        print("== CURRENT SYSTEM INFO ==")
        get_system_info()

    data, params, pytorch_variables = load_from_zip_file(
        path, device=device, custom_objects=custom_objects, print_system_info=print_system_info
    )

    # Remove stored device information and replace with ours
    if "policy_kwargs" in data:
        data["policy_kwargs"].pop("device", None)

    if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
        raise ValueError(
            "The specified policy kwargs do not equal the stored policy kwargs. "
            f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
        )

    if "observation_space" not in data or "action_space" not in data:
        raise KeyError("The observation_space and action_space were not given, can't verify new environments")

    if env is not None:
        # Wrap first if needed
        env = cls._wrap_env(env, data["verbose"])
        # Check if given env is valid
        check_for_correct_spaces(env, data["observation_space"], data["action_space"])
        # Discard `_last_obs`, this will force the env to reset before training
        # See issue https://github.com/DLR-RM/stable-baselines3/issues/597
        # (`data` has already been subscripted above, so no extra None check is needed.)
        if force_reset:
            data["_last_obs"] = None
    elif "env" in data:
        # Use stored env, if one exists. If not, continue as is (can be used for predict)
        env = data["env"]

    # noinspection PyArgumentList
    model = cls(  # pytype: disable=not-instantiable,wrong-keyword-args
        policy=data["policy_class"],
        env=env,
        device=device,
        _init_setup_model=False,  # pytype: disable=not-instantiable,wrong-keyword-args
    )

    # Load parameters; kwargs are applied last so they override the saved values
    model.__dict__.update(data)
    model.__dict__.update(kwargs)
    model._setup_model()

    # Put state_dicts back in place
    model.set_parameters(params, exact_match=True, device=device)

    # Put other pytorch variables back in place
    if pytorch_variables is not None:
        for name in pytorch_variables:
            # Skip if PyTorch variable was not defined (to ensure backward compatibility).
            # This happens when using SAC/TQC.
            # SAC has an entropy coefficient which can be fixed or optimized.
            # If it is optimized, an additional PyTorch variable `log_ent_coef` is defined,
            # otherwise it is initialized to `None`.
            if pytorch_variables[name] is None:
                continue
            # Set the data attribute directly to avoid issue when using optimizers
            # See https://github.com/DLR-RM/stable-baselines3/issues/391
            recursive_setattr(model, name + ".data", pytorch_variables[name].data)

    # Sample gSDE exploration matrix, so it uses the right device
    # see issue #44
    if model.use_sde:
        model.policy.reset_noise()  # pytype: disable=attribute-error
    return model
def load(
    cls,
    path: Union[str, pathlib.Path, io.BufferedIOBase],
    env: Optional[GymEnv] = None,
    device: Union[th.device, str] = "auto",
    **kwargs,
) -> "BaseAlgorithm":
    """
    Load the model from a zip-file.

    :param path: path to the file (or a file-like) where to
        load the agent from
    :param env: the new environment to run the loaded model on
        (can be None if you only need prediction from a trained model);
        has priority over any saved environment
    :param device: Device on which the code should run.
    :param kwargs: extra arguments to change the model when loading
    :return: a new ``cls`` instance with the saved data and parameters applied
    :raises ValueError: if ``policy_kwargs`` is given and differs from the stored one
    :raises KeyError: if the archive lacks ``observation_space`` or ``action_space``
    """
    data, params, pytorch_variables = load_from_zip_file(path, device=device)

    # Remove stored device information and replace with ours
    if "policy_kwargs" in data:
        data["policy_kwargs"].pop("device", None)

    if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
        raise ValueError(
            "The specified policy kwargs do not equal the stored policy kwargs. "
            f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
        )

    if "observation_space" not in data or "action_space" not in data:
        raise KeyError("The observation_space and action_space were not given, can't verify new environments")

    if env is not None:
        # Wrap first if needed
        env = cls._wrap_env(env, data["verbose"])
        # Check if given env is valid
        check_for_correct_spaces(env, data["observation_space"], data["action_space"])
    elif "env" in data:
        # Use stored env, if one exists. If not, continue as is (can be used for predict)
        env = data["env"]

    # noinspection PyArgumentList
    model = cls(
        policy=data["policy_class"],
        env=env,
        device=device,
        _init_setup_model=False,  # pytype: disable=not-instantiable,wrong-keyword-args
    )

    # Load parameters; kwargs are applied last so they override the saved values
    model.__dict__.update(data)
    model.__dict__.update(kwargs)
    model._setup_model()

    # Put state_dicts back in place
    model.set_parameters(params, exact_match=True, device=device)

    # Put other pytorch variables back in place
    if pytorch_variables is not None:
        for name in pytorch_variables:
            recursive_setattr(model, name, pytorch_variables[name])

    # Sample gSDE exploration matrix, so it uses the right device
    # see issue #44
    if model.use_sde:
        model.policy.reset_noise()  # pytype: disable=attribute-error
    return model
def load(
    cls, load_path: str, env: Optional[GymEnv] = None, device: Union[th.device, str] = "auto", **kwargs
) -> "BaseAlgorithm":
    """
    Load the model from a zip-file.

    :param load_path: the location of the saved data
    :param env: the new environment to run the loaded model on
        (can be None if you only need prediction from a trained model);
        has priority over any saved environment
    :param device: Device on which the code should run.
    :param kwargs: extra arguments to change the model when loading
    :return: a new ``cls`` instance with the saved data and parameters applied
    :raises ValueError: if ``policy_kwargs`` is given and differs from the stored one
    :raises KeyError: if the archive lacks ``observation_space`` or ``action_space``
    """
    data, params, tensors = load_from_zip_file(load_path, device=device)

    # The stored device is discarded; the ``device`` argument is used instead.
    if "policy_kwargs" in data:
        data["policy_kwargs"].pop("device", None)

    if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
        raise ValueError(
            "The specified policy kwargs do not equal the stored policy kwargs. "
            f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
        )

    # Check if observation space and action space are part of the saved parameters
    if "observation_space" not in data or "action_space" not in data:
        raise KeyError("The observation_space and action_space were not given, can't verify new environments")

    if env is not None:
        # Check if given env is valid
        check_for_correct_spaces(env, data["observation_space"], data["action_space"])
    elif "env" in data:
        # If no new env was given use stored env if possible
        env = data["env"]

    # noinspection PyArgumentList
    model = cls(
        policy=data["policy_class"],
        env=env,
        device=device,
        _init_setup_model=False,  # pytype: disable=not-instantiable,wrong-keyword-args
    )

    # Load parameters; kwargs are applied last so they override the saved values
    model.__dict__.update(data)
    model.__dict__.update(kwargs)
    model._setup_model()

    # Put state_dicts back in place
    for name in params:
        attr = recursive_getattr(model, name)
        attr.load_state_dict(params[name])

    # Put tensors back in place
    if tensors is not None:
        for name in tensors:
            recursive_setattr(model, name, tensors[name])

    # Sample gSDE exploration matrix, so it uses the right device
    # see issue #44
    if model.use_sde:
        model.policy.reset_noise()  # pytype: disable=attribute-error
    return model
def load(
    cls,
    path: Union[str, pathlib.Path, io.BufferedIOBase],
    env: Optional[GymEnv] = None,
    device: Union[th.device, str] = "auto",
    custom_objects: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> "BaseAlgorithm":
    """
    Load the model from a zip-file.

    :param path: path to the file (or a file-like) where to
        load the agent from
    :param env: the new environment to run the loaded model on
        (can be None if you only need prediction from a trained model);
        has priority over any saved environment
    :param device: Device on which the code should run.
    :param custom_objects: Dictionary of objects to replace
        upon loading. If a variable is present in this dictionary as a
        key, it will not be deserialized and the corresponding item
        will be used instead. Similar to custom_objects in
        ``keras.models.load_model``. Useful when you have an object in
        file that can not be deserialized.
    :param kwargs: extra arguments to change the model when loading.
        ``model_class``, ``online_sampling`` and ``max_episode_length`` are ignored
        (the stored values are always used); ``n_sampled_goal`` and
        ``goal_selection_strategy`` override the stored values.
    :return: a new ``cls`` instance whose wrapped ``model`` has the saved data
        and parameters applied
    :raises ValueError: if ``policy_kwargs`` is given and differs from the stored one
    :raises KeyError: if the archive lacks ``observation_space`` or ``action_space``
    """
    data, params, pytorch_variables = load_from_zip_file(path, device=device, custom_objects=custom_objects)

    # Remove stored device information and replace with ours
    if "policy_kwargs" in data:
        data["policy_kwargs"].pop("device", None)

    if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
        raise ValueError(
            "The specified policy kwargs do not equal the stored policy kwargs. "
            f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
        )

    # Check if observation space and action space are part of the saved parameters
    if "observation_space" not in data or "action_space" not in data:
        raise KeyError("The observation_space and action_space were not given, can't verify new environments")

    if env is not None:
        # Wrap first if needed
        env = cls._wrap_env(env, data["verbose"])
        # Check if given env is valid
        check_for_correct_spaces(env, data["observation_space"], data["action_space"])
    elif "env" in data:
        # Use stored env, if one exists. If not, continue as is (can be used for predict)
        env = data["env"]

    if "use_sde" in data and data["use_sde"]:
        kwargs["use_sde"] = True

    # Keys that cannot be changed: silently drop any override
    for key in ("model_class", "online_sampling", "max_episode_length"):
        kwargs.pop(key, None)

    # Keys that can be changed: move the override into the stored data
    for key in ("n_sampled_goal", "goal_selection_strategy"):
        if key in kwargs:
            data[key] = kwargs.pop(key)  # pytype: disable=unsupported-operands

    # ``policy_kwargs`` is passed explicitly below. A caller-supplied value has already
    # been checked to equal the stored one, but forwarding it again through ``**``
    # would raise "TypeError: got multiple values for keyword argument".
    init_kwargs = {key: value for key, value in kwargs.items() if key != "policy_kwargs"}

    # NOTE(review): ``device`` is not forwarded to the constructor here; it is only
    # used for reading the archive and in ``set_parameters`` - confirm this is intended.
    # noinspection PyArgumentList
    her_model = cls(
        policy=data["policy_class"],
        env=env,
        model_class=data["model_class"],
        n_sampled_goal=data["n_sampled_goal"],
        goal_selection_strategy=data["goal_selection_strategy"],
        online_sampling=data["online_sampling"],
        max_episode_length=data["max_episode_length"],
        policy_kwargs=data["policy_kwargs"],
        _init_setup_model=False,  # pytype: disable=not-instantiable,wrong-keyword-args
        **init_kwargs,
    )

    # Load parameters into the wrapped model; kwargs override the saved values
    her_model.model.__dict__.update(data)
    her_model.model.__dict__.update(kwargs)
    her_model._setup_model()

    # Mirror the wrapped model's progress counters on the wrapper
    her_model._total_timesteps = her_model.model._total_timesteps
    her_model.num_timesteps = her_model.model.num_timesteps
    her_model._episode_num = her_model.model._episode_num

    # Put state_dicts back in place
    her_model.model.set_parameters(params, exact_match=True, device=device)

    # Put other pytorch variables back in place
    if pytorch_variables is not None:
        for name in pytorch_variables:
            recursive_setattr(her_model.model, name, pytorch_variables[name])

    # Sample gSDE exploration matrix, so it uses the right device
    # see issue #44
    if her_model.model.use_sde:
        her_model.model.policy.reset_noise()  # pytype: disable=attribute-error
    return her_model
def load(cls, path, env=None, device="auto", custom_objects=None, **kwargs):
    """
    Load the model from a zip-file.

    :param path: path to the file (or a file-like) where to load the agent from
    :param env: the new environment to run the loaded model on; when None, the
        environment stored in the archive is used if there is one
    :param device: Device on which the code should run
    :param custom_objects: Dictionary of objects to replace upon loading
    :param kwargs: extra arguments to change the model when loading
    :return: a new ``cls`` instance with the saved data and parameters applied
    :raises ValueError: if ``policy_kwargs`` is given and differs from the stored one
    :raises KeyError: if the archive lacks ``observation_space`` or ``action_space``
    """
    data, params, pytorch_variables = load_from_zip_file(path, device=device, custom_objects=custom_objects)

    # Remove stored device information and replace with ours
    if "policy_kwargs" in data:
        data["policy_kwargs"].pop("device", None)

    if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
        raise ValueError(
            "The specified policy kwargs do not equal the stored policy kwargs. "
            f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
        )

    if "observation_space" not in data or "action_space" not in data:
        raise KeyError("The observation_space and action_space were not given, can't verify new environments")

    if env is not None:
        # Wrap first if needed
        env = cls._wrap_env(env, data["verbose"])
        # Check if given env is valid
        check_for_correct_spaces(env, data["observation_space"], data["action_space"])
    elif "env" in data:
        # Use stored env, if one exists. If not, continue as is (can be used for predict)
        env = data["env"]

    model = cls(policy=data["policy_class"], env=env, device=device, _init_setup_model=False)

    # Load parameters; kwargs are applied last so they override the saved values
    model.__dict__.update(data)
    model.__dict__.update(kwargs)
    model._setup_model()

    # Put state_dicts back in place
    model.set_parameters(params, exact_match=True, device=device)

    # Put other pytorch variables back in place
    if pytorch_variables is not None:
        for name in pytorch_variables:
            recursive_setattr(model, name, pytorch_variables[name])

    # Sample gSDE exploration matrix, so it uses the right device
    if model.use_sde:
        model.policy.reset_noise()
    return model