def test_quantile_huber_loss():
    assert np.isclose(quantile_huber_loss(th.zeros(1, 10), th.ones(1, 10)), 2.5)
    assert np.isclose(quantile_huber_loss(th.zeros(1, 10), th.ones(1, 10), sum_over_quantiles=False), 0.25)

    with pytest.raises(ValueError):
        quantile_huber_loss(th.zeros(1, 4, 4), th.zeros(1, 4))
    with pytest.raises(ValueError):
        quantile_huber_loss(th.zeros(1, 4), th.zeros(1, 1, 4))
    with pytest.raises(ValueError):
        quantile_huber_loss(th.zeros(4, 4), th.zeros(3, 4))
    with pytest.raises(ValueError):
        quantile_huber_loss(th.zeros(4, 4, 4, 4), th.zeros(4, 4, 4, 4))
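# Why the constants above: with zero current quantiles and unit targets, every pairwise
# TD error is 1, so the Huber term is 0.5 and each element is weighted by the quantile
# midpoint tau_i = (i + 0.5) / 10. Averaging those elements gives 0.25; summing over the
# 10 target quantiles first gives mean_i(10 * 0.5 * tau_i) = 5 * 0.5 = 2.5.
# The sketch below is a minimal, 2-D-only reference that reproduces those values and the
# shape checks exercised by the test; it is an assumption about the tested function's
# semantics, not a copy of its implementation.
import torch as th


def reference_quantile_huber_loss(current: th.Tensor, target: th.Tensor, sum_over_quantiles: bool = True) -> th.Tensor:
    # Restrict to (batch, n_quantiles) inputs with matching batch sizes
    if current.ndim != 2 or target.ndim != 2 or current.shape[0] != target.shape[0]:
        raise ValueError(f"Unsupported shapes: {current.shape} vs {target.shape}")
    n_quantiles = current.shape[-1]
    # Quantile midpoints tau_i = (i + 0.5) / n
    taus = (th.arange(n_quantiles, dtype=th.float32) + 0.5) / n_quantiles
    # Pairwise TD errors, shape (batch, n_current_quantiles, n_target_quantiles)
    delta = target.unsqueeze(-2) - current.unsqueeze(-1)
    huber = th.where(delta.abs() > 1, delta.abs() - 0.5, 0.5 * delta ** 2)
    loss = (taus.view(1, -1, 1) - (delta.detach() < 0).float()).abs() * huber
    # Sum over the target-quantile dimension (2.5 in the test) or average everything (0.25)
    return loss.sum(dim=-1).mean() if sum_over_quantiles else loss.mean()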
def train(self, gradient_steps: int, batch_size: int = 100) -> None:
    # Update learning rate according to schedule
    self._update_learning_rate(self.policy.optimizer)

    losses = []
    for gradient_step in range(gradient_steps):
        # Sample replay buffer
        replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

        with th.no_grad():
            # Compute the quantiles of the next observations
            next_quantiles = self.quantile_net_target(replay_data.next_observations)
            # Compute the greedy actions which maximize the next Q values
            next_greedy_actions = next_quantiles.mean(dim=1, keepdim=True).argmax(dim=2, keepdim=True)
            # Make "n_quantiles" copies of actions, and reshape to (batch_size, n_quantiles, 1)
            next_greedy_actions = next_greedy_actions.expand(batch_size, self.n_quantiles, 1)
            # Follow greedy policy: use the one with the highest Q values
            next_quantiles = next_quantiles.gather(dim=2, index=next_greedy_actions).squeeze(dim=2)
            # 1-step TD target
            target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_quantiles

        # Get current quantile estimates
        current_quantiles = self.quantile_net(replay_data.observations)

        # Make "n_quantiles" copies of actions, and reshape to (batch_size, n_quantiles, 1)
        actions = replay_data.actions[..., None].long().expand(batch_size, self.n_quantiles, 1)
        # Retrieve the quantiles for the actions from the replay buffer
        current_quantiles = th.gather(current_quantiles, dim=2, index=actions).squeeze(dim=2)

        # Compute quantile Huber loss, summing over the quantile dimension as in the paper
        loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=True)
        losses.append(loss.item())

        # Optimize the policy
        self.policy.optimizer.zero_grad()
        loss.backward()
        # Clip gradient norm
        if self.max_grad_norm is not None:
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.policy.optimizer.step()

    # Increase update counter
    self._n_updates += gradient_steps

    logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
    logger.record("train/loss", np.mean(losses))
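# The expand/gather pattern above is the only shape-sensitive step of the loop: the network
# outputs quantiles for every action, and the replayed action index must be broadcast across
# the quantile dimension before it can select a single column. A self-contained illustration
# of that indexing (all tensors and sizes below are made up for the demo):
import torch as th

batch_size, n_quantiles, n_actions = 2, 5, 3
# Fake network output, shape (batch_size, n_quantiles, n_actions)
quantiles = th.randn(batch_size, n_quantiles, n_actions)
# One discrete action per transition, shape (batch_size, 1) as stored in a replay buffer
actions = th.tensor([[0], [2]])
# Broadcast the action index across the quantile dimension: (batch_size, n_quantiles, 1)
index = actions[..., None].long().expand(batch_size, n_quantiles, 1)
# Select the chosen action's quantiles and drop the trailing dim: (batch_size, n_quantiles)
selected = th.gather(quantiles, dim=2, index=index).squeeze(dim=2)
assert selected.shape == (batch_size, n_quantiles)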
def train(self, gradient_steps: int, batch_size: int = 64) -> None:
    # Update the optimizers' learning rates
    optimizers = [self.actor.optimizer, self.critic.optimizer]
    if self.ent_coef_optimizer is not None:
        optimizers += [self.ent_coef_optimizer]

    # Update learning rate according to lr schedule
    self._update_learning_rate(optimizers)

    ent_coef_losses, ent_coefs = [], []
    actor_losses, critic_losses = [], []

    for gradient_step in range(gradient_steps):
        # Sample replay buffer
        replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

        # We need to sample because `log_std` may have changed between two gradient steps
        if self.use_sde:
            self.actor.reset_noise()

        # Action by the current actor for the sampled state
        actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
        log_prob = log_prob.reshape(-1, 1)

        ent_coef_loss = None
        if self.ent_coef_optimizer is not None:
            # Important: detach the variable from the graph
            # so we don't change it with other losses
            # see https://github.com/rail-berkeley/softlearning/issues/60
            ent_coef = th.exp(self.log_ent_coef.detach())
            ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
            ent_coef_losses.append(ent_coef_loss.item())
        else:
            ent_coef = self.ent_coef_tensor

        ent_coefs.append(ent_coef.item())
        self.replay_buffer.ent_coef = ent_coef.item()

        # Optimize entropy coefficient, also called
        # entropy temperature or alpha in the paper
        if ent_coef_loss is not None:
            self.ent_coef_optimizer.zero_grad()
            ent_coef_loss.backward()
            self.ent_coef_optimizer.step()

        with th.no_grad():
            # Select action according to policy
            next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
            # Compute and cut quantiles at the next state
            # batch x nets x quantiles
            next_quantiles = self.critic_target(replay_data.next_observations, next_actions)

            # Sort and drop top k quantiles to control overestimation
            n_target_quantiles = self.critic.quantiles_total - self.top_quantiles_to_drop_per_net * self.critic.n_critics
            next_quantiles, _ = th.sort(next_quantiles.reshape(batch_size, -1))
            next_quantiles = next_quantiles[:, :n_target_quantiles]

            # TD error + entropy term
            target_quantiles = next_quantiles - ent_coef * next_log_prob.reshape(-1, 1)
            target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_quantiles
            # Make target_quantiles broadcastable to (batch_size, n_critics, n_target_quantiles)
            target_quantiles.unsqueeze_(dim=1)

        # Get current quantile estimates using actions from the replay buffer
        current_quantiles = self.critic(replay_data.observations, replay_data.actions)
        # Compute critic loss, not summing over the quantile dimension as in the paper
        critic_loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=False)
        critic_losses.append(critic_loss.item())

        # Optimize the critic
        self.critic.optimizer.zero_grad()
        critic_loss.backward()
        self.critic.optimizer.step()

        # Compute actor loss
        qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
        actor_loss = (ent_coef * log_prob - qf_pi).mean()
        actor_losses.append(actor_loss.item())

        # Optimize the actor
        self.actor.optimizer.zero_grad()
        actor_loss.backward()
        self.actor.optimizer.step()

        # Update target networks
        if gradient_step % self.target_update_interval == 0:
            polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)

    self._n_updates += gradient_steps

    logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
    logger.record("train/ent_coef", np.mean(ent_coefs))
    logger.record("train/actor_loss", np.mean(actor_losses))
    logger.record("train/critic_loss", np.mean(critic_losses))
    if len(ent_coef_losses) > 0:
        logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
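# The distinctive step in the critic update above is the quantile truncation: the quantiles
# of all target critics are pooled, sorted, and the largest ones are dropped before forming
# the TD target, which curbs overestimation. A self-contained illustration with made-up
# sizes (the variable names mirror the attributes used above):
import torch as th

batch_size, n_critics, n_quantiles_per_net = 2, 2, 25
top_quantiles_to_drop_per_net = 2
quantiles_total = n_critics * n_quantiles_per_net

# Fake target-critic output, shape (batch_size, n_critics, n_quantiles_per_net)
next_quantiles = th.randn(batch_size, n_critics, n_quantiles_per_net)

# Pool all critics' quantiles, sort them, and keep only the smallest n_target_quantiles
n_target_quantiles = quantiles_total - top_quantiles_to_drop_per_net * n_critics
pooled, _ = th.sort(next_quantiles.reshape(batch_size, -1))
truncated = pooled[:, :n_target_quantiles]
assert truncated.shape == (batch_size, n_target_quantiles)  # here (2, 46)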