def calc_loss(self, q_values: Tensor, target_q_values: Tensor,
              actions: Tensor, rewards: Tensor, done_mask: Tensor) -> Tensor:
    """
    Calculate the MSE loss of this step.
    The loss for an example is defined as:
        Q_samp(s) = r                                   if done
                  = r + gamma * max_a' Q_target(s', a') otherwise
        loss = (Q_samp(s) - Q(s, a))^2

    Args:
        q_values: (torch tensor) shape = (batch_size, num_actions)
            The Q-values that your current network estimates
            (i.e. Q(s, a') for all a')
        target_q_values: (torch tensor) shape = (batch_size, num_actions)
            The target Q-values that your target network estimates
            (i.e. Q_target(s', a') for all a')
        actions: (torch tensor) shape = (batch_size,)
            The actions that you actually took at each step (i.e. a)
        rewards: (torch tensor) shape = (batch_size,)
            The rewards that you actually got at each step (i.e. r)
        done_mask: (torch tensor) shape = (batch_size,)
            A boolean mask of examples where we reached the terminal state

    Hint:
        You may find the following functions useful
            - torch.max
            - torch.sum
            - torch.nn.functional.one_hot
            - torch.nn.functional.mse_loss
        You can treat `done_mask` as a tensor of 0s and 1s, where 0 means
        not done and 1 means done, by casting with torch.type as done below.

        To extract Q(a) for a specific "a" you can use torch.sum and
        torch.nn.functional.one_hot. Think about how.
    """
    # you may need this variable
    num_actions = self.env.action_space.n
    gamma = self.config.gamma
    done_mask = done_mask.type(torch.int)
    actions = actions.type(torch.int64)
    ##############################################################
    ##################### YOUR CODE HERE - 3-5 lines #############
    # max_a' Q_target(s', a') for each example in the batch
    target_q = torch.max(target_q_values, dim=1).values
    # Q_samp(s): bootstrap from the target network only for
    # non-terminal transitions
    q_samp = rewards + (1 - done_mask) * gamma * target_q
    # Q(s, a): mask out all but the Q-value of the action actually taken
    q_sa = torch.sum(
        q_values * torch.nn.functional.one_hot(actions, num_actions), dim=1)
    loss = torch.nn.functional.mse_loss(q_samp, q_sa)
    ##############################################################
    ######################## END YOUR CODE #######################
    return loss
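# A minimal standalone sketch of the same Bellman-target computation as
# calc_loss above, usable without the surrounding class. The function name,
# default gamma, and the dummy shapes in the usage comment are illustrative
# assumptions, not part of the original code.
import torch
import torch.nn.functional as F


def dqn_loss_sketch(q_values, target_q_values, actions, rewards, done_mask,
                    gamma=0.99):
    # q_values / target_q_values: (batch_size, num_actions)
    num_actions = q_values.shape[1]
    # max_a' Q_target(s', a'), shape (batch_size,)
    target_q = torch.max(target_q_values, dim=1).values
    # Bootstrapped target, zeroed where the episode terminated
    q_samp = rewards + (1 - done_mask.int()) * gamma * target_q
    # Q(s, a) for the actions actually taken
    q_sa = torch.sum(q_values * F.one_hot(actions.long(), num_actions), dim=1)
    return F.mse_loss(q_samp, q_sa)


# Example usage with dummy tensors (batch of 4 transitions, 3 actions):
#   q = torch.randn(4, 3); tq = torch.randn(4, 3)
#   a = torch.tensor([0, 2, 1, 0]); r = torch.randn(4)
#   d = torch.tensor([False, False, True, False])
#   loss = dqn_loss_sketch(q, tq, a, r, d)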
def jaccard_loss(logits: Tensor, mask: Tensor) -> Tensor:
    """Computes the Jaccard loss, a.k.a. the IoU loss.

    Note that PyTorch optimizers minimize a loss. In this case, we would
    like to maximize the Jaccard index, so we return one minus the Jaccard
    index.

    Args:
        logits: a tensor of shape [B, C, H, W]. Corresponds to the raw
            output or logits of the model.
        mask: a tensor of shape [B, H, W] or [B, 1, H, W].

    Returns:
        jacc_loss: the Jaccard loss.
    """
    num_classes = logits.shape[1]
    if num_classes == 1:
        # Binary case: build a 2-channel one-hot target and reorder it to
        # [positive, negative] so it lines up with the stacked probabilities.
        true_1_hot = torch.eye(num_classes + 1)[mask.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        true_1_hot_f = true_1_hot[:, 0:1, :, :]
        true_1_hot_s = true_1_hot[:, 1:2, :, :]
        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
        pos_prob = logits.sigmoid()
        neg_prob = 1 - pos_prob
        probas = torch.cat([pos_prob, neg_prob], dim=1)
    else:
        # Multi-class case: one-hot the integer mask to [B, C, H, W]
        true_1_hot = torch.eye(num_classes)[mask.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        probas = logits.softmax(dim=1)
    true_1_hot = true_1_hot.type(logits.type())
    # Sum over the batch and spatial dimensions, keeping the class dimension
    dims = (0, ) + tuple(range(2, probas.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    union = cardinality - intersection
    # Add the smallest representable epsilon to guard against empty unions
    jacc_loss = (intersection / (union + torch.finfo(logits.dtype).eps)).mean()
    return 1.0 - jacc_loss
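# A small usage sketch for jaccard_loss above. The batch size, class count,
# and spatial resolution are illustrative assumptions; any [B, C, H, W]
# logits with a matching integer mask would work the same way.
def _jaccard_loss_example():
    logits = torch.randn(2, 5, 32, 32, requires_grad=True)  # [B, C, H, W]
    mask = torch.randint(0, 5, (2, 32, 32))                 # [B, H, W] labels
    loss = jaccard_loss(logits, mask)                       # scalar in [0, 1]
    loss.backward()  # gradients flow through the softmax probabilities
    return loss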
def calc_loss(self, q_values: Tensor, target_q_values: Tensor,
              actions: Tensor, rewards: Tensor, done_mask: Tensor,
              state: Tensor, next_state: Tensor) -> Tensor:
    """
    Calculate the MSE loss of this step.
    The loss for an example is defined as:
        Q_samp(s) = r                                   if done
                  = r + gamma * max_a' Q_target(s', a') otherwise
        loss = (Q_samp(s) - Q(s, a))^2

    Args:
        q_values: (torch tensor) shape = (batch_size, num_actions)
            The Q-values that your current network estimates
            (i.e. Q(s, a') for all a')
        target_q_values: (torch tensor) shape = (batch_size, num_actions)
            The target Q-values that your target network estimates
            (i.e. Q_target(s', a') for all a')
        actions: (torch tensor) shape = (batch_size,)
            The actions that you actually took at each step (i.e. a)
        rewards: (torch tensor) shape = (batch_size,)
            The rewards that you actually got at each step (i.e. r)
        done_mask: (torch tensor) shape = (batch_size,)
            A boolean mask of examples where we reached the terminal state

    Hint:
        You may find the following functions useful
            - torch.max
            - torch.sum
            - torch.nn.functional.one_hot
            - torch.nn.functional.mse_loss
        You can treat `done_mask` as a tensor of 0s and 1s, where 0 means
        not done and 1 means done, by casting with torch.type as done below.

        To extract Q(a) for a specific "a" you can use torch.sum and
        torch.nn.functional.one_hot. Think about how.
    """
    # you may need this variable
    num_actions = self.env.action_space.n
    gamma = self.config.gamma
    done_mask = done_mask.type(torch.int)
    actions = actions.type(torch.int64)
    ##############################################################
    ##################### YOUR CODE HERE - 3-5 lines #############
    # This is the vanilla DQN loss; the uncommented code below is the
    # DDQN loss.
    # best_target_q = torch.max(target_q_values, dim=1).values
    # q_samp = rewards + (1 - done_mask) * gamma * best_target_q
    # q_sa = torch.sum(
    #     q_values * torch.nn.functional.one_hot(actions, num_actions), dim=1)
    # loss = torch.nn.functional.mse_loss(q_samp, q_sa)

    # Note: this assumes a single GPU is available at cuda:0
    state = state.to('cuda:0')
    next_state = next_state.to('cuda:0')
    actions = actions.to('cuda:0')
    rewards = rewards.to('cuda:0')
    done_mask = done_mask.to('cuda:0')

    # Q(s, a): Q-values of the actions actually taken
    actions = actions.unsqueeze(-1)
    state_action_vals = self.get_q_values(state, 'q_network').gather(1, actions)
    state_action_vals = state_action_vals.squeeze(-1)
    # Double DQN: select a' with the online network...
    next_state_action = self.get_q_values(next_state, 'q_network').max(1)[1]
    next_state_action = next_state_action.unsqueeze(-1)
    # ...but evaluate Q_target(s', a') with the target network
    next_state_vals = self.get_q_values(next_state, 'target').gather(
        1, next_state_action).squeeze(-1)
    # Bootstrap only for non-terminal transitions; detach so no gradients
    # flow into the target network
    exp_sa_vals = next_state_vals.detach() * gamma * (1 - done_mask) + rewards
    loss = torch.nn.functional.mse_loss(state_action_vals, exp_sa_vals)
    ##############################################################
    ######################## END YOUR CODE #######################
    return loss
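# A standalone sketch of the Double-DQN target used above: the online
# network picks argmax_a' Q(s', a'), the target network evaluates it.
# `online_q_next` and `target_q_next` are illustrative stand-ins for
# self.get_q_values(next_state, 'q_network') and
# self.get_q_values(next_state, 'target'); they are not part of the
# original class.
def ddqn_target_sketch(online_q_next, target_q_next, rewards, done_mask,
                       gamma=0.99):
    # Action selection with the online network, shape (batch_size, 1)
    best_actions = online_q_next.argmax(dim=1, keepdim=True)
    # Action evaluation with the target network, shape (batch_size,)
    next_vals = target_q_next.gather(1, best_actions).squeeze(-1)
    # Detach so the target is treated as a constant during backprop
    return rewards + gamma * (1 - done_mask.int()) * next_vals.detach()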