Example #1
def categorical_q_loss(agent: DQN, batch: collections.namedtuple):
    """Categorical DQN loss function to calculate the loss of the Q-function

    Args:
        agent (:obj:`DQN`): The agent
        batch (:obj:`collections.namedtuple` of :obj:`torch.Tensor`): Batch of experiences

    Returns:
        loss (:obj:`torch.Tensor`): Calculated loss of the Q-function
    """
    q_values = agent.get_q_values(batch.states, batch.actions)
    target_q_values = agent.get_target_q_values(batch.next_states,
                                                batch.rewards, batch.dones)

    # Cross-entropy between the projected target distribution and the predicted distribution
    loss = -(target_q_values * q_values.log()).sum(1).mean()
    return loss
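For reference, a minimal, self-contained sketch of the same cross-entropy computation using dummy tensors in place of the agent's outputs; the batch size, number of value atoms, and random distributions below are illustrative assumptions, not part of the listing above:

import torch

batch_size, n_atoms = 4, 51  # illustrative sizes only

# Dummy predicted distribution over value atoms (softmax output of the network)
q_values = torch.softmax(torch.randn(batch_size, n_atoms), dim=1)
# Dummy projected target distribution (would come from the Bellman projection)
target_q_values = torch.softmax(torch.randn(batch_size, n_atoms), dim=1)

# Same form as the loss above: -sum_i m_i * log p_i, averaged over the batch
loss = -(target_q_values * q_values.log()).sum(1).mean()
print(loss)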
Example #2
def prioritized_q_loss(agent: DQN, batch: collections.namedtuple):
    """Function to calculate the loss of the Q-function

    Returns:
        agent (:obj:`DQN`): The agent
        loss (:obj:`torch.Tensor`): Calculateed loss of the Q-function
    """
    q_values = agent.get_q_values(batch.states, batch.actions)
    target_q_values = agent.get_target_q_values(batch.next_states,
                                                batch.rewards, batch.dones)

    # Weighted MSE Loss
    loss = batch.weights * (q_values - target_q_values.detach())**2
    # Priorities are the weighted squared TD-errors plus a small constant to avoid zero priorities
    priorities = loss + 1e-5
    loss = loss.mean()
    agent.replay_buffer.update_priorities(batch.indices,
                                          priorities.detach().cpu().numpy())
    agent.logs["value_loss"].append(loss.item())
    return loss
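Below is a minimal, self-contained sketch of the weighted-MSE and priority computation above with dummy tensors; the shapes and importance-sampling weights are illustrative assumptions, and the replay-buffer priority update and logging are omitted:

import torch

batch_size = 4  # illustrative size only

q_values = torch.randn(batch_size)          # dummy predicted Q(s, a)
target_q_values = torch.randn(batch_size)   # dummy bootstrapped targets
weights = torch.rand(batch_size)            # dummy importance-sampling weights

# Per-sample weighted squared TD-error, as in the loss above
elementwise_loss = weights * (q_values - target_q_values.detach()) ** 2
# New priorities: per-sample loss plus a small constant so no priority is zero
priorities = elementwise_loss + 1e-5
loss = elementwise_loss.mean()
print(loss, priorities)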