Example #1
import torch

# `DQN` below is the library's base DQN agent class, assumed to be in scope.


def ddqn_q_target(
    agent: DQN,
    next_states: torch.Tensor,
    rewards: torch.Tensor,
    dones: torch.Tensor,
) -> torch.Tensor:
    """Double Q-learning target

    Can be used to replace the `get_target_values` method of the Base DQN
    class in any DQN algorithm

    Args:
        agent (:obj:`DQN`): The agent
        next_states (:obj:`torch.Tensor`): Next states being encountered by the agent
        rewards (:obj:`torch.Tensor`): Rewards received by the agent
        dones (:obj:`torch.Tensor`): Game over status of each environment

    Returns:
        target_q_values (:obj:`torch.Tensor`): Target Q values using Double Q-learning
    """
    # Action selection: greedy next actions come from the online network
    next_q_value_dist = agent.model(next_states)
    next_best_actions = torch.argmax(next_q_value_dist, dim=-1).unsqueeze(-1)

    rewards, dones = rewards.unsqueeze(-1), dones.unsqueeze(-1)

    # Action evaluation: the selected actions are scored by the target network,
    # decoupling selection from evaluation (the Double Q-learning correction)
    next_q_target_value_dist = agent.target_model(next_states)
    max_next_q_target_values = next_q_target_value_dist.gather(
        2, next_best_actions
    )
    target_q_values = rewards + agent.gamma * torch.mul(
        max_next_q_target_values, (1 - dones)
    )
    return target_q_values
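The docstring above says this helper can stand in for the base agent's `get_target_values` method. The sketch below simply calls it directly with a stand-in agent so the expected `(batch_size, n_envs, ...)` tensor shapes are explicit; the `SimpleNamespace` stub and its attributes (`model`, `target_model`, `gamma`) mirror what the helper reads and are illustrative assumptions, not the library's real agent.

# Minimal usage sketch (assumed shapes and a stub agent, not the real DQN class)
from types import SimpleNamespace

import torch
import torch.nn as nn

batch_size, n_envs, state_dim, n_actions = 4, 2, 8, 3
q_net = nn.Linear(state_dim, n_actions)  # toy Q-network: states -> action values

stub_agent = SimpleNamespace(model=q_net, target_model=q_net, gamma=0.99)

next_states = torch.randn(batch_size, n_envs, state_dim)
rewards = torch.rand(batch_size, n_envs)
dones = torch.zeros(batch_size, n_envs)

targets = ddqn_q_target(stub_agent, next_states, rewards, dones)
print(targets.shape)  # torch.Size([4, 2, 1])

In a real agent, `model` and `target_model` would be separate networks with the target copied or soft-updated periodically; a single shared `q_net` is used here only to keep the sketch self-contained.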
Example #2
import torch


def categorical_q_target(
    agent: DQN,
    next_states: torch.Tensor,
    rewards: torch.Tensor,
    dones: torch.Tensor,
) -> torch.Tensor:
    """Projected Distribution of Q-values

    Helper function for Categorical/Distributional DQN

    Args:
        agent (:obj:`DQN`): The agent
        next_states (:obj:`torch.Tensor`): Next states being encountered by the agent
        rewards (:obj:`torch.Tensor`): Rewards received by the agent
        dones (:obj:`torch.Tensor`): Game over status of each environment

    Returns:
        target_q_values (:obj:`torch.Tensor`): Projected Q-value distribution, used as the target for the distributional loss
    """
    delta_z = float(agent.v_max - agent.v_min) / (agent.num_atoms - 1)
    support = torch.linspace(agent.v_min, agent.v_max, agent.num_atoms)

    # Probability mass over atoms from the target network
    next_prob_dist = agent.target_model(next_states)
    # Greedy next actions by expected value: sum over atoms of prob * support
    next_actions = (
        torch.argmax((next_prob_dist * support).sum(-1), dim=-1)
        .unsqueeze(-1)
        .unsqueeze(-1)
    )

    next_actions = next_actions.expand(
        agent.batch_size, agent.env.n_envs, 1, agent.num_atoms
    )
    # Probability mass of the selected actions; this mass is projected below
    next_q_values = next_prob_dist.gather(2, next_actions).squeeze(2)

    rewards = rewards.unsqueeze(-1).expand_as(next_q_values)
    dones = dones.unsqueeze(-1).expand_as(next_q_values)

    # Bellman update of each atom, then projection onto the fixed support.
    # Notation follows Section 4 of the C51 paper (Bellemare et al., 2017).
    Tz = rewards + (1 - dones) * agent.gamma * support
    Tz = Tz.clamp(min=agent.v_min, max=agent.v_max)
    bz = (Tz - agent.v_min) / delta_z
    l = bz.floor().long()
    u = bz.ceil().long()
    # When bz lands exactly on an atom (l == u), shift one bound so the
    # (u - bz) and (bz - l) weights below do not both vanish and drop mass
    l[(u > 0) & (l == u)] -= 1
    u[(l < (agent.num_atoms - 1)) & (l == u)] += 1

    # Flat offsets so each (batch, env) row indexes its own num_atoms-long
    # slice once the target tensor is flattened with view(-1)
    offset = (
        torch.linspace(
            0,
            (agent.batch_size * agent.env.n_envs - 1) * agent.num_atoms,
            agent.batch_size * agent.env.n_envs,
        )
        .long()
        .view(agent.batch_size, agent.env.n_envs, 1)
        .expand(agent.batch_size, agent.env.n_envs, agent.num_atoms)
    )

    # Distribute each atom's mass between its two neighbouring support atoms
    target_q_values = torch.zeros(next_q_values.size())
    target_q_values.view(-1).index_add_(
        0,
        (l + offset).view(-1),
        (next_q_values * (u.float() - bz)).view(-1),
    )
    target_q_values.view(-1).index_add_(
        0,
        (u + offset).view(-1),
        (next_q_values * (bz - l.float())).view(-1),
    )
    return target_q_values
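As with the first example, the sketch below exercises categorical_q_target with a stand-in agent so the expected shapes are visible. The AtomNet toy network, the SimpleNamespace stub, and all of its attributes (target_model, v_min, v_max, num_atoms, gamma, batch_size, env.n_envs) mirror what the helper reads and are assumptions made for illustration, not the library's real categorical DQN agent.

# Minimal usage sketch with a stub agent (assumed attribute names and shapes)
from types import SimpleNamespace

import torch
import torch.nn as nn


class AtomNet(nn.Module):
    """Toy network returning a softmax distribution over atoms per action."""

    def __init__(self, state_dim, n_actions, num_atoms):
        super().__init__()
        self.n_actions, self.num_atoms = n_actions, num_atoms
        self.fc = nn.Linear(state_dim, n_actions * num_atoms)

    def forward(self, states):
        logits = self.fc(states)
        logits = logits.view(*states.shape[:-1], self.n_actions, self.num_atoms)
        return torch.softmax(logits, dim=-1)


batch_size, n_envs, state_dim, n_actions, num_atoms = 4, 2, 8, 3, 51
stub_agent = SimpleNamespace(
    target_model=AtomNet(state_dim, n_actions, num_atoms),
    v_min=-10.0,
    v_max=10.0,
    num_atoms=num_atoms,
    gamma=0.99,
    batch_size=batch_size,
    env=SimpleNamespace(n_envs=n_envs),
)

next_states = torch.randn(batch_size, n_envs, state_dim)
rewards = torch.rand(batch_size, n_envs)
dones = torch.zeros(batch_size, n_envs)

projected = categorical_q_target(stub_agent, next_states, rewards, dones)
print(projected.shape)       # torch.Size([4, 2, 51])
print(projected.sum(-1)[0])  # each row sums to ~1: a valid probability distribution

Each returned row is a probability distribution over the num_atoms support points, which is what a cross-entropy style distributional loss expects as its target.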