def _categorical_loss(self, states_t, actions_t, rewards_t, states_tp1, done_t):
    gammas, done_t, rewards_t = self._process_components(done_t, rewards_t)

    # actor loss
    # For now we have the same actor for all heads of the critic
    # [bs; num_heads; num_atoms] -> many-heads view transform
    # [{bs * num_heads}; num_atoms]
    logits_tp0 = (
        self.critic(states_t, self.actor(states_t))
        .squeeze_(dim=2)
        .view(-1, self.num_atoms)
    )
    # [{bs * num_heads}; num_atoms]
    probs_tp0 = torch.softmax(logits_tp0, dim=-1)
    # [{bs * num_heads}]
    q_values_tp0 = torch.sum(probs_tp0 * self.z, dim=-1)
    policy_loss = -torch.mean(q_values_tp0)

    # critic loss (kl-divergence between categorical distributions)
    # [bs; num_heads; num_atoms] -> many-heads view transform
    # [{bs * num_heads}; num_atoms]
    logits_t = (
        self.critic(states_t, actions_t)
        .squeeze_(dim=2)
        .view(-1, self.num_atoms)
    )

    # [bs; action_size]
    actions_tp1 = self.target_actor(states_tp1)
    # [bs; num_heads; num_atoms] -> many-heads view transform
    # [{bs * num_heads}; num_atoms]
    logits_tp1 = (
        self.target_critic(states_tp1, actions_tp1)
        .squeeze_(dim=2)
        .view(-1, self.num_atoms)
    ).detach()

    # [bs; num_heads; num_atoms] -> many-heads view transform
    # [{bs * num_heads}; num_atoms]
    atoms_target_t = (rewards_t + (1 - done_t) * gammas * self.z).view(
        -1, self.num_atoms
    )

    value_loss = utils.categorical_loss(
        # [{bs * num_heads}; num_atoms]
        logits_t,
        # [{bs * num_heads}; num_atoms]
        logits_tp1,
        # [{bs * num_heads}; num_atoms]
        atoms_target_t,
        self.z,
        self.delta_z,
        self.v_min,
        self.v_max,
    )

    return policy_loss, value_loss
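# A minimal sketch of what `utils.categorical_loss` is assumed to compute: the
# C51-style projection of the target atoms onto the fixed support, followed by
# a cross-entropy between the projected target distribution and the predicted
# one. The helper names (`_categorical_projection_sketch`,
# `_categorical_loss_sketch`) are illustrative, not the library's actual
# implementation; only the argument order mirrors the call sites above.
import torch
import torch.nn.functional as F


def _categorical_projection_sketch(atoms_target_t, probs_target_t, z, delta_z, v_min, v_max):
    # atoms_target_t, probs_target_t: [batch; num_atoms]; z: [num_atoms]
    num_atoms = z.numel()
    # clamp targets to the support and find their fractional atom index
    b = ((atoms_target_t.clamp(v_min, v_max) - v_min) / delta_z).clamp(0, num_atoms - 1)
    lower, upper = b.floor().long(), b.ceil().long()
    same_bin = (upper == lower).float()  # target landed exactly on an atom
    proj = torch.zeros_like(probs_target_t)
    # split each target probability mass between the two nearest support atoms
    proj.scatter_add_(-1, lower, probs_target_t * (upper.float() - b + same_bin))
    proj.scatter_add_(-1, upper, probs_target_t * (b - lower.float()))
    return proj


def _categorical_loss_sketch(logits_t, logits_tp1, atoms_target_t, z, delta_z, v_min, v_max):
    # cross-entropy between the projected target and the predicted distribution
    probs_tp1 = torch.softmax(logits_tp1, dim=-1)
    proj_tp1 = _categorical_projection_sketch(atoms_target_t, probs_tp1, z, delta_z, v_min, v_max)
    log_probs_t = F.log_softmax(logits_t, dim=-1)
    return -(proj_tp1.detach() * log_probs_t).sum(dim=-1).mean()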
def _categorical_value_loss(self, states_t, logits_t, returns_t, states_tp1, done_t):
    # @TODO: WIP, no guarantees
    # [bs; num_heads; num_atoms]
    logits_tp0 = self.critic(states_t).squeeze_(dim=2)
    probs_tp0 = torch.softmax(logits_tp0, dim=-1)
    values_tp0 = torch.sum(probs_tp0 * self.z, dim=-1, keepdim=True)

    probs_t = torch.softmax(logits_t, dim=-1)
    values_t = torch.sum(probs_t * self.z, dim=-1, keepdim=True)
    value_loss = 0.5 * self._value_loss(values_tp0, values_t, returns_t)

    # [bs; num_heads; num_atoms]
    logits_tp1 = self.critic(states_tp1).squeeze_(dim=2).detach()
    # [bs; num_heads; num_atoms]
    atoms_target_t = returns_t + (1 - done_t) * self._gammas_torch * self.z
    value_loss += 0.5 * utils.categorical_loss(
        logits_tp0.view(-1, self.num_atoms),
        logits_tp1.view(-1, self.num_atoms),
        atoms_target_t.view(-1, self.num_atoms),
        self.z,
        self.delta_z,
        self.v_min,
        self.v_max,
    )

    return value_loss
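# Hedged sketch: the fixed support `self.z` and spacing `self.delta_z` used by
# all of these losses are assumed to be built once at init time, roughly as
# below. The variable names mirror the attributes above; the exact wiring in
# the constructor is an assumption, not verbatim library code.
import torch

v_min, v_max, num_atoms = -10.0, 10.0, 51            # example hyperparameters
delta_z = (v_max - v_min) / (num_atoms - 1)          # distance between atoms
z = torch.linspace(v_min, v_max, num_atoms)          # [num_atoms] support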
def _categorical_loss(self, states_t, actions_t, rewards_t, states_tp1, done_t):
    gammas, done_t, rewards_t = self._process_components(done_t, rewards_t)

    # [bs; 1] ->
    # [bs; 1; 1; 1]
    actions_t = actions_t[:, None, None, :]
    # [bs; num_heads; 1; num_atoms]
    indices_t = actions_t.repeat(1, self._num_heads, 1, self.num_atoms)
    # [bs; num_heads; num_actions; num_atoms]
    q_logits_t = self.critic(states_t)
    # [bs; num_heads; 1; num_atoms] -> gathering selected actions
    # [bs; num_heads; num_atoms] -> many-heads view transform
    # [{bs * num_heads}; num_atoms]
    logits_t = (
        q_logits_t.gather(-2, indices_t).squeeze(-2).view(-1, self.num_atoms)
    )

    # [bs; num_heads; num_actions; num_atoms]
    q_logits_tp1 = self.target_critic(states_tp1).detach()
    # [bs; num_heads; num_actions; num_atoms] -> categorical value
    # [bs; num_heads; num_actions] -> gathering best actions
    # [bs; num_heads; 1]
    actions_tp1 = (
        (torch.softmax(q_logits_tp1, dim=-1) * self.z)
        .sum(dim=-1)
        .argmax(dim=-1, keepdim=True)
    )
    # [bs; num_heads; 1] ->
    # [bs; num_heads; 1; 1] ->
    # [bs; num_heads; 1; num_atoms]
    indices_tp1 = actions_tp1.unsqueeze(-1).repeat(1, 1, 1, self.num_atoms)
    # [bs; num_heads; 1; num_atoms] -> gathering best actions
    # [bs; num_heads; num_atoms] -> many-heads view transform
    # [{bs * num_heads}; num_atoms]
    logits_tp1 = (
        q_logits_tp1.gather(-2, indices_tp1).squeeze(-2).view(-1, self.num_atoms)
    ).detach()

    # [bs; num_heads; num_atoms] -> many-heads view transform
    # [{bs * num_heads}; num_atoms]
    atoms_target_t = (
        (rewards_t + (1 - done_t) * gammas * self.z)
        .view(-1, self.num_atoms)
        .detach()
    )

    value_loss = utils.categorical_loss(
        # [{bs * num_heads}; num_atoms]
        logits_t,
        # [{bs * num_heads}; num_atoms]
        logits_tp1,
        # [{bs * num_heads}; num_atoms]
        atoms_target_t,
        self.z,
        self.delta_z,
        self.v_min,
        self.v_max,
    )

    if self.entropy_regularization is not None:
        q_values_t = torch.sum(
            torch.softmax(q_logits_t, dim=-1) * self.z, dim=-1
        )
        value_loss -= \
            self.entropy_regularization * self._compute_entropy(q_values_t)

    return value_loss
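# Hedged sanity check for the gather-based greedy-action selection above, on
# toy shapes (bs=2, num_heads=3, num_actions=4, num_atoms=5 are made up purely
# for illustration; nothing here is part of the library).
import torch

bs, num_heads, num_actions, num_atoms = 2, 3, 4, 5
q_logits = torch.randn(bs, num_heads, num_actions, num_atoms)
z = torch.linspace(-1.0, 1.0, num_atoms)

# expected value per action, then the greedy action per head
q_values = (torch.softmax(q_logits, dim=-1) * z).sum(dim=-1)   # [bs; num_heads; num_actions]
greedy = q_values.argmax(dim=-1, keepdim=True)                 # [bs; num_heads; 1]

# gather that action's atom logits, as in `logits_tp1` above
indices = greedy.unsqueeze(-1).repeat(1, 1, 1, num_atoms)      # [bs; num_heads; 1; num_atoms]
greedy_logits = q_logits.gather(-2, indices).squeeze(-2)       # [bs; num_heads; num_atoms]
assert greedy_logits.shape == (bs, num_heads, num_atoms)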
def _categorical_loss(self, states_t, actions_t, rewards_t, states_tp1, done_t):
    gammas, done_t, rewards_t = self._process_components(done_t, rewards_t)

    # actor loss
    # [bs; action_size]
    actions_tp0, logprob_tp0 = self.actor(states_t, logprob=True)
    logprob_tp0 = logprob_tp0 / self.reward_scale
    # {num_critics} * [bs; num_heads; num_atoms]
    # -> many-heads view transform
    # {num_critics} * [{bs * num_heads}; num_atoms]
    logits_tp0 = [
        x(states_t, actions_tp0).squeeze_(dim=2).view(-1, self.num_atoms)
        for x in self.critics
    ]
    # -> categorical probs
    # {num_critics} * [{bs * num_heads}; num_atoms]
    probs_tp0 = [torch.softmax(x, dim=-1) for x in logits_tp0]
    # -> categorical value
    # {num_critics} * [{bs * num_heads}; 1]
    q_values_tp0 = [
        torch.sum(x * self.z, dim=-1, keepdim=True) for x in probs_tp0
    ]
    # [{bs * num_heads}; num_critics] -> min over all critics
    # [{bs * num_heads}]
    q_values_tp0_min = torch.cat(q_values_tp0, dim=-1).min(dim=-1)[0]
    # For now we use the same actor for each gamma
    policy_loss = torch.mean(logprob_tp0[:, None] - q_values_tp0_min)

    # critic loss (kl-divergence between categorical distributions)
    # [bs; action_size]
    actions_tp1, logprob_tp1 = self.actor(states_tp1, logprob=True)
    logprob_tp1 = logprob_tp1 / self.reward_scale
    # {num_critics} * [bs; num_heads; num_atoms]
    # -> many-heads view transform
    # {num_critics} * [{bs * num_heads}; num_atoms]
    logits_t = [
        x(states_t, actions_t).squeeze_(dim=2).view(-1, self.num_atoms)
        for x in self.critics
    ]
    # {num_critics} * [bs; num_heads; num_atoms]
    logits_tp1 = [
        x(states_tp1, actions_tp1).squeeze_(dim=2)
        for x in self.target_critics
    ]
    # {num_critics} * [bs; num_heads; num_atoms]
    probs_tp1 = [torch.softmax(x, dim=-1) for x in logits_tp1]
    # {num_critics} * [bs; num_heads; 1]
    q_values_tp1 = [
        torch.sum(x * self.z, dim=-1, keepdim=True) for x in probs_tp1
    ]
    # [bs; num_heads; num_critics] -> argmin over all critics
    # [bs; num_heads]
    probs_ids_tp1_min = torch.cat(q_values_tp1, dim=-1).argmin(dim=-1)

    # [bs; num_heads; num_atoms; num_critics]
    logits_tp1 = torch.cat([x.unsqueeze(-1) for x in logits_tp1], dim=-1)

    # @TODO: smarter way to do this (other than reshaping)?
    probs_ids_tp1_min = probs_ids_tp1_min.view(-1)
    # [bs; num_heads; num_atoms; num_critics] -> many-heads view transform
    # [{bs * num_heads}; num_atoms; num_critics] -> select the argmin critic
    # [{bs * num_heads}; num_atoms]
    logits_tp1 = (
        logits_tp1.view(-1, self.num_atoms, self._num_critics)[
            range(len(probs_ids_tp1_min)), :, probs_ids_tp1_min
        ].view(-1, self.num_atoms)
    ).detach()

    # [bs; num_atoms] -> unsqueeze so it is the same for each head
    # [bs; 1; num_atoms]
    z_target_tp1 = (
        self.z[None, :] - logprob_tp1[:, None]
    ).unsqueeze(1).detach()
    # [bs; num_heads; num_atoms] -> many-heads view transform
    # [{bs * num_heads}; num_atoms]
    atoms_target_t = (rewards_t + (1 - done_t) * gammas * z_target_tp1).view(
        -1, self.num_atoms
    )

    value_loss = [
        utils.categorical_loss(
            # [{bs * num_heads}; num_atoms]
            x,
            # [{bs * num_heads}; num_atoms]
            logits_tp1,
            # [{bs * num_heads}; num_atoms]
            atoms_target_t,
            self.z,
            self.delta_z,
            self.v_min,
            self.v_max,
        )
        for x in logits_t
    ]

    return policy_loss, value_loss
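# Hedged sketch addressing the "@TODO" above: the minimum-value critic's
# logits can be selected per (batch, head) with `torch.gather` along the
# critic dimension instead of flattening and fancy-indexing. Shapes and names
# are illustrative only.
import torch

bs, num_heads, num_atoms, num_critics = 2, 3, 5, 2
logits_tp1 = torch.randn(bs, num_heads, num_atoms, num_critics)
probs_ids_tp1_min = torch.randint(num_critics, (bs, num_heads))    # argmin over critics

# [bs; num_heads] -> [bs; num_heads; num_atoms; 1] index tensor
index = probs_ids_tp1_min[:, :, None, None].expand(bs, num_heads, num_atoms, 1)
# gather along the critic dimension, then drop it
selected = logits_tp1.gather(-1, index).squeeze(-1)                # [bs; num_heads; num_atoms]
assert selected.shape == (bs, num_heads, num_atoms)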