def train(self):
    """
    Run one A2C gradient step over the whole current rollout buffer.

    The buffer is requested with ``batch_size=None``, so the loop below
    processes all gathered data in a single iteration. The optimizer's
    learning rate is refreshed first, and statistics are recorded with the
    module-level ``logger`` afterwards.

    :return: ``(entropy_loss, policy_loss, value_loss)`` as Python scalars,
        taken from the processed batch.
    """
    optimizer = self.policy.optimizer
    # Refresh the learning rate from its schedule before stepping
    self._update_learning_rate(optimizer)

    # batch_size=None -> a single batch containing the whole buffer
    for batch in self.rollout_buffer.get(batch_size=None):
        acts = batch.actions
        if isinstance(self.action_space, spaces.Discrete):
            # Discrete actions are stored as floats; convert to long indices
            acts = acts.long().flatten()

        # TODO: avoid second computation of everything because of the gradient
        value_preds, log_prob, entropy = self.policy.evaluate_actions(batch.observations, acts)
        value_preds = value_preds.flatten()

        adv = batch.advantages
        # Advantage normalization is an addition over the original A2C
        if self.normalize_advantage:
            adv = (adv - adv.mean()) / (adv.std() + 1e-8)

        # Vanilla policy-gradient objective
        policy_loss = -(adv * log_prob).mean()
        # Regress value predictions onto the TD(gae_lambda) returns
        value_loss = F.mse_loss(batch.returns, value_preds)
        # Negative entropy (minimizing it favors exploration); when the
        # distribution has no analytical entropy, -log_prob is used as a
        # sample-based estimate.
        entropy_loss = -th.mean(-log_prob) if entropy is None else -th.mean(entropy)

        total_loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

        optimizer.zero_grad()
        total_loss.backward()
        # Bound the global gradient norm before applying the step
        th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        optimizer.step()

    explained_var = explained_variance(
        self.rollout_buffer.values.flatten(),
        self.rollout_buffer.returns.flatten(),
    )
    self._n_updates += 1

    ent_val, pol_val, val_val = entropy_loss.item(), policy_loss.item(), value_loss.item()

    logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
    for key, stat in (
        ("train/explained_variance", explained_var),
        ("train/entropy_loss", ent_val),
        ("train/policy_loss", pol_val),
        ("train/value_loss", val_val),
    ):
        logger.record(key, stat)
    if hasattr(self.policy, "log_std"):
        logger.record("train/std", th.exp(self.policy.log_std).mean().item())

    return ent_val, pol_val, val_val
def train_orig(self) -> None:
    """
    Update policy using the currently gathered rollout buffer.

    PPO update with the clipped surrogate objective. Runs up to
    ``self.n_epochs`` passes over the rollout buffer in minibatches of
    ``self.batch_size``. After each full epoch, the mean approximate KL of
    that epoch is compared with ``1.5 * self.target_kl`` when ``target_kl`` is
    set. The KL is estimated as ``mean(old_log_prob - log_prob)``, using
    log-probs evaluated before each optimizer step. If the threshold is
    exceeded, the remaining epochs are skipped. Statistics are recorded
    through the module-level ``logger``.
    """
    # Update optimizer learning rate
    self._update_learning_rate(self.policy.optimizer)
    # Compute current clip range
    clip_range = self.clip_range(self._current_progress_remaining)
    # Optional: clip range for the value function
    if self.clip_range_vf is not None:
        clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

    entropy_losses = []
    pg_losses, value_losses = [], []
    clip_fractions = []

    # train for n_epochs epochs
    for epoch in range(self.n_epochs):
        approx_kl_divs = []
        # Do a complete pass on the rollout buffer
        for rollout_data in self.rollout_buffer.get(self.batch_size):
            actions = rollout_data.actions
            if isinstance(self.action_space, spaces.Discrete):
                # Convert discrete action from float to long
                actions = rollout_data.actions.long().flatten()

            # Re-sample the noise matrix because the log_std has changed
            # TODO: investigate why there is no issue with the gradient
            # if that line is commented (as in SAC)
            if self.use_sde:
                self.policy.reset_noise(self.batch_size)

            values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
            values = values.flatten()

            # Normalize advantage. Skipped for a single-sample minibatch:
            # torch's default (unbiased) std() of one element is NaN, which
            # would propagate into the loss and the parameters.
            advantages = rollout_data.advantages
            if len(advantages) > 1:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            # ratio between old and new policy, should be one at the first iteration
            ratio = th.exp(log_prob - rollout_data.old_log_prob)

            # clipped surrogate loss
            policy_loss_1 = advantages * ratio
            policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
            policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

            # Logging
            pg_losses.append(policy_loss.item())
            clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
            clip_fractions.append(clip_fraction)

            if self.clip_range_vf is None:
                # No clipping
                values_pred = values
            else:
                # Clip the difference between old and new value
                # NOTE: this depends on the reward scaling
                values_pred = rollout_data.old_values + th.clamp(
                    values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                )
            # Value loss using the TD(gae_lambda) target
            value_loss = F.mse_loss(rollout_data.returns, values_pred)
            value_losses.append(value_loss.item())

            # Entropy loss favor exploration
            if entropy is None:
                # Approximate entropy when no analytical form
                entropy_loss = -th.mean(-log_prob)
            else:
                entropy_loss = -th.mean(entropy)

            entropy_losses.append(entropy_loss.item())

            loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

            # Optimization step
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip grad norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()
            # NOTE: log_prob was computed before this step, so this KL estimate
            # describes the policy as it was before the update.
            approx_kl_divs.append(th.mean(rollout_data.old_log_prob - log_prob).detach().cpu().numpy())

        # Count epochs actually run, so early stopping does not inflate the counter
        self._n_updates += 1

        if self.target_kl is not None and np.mean(approx_kl_divs) > 1.5 * self.target_kl:
            print(f"Early stopping at step {epoch} due to reaching max kl: {np.mean(approx_kl_divs):.2f}")
            break

    explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

    # Logs
    logger.record("train/entropy_loss", np.mean(entropy_losses))
    logger.record("train/policy_gradient_loss", np.mean(pg_losses))
    logger.record("train/value_loss", np.mean(value_losses))
    # Only the last completed epoch's KL values are reported
    logger.record("train/approx_kl", np.mean(approx_kl_divs))
    logger.record("train/clip_fraction", np.mean(clip_fractions))
    logger.record("train/loss", loss.item())
    logger.record("train/explained_variance", explained_var)
    if hasattr(self.policy, "log_std"):
        logger.record("train/std", th.exp(self.policy.log_std).mean().item())

    logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
    logger.record("train/clip_range", clip_range)
    if self.clip_range_vf is not None:
        logger.record("train/clip_range_vf", clip_range_vf)
def train(self, gradient_steps: int, batch_size: Optional[int] = None) -> None:
    """
    Perform one A2C update using the whole current rollout buffer.

    :param gradient_steps: must be 1; A2C takes a single gradient step per rollout.
    :param batch_size: must be None; A2C uses the entire buffer as one batch.
    """
    # Update optimizer learning rate
    self._update_learning_rate(self.policy.optimizer)
    # A2C with gradient_steps > 1 does not make sense
    assert gradient_steps == 1, "A2C does not support multiple gradient steps"
    # We do not use minibatches for A2C
    assert batch_size is None, "A2C does not support minibatch"

    # batch_size=None -> a single batch containing the whole buffer
    for rollout_data in self.rollout_buffer.get(batch_size=None):
        actions = rollout_data.actions
        if isinstance(self.action_space, spaces.Discrete):
            # Convert discrete action from float to long
            actions = actions.long().flatten()

        # TODO: avoid second computation of everything because of the gradient
        values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
        values = values.flatten()

        # Normalize advantage (not present in the original implementation)
        advantages = rollout_data.advantages
        if self.normalize_advantage:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Policy gradient loss
        policy_loss = -(advantages * log_prob).mean()

        # Value loss using the TD(gae_lambda) target
        value_loss = F.mse_loss(rollout_data.returns, values)

        # Entropy loss favor exploration
        if entropy is None:
            # Approximate entropy when no analytical form
            entropy_loss = -log_prob.mean()
        else:
            entropy_loss = -th.mean(entropy)

        loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

        # Optimization step
        self.policy.optimizer.zero_grad()
        loss.backward()

        # Clip grad norm
        th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.policy.optimizer.step()

    # Predicted values first, target returns second. This matches the other
    # train variants in this file and assumes an SB3-style
    # explained_variance(y_pred, y_true) signature. The previous
    # (returns, values) order normalized by the variance of the predictions
    # instead of the targets.
    explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

    self._n_updates += 1
    logger.logkv("n_updates", self._n_updates)
    logger.logkv("explained_variance", explained_var)
    logger.logkv("entropy_loss", entropy_loss.item())
    logger.logkv("policy_loss", policy_loss.item())
    logger.logkv("value_loss", value_loss.item())
    if hasattr(self.policy, 'log_std'):
        logger.logkv("std", th.exp(self.policy.log_std).mean().item())
def train(self) -> None:
    """
    Update policy using the currently gathered rollout buffer.

    PPO update with the clipped surrogate objective. Performs up to
    ``self.n_epochs`` passes over the rollout buffer in minibatches of
    ``self.batch_size``. Before each optimizer step, an approximate reverse KL
    divergence between the old and current policy is computed. If it exceeds
    ``1.5 * self.target_kl`` (when ``target_kl`` is set), that minibatch's
    step is skipped and training stops for this rollout. Statistics are
    recorded through ``self.logger``.
    """
    # Switch to train mode (this affects batch norm / dropout)
    self.policy.set_training_mode(True)
    # Update optimizer learning rate
    self._update_learning_rate(self.policy.optimizer)
    # Compute current clip range
    clip_range = self.clip_range(self._current_progress_remaining)
    # Optional: clip range for the value function
    if self.clip_range_vf is not None:
        clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

    entropy_losses = []
    pg_losses, value_losses = [], []
    clip_fractions = []

    continue_training = True

    # train for n_epochs epochs
    for epoch in range(self.n_epochs):
        approx_kl_divs = []
        # Do a complete pass on the rollout buffer
        for rollout_data in self.rollout_buffer.get(self.batch_size):
            actions = rollout_data.actions
            if isinstance(self.action_space, spaces.Discrete):
                # Convert discrete action from float to long
                actions = rollout_data.actions.long().flatten()

            # Re-sample the noise matrix because the log_std has changed
            if self.use_sde:
                self.policy.reset_noise(self.batch_size)

            values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
            values = values.flatten()

            # Normalize advantage. Skipped for a single-sample minibatch:
            # torch's default (unbiased) std() of one element is NaN, which
            # would propagate into the loss and the parameters.
            advantages = rollout_data.advantages
            if len(advantages) > 1:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            # ratio between old and new policy, should be one at the first iteration
            ratio = th.exp(log_prob - rollout_data.old_log_prob)

            # clipped surrogate loss
            policy_loss_1 = advantages * ratio
            policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
            policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

            # Logging
            pg_losses.append(policy_loss.item())
            clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
            clip_fractions.append(clip_fraction)

            if self.clip_range_vf is None:
                # No clipping
                values_pred = values
            else:
                # Clip the difference between old and new value
                # NOTE: this depends on the reward scaling
                values_pred = rollout_data.old_values + th.clamp(
                    values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                )
            # Value loss using the TD(gae_lambda) target
            value_loss = F.mse_loss(rollout_data.returns, values_pred)
            value_losses.append(value_loss.item())

            # Entropy loss favor exploration
            if entropy is None:
                # Approximate entropy when no analytical form
                entropy_loss = -th.mean(-log_prob)
            else:
                entropy_loss = -th.mean(entropy)

            entropy_losses.append(entropy_loss.item())

            loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

            # Calculate approximate form of reverse KL Divergence for early stopping
            # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
            # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
            # and Schulman blog: http://joschu.net/blog/kl-approx.html
            with th.no_grad():
                log_ratio = log_prob - rollout_data.old_log_prob
                approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                approx_kl_divs.append(approx_kl_div)

            if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                continue_training = False
                if self.verbose >= 1:
                    print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                # Skip the optimizer step for this minibatch
                break

            # Optimization step
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip grad norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Count epochs actually run (a partially-run final epoch counts once)
        # instead of always adding n_epochs after an early stop.
        self._n_updates += 1
        if not continue_training:
            break

    explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

    # Logs
    self.logger.record("train/entropy_loss", np.mean(entropy_losses))
    self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
    self.logger.record("train/value_loss", np.mean(value_losses))
    # Only the last (possibly partial) epoch's KL values are reported
    self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
    self.logger.record("train/clip_fraction", np.mean(clip_fractions))
    self.logger.record("train/loss", loss.item())
    self.logger.record("train/explained_variance", explained_var)
    if hasattr(self.policy, "log_std"):
        self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

    self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
    self.logger.record("train/clip_range", clip_range)
    if self.clip_range_vf is not None:
        self.logger.record("train/clip_range_vf", clip_range_vf)
def train(self) -> None:
    """
    Update policy using the currently gathered rollout buffer.

    Repeatedly sweeps the rollout buffer in minibatches of ``self.batch_size``.
    Each step optimizes an unclipped importance-weighted policy-gradient loss
    plus the value and entropy terms. ``self.step_constraint.before_updates``
    is called once before any update. After every optimizer step, updating
    stops as soon as ``self.step_constraint.check_constraint(self)`` is truthy
    or ``self.step_constraint_max_updates`` optimizer steps have been taken.
    ``self._n_updates`` is advanced by the number of steps actually taken.
    """
    # Update optimizer learning rate
    self._update_learning_rate(self.policy.optimizer)

    entropy_losses = []
    pg_losses, value_losses = [], []

    # Train until the step constraint tells otherwise
    num_steps_before_termination = 0
    continue_updating = True
    self.step_constraint.before_updates(self, self.rollout_buffer)

    while continue_updating:
        approx_kl_divs = []
        # Do a complete pass on the rollout buffer
        for rollout_data in self.rollout_buffer.get(self.batch_size):
            actions = rollout_data.actions
            if isinstance(self.action_space, spaces.Discrete):
                # Convert discrete action from float to long
                actions = rollout_data.actions.long().flatten()

            # Re-sample the noise matrix because the log_std has changed
            # TODO: investigate why there is no issue with the gradient
            # if that line is commented (as in SAC)
            if self.use_sde:
                self.policy.reset_noise(self.batch_size)

            values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
            values = values.flatten()

            # Normalize advantage. Skipped for a single-sample minibatch:
            # torch's default (unbiased) std() of one element is NaN, which
            # would propagate into the loss and the parameters.
            advantages = rollout_data.advantages
            if len(advantages) > 1:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            # ratio between old and new policy, should be one at the first iteration
            ratio = th.exp(log_prob - rollout_data.old_log_prob)
            # Unclipped surrogate; the step constraint bounds how far the policy moves
            policy_loss = (-(advantages * ratio)).mean()

            # Logging
            pg_losses.append(policy_loss.item())

            values_pred = values
            # Value loss using the TD(gae_lambda) target
            value_loss = F.mse_loss(rollout_data.returns, values_pred)
            value_losses.append(value_loss.item())

            # Entropy loss favor exploration
            if entropy is None:
                # Approximate entropy when no analytical form
                entropy_loss = -th.mean(-log_prob)
            else:
                entropy_loss = -th.mean(entropy)

            entropy_losses.append(entropy_loss.item())

            loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

            # Optimization step
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip grad norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()
            # NOTE: log_prob was computed before this step, so this KL estimate
            # describes the policy as it was before the update.
            approx_kl_divs.append(th.mean(rollout_data.old_log_prob - log_prob).detach().cpu().numpy())

            num_steps_before_termination += 1
            # check_constraint is evaluated first; the step cap short-circuits only if it is falsy
            if (
                self.step_constraint.check_constraint(self)
                or num_steps_before_termination >= self.step_constraint_max_updates
            ):
                continue_updating = False
                break

    self._n_updates += num_steps_before_termination

    explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

    # Logs
    logger.record("train/entropy_loss", np.mean(entropy_losses))
    logger.record("train/policy_gradient_loss", np.mean(pg_losses))
    logger.record("train/value_loss", np.mean(value_losses))
    # Only the last sweep's KL values are reported
    logger.record("train/approx_kl", np.mean(approx_kl_divs))
    logger.record("train/loss", loss.item())
    logger.record("train/explained_variance", explained_var)
    if hasattr(self.policy, "log_std"):
        logger.record("train/std", th.exp(self.policy.log_std).mean().item())

    logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
def train(self, gradient_steps: int, batch_size: Optional[int] = None) -> None:
    """
    Perform one A2C update using the whole current rollout buffer.

    :param gradient_steps: must be 1; A2C takes a single gradient step per rollout.
    :param batch_size: must be None; A2C uses the entire buffer as one batch.
    """
    # Update optimizer learning rate
    self._update_learning_rate(self.policy.optimizer)
    # A2C with gradient_steps > 1 does not make sense
    assert gradient_steps == 1, "A2C does not support multiple gradient steps"
    # We do not use minibatches for A2C
    assert batch_size is None, "A2C does not support minibatch"

    # batch_size=None -> a single batch containing the whole buffer
    for rollout_data in self.rollout_buffer.get(batch_size=None):
        actions = rollout_data.actions
        if isinstance(self.action_space, spaces.Discrete):
            # Convert discrete action from float to long
            actions = actions.long().flatten()

        # TODO: avoid second computation of everything because of the gradient
        values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
        values = values.flatten()

        # Normalize advantage (not present in the original implementation)
        advantages = rollout_data.advantages
        if self.normalize_advantage:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Policy gradient loss
        policy_loss = -(advantages * log_prob).mean()

        # Value loss using the TD(gae_lambda) target
        value_loss = F.mse_loss(rollout_data.returns, values)

        # Entropy loss favor exploration
        if entropy is None:
            # Approximate entropy when no analytical form
            entropy_loss = -log_prob.mean()
        else:
            entropy_loss = -th.mean(entropy)

        loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

        # Optimization step
        self.policy.optimizer.zero_grad()
        loss.backward()

        # Clip grad norm
        th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.policy.optimizer.step()

    # Predicted values first, target returns second. This matches the other
    # train variants in this file and assumes an SB3-style
    # explained_variance(y_pred, y_true) signature. The previous
    # (returns, values) order normalized by the variance of the predictions
    # instead of the targets.
    explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

    self._n_updates += 1
    logger.logkv("n_updates", self._n_updates)
    logger.logkv("explained_variance", explained_var)
    logger.logkv("entropy_loss", entropy_loss.item())
    logger.logkv("policy_loss", policy_loss.item())
    logger.logkv("value_loss", value_loss.item())
    if hasattr(self.policy, 'log_std'):
        logger.logkv("std", th.exp(self.policy.log_std).mean().item())