import torch

# `DQN` is the library's base DQN agent class. The relative import below is an
# assumption -- adjust it to match the actual package layout.
from .base import DQN


def ddqn_q_target(
    agent: DQN,
    next_states: torch.Tensor,
    rewards: torch.Tensor,
    dones: torch.Tensor,
) -> torch.Tensor:
    """Double Q-learning target: r + gamma * Q_target(s', argmax_a Q(s', a))

    The online network picks the greedy next action and the target network
    evaluates it, which reduces the overestimation bias of vanilla DQN
    (van Hasselt et al., 2016). Can be used to replace the
    `get_target_values` method of the base DQN class in any DQN algorithm.

    Args:
        agent (:obj:`DQN`): The agent
        next_states (:obj:`torch.Tensor`): Next states encountered by the agent
        rewards (:obj:`torch.Tensor`): Rewards received by the agent
        dones (:obj:`torch.Tensor`): Game over status of each environment

    Returns:
        target_q_values (:obj:`torch.Tensor`): Target Q values computed with
            Double Q-learning
    """
    # Greedy next actions according to the *online* network
    next_q_value_dist = agent.model(next_states)
    next_best_actions = torch.argmax(next_q_value_dist, dim=-1).unsqueeze(-1)

    rewards, dones = rewards.unsqueeze(-1), dones.unsqueeze(-1)

    # Evaluate those actions with the *target* network
    next_q_target_value_dist = agent.target_model(next_states)
    max_next_q_target_values = next_q_target_value_dist.gather(2, next_best_actions)

    # Bootstrap only for non-terminal transitions
    target_q_values = rewards + agent.gamma * torch.mul(
        max_next_q_target_values, (1 - dones)
    )
    return target_q_values
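
# --- Usage sketch -------------------------------------------------------------
# `ddqn_q_target` is designed to be patched in as an agent's
# `get_target_values` method, assuming that method takes
# (self, next_states, rewards, dones). The binding below is illustrative, not
# part of the library; `agent` stands for any base DQN agent built elsewhere.
#
#     import types
#
#     agent.get_target_values = types.MethodType(ddqn_q_target, agent)
#     # subsequent target computations now use Double Q-learning
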
def categorical_q_target(
    agent: DQN,
    next_states: torch.Tensor,
    rewards: torch.Tensor,
    dones: torch.Tensor,
) -> torch.Tensor:
    """Projected distribution of Q-values

    Helper function for Categorical/Distributional DQN (C51). Projects the
    Bellman update of the target network's return distribution back onto the
    fixed support of `num_atoms` atoms spaced evenly over [v_min, v_max];
    refer to Section 4 of Bellemare et al. (2017), "A Distributional
    Perspective on Reinforcement Learning", for the notation used below.

    Args:
        agent (:obj:`DQN`): The agent
        next_states (:obj:`torch.Tensor`): Next states encountered by the agent
        rewards (:obj:`torch.Tensor`): Rewards received by the agent
        dones (:obj:`torch.Tensor`): Game over status of each environment

    Returns:
        target_q_values (:obj:`torch.Tensor`): Projected Q-value distribution
    """
    delta_z = float(agent.v_max - agent.v_min) / (agent.num_atoms - 1)
    support = torch.linspace(
        agent.v_min, agent.v_max, agent.num_atoms, device=next_states.device
    )

    # Support-weighted distribution from the target network; the greedy next
    # action maximises the implied expected value
    next_q_value_dist = agent.target_model(next_states) * support
    next_actions = (
        torch.argmax(next_q_value_dist.sum(-1), dim=-1).unsqueeze(-1).unsqueeze(-1)
    )
    next_actions = next_actions.expand(
        agent.batch_size, agent.env.n_envs, 1, agent.num_atoms
    )
    next_q_values = next_q_value_dist.gather(2, next_actions).squeeze(2)

    rewards = rewards.unsqueeze(-1).expand_as(next_q_values)
    dones = dones.unsqueeze(-1).expand_as(next_q_values)

    # Bellman update of each atom, clipped to the support: Tz_j = r + gamma * z_j
    Tz = rewards + (1 - dones) * agent.gamma * support
    Tz = Tz.clamp(min=agent.v_min, max=agent.v_max)

    # Fractional position of each updated atom on the support, and the indices
    # of its two neighbouring atoms
    bz = (Tz - agent.v_min) / delta_z
    l = bz.floor().long()
    u = bz.ceil().long()

    # When bz lands exactly on an atom, l == u and both interpolation weights
    # vanish; nudge the indices apart so no probability mass is lost
    l[(u > 0) * (l == u)] -= 1
    u[(l < (agent.num_atoms - 1)) * (l == u)] += 1

    # Flat offsets so index_add_ can scatter all (batch, env) rows at once
    offset = (
        torch.linspace(
            0,
            (agent.batch_size * agent.env.n_envs - 1) * agent.num_atoms,
            agent.batch_size * agent.env.n_envs,
            device=next_states.device,
        )
        .long()
        .view(agent.batch_size, agent.env.n_envs, 1)
        .expand(agent.batch_size, agent.env.n_envs, agent.num_atoms)
    )

    # Distribute each atom's mass between its two neighbours in proportion to
    # its distance from them
    target_q_values = torch.zeros(next_q_values.size(), device=next_states.device)
    target_q_values.view(-1).index_add_(
        0, (l + offset).view(-1), (next_q_values * (u.float() - bz)).view(-1)
    )
    target_q_values.view(-1).index_add_(
        0, (u + offset).view(-1), (next_q_values * (bz - l.float())).view(-1)
    )
    return target_q_values
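
# --- Smoke test ----------------------------------------------------------------
# A minimal, self-contained check of `categorical_q_target` using a mock agent.
# Everything below is illustrative and not part of the library: the fake
# network returns a valid softmax distribution over atoms for each
# (state, action) pair, matching the (batch, n_envs, n_actions, num_atoms)
# layout the helper expects.
if __name__ == "__main__":
    from types import SimpleNamespace

    batch_size, n_envs, n_actions, num_atoms = 4, 2, 3, 51

    def fake_dist_model(states: torch.Tensor) -> torch.Tensor:
        # Random but valid probability distribution over the atoms
        logits = torch.randn(batch_size, n_envs, n_actions, num_atoms)
        return torch.softmax(logits, dim=-1)

    agent = SimpleNamespace(
        model=fake_dist_model,
        target_model=fake_dist_model,
        gamma=0.99,
        v_min=-10.0,
        v_max=10.0,
        num_atoms=num_atoms,
        batch_size=batch_size,
        env=SimpleNamespace(n_envs=n_envs),
    )

    next_states = torch.randn(batch_size, n_envs, 8)
    rewards = torch.rand(batch_size, n_envs)
    dones = torch.zeros(batch_size, n_envs)

    target = categorical_q_target(agent, next_states, rewards, dones)
    print(target.shape)  # expected: torch.Size([4, 2, 51])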