def _setup_model(self) -> None:
    """
    Set up the learning-rate schedule, seed the RNGs, create the replay
    buffer (``DictReplayBuffer`` / ``ReplayBuffer`` / ``HerReplayBuffer``)
    and instantiate the policy on the target device.

    NOTE(review): this top-level function looks like a detached duplicate of
    ``OffPolicyAlgorithm._setup_model`` defined inside the class below —
    confirm whether it is dead code that should be removed.
    """
    self._setup_lr_schedule()
    self.set_random_seed(self.seed)

    # Use DictReplayBuffer if needed
    if self.replay_buffer_class is None:
        if isinstance(self.observation_space, gym.spaces.Dict):
            self.replay_buffer_class = DictReplayBuffer
        else:
            self.replay_buffer_class = ReplayBuffer

    elif self.replay_buffer_class == HerReplayBuffer:
        assert self.env is not None, "You must pass an environment when using `HerReplayBuffer`"

        # If using offline sampling, we need a classic replay buffer too
        if self.replay_buffer_kwargs.get("online_sampling", True):
            replay_buffer = None
        else:
            replay_buffer = DictReplayBuffer(
                self.buffer_size,
                self.observation_space,
                self.action_space,
                self.device,
                optimize_memory_usage=self.optimize_memory_usage,
            )

        self.replay_buffer = HerReplayBuffer(
            self.env,
            self.buffer_size,
            self.device,
            replay_buffer=replay_buffer,
            **self.replay_buffer_kwargs,
        )

    # Either no class was given (resolved above) or a custom (non-HER)
    # replay buffer class was supplied by the user.
    if self.replay_buffer is None:
        self.replay_buffer = self.replay_buffer_class(
            self.buffer_size,
            self.observation_space,
            self.action_space,
            self.device,
            optimize_memory_usage=self.optimize_memory_usage,
            **self.replay_buffer_kwargs,
        )

    self.policy = self.policy_class(  # pytype:disable=not-instantiable
        self.observation_space,
        self.action_space,
        self.lr_schedule,
        **self.policy_kwargs,  # pytype:disable=not-instantiable
    )
    self.policy = self.policy.to(self.device)

    # Convert train freq parameter to TrainFreq object
    self._convert_train_freq()
class OffPolicyAlgorithm(BaseAlgorithm):
    """
    The base for Off-Policy algorithms (ex: SAC/TD3)

    :param policy: Policy object
    :param env: The environment to learn from
        (if registered in Gym, can be str. Can be None for loading trained models)
    :param policy_base: The base policy used by this method
    :param learning_rate: learning rate for the optimizer,
        it can be a function of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for
        before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass
        a tuple of frequency and unit like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout
        (see ``train_freq``). Set to ``-1`` means to do as many gradient steps as
        steps done in the environment during the rollout.
    :param action_noise: the action noise type (None by default), this can help
        for hard exploration problem. Cf common.noise for the different action noise type.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param policy_kwargs: Additional arguments to be passed to the policy on creation
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param verbose: The verbosity level: 0 none, 1 training information, 2 debug
    :param device: Device on which the code should run.
        By default, it will try to use a Cuda compatible device and fallback to cpu
        if it is not possible.
    :param support_multi_env: Whether the algorithm supports training
        with multiple environments (as in A2C)
    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically.
        (Only available when passing string for the environment)
    :param monitor_wrapper: When creating an environment, whether to wrap it
        or not in a Monitor wrapper.
    :param seed: Seed for the pseudo random generators
    :param use_sde: Whether to use State Dependent Exploration (SDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
        during the warm up phase (before learning starts)
    :param sde_support: Whether the model support gSDE or not
    :param remove_time_limit_termination: Remove terminations (dones) that are due to
        time limit. See https://github.com/hill-a/stable-baselines/issues/863
    :param supported_action_spaces: The action spaces supported by the algorithm.
    """

    def __init__(
        self,
        policy: Type[BasePolicy],
        env: Union[GymEnv, str],
        policy_base: Type[BasePolicy],
        learning_rate: Union[float, Schedule],
        buffer_size: int = 1000000,  # 1e6
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = (1, "step"),
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        tensorboard_log: Optional[str] = None,
        verbose: int = 0,
        device: Union[th.device, str] = "auto",
        support_multi_env: bool = False,
        create_eval_env: bool = False,
        monitor_wrapper: bool = True,
        seed: Optional[int] = None,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        use_sde_at_warmup: bool = False,
        sde_support: bool = True,
        remove_time_limit_termination: bool = False,
        supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
    ):
        super(OffPolicyAlgorithm, self).__init__(
            policy=policy,
            env=env,
            policy_base=policy_base,
            learning_rate=learning_rate,
            policy_kwargs=policy_kwargs,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            support_multi_env=support_multi_env,
            create_eval_env=create_eval_env,
            monitor_wrapper=monitor_wrapper,
            seed=seed,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            supported_action_spaces=supported_action_spaces,
        )
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.learning_starts = learning_starts
        self.tau = tau
        self.gamma = gamma
        self.gradient_steps = gradient_steps
        self.action_noise = action_noise
        self.optimize_memory_usage = optimize_memory_usage
        self.replay_buffer_class = replay_buffer_class
        if replay_buffer_kwargs is None:
            replay_buffer_kwargs = {}
        self.replay_buffer_kwargs = replay_buffer_kwargs
        # Used by HER subclasses for offline goal sampling
        self._episode_storage = None

        # Remove terminations (dones) that are due to time limit
        # see https://github.com/hill-a/stable-baselines/issues/863
        self.remove_time_limit_termination = remove_time_limit_termination

        # Save train freq parameter, will be converted later to TrainFreq object
        self.train_freq = train_freq

        self.actor = None  # type: Optional[th.nn.Module]
        self.replay_buffer = None  # type: Optional[ReplayBuffer]
        # Update policy keyword arguments
        if sde_support:
            self.policy_kwargs["use_sde"] = self.use_sde
        # For gSDE only
        self.use_sde_at_warmup = use_sde_at_warmup

    def _convert_train_freq(self) -> None:
        """
        Convert `train_freq` parameter (int or tuple)
        to a TrainFreq object.
        """
        if not isinstance(self.train_freq, TrainFreq):
            train_freq = self.train_freq

            # The value of the train frequency will be checked later
            if not isinstance(train_freq, tuple):
                # A bare int means "every n steps"
                train_freq = (train_freq, "step")

            try:
                train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
            except ValueError:
                raise ValueError(f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!")

            if not isinstance(train_freq[0], int):
                raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")

            self.train_freq = TrainFreq(*train_freq)

    def _setup_model(self) -> None:
        """
        Create the learning-rate schedule, seed the RNGs, build the replay
        buffer and instantiate the policy on the target device.
        """
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        # Use DictReplayBuffer if needed
        if self.replay_buffer_class is None:
            if isinstance(self.observation_space, gym.spaces.Dict):
                self.replay_buffer_class = DictReplayBuffer
            else:
                self.replay_buffer_class = ReplayBuffer

        elif self.replay_buffer_class == HerReplayBuffer:
            assert self.env is not None, "You must pass an environment when using `HerReplayBuffer`"

            # If using offline sampling, we need a classic replay buffer too
            if self.replay_buffer_kwargs.get("online_sampling", True):
                replay_buffer = None
            else:
                replay_buffer = DictReplayBuffer(
                    self.buffer_size,
                    self.observation_space,
                    self.action_space,
                    self.device,
                    optimize_memory_usage=self.optimize_memory_usage,
                )

            self.replay_buffer = HerReplayBuffer(
                self.env,
                self.buffer_size,
                self.device,
                replay_buffer=replay_buffer,
                **self.replay_buffer_kwargs,
            )

        # Either no class was given (resolved above) or a custom (non-HER)
        # replay buffer class was supplied by the user.
        if self.replay_buffer is None:
            self.replay_buffer = self.replay_buffer_class(
                self.buffer_size,
                self.observation_space,
                self.action_space,
                self.device,
                optimize_memory_usage=self.optimize_memory_usage,
                **self.replay_buffer_kwargs,
            )

        self.policy = self.policy_class(  # pytype:disable=not-instantiable
            self.observation_space,
            self.action_space,
            self.lr_schedule,
            **self.policy_kwargs,  # pytype:disable=not-instantiable
        )
        self.policy = self.policy.to(self.device)

        # Convert train freq parameter to TrainFreq object
        self._convert_train_freq()

    def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
        """
        Save the replay buffer as a pickle file.

        :param path: Path to the file where the replay buffer should be saved.
            if path is a str or pathlib.Path, the path is automatically created if necessary.
        """
        assert self.replay_buffer is not None, "The replay buffer is not defined"
        save_to_pkl(path, self.replay_buffer, self.verbose)

    def load_replay_buffer(
        self,
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        truncate_last_traj: bool = True,
    ) -> None:
        """
        Load a replay buffer from a pickle file.

        :param path: Path to the pickled replay buffer.
        :param truncate_last_traj: When using ``HerReplayBuffer`` with online sampling:
            If set to ``True``, we assume that the last trajectory in the replay buffer
            was finished (and truncate it).
            If set to ``False``, we assume that we continue the same trajectory
            (same episode).
        """
        self.replay_buffer = load_from_pkl(path, self.verbose)
        assert isinstance(self.replay_buffer, ReplayBuffer), "The replay buffer must inherit from ReplayBuffer class"

        # Backward compatibility with SB3 < 2.1.0 replay buffer
        # Keep old behavior: do not handle timeout termination separately
        if not hasattr(self.replay_buffer, "handle_timeout_termination"):  # pragma: no cover
            self.replay_buffer.handle_timeout_termination = False
            self.replay_buffer.timeouts = np.zeros_like(self.replay_buffer.dones)

        if isinstance(self.replay_buffer, HerReplayBuffer):
            assert self.env is not None, "You must pass an environment at load time when using `HerReplayBuffer`"
            self.replay_buffer.set_env(self.get_env())
            if truncate_last_traj:
                self.replay_buffer.truncate_last_trajectory()

    def _setup_learn(
        self,
        total_timesteps: int,
        eval_env: Optional[GymEnv],
        callback: MaybeCallback = None,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
        tb_log_name: str = "run",
    ) -> Tuple[int, BaseCallback]:
        """
        cf `BaseAlgorithm`.
        """
        # Prevent continuity issue by truncating trajectory
        # when using memory efficient replay buffer
        # see https://github.com/DLR-RM/stable-baselines3/issues/46

        # Special case when using HerReplayBuffer,
        # the classic replay buffer is inside it when using offline sampling
        if isinstance(self.replay_buffer, HerReplayBuffer):
            replay_buffer = self.replay_buffer.replay_buffer
        else:
            replay_buffer = self.replay_buffer

        truncate_last_traj = (
            self.optimize_memory_usage
            and reset_num_timesteps
            and replay_buffer is not None
            and (replay_buffer.full or replay_buffer.pos > 0)
        )

        if truncate_last_traj:
            warnings.warn(
                "The last trajectory in the replay buffer will be truncated, "
                "see https://github.com/DLR-RM/stable-baselines3/issues/46."
                "You should use `reset_num_timesteps=False` or `optimize_memory_usage=False`"
                "to avoid that issue."
            )
            # Go to the previous index
            pos = (replay_buffer.pos - 1) % replay_buffer.buffer_size
            # Mark the last stored transition as terminal so the truncated
            # trajectory is not continued after reset
            replay_buffer.dones[pos] = True

        return super()._setup_learn(
            total_timesteps,
            eval_env,
            callback,
            eval_freq,
            n_eval_episodes,
            log_path,
            reset_num_timesteps,
            tb_log_name,
        )

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "run",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> "OffPolicyAlgorithm":
        """
        Collect rollouts and train until ``total_timesteps`` is reached.
        cf `BaseAlgorithm`.
        """
        total_timesteps, callback = self._setup_learn(
            total_timesteps,
            eval_env,
            callback,
            eval_freq,
            n_eval_episodes,
            eval_log_path,
            reset_num_timesteps,
            tb_log_name,
        )

        callback.on_training_start(locals(), globals())

        while self.num_timesteps < total_timesteps:
            rollout = self.collect_rollouts(
                self.env,
                train_freq=self.train_freq,
                action_noise=self.action_noise,
                callback=callback,
                learning_starts=self.learning_starts,
                replay_buffer=self.replay_buffer,
                log_interval=log_interval,
            )

            if rollout.continue_training is False:
                break

            if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
                # If no `gradient_steps` is specified,
                # do as many gradients steps as steps performed during the rollout
                gradient_steps = self.gradient_steps if self.gradient_steps > 0 else rollout.episode_timesteps
                self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)

        callback.on_training_end()

        return self

    def train(self, gradient_steps: int, batch_size: int) -> None:
        """
        Sample the replay buffer and do the updates
        (gradient descent and update target networks)
        """
        raise NotImplementedError()

    def _sample_action(
        self, learning_starts: int, action_noise: Optional[ActionNoise] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Sample an action according to the exploration policy.
        This is either done by sampling the probability distribution of the policy,
        or sampling a random action (from a uniform distribution over the action space)
        or by adding noise to the deterministic output.

        :param action_noise: Action noise that will be used for exploration
            Required for deterministic policy (e.g. TD3). This can also be used
            in addition to the stochastic policy for SAC.
        :param learning_starts: Number of steps before learning for the warm-up phase.
        :return: action to take in the environment
            and scaled action that will be stored in the replay buffer.
            The two differs when the action space is not normalized (bounds are not [-1, 1]).
        """
        # Select action randomly or according to policy
        if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
            # Warmup phase
            unscaled_action = np.array([self.action_space.sample()])
        else:
            # Note: when using continuous actions,
            # we assume that the policy uses tanh to scale the action
            # We use non-deterministic action in the case of SAC, for TD3, it does not matter
            unscaled_action, _ = self.predict(self._last_obs, deterministic=False)

        # Rescale the action from [low, high] to [-1, 1]
        if isinstance(self.action_space, gym.spaces.Box):
            scaled_action = self.policy.scale_action(unscaled_action)

            # Add noise to the action (improve exploration)
            if action_noise is not None:
                scaled_action = np.clip(scaled_action + action_noise(), -1, 1)

            # We store the scaled action in the buffer
            buffer_action = scaled_action
            action = self.policy.unscale_action(scaled_action)
        else:
            # Discrete case, no need to normalize or clip
            buffer_action = unscaled_action
            action = buffer_action
        return action, buffer_action

    def _dump_logs(self) -> None:
        """
        Write log.
        """
        time_elapsed = time.time() - self.start_time
        # Small epsilon avoids a division by zero on the very first dump
        fps = int(self.num_timesteps / (time_elapsed + 1e-8))
        self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
        if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
            self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
            self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
        self.logger.record("time/fps", fps)
        self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
        self.logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard")
        if self.use_sde:
            self.logger.record("train/std", (self.actor.get_std()).mean().item())

        if len(self.ep_success_buffer) > 0:
            self.logger.record("rollout/success rate", safe_mean(self.ep_success_buffer))
        # Pass the number of timesteps for tensorboard
        self.logger.dump(step=self.num_timesteps)

    def _on_step(self) -> None:
        """
        Method called after each step in the environment.
        It is meant to trigger DQN target network update
        but can be used for other purposes
        """
        pass

    def _store_transition(
        self,
        replay_buffer: ReplayBuffer,
        buffer_action: np.ndarray,
        new_obs: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        """
        Store transition in the replay buffer.
        We store the normalized action and the unnormalized observation.
        It also handles terminal observations (because VecEnv resets automatically).

        :param replay_buffer: Replay buffer object where to store the transition.
        :param buffer_action: normalized action
        :param new_obs: next observation in the current episode
            or first observation of the episode (when done is True)
        :param reward: reward for the current transition
        :param done: Termination signal
        :param infos: List of additional information about the transition.
            It may contain the terminal observations and information about timeout.
        """
        # Store only the unnormalized version
        if self._vec_normalize_env is not None:
            new_obs_ = self._vec_normalize_env.get_original_obs()
            reward_ = self._vec_normalize_env.get_original_reward()
        else:
            # Avoid changing the original ones
            self._last_original_obs, new_obs_, reward_ = self._last_obs, new_obs, reward

        # As the VecEnv resets automatically, new_obs is already the
        # first observation of the next episode
        if done and infos[0].get("terminal_observation") is not None:
            next_obs = infos[0]["terminal_observation"]
            # VecNormalize normalizes the terminal observation
            if self._vec_normalize_env is not None:
                next_obs = self._vec_normalize_env.unnormalize_obs(next_obs)
        else:
            next_obs = new_obs_

        replay_buffer.add(
            self._last_original_obs,
            next_obs,
            buffer_action,
            reward_,
            done,
            infos,
        )

        self._last_obs = new_obs
        # Save the unnormalized observation
        if self._vec_normalize_env is not None:
            self._last_original_obs = new_obs_

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        train_freq: TrainFreq,
        replay_buffer: ReplayBuffer,
        action_noise: Optional[ActionNoise] = None,
        learning_starts: int = 0,
        log_interval: Optional[int] = None,
    ) -> RolloutReturn:
        """
        Collect experiences and store them into a ``ReplayBuffer``.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param train_freq: How much experience to collect
            by doing rollouts of current policy.
            Either ``TrainFreq(<n>, TrainFrequencyUnit.STEP)``
            or ``TrainFreq(<n>, TrainFrequencyUnit.EPISODE)``
            with ``<n>`` being an integer greater than 0.
        :param action_noise: Action noise that will be used for exploration
            Required for deterministic policy (e.g. TD3). This can also be used
            in addition to the stochastic policy for SAC.
        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param replay_buffer:
        :param log_interval: Log data every ``log_interval`` episodes
        :return:
        """
        episode_rewards, total_timesteps = [], []
        num_collected_steps, num_collected_episodes = 0, 0

        assert isinstance(env, VecEnv), "You must pass a VecEnv"
        assert env.num_envs == 1, "OffPolicyAlgorithm only support single environment"
        assert train_freq.frequency > 0, "Should at least collect one step or episode."

        if self.use_sde:
            self.actor.reset_noise()

        callback.on_rollout_start()
        continue_training = True

        while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
            done = False
            episode_reward, episode_timesteps = 0.0, 0

            while not done:
                if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
                    # Sample a new noise matrix
                    self.actor.reset_noise()

                # Select action randomly or according to policy
                action, buffer_action = self._sample_action(learning_starts, action_noise)

                # Rescale and perform action
                new_obs, reward, done, infos = env.step(action)

                self.num_timesteps += 1
                episode_timesteps += 1
                num_collected_steps += 1

                # Give access to local variables
                callback.update_locals(locals())
                # Only stop training if return value is False, not when it is None.
                if callback.on_step() is False:
                    return RolloutReturn(0.0, num_collected_steps, num_collected_episodes, continue_training=False)

                episode_reward += reward

                # Retrieve reward and episode length if using Monitor wrapper
                self._update_info_buffer(infos, done)

                # Store data in replay buffer (normalized action and unnormalized observation)
                self._store_transition(replay_buffer, buffer_action, new_obs, reward, done, infos)

                self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)

                # For DQN, check if the target network should be updated
                # and update the exploration schedule
                # For SAC/TD3, the update is done as the same time as the gradient update
                # see https://github.com/hill-a/stable-baselines/issues/900
                self._on_step()

                if not should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
                    break

            if done:
                num_collected_episodes += 1
                self._episode_num += 1
                episode_rewards.append(episode_reward)
                total_timesteps.append(episode_timesteps)

                if action_noise is not None:
                    action_noise.reset()

                # Log training infos
                if log_interval is not None and self._episode_num % log_interval == 0:
                    self._dump_logs()

        mean_reward = np.mean(episode_rewards) if num_collected_episodes > 0 else 0.0

        callback.on_rollout_end()

        return RolloutReturn(mean_reward, num_collected_steps, num_collected_episodes, continue_training)
def __init__(
    self,
    policy: Union[str, Type[BasePolicy]],
    env: Union[GymEnv, str],
    model_class: Type[OffPolicyAlgorithm],
    n_sampled_goal: int = 4,
    goal_selection_strategy: Union[GoalSelectionStrategy, str] = "future",
    online_sampling: bool = False,
    max_episode_length: Optional[int] = None,
    *args,
    **kwargs,
):
    """
    Build the wrapped off-policy model and the HER episode storage.

    NOTE(review): this top-level function looks like a detached duplicate of
    ``HER.__init__`` defined inside the ``HER`` class later in this file
    (it even calls ``super(HER, self)``) — confirm whether it is dead code
    that should be removed.
    """
    # we will use the policy and learning rate from the model
    super(HER, self).__init__(policy=BasePolicy, env=env, policy_base=BasePolicy, learning_rate=0.0)
    # The wrapped model owns the real policy and learning rate
    del self.policy, self.learning_rate

    if self.get_vec_normalize_env() is not None:
        assert online_sampling, "You must pass `online_sampling=True` if you want to use `VecNormalize` with `HER`"

    _init_setup_model = kwargs.get("_init_setup_model", True)
    if "_init_setup_model" in kwargs:
        del kwargs["_init_setup_model"]

    # model initialization
    self.model_class = model_class
    self.model = model_class(
        policy=policy,
        env=self.env,
        _init_setup_model=False,  # pytype: disable=wrong-keyword-args
        *args,
        **kwargs,  # pytype: disable=wrong-keyword-args
    )

    # Make HER use self.model.action_noise
    del self.action_noise
    self.verbose = self.model.verbose
    self.tensorboard_log = self.model.tensorboard_log

    # convert goal_selection_strategy into GoalSelectionStrategy if string
    if isinstance(goal_selection_strategy, str):
        self.goal_selection_strategy = KEY_TO_GOAL_STRATEGY[goal_selection_strategy.lower()]
    else:
        self.goal_selection_strategy = goal_selection_strategy

    # check if goal_selection_strategy is valid
    assert isinstance(
        self.goal_selection_strategy, GoalSelectionStrategy
    ), f"Invalid goal selection strategy, please use one of {list(GoalSelectionStrategy)}"

    self.n_sampled_goal = n_sampled_goal
    # if we sample her transitions online use custom replay buffer
    self.online_sampling = online_sampling
    # compute ratio between HER replays and regular replays in percent for online HER sampling
    self.her_ratio = 1 - (1.0 / (self.n_sampled_goal + 1))
    # maximum steps in episode
    self.max_episode_length = get_time_limit(self.env, max_episode_length)
    # storage for transitions of current episode for offline sampling
    # for online sampling, it replaces the "classic" replay buffer completely
    her_buffer_size = self.buffer_size if online_sampling else self.max_episode_length

    assert self.env is not None, "Because it needs access to `env.compute_reward()` HER you must provide the env."

    self._episode_storage = HerReplayBuffer(
        self.env,
        her_buffer_size,
        self.max_episode_length,
        self.goal_selection_strategy,
        self.env.observation_space,
        self.env.action_space,
        self.device,
        self.n_envs,
        self.her_ratio,  # pytype: disable=wrong-arg-types
    )

    # counter for steps in episode
    self.episode_steps = 0
    if _init_setup_model:
        self._setup_model()
class HER(BaseAlgorithm): """ Hindsight Experience Replay (HER) Paper: https://arxiv.org/abs/1707.01495 .. warning:: For performance reasons, the maximum number of steps per episodes must be specified. In most cases, it will be inferred if you specify ``max_episode_steps`` when registering the environment or if you use a ``gym.wrappers.TimeLimit`` (and ``env.spec`` is not None). Otherwise, you can directly pass ``max_episode_length`` to the model constructor For additional offline algorithm specific arguments please have a look at the corresponding documentation. :param policy: The policy model to use. :param env: The environment to learn from (if registered in Gym, can be str) :param model_class: Off policy model which will be used with hindsight experience replay. (SAC, TD3, DDPG, DQN) :param n_sampled_goal: Number of sampled goals for replay. (offline sampling) :param goal_selection_strategy: Strategy for sampling goals for replay. One of ['episode', 'final', 'future', 'random'] :param online_sampling: Sample HER transitions online. :param learning_rate: learning rate for the optimizer, it can be a function of the current progress remaining (from 1 to 0) :param max_episode_length: The maximum length of an episode. If not specified, it will be automatically inferred if the environment uses a ``gym.wrappers.TimeLimit`` wrapper. 
""" def __init__( self, policy: Union[str, Type[BasePolicy]], env: Union[GymEnv, str], model_class: Type[OffPolicyAlgorithm], n_sampled_goal: int = 4, goal_selection_strategy: Union[GoalSelectionStrategy, str] = "future", online_sampling: bool = False, max_episode_length: Optional[int] = None, *args, **kwargs, ): # we will use the policy and learning rate from the model super(HER, self).__init__(policy=BasePolicy, env=env, policy_base=BasePolicy, learning_rate=0.0) del self.policy, self.learning_rate if self.get_vec_normalize_env() is not None: assert online_sampling, "You must pass `online_sampling=True` if you want to use `VecNormalize` with `HER`" _init_setup_model = kwargs.get("_init_setup_model", True) if "_init_setup_model" in kwargs: del kwargs["_init_setup_model"] # model initialization self.model_class = model_class self.model = model_class( policy=policy, env=self.env, _init_setup_model=False, # pytype: disable=wrong-keyword-args *args, **kwargs, # pytype: disable=wrong-keyword-args ) # Make HER use self.model.action_noise del self.action_noise self.verbose = self.model.verbose self.tensorboard_log = self.model.tensorboard_log # convert goal_selection_strategy into GoalSelectionStrategy if string if isinstance(goal_selection_strategy, str): self.goal_selection_strategy = KEY_TO_GOAL_STRATEGY[ goal_selection_strategy.lower()] else: self.goal_selection_strategy = goal_selection_strategy # check if goal_selection_strategy is valid assert isinstance( self.goal_selection_strategy, GoalSelectionStrategy ), f"Invalid goal selection strategy, please use one of {list(GoalSelectionStrategy)}" self.n_sampled_goal = n_sampled_goal # if we sample her transitions online use custom replay buffer self.online_sampling = online_sampling # compute ratio between HER replays and regular replays in percent for online HER sampling self.her_ratio = 1 - (1.0 / (self.n_sampled_goal + 1)) # maximum steps in episode self.max_episode_length = get_time_limit(self.env, 
max_episode_length) # storage for transitions of current episode for offline sampling # for online sampling, it replaces the "classic" replay buffer completely her_buffer_size = self.buffer_size if online_sampling else self.max_episode_length assert self.env is not None, "Because it needs access to `env.compute_reward()` HER you must provide the env." self._episode_storage = HerReplayBuffer( self.env, her_buffer_size, self.max_episode_length, self.goal_selection_strategy, self.env.observation_space, self.env.action_space, self.device, self.n_envs, self.her_ratio, # pytype: disable=wrong-arg-types ) # counter for steps in episode self.episode_steps = 0 if _init_setup_model: self._setup_model() def _setup_model(self) -> None: self.model._setup_model() # assign episode storage to replay buffer when using online HER sampling if self.online_sampling: self.model.replay_buffer = self._episode_storage def predict( self, observation: np.ndarray, state: Optional[np.ndarray] = None, mask: Optional[np.ndarray] = None, deterministic: bool = False, ) -> Tuple[np.ndarray, Optional[np.ndarray]]: return self.model.predict(observation, state, mask, deterministic) def learn( self, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, eval_env: Optional[GymEnv] = None, eval_freq: int = -1, n_eval_episodes: int = 5, tb_log_name: str = "HER", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, ) -> BaseAlgorithm: total_timesteps, callback = self._setup_learn( total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name) self.model.start_time = self.start_time self.model.ep_info_buffer = self.ep_info_buffer self.model.ep_success_buffer = self.ep_success_buffer self.model.num_timesteps = self.num_timesteps self.model._episode_num = self._episode_num self.model._last_obs = self._last_obs self.model._total_timesteps = self._total_timesteps callback.on_training_start(locals(), globals()) while 
        self.num_timesteps < total_timesteps:  # NOTE(review): continues a `while` header that starts before this chunk
            # Collect experience with the current policy and store it in the replay buffer
            rollout = self.collect_rollouts(
                self.env,
                train_freq=self.train_freq,
                action_noise=self.action_noise,
                callback=callback,
                learning_starts=self.learning_starts,
                log_interval=log_interval,
            )

            if rollout.continue_training is False:
                break

            # Train only once warm-up is over and there is data to sample from
            if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts and self.replay_buffer.size() > 0:
                # If no `gradient_steps` is specified,
                # do as many gradients steps as steps performed during the rollout
                gradient_steps = self.gradient_steps if self.gradient_steps > 0 else rollout.episode_timesteps
                self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)

        callback.on_training_end()

        return self

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        train_freq: TrainFreq,
        action_noise: Optional[ActionNoise] = None,
        learning_starts: int = 0,
        log_interval: Optional[int] = None,
    ) -> RolloutReturn:
        """
        Collect experiences and store them into a ``ReplayBuffer``.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param train_freq: How much experience to collect
            by doing rollouts of current policy.
            Either ``TrainFreq(<n>, TrainFrequencyUnit.STEP)``
            or ``TrainFreq(<n>, TrainFrequencyUnit.EPISODE)``
            with ``<n>`` being an integer greater than 0.
        :param action_noise: Action noise that will be used for exploration
            Required for deterministic policy (e.g. TD3). This can also be used
            in addition to the stochastic policy for SAC.
        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param log_interval: Log data every ``log_interval`` episodes
        :return: A ``RolloutReturn`` with mean reward, step/episode counts and a
            ``continue_training`` flag.
        """
        episode_rewards, total_timesteps = [], []
        num_collected_steps, num_collected_episodes = 0, 0

        assert isinstance(env, VecEnv), "You must pass a VecEnv"
        assert env.num_envs == 1, "OffPolicyAlgorithm only support single environment"
        assert train_freq.frequency > 0, "Should at least collect one step or episode."

        if self.model.use_sde:
            # gSDE: resample the exploration noise matrix at the start of the rollout
            self.actor.reset_noise()

        callback.on_rollout_start()
        continue_training = True

        while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
            done = False
            episode_reward, episode_timesteps = 0.0, 0

            while not done:
                # concatenate observation and (desired) goal
                observation = self._last_obs
                self._last_obs = ObsDictWrapper.convert_dict(observation)

                if (self.model.use_sde and self.model.sde_sample_freq > 0
                        and num_collected_steps % self.model.sde_sample_freq == 0):
                    # Sample a new noise matrix
                    self.actor.reset_noise()

                # Select action randomly or according to policy
                # (the wrapped model needs the flattened obs to sample the action)
                self.model._last_obs = self._last_obs
                action, buffer_action = self._sample_action(learning_starts, action_noise)

                # Perform action
                new_obs, reward, done, infos = env.step(action)

                # Mirror the step counter into the wrapped model so its schedules stay in sync
                self.num_timesteps += 1
                self.model.num_timesteps = self.num_timesteps
                episode_timesteps += 1
                num_collected_steps += 1

                # Only stop training if return value is False, not when it is None.
                if callback.on_step() is False:
                    return RolloutReturn(0.0, num_collected_steps, num_collected_episodes, continue_training=False)

                episode_reward += reward

                # Retrieve reward and episode length if using Monitor wrapper
                self._update_info_buffer(infos, done)
                self.model.ep_info_buffer = self.ep_info_buffer
                self.model.ep_success_buffer = self.ep_success_buffer

                # == Store transition in the replay buffer and/or in the episode storage ==

                if self._vec_normalize_env is not None:
                    # Store only the unnormalized version
                    new_obs_ = self._vec_normalize_env.get_original_obs()
                    reward_ = self._vec_normalize_env.get_original_reward()
                else:
                    # Avoid changing the original ones
                    self._last_original_obs, new_obs_, reward_ = observation, new_obs, reward

                # NOTE(review): in the VecNormalize branch above, _last_original_obs is only
                # refreshed further down ("Save the unnormalized new observation") — confirm intended
                self.model._last_original_obs = self._last_original_obs

                # As the VecEnv resets automatically, new_obs is already the
                # first observation of the next episode
                if done and infos[0].get("terminal_observation") is not None:
                    next_obs = infos[0]["terminal_observation"]
                    # VecNormalize normalizes the terminal observation
                    if self._vec_normalize_env is not None:
                        next_obs = self._vec_normalize_env.unnormalize_obs(next_obs)
                else:
                    next_obs = new_obs_

                if self.online_sampling:
                    # HER online sampling: the buffer stores the dict obs and infos directly
                    self.replay_buffer.add(self._last_original_obs, next_obs, buffer_action, reward_, done, infos)
                else:
                    # concatenate observation with (desired) goal
                    flattened_obs = ObsDictWrapper.convert_dict(self._last_original_obs)
                    flattened_next_obs = ObsDictWrapper.convert_dict(next_obs)
                    # add to replay buffer
                    self.replay_buffer.add(flattened_obs, flattened_next_obs, buffer_action, reward_, done)
                    # add current transition to episode storage
                    self._episode_storage.add(self._last_original_obs, next_obs, buffer_action, reward_, done, infos)

                self._last_obs = new_obs
                self.model._last_obs = self._last_obs

                # Save the unnormalized new observation
                if self._vec_normalize_env is not None:
                    self._last_original_obs = new_obs_
                    self.model._last_original_obs = self._last_original_obs

                self.model._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)

                # For DQN, check if the target network should be updated
                # and update the exploration schedule
                # For SAC/TD3, the update is done as the same time as the gradient update
                # see https://github.com/hill-a/stable-baselines/issues/900
                self.model._on_step()

                self.episode_steps += 1

                if not should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
                    break

            # Episode finished (env signalled done, or the HER max episode length was hit)
            if done or self.episode_steps >= self.max_episode_length:
                if self.online_sampling:
                    self.replay_buffer.store_episode()
                else:
                    self._episode_storage.store_episode()
                    # sample virtual transitions and store them in replay buffer
                    self._sample_her_transitions()
                    # clear storage for current episode
                    self._episode_storage.reset()

                num_collected_episodes += 1
                self._episode_num += 1
                self.model._episode_num = self._episode_num
                episode_rewards.append(episode_reward)
                total_timesteps.append(episode_timesteps)

                if action_noise is not None:
                    action_noise.reset()

                # Log training infos
                if log_interval is not None and self._episode_num % log_interval == 0:
                    self._dump_logs()

                self.episode_steps = 0

        mean_reward = np.mean(episode_rewards) if num_collected_episodes > 0 else 0.0

        callback.on_rollout_end()

        return RolloutReturn(mean_reward, num_collected_steps, num_collected_episodes, continue_training)

    def _sample_her_transitions(self) -> None:
        """
        Sample additional goals and store new transitions in replay buffer
        when using offline sampling.
        """
        # Sample goals and get new observations
        # maybe_vec_env=None as we should store unnormalized transitions,
        # they will be normalized at sampling time
        observations, next_observations, actions, rewards = self._episode_storage.sample_offline(
            n_sampled_goal=self.n_sampled_goal)

        # store data in replay buffer
        # Virtual transitions never terminate an episode, hence all-False dones
        dones = np.zeros((len(observations)), dtype=bool)
        self.replay_buffer.extend(observations, next_observations, actions, rewards, dones)

    def __getattr__(self, item: str) -> Any:
        """
        Find attribute from model class if this class does not have it.
        """
        # Delegate unknown attributes to the wrapped off-policy model
        if hasattr(self.model, item):
            return getattr(self.model, item)
        else:
            raise AttributeError(f"{self} has no attribute {item}")

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        # Saving is entirely delegated to the wrapped model
        return self.model._get_torch_save_params()

    def save(
        self,
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        exclude: Optional[Iterable[str]] = None,
        include: Optional[Iterable[str]] = None,
    ) -> None:
        """
        Save all the attributes of the object and the model parameters in a zip-file.
:param path: path to the file where the rl agent should be saved :param exclude: name of parameters that should be excluded in addition to the default one :param include: name of parameters that might be excluded but should be included anyway """ # add HER parameters to model self.model.n_sampled_goal = self.n_sampled_goal self.model.goal_selection_strategy = self.goal_selection_strategy self.model.online_sampling = self.online_sampling self.model.model_class = self.model_class self.model.max_episode_length = self.max_episode_length self.model.save(path, exclude, include) @classmethod def load( cls, path: Union[str, pathlib.Path, io.BufferedIOBase], env: Optional[GymEnv] = None, device: Union[th.device, str] = "auto", custom_objects: Optional[Dict[str, Any]] = None, **kwargs, ) -> "BaseAlgorithm": """ Load the model from a zip-file :param path: path to the file (or a file-like) where to load the agent from :param env: the new environment to run the loaded model on (can be None if you only need prediction from a trained model) has priority over any saved environment :param device: Device on which the code should run. :param custom_objects: Dictionary of objects to replace upon loading. If a variable is present in this dictionary as a key, it will not be deserialized and the corresponding item will be used instead. Similar to custom_objects in ``keras.models.load_model``. Useful when you have an object in file that can not be deserialized. :param kwargs: extra arguments to change the model when loading """ data, params, pytorch_variables = load_from_zip_file( path, device=device, custom_objects=custom_objects) # Remove stored device information and replace with ours if "policy_kwargs" in data: if "device" in data["policy_kwargs"]: del data["policy_kwargs"]["device"] if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data[ "policy_kwargs"]: raise ValueError( f"The specified policy kwargs do not equal the stored policy kwargs." 
f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}" ) # check if observation space and action space are part of the saved parameters if "observation_space" not in data or "action_space" not in data: raise KeyError( "The observation_space and action_space were not given, can't verify new environments" ) # check if given env is valid if env is not None: # Wrap first if needed env = cls._wrap_env(env, data["verbose"]) # Check if given env is valid check_for_correct_spaces(env, data["observation_space"], data["action_space"]) else: # Use stored env, if one exists. If not, continue as is (can be used for predict) if "env" in data: env = data["env"] if "use_sde" in data and data["use_sde"]: kwargs["use_sde"] = True # Keys that cannot be changed for key in {"model_class", "online_sampling", "max_episode_length"}: if key in kwargs: del kwargs[key] # Keys that can be changed for key in {"n_sampled_goal", "goal_selection_strategy"}: if key in kwargs: data[key] = kwargs[key] # pytype: disable=unsupported-operands del kwargs[key] # noinspection PyArgumentList her_model = cls( policy=data["policy_class"], env=env, model_class=data["model_class"], n_sampled_goal=data["n_sampled_goal"], goal_selection_strategy=data["goal_selection_strategy"], online_sampling=data["online_sampling"], max_episode_length=data["max_episode_length"], policy_kwargs=data["policy_kwargs"], _init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args **kwargs, ) # load parameters her_model.model.__dict__.update(data) her_model.model.__dict__.update(kwargs) her_model._setup_model() her_model._total_timesteps = her_model.model._total_timesteps her_model.num_timesteps = her_model.model.num_timesteps her_model._episode_num = her_model.model._episode_num # put state_dicts back in place her_model.model.set_parameters(params, exact_match=True, device=device) # put other pytorch variables back in place if pytorch_variables is not None: for name in 
pytorch_variables: recursive_setattr(her_model.model, name, pytorch_variables[name]) # Sample gSDE exploration matrix, so it uses the right device # see issue #44 if her_model.model.use_sde: her_model.model.policy.reset_noise() # pytype: disable=attribute-error return her_model def load_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase], truncate_last_trajectory: bool = True) -> None: """ Load a replay buffer from a pickle file and set environment for replay buffer (only online sampling). :param path: Path to the pickled replay buffer. :param truncate_last_trajectory: Only for online sampling. If set to ``True`` we assume that the last trajectory in the replay buffer was finished. If it is set to ``False`` we assume that we continue the same trajectory (same episode). """ self.model.load_replay_buffer(path=path) if self.online_sampling: # set environment self.replay_buffer.set_env(self.env) # If we are at the start of an episode, no need to truncate current_idx = self.replay_buffer.current_idx # truncate interrupted episode if truncate_last_trajectory and current_idx > 0: warnings.warn( "The last trajectory in the replay buffer will be truncated.\n" "If you are in the same episode as when the replay buffer was saved,\n" "you should use `truncate_last_trajectory=False` to avoid that issue." ) # get current episode and transition index pos = self.replay_buffer.pos # set episode length for current episode self.replay_buffer.episode_lengths[pos] = current_idx # set done = True for current episode # current_idx was already incremented self.replay_buffer.buffer["done"][pos][current_idx - 1] = np.array( [True], dtype=np.float32) # reset current transition index self.replay_buffer.current_idx = 0 # increment episode counter self.replay_buffer.pos = ( self.replay_buffer.pos + 1) % self.replay_buffer.max_episode_stored # update "full" indicator self.replay_buffer.full = self.replay_buffer.full or self.replay_buffer.pos == 0