Example #1
    def calc_loss(self, q_values: Tensor, target_q_values: Tensor,
                  actions: Tensor, rewards: Tensor,
                  done_mask: Tensor) -> Tensor:
        """
        Calculate the MSE loss of this step.
        The loss for an example is defined as:
            Q_samp(s) = r if done
                        = r + gamma * max_a' Q_target(s', a') otherwise
            loss = (Q_samp(s) - Q(s, a))^2

        Args:
            q_values: (torch tensor) shape = (batch_size, num_actions)
                The Q-values that your current network estimates (i.e. Q(s, a') for all a')
            target_q_values: (torch tensor) shape = (batch_size, num_actions)
                The target Q-values that your target network estimates (i.e. Q_target(s', a') for all a')
            actions: (torch tensor) shape = (batch_size,)
                The actions that you actually took at each step (i.e. a)
            rewards: (torch tensor) shape = (batch_size,)
                The rewards that you actually got at each step (i.e. r)
            done_mask: (torch tensor) shape = (batch_size,)
                A boolean mask of examples where we reached the terminal state

        Hint:
            You may find the following functions useful
                - torch.max
                - torch.sum
                - torch.nn.functional.one_hot
                - torch.nn.functional.mse_loss
            You can treat `done_mask` as a tensor of 0s and 1s, where 0 means not done and 1 means
            done, by casting it with `Tensor.type` as done below.

            To extract Q(s, a) for a specific action "a", you can combine torch.sum and
            torch.nn.functional.one_hot. Think about how; a standalone sketch of this trick follows
            this example.
        """
        # you may need this variable
        num_actions = self.env.action_space.n
        gamma = self.config.gamma
        done_mask = done_mask.type(torch.int)
        actions = actions.type(torch.int64)
        ##############################################################
        ##################### YOUR CODE HERE - 3-5 lines #############
        # max_a' Q_target(s', a') over the next-state actions
        target_q = torch.max(target_q_values, dim=1).values
        # Bellman target: r if done, else r + gamma * max_a' Q_target(s', a')
        q_samp = rewards + (1 - done_mask) * gamma * target_q
        # Q(s, a) for the action actually taken, extracted with a one-hot mask
        q_sa = torch.sum(
            q_values * torch.nn.functional.one_hot(actions, num_actions),
            dim=1)
        loss = torch.nn.functional.mse_loss(q_samp, q_sa)
        ##############################################################
        ######################## END YOUR CODE #######################
        return loss
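
The one-hot trick from the hint above can be illustrated on its own. The following is a minimal, self-contained sketch with made-up shapes and values (the batch size, action count, and dummy tensors are assumptions for illustration); it also checks that the one-hot extraction agrees with Tensor.gather, which Example #3 below uses instead.

import torch
import torch.nn.functional as F

batch_size, num_actions = 4, 3
q_values = torch.randn(batch_size, num_actions)  # Q(s, a') for all a'
actions = torch.tensor([0, 2, 1, 2])             # actions actually taken

# One-hot trick: zero out every column except the chosen action, then sum
# across the action dimension to get Q(s, a).
q_sa_onehot = torch.sum(q_values * F.one_hot(actions, num_actions), dim=1)

# Equivalent extraction with gather.
q_sa_gather = q_values.gather(1, actions.unsqueeze(-1)).squeeze(-1)

assert torch.allclose(q_sa_onehot, q_sa_gather)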
Example #2
def jaccard_loss(logits: Tensor, mask: Tensor) -> Tensor:
    """Computes the Jaccard loss, a.k.a the IoU loss.
    Note that PyTorch optimizers minimize a loss. In this
    case, we would like to maximize the jaccard loss so we
    return the negated jaccard loss.
    Args:
        mask: a tensor of shape [B, H, W] or [B, 1, H, W].
        logits: a tensor of shape [B, C, H, W]. Corresponds to
            the raw output or logits of the model.
    Returns:
        jacc_loss: the Jaccard loss.
    """
    num_classes = logits.shape[1]
    if num_classes == 1:
        true_1_hot = torch.eye(num_classes + 1,
                               device=mask.device)[mask.squeeze(1).long()]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        true_1_hot_f = true_1_hot[:, 0:1, :, :]
        true_1_hot_s = true_1_hot[:, 1:2, :, :]
        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
        pos_prob = logits.sigmoid()
        neg_prob = 1 - pos_prob
        probas = torch.cat([pos_prob, neg_prob], dim=1)
    else:
        true_1_hot = torch.eye(num_classes,
                               device=mask.device)[mask.squeeze(1).long()]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        probas = logits.softmax(dim=1)
    true_1_hot = true_1_hot.type(logits.type())
    # Sum over the batch and spatial dimensions, keeping the class dimension
    dims = (0, ) + tuple(range(2, probas.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    union = cardinality - intersection
    # Small epsilon avoids division by zero when a class is absent from the batch
    jacc_loss = (intersection / (union + torch.finfo(logits.dtype).eps)).mean()
    return 1.0 - jacc_loss
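
For reference, a minimal usage sketch of jaccard_loss as defined above. The shapes and tensors are made up for illustration (a multi-class case with 4 classes); the mask holds integer class labels.

import torch

logits = torch.randn(2, 4, 8, 8, requires_grad=True)  # [B, C, H, W] raw model output
mask = torch.randint(0, 4, (2, 8, 8))                  # [B, H, W] integer class labels

loss = jaccard_loss(logits, mask)  # scalar: 1 - mean IoU over classes
loss.backward()                    # gradients flow back through the softmax probabilities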
Example #3
    def calc_loss(self, q_values: Tensor, target_q_values: Tensor,
                  actions: Tensor, rewards: Tensor, done_mask: Tensor,
                  state: Tensor, next_state: Tensor) -> Tensor:
        """
        Calculate the MSE loss of this step.
        The loss for an example is defined as:
            Q_samp(s) = r if done
                        = r + gamma * max_a' Q_target(s', a') otherwise
            loss = (Q_samp(s) - Q(s, a))^2

        Args:
            q_values: (torch tensor) shape = (batch_size, num_actions)
                The Q-values that your current network estimates (i.e. Q(s, a') for all a')
            target_q_values: (torch tensor) shape = (batch_size, num_actions)
                The target Q-values that your target network estimates (i.e. Q_target(s', a') for all a')
            actions: (torch tensor) shape = (batch_size,)
                The actions that you actually took at each step (i.e. a)
            rewards: (torch tensor) shape = (batch_size,)
                The rewards that you actually got at each step (i.e. r)
            done_mask: (torch tensor) shape = (batch_size,)
                A boolean mask of examples where we reached the terminal state

        Hint:
            You may find the following functions useful
                - torch.max
                - torch.sum
                - torch.nn.functional.one_hot
                - torch.nn.functional.mse_loss
            You can treat `done_mask` as a tensor of 0s and 1s, where 0 means not done and 1 means
            done, by casting it with `Tensor.type` as done below.

            To extract Q(s, a) for a specific action "a", you can combine torch.sum and
            torch.nn.functional.one_hot. Think about how.
        """
        # you may need this variable
        num_actions = self.env.action_space.n
        gamma = self.config.gamma
        done_mask = done_mask.type(torch.int)
        actions = actions.type(torch.int64)
        ##############################################################
        ##################### YOUR CODE HERE - 3-5 lines #############
        # The block below is the vanilla DQN loss; the code that follows
        # implements the Double DQN (DDQN) loss instead.
        '''
        best_target_q = torch.reshape(torch.max(target_q_values, dim=1, keepdim=True).values, (-1,))
        Q_samp = rewards + (1 - done_mask) * gamma * best_target_q
        Q_sa = torch.sum(q_values * torch.nn.functional.one_hot(actions, self.env.action_space.n), dim=1)
        loss = torch.nn.functional.mse_loss(Q_samp, Q_sa)
        '''
        # Move the batch tensors onto the same device as the network outputs
        # instead of hard-coding 'cuda:0'
        device = q_values.device
        state = state.to(device)
        next_state = next_state.to(device)
        rewards = rewards.to(device)
        done_mask = done_mask.to(device)
        actions = actions.to(device).unsqueeze(-1)

        # Q(s, a) for the actions actually taken, from the online network
        state_action_vals = self.get_q_values(state, 'q_network').gather(
            1, actions).squeeze(-1)

        # Double DQN: pick a' with the online network, evaluate it with the
        # target network
        next_state_action = self.get_q_values(
            next_state, 'q_network').max(1)[1].unsqueeze(-1)
        next_state_vals = self.get_q_values(next_state, 'target').gather(
            1, next_state_action).squeeze(-1)

        # Bellman target: r if done, else r + gamma * Q_target(s', argmax_a' Q(s', a'))
        exp_sa_vals = rewards + gamma * (1 - done_mask) * next_state_vals.detach()
        loss = torch.nn.functional.mse_loss(state_action_vals, exp_sa_vals)

        ##############################################################
        ######################## END YOUR CODE #######################
        return loss
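
The Double DQN target used above can also be written down in isolation. Below is a minimal sketch with stand-in tensors: online_q_next and target_q_next are assumptions standing in for self.get_q_values(next_state, 'q_network') and self.get_q_values(next_state, 'target'), and the shapes and values are made up for illustration.

import torch

gamma = 0.99
batch_size, num_actions = 4, 3

rewards = torch.randn(batch_size)
done_mask = torch.tensor([0, 0, 1, 0])
online_q_next = torch.randn(batch_size, num_actions)  # Q(s', .) from the online network
target_q_next = torch.randn(batch_size, num_actions)  # Q_target(s', .) from the target network

# Select a' with the online network, evaluate it with the target network.
best_actions = online_q_next.argmax(dim=1, keepdim=True)
next_vals = target_q_next.gather(1, best_actions).squeeze(-1)

# Bellman target: r if done, else r + gamma * Q_target(s', argmax_a' Q(s', a'))
target = rewards + gamma * (1 - done_mask) * next_vals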