def _quantile_loss(self, states_t, actions_t, rewards_t, states_tp1, done_t):
    gammas = self._gammas**self._n_step

    # actor loss
    policy_loss = -torch.mean(self.critic(states_t, self.actor(states_t)))

    # critic loss (quantile regression)
    atoms_t = self.critic(states_t, actions_t)  # B x num_heads x num_atoms
    atoms_tp1 = self.target_critic(
        states_tp1, self.target_actor(states_tp1)
    ).detach()  # B x num_heads x num_atoms

    done_t = done_t[:, None, :]  # B x 1 x 1
    rewards_t = rewards_t[:, None, :]  # B x 1 x 1
    gammas = gammas[None, :, None]  # 1 x num_heads x 1

    atoms_target_t = rewards_t + (1 - done_t) * gammas * atoms_tp1
    value_loss = quantile_loss(
        atoms_t.view(-1, self.num_atoms),
        atoms_target_t.view(-1, self.num_atoms),
        self.tau, self.num_atoms, self.critic_criterion
    )

    return policy_loss, value_loss
def _quantile_loss(self, states_t, actions_t, rewards_t, states_tp1, done_t):
    # actor loss: maximize the minimum (over critics) of the mean Q-value
    actions_tp0 = self.actor(states_t)
    atoms_tp0 = [
        x(states_t, actions_tp0).unsqueeze_(-1) for x in self.critics
    ]
    q_values_tp0_min = torch.cat(atoms_tp0, dim=-1).mean(dim=1).min(dim=1)[0]
    policy_loss = -torch.mean(q_values_tp0_min)

    # critic loss (quantile regression)
    # target policy smoothing: perturb the target actor's actions with noise
    actions_tp1 = self.target_actor(states_tp1).detach()
    actions_tp1 = self._add_noise_to_actions(actions_tp1)
    atoms_t = [x(states_t, actions_t) for x in self.critics]
    atoms_tp1 = torch.cat([
        x(states_tp1, actions_tp1).unsqueeze_(-1)
        for x in self.target_critics
    ], dim=-1)
    # per sample, bootstrap from the critic with the smallest mean Q-value
    atoms_ids_tp1_min = atoms_tp1.mean(dim=1).argmin(dim=1)
    atoms_tp1 = \
        atoms_tp1[range(len(atoms_tp1)), :, atoms_ids_tp1_min].detach()
    gamma = self.gamma**self.n_step
    atoms_target_t = rewards_t + (1 - done_t) * gamma * atoms_tp1
    value_loss = [
        quantile_loss(
            x, atoms_target_t, self.tau, self.n_atoms, self.critic_criterion
        ) for x in atoms_t
    ]

    return policy_loss, value_loss
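# NOTE: `_add_noise_to_actions` is not defined in this listing. Below is a
# minimal sketch of the TD3-style target policy smoothing it most likely
# performs; the noise parameters and the [-1, 1] action bounds are
# illustrative assumptions, not taken from the original code.
import torch


def add_noise_to_actions(actions, noise_std=0.2, noise_clip=0.5):
    # sample Gaussian noise, clip it, and add it to the target policy actions
    noise = torch.randn_like(actions) * noise_std
    noise = noise.clamp(-noise_clip, noise_clip)
    # keep the smoothed actions inside the valid action range
    return (actions + noise).clamp(-1.0, 1.0)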
def _quantile_loss(self, states_t, actions_t, rewards_t, states_tp1, done_t):
    gamma = self.gamma**self.n_step

    # actor loss
    policy_loss = -torch.mean(self.critic(states_t, self.actor(states_t)))

    # critic loss (quantile regression)
    atoms_t = self.critic(states_t, actions_t)
    atoms_tp1 = self.target_critic(
        states_tp1, self.target_actor(states_tp1)
    ).detach()
    atoms_target_t = rewards_t + (1 - done_t) * gamma * atoms_tp1
    value_loss = quantile_loss(
        atoms_t, atoms_target_t, self.tau, self.n_atoms,
        self.critic_criterion
    )

    return policy_loss, value_loss
def _quantile_loss(self, states_t, actions_t, rewards_t, states_tp1, done_t):
    gamma = self._gamma**self._n_step

    # critic loss (quantile regression)
    indices_t = actions_t.repeat(1, self.num_atoms).unsqueeze(1)
    atoms_t = self.critic(states_t).gather(1, indices_t).squeeze(1)

    all_atoms_tp1 = self.target_critic(states_tp1).detach()
    q_values_tp1 = all_atoms_tp1.mean(dim=-1)
    actions_tp1 = torch.argmax(q_values_tp1, dim=-1, keepdim=True)
    indices_tp1 = actions_tp1.repeat(1, self.num_atoms).unsqueeze(1)
    atoms_tp1 = all_atoms_tp1.gather(1, indices_tp1).squeeze(1)

    atoms_target_t = rewards_t + (1 - done_t) * gamma * atoms_tp1
    value_loss = quantile_loss(
        atoms_t, atoms_target_t, self.tau, self.num_atoms,
        self.critic_criterion
    )

    return value_loss
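# Shape check for the discrete-action gather used above (illustrative only;
# the batch size, number of actions, and number of atoms are made-up values,
# and `all_atoms` stands in for the critic output of shape
# B x num_actions x num_atoms).
import torch

batch_size, num_actions, num_atoms = 4, 3, 5
all_atoms = torch.randn(batch_size, num_actions, num_atoms)
actions = torch.randint(num_actions, size=(batch_size, 1))  # taken actions, B x 1

# replicate the action index across atoms: B x 1 -> B x 1 x num_atoms
indices = actions.repeat(1, num_atoms).unsqueeze(1)
# pick the atoms of the taken action: B x num_atoms
atoms = all_atoms.gather(1, indices).squeeze(1)
assert atoms.shape == (batch_size, num_atoms)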
def _quantile_loss(
    self, states_t, actions_t, rewards_t, states_tp1, done_t
):
    # actor loss: minimize (scaled log-prob - min over critics of mean Q)
    actions_tp0, log_pi_tp0 = self.actor(states_t, with_log_pi=True)
    log_pi_tp0 = log_pi_tp0[:, None] / self.reward_scale
    atoms_tp0 = [
        x(states_t, actions_tp0).unsqueeze_(-1) for x in self.critics
    ]
    q_values_tp0_min = torch.cat(
        atoms_tp0, dim=-1
    ).mean(dim=1).min(dim=1)[0]
    policy_loss = torch.mean(log_pi_tp0 - q_values_tp0_min)

    # critic loss (quantile regression)
    actions_tp1, log_pi_tp1 = self.actor(states_tp1, with_log_pi=True)
    log_pi_tp1 = log_pi_tp1[:, None] / self.reward_scale
    atoms_t = [x(states_t, actions_t) for x in self.critics]
    atoms_tp1 = torch.cat([
        x(states_tp1, actions_tp1).unsqueeze_(-1)
        for x in self.target_critics
    ], dim=-1)
    # per sample, bootstrap from the critic with the smallest mean Q-value
    atoms_ids_tp1_min = atoms_tp1.mean(dim=1).argmin(dim=1)
    atoms_tp1 = atoms_tp1[range(len(atoms_tp1)), :, atoms_ids_tp1_min]
    gamma = self.gamma**self.n_step
    # entropy-regularized target: subtract the scaled log-prob before bootstrapping
    atoms_tp1 = (atoms_tp1 - log_pi_tp1).detach()
    atoms_target_t = rewards_t + (1 - done_t) * gamma * atoms_tp1
    value_loss = [
        quantile_loss(
            x, atoms_target_t, self.tau, self.n_atoms, self.critic_criterion
        ) for x in atoms_t
    ]

    return policy_loss, value_loss
def _quantile_loss(self, states_t, actions_t, rewards_t, states_tp1, done_t):
    critic = self.critic
    target_critic = self.target_critic

    gammas = (self._gammas**self._n_step)[None, :, None]  # 1 x num_heads x 1
    done_t = done_t[:, None, :]  # B x 1 x 1
    rewards_t = rewards_t[:, None, :]  # B x 1 x 1
    actions_t = actions_t[:, None, None, :]  # B x 1 x 1 x 1
    indices_t = actions_t.repeat(1, self._num_heads, 1, self.num_atoms)
    # B x num_heads x 1 x num_atoms

    # critic loss (quantile regression)
    atoms_t = critic(states_t).gather(-2, indices_t).squeeze(-2)
    # B x num_heads x num_atoms
    all_atoms_tp1 = target_critic(states_tp1).detach()
    # B x num_heads x num_actions x num_atoms
    q_values_tp1 = all_atoms_tp1.mean(dim=-1)
    # B x num_heads x num_actions
    actions_tp1 = torch.argmax(q_values_tp1, dim=-1, keepdim=True)
    # B x num_heads x 1
    indices_tp1 = actions_tp1.unsqueeze(-1).repeat(1, 1, 1, self.num_atoms)
    # B x num_heads x 1 x num_atoms
    atoms_tp1 = all_atoms_tp1.gather(-2, indices_tp1).squeeze(-2)
    # B x num_heads x num_atoms

    atoms_target_t = rewards_t + (1 - done_t) * gammas * atoms_tp1
    value_loss = quantile_loss(
        atoms_t.view(-1, self.num_atoms),
        atoms_target_t.view(-1, self.num_atoms),
        self.tau, self.num_atoms, self.critic_criterion
    )

    return value_loss
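# The shared `quantile_loss` helper called by all of the functions above is
# not shown in this listing. The sketch below is one standard way to compute
# the quantile-regression Huber loss from QR-DQN, written against the call
# signature used above (atoms_t, atoms_target_t, tau, num_atoms, criterion).
# It illustrates the technique under those assumptions and is not necessarily
# the exact helper behind this code; in particular it ignores the `criterion`
# argument, hard-codes a Huber (smooth L1) penalty, and uses `num_atoms` only
# implicitly via the tensor shapes.
import torch
import torch.nn.functional as F


def quantile_loss(atoms_t, atoms_target_t, tau, num_atoms, criterion=None):
    # atoms_t:        B x num_atoms, predicted quantile values
    # atoms_target_t: B x num_atoms, target quantile values (already detached)
    # tau:            num_atoms, quantile midpoints, e.g. (2i + 1) / (2 * num_atoms)
    # pairwise TD errors u_ij = target_j - prediction_i: B x num_atoms x num_atoms
    atoms_diff = atoms_target_t[:, None, :] - atoms_t[:, :, None]
    # asymmetric quantile weights |tau_i - 1{u_ij < 0}|
    weights = torch.abs(tau[None, :, None] - (atoms_diff.detach() < 0).float())
    # elementwise Huber penalty on the pairwise errors (kappa = 1)
    huber = F.smooth_l1_loss(
        atoms_t[:, :, None].expand_as(atoms_diff),
        atoms_target_t[:, None, :].expand_as(atoms_diff),
        reduction="none",
    )
    # mean over target atoms, sum over predicted quantiles, mean over the batch
    return (weights * huber).mean(dim=-1).sum(dim=-1).mean()


def make_tau(num_atoms):
    # quantile midpoints of the kind commonly passed as `self.tau` above
    return (torch.arange(num_atoms, dtype=torch.float32) + 0.5) / num_atoms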