def categorical_q_loss(agent: DQN, batch: collections.namedtuple):
    """Categorical DQN loss function to calculate the loss of the Q-function

    Args:
        agent (:obj:`DQN`): The agent
        batch (:obj:`collections.namedtuple` of :obj:`torch.Tensor`): Batch of experiences

    Returns:
        loss (:obj:`torch.Tensor`): Calculated loss of the Q-function
    """
    q_values = agent.get_q_values(batch.states, batch.actions)
    target_q_values = agent.get_target_q_values(
        batch.next_states, batch.rewards, batch.dones
    )

    # Cross-entropy between the target distribution and the predicted
    # distribution over the value atoms (summed over the atom axis, then
    # averaged over the batch) — not a difference of Q-values.
    # NOTE(review): q_values.log() yields -inf for zero-probability atoms;
    # presumably the network outputs strictly positive softmax probabilities
    # — confirm upstream clamping if that is not guaranteed.
    loss = -(target_q_values * q_values.log()).sum(1).mean()
    return loss
def prioritized_q_loss(agent: DQN, batch: collections.namedtuple):
    """Function to calculate the loss of the Q-function for a prioritized replay buffer

    Also updates the replay buffer's priorities for the sampled transitions
    and appends the scalar loss to ``agent.logs["value_loss"]``.

    Args:
        agent (:obj:`DQN`): The agent
        batch (:obj:`collections.namedtuple` of :obj:`torch.Tensor`): Batch of experiences

    Returns:
        loss (:obj:`torch.Tensor`): Calculated loss of the Q-function
    """
    q_values = agent.get_q_values(batch.states, batch.actions)
    target_q_values = agent.get_target_q_values(
        batch.next_states, batch.rewards, batch.dones
    )

    # Importance-sampling-weighted MSE; the target is detached so gradients
    # flow only through the online network's predictions.
    loss = batch.weights * (q_values - target_q_values.detach()) ** 2

    # Priorities are taken as the per-sample losses plus a small epsilon so
    # every transition keeps a non-zero probability of being resampled.
    # NOTE(review): these priorities include the IS weights (weighted squared
    # error), not the raw |td-error| — confirm this is intentional.
    priorities = loss + 1e-5

    loss = loss.mean()
    agent.replay_buffer.update_priorities(
        batch.indices, priorities.detach().cpu().numpy()
    )
    agent.logs["value_loss"].append(loss.item())
    return loss