def test_get_system_info():
    """Check the values returned by ``get_system_info(print_info=True)``.

    The returned dict must report the installed Stable-Baselines3 version,
    and the returned string must mention each expected component.
    """
    system_info, report = get_system_info(print_info=True)
    # The structured info must carry the library version as a string
    assert system_info["Stable-Baselines3"] == str(sb3.__version__)
    # Every one of these labels has to appear in the textual report
    for label in ("Python", "PyTorch", "GPU Enabled", "Numpy", "Gym"):
        assert label in report
def save_to_zip_file(
    save_path: Union[str, pathlib.Path, io.BufferedIOBase],
    data: Optional[Dict[str, Any]] = None,
    params: Optional[Dict[str, Any]] = None,
    pytorch_variables: Optional[Dict[str, Any]] = None,
    verbose: int = 0,
) -> None:
    """
    Save model data to a zip archive.

    The archive always receives the entries ``_stable_baselines3_version`` and
    ``system_info.txt``; ``data``, ``pytorch_variables.pth`` and one
    ``<name>.pth`` entry per item of ``params`` are written only when the
    corresponding argument is not None.

    :param save_path: Where to store the model.
        if save_path is a str or pathlib.Path ensures that the path actually exists.
    :param data: Class parameters being stored (non-PyTorch variables)
    :param params: Model parameters being stored expected to contain an entry for every
        state_dict with its name and the state_dict.
    :param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
    :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
        NOTE: this value is currently not forwarded, ``open_path`` is always
        called with ``verbose=0``.
    """
    file = open_path(save_path, "w", verbose=0, suffix="zip")
    # Only a handle opened here (i.e. the caller gave a path) is ours to close;
    # a file-like object passed by the caller stays under the caller's control.
    # NOTE(review): assumes ``open_path`` returns an open file object for
    # str/pathlib.Path inputs -- confirm against its definition.
    opened_here = isinstance(save_path, (str, pathlib.Path))
    try:
        # data/params can be None, so do not
        # try to serialize them blindly
        serialized_data = data_to_json(data) if data is not None else None

        # Create a zip-archive and write our objects there.
        with zipfile.ZipFile(file, mode="w") as archive:
            # Do not try to save "None" elements
            if serialized_data is not None:
                archive.writestr("data", serialized_data)
            if pytorch_variables is not None:
                with archive.open("pytorch_variables.pth", mode="w") as pytorch_variables_file:
                    th.save(pytorch_variables, pytorch_variables_file)
            if params is not None:
                for file_name, dict_ in params.items():
                    with archive.open(file_name + ".pth", mode="w") as param_file:
                        th.save(dict_, param_file)
            # Save metadata: library version when file was saved
            archive.writestr("_stable_baselines3_version", sb3.__version__)
            # Save system info about the current python env
            archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
    finally:
        # Release the file handle even if serialization or writing failed,
        # otherwise it stays open until garbage collection.
        if opened_here:
            file.close()
def load(
    cls,
    path: Union[str, pathlib.Path, io.BufferedIOBase],
    env: Optional[GymEnv] = None,
    device: Union[th.device, str] = "auto",
    custom_objects: Optional[Dict[str, Any]] = None,
    print_system_info: bool = False,
    force_reset: bool = True,
    **kwargs,
) -> "BaseAlgorithm":
    """
    Load the model from a zip-file.
    Warning: ``load`` re-creates the model from scratch, it does not update it in-place!
    For an in-place load use ``set_parameters`` instead.

    :param path: path to the file (or a file-like) where to
        load the agent from
    :param env: the new environment to run the loaded model on
        (can be None if you only need prediction from a trained model) has priority over any saved environment
    :param device: Device on which the code should run.
    :param custom_objects: Dictionary of objects to replace
        upon loading. If a variable is present in this dictionary as a
        key, it will not be deserialized and the corresponding item
        will be used instead. Similar to custom_objects in
        ``keras.models.load_model``. Useful when you have an object in
        file that can not be deserialized.
    :param print_system_info: Whether to print system info from the saved model
        and the current system info (useful to debug loading issues)
    :param force_reset: Force call to ``reset()`` before training
        to avoid unexpected behavior.
        See https://github.com/DLR-RM/stable-baselines3/issues/597
    :param kwargs: extra arguments to change the model when loading
    :return: new model instance with loaded parameters
    :raises ValueError: if ``policy_kwargs`` is passed in ``kwargs`` and differs
        from the stored policy kwargs
    :raises KeyError: if the stored data lacks ``observation_space`` or ``action_space``
    """
    if print_system_info:
        print("== CURRENT SYSTEM INFO ==")
        get_system_info()

    data, params, pytorch_variables = load_from_zip_file(
        path, device=device, custom_objects=custom_objects, print_system_info=print_system_info
    )

    # Remove stored device information and replace with ours
    if "policy_kwargs" in data:
        if "device" in data["policy_kwargs"]:
            del data["policy_kwargs"]["device"]

    if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
        # Note the trailing space: the two fragments are concatenated into one message
        raise ValueError(
            "The specified policy kwargs do not equal the stored policy kwargs. "
            f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
        )

    if "observation_space" not in data or "action_space" not in data:
        raise KeyError("The observation_space and action_space were not given, can't verify new environments")

    if env is not None:
        # Wrap first if needed
        env = cls._wrap_env(env, data["verbose"])
        # Check if given env is valid
        check_for_correct_spaces(env, data["observation_space"], data["action_space"])
        # Discard `_last_obs`, this will force the env to reset before training
        # See issue https://github.com/DLR-RM/stable-baselines3/issues/597
        # (`data` has already been indexed above, so it cannot be None here)
        if force_reset:
            data["_last_obs"] = None
    else:
        # Use stored env, if one exists. If not, continue as is (can be used for predict)
        if "env" in data:
            env = data["env"]

    # noinspection PyArgumentList
    model = cls(  # pytype: disable=not-instantiable,wrong-keyword-args
        policy=data["policy_class"],
        env=env,
        device=device,
        _init_setup_model=False,  # pytype: disable=not-instantiable,wrong-keyword-args
    )

    # load parameters: stored attributes first, then user overrides on top
    model.__dict__.update(data)
    model.__dict__.update(kwargs)
    model._setup_model()

    # put state_dicts back in place
    model.set_parameters(params, exact_match=True, device=device)

    # put other pytorch variables back in place
    if pytorch_variables is not None:
        for name, variable in pytorch_variables.items():
            # Skip if PyTorch variable was not defined (to ensure backward compatibility).
            # This happens when using SAC/TQC.
            # SAC has an entropy coefficient which can be fixed or optimized.
            # If it is optimized, an additional PyTorch variable `log_ent_coef` is defined,
            # otherwise it is initialized to `None`.
            if variable is None:
                continue
            # Set the data attribute directly to avoid issue when using optimizers
            # See https://github.com/DLR-RM/stable-baselines3/issues/391
            recursive_setattr(model, name + ".data", variable.data)

    # Sample gSDE exploration matrix, so it uses the right device
    # see issue #44
    if model.use_sde:
        model.policy.reset_noise()  # pytype: disable=attribute-error
    return model