def forward(self, input, target):
    """Compute the discriminative (contrastive) embedding loss for a batch.

    Args:
        input (torch.tensor): embeddings predicted by the network (NxExDxHxW),
            where E is the embedding dimensionality
        target (torch.tensor): ground truth instance segmentation (NxDxHxW)

    Returns:
        Mean (over the remaining batch dimension) of
        ``alpha * variance_term + beta * distance_term + gamma * regularization_term``.
    """
    # Count of distinct label values present anywhere in the batch.
    num_instances = torch.unique(target).size()[0]
    # One-hot encode every label: N x D x H x W -> N x C x D x H x W.
    # NOTE(review): C is the number of *unique* labels, which presumably assumes
    # labels are contiguous 0..C-1 inside expand_as_one_hot — confirm.
    one_hot = expand_as_one_hot(target, num_instances)

    # Both tensors must be 5D and share the spatial dimensions (D, H, W).
    assert input.dim() == one_hot.dim() == 5
    assert input.size()[2:] == one_hot.size()[2:]

    # Mean embedding of each instance, plus the embeddings grouped per instance.
    means, grouped_embeddings = self._compute_cluster_means(input, one_hot)

    variance = self._compute_variance_term(means, grouped_embeddings, one_hot)
    distance = self._compute_distance_term(means, num_instances)
    regularization = self._compute_regularizer_term(means, num_instances)

    weighted_sum = self.alpha * variance + self.beta * distance + self.gamma * regularization
    # Collapse the batch dimension into a single scalar.
    return torch.mean(weighted_sum)
def forward(self, input, target, weights):
    """Compute the pixel-wise weighted cross-entropy loss.

    Args:
        input (torch.tensor): raw network output (NxCxDxHxW)
        target (torch.tensor): ground truth labels (NxDxHxW)
        weights (torch.tensor): per-voxel loss weights, same shape as ``target``

    Returns:
        Scalar tensor: the mean over all N*C*D*H*W elements of
        ``-class_weights[c] * weights * one_hot(target) * self.log_softmax(input)``.
    """
    assert target.size() == weights.size()
    # Normalize the input (self.log_softmax presumably applies log-softmax over
    # the channel dim — confirm against __init__).
    log_probabilities = self.log_softmax(input)
    # The loss is computed manually against a one-hot target, so expand the
    # label volume to the input's layout: (NxDxHxW) -> (NxCxDxHxW).
    target = expand_as_one_hot(target, C=input.size()[1], ignore_index=self.ignore_index)
    # Insert the channel axis at dim 1 (NxDxHxW -> Nx1xDxHxW) so each sample's
    # weights broadcast across all classes. unsqueeze(0) would give 1xNxDxHxW,
    # aligning the batch axis with the channel axis: that fails for N != C and
    # silently mixes samples with classes when N == C.
    weights = weights.unsqueeze(1).expand_as(input)

    # Mask ignore_index if present.
    if self.ignore_index is not None:
        # Assumes expand_as_one_hot marks ignored voxels with ignore_index in
        # the expanded tensor — confirm. Comparison results never require
        # grad, so no Variable/.data wrapping is needed.
        mask = target.ne(self.ignore_index).float()
        log_probabilities = log_probabilities * mask
        target = target * mask

    # Lazily create (and cache as a buffer) uniform class weights sized to the
    # channel count of the first input seen.
    if self.class_weights is None:
        class_weights = torch.ones(input.size()[1]).float().to(input.device)
        self.register_buffer('class_weights', class_weights)

    # Reshape class weights to (1, C, 1, 1, 1) so they broadcast over N, D, H, W.
    class_weights = self.class_weights.view(1, -1, 1, 1, 1)
    weights = class_weights * weights

    # Per-element weighted negative log-likelihood, then average.
    result = -weights * target * log_probabilities
    return result.mean()
def __call__(self, input, target):
    """
    :param input: 5D probability maps torch float tensor (NxCxDxHxW)
    :param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be expanded to 5D as one-hot
    :return: intersection over union, averaged over the channels not listed in
        ``self.skip_channels`` and then over the batch. The caller's ``target``
        tensor is never modified.
    """
    assert input.dim() == 5

    n_classes = input.size()[1]
    if target.dim() == 4:
        target = expand_as_one_hot(target, C=n_classes, ignore_index=self.ignore_index)

    assert input.size() == target.size()

    per_batch_iou = []
    for _input, _target in zip(input, target):
        binary_prediction = self._binarize_predictions(_input, n_classes)

        if self.ignore_index is not None:
            # _target is a view into `target`. For a 5D target that tensor is
            # the caller's own, so clone before zeroing out ignored voxels to
            # avoid overwriting the caller's ignore_index labels in place.
            # NOTE(review): binary_prediction is assumed to be a fresh tensor
            # from _binarize_predictions (not a view of input) — confirm.
            _target = _target.clone()
            mask = _target == self.ignore_index
            binary_prediction[mask] = 0
            _target[mask] = 0

        # convert to uint8 just in case
        binary_prediction = binary_prediction.byte()
        _target = _target.byte()

        per_channel_iou = [
            self._jaccard_index(binary_prediction[c], _target[c])
            for c in range(n_classes)
            if c not in self.skip_channels
        ]
        assert per_channel_iou, "All channels were ignored from the computation"
        per_batch_iou.append(torch.mean(torch.tensor(per_channel_iou)))

    return torch.mean(torch.tensor(per_batch_iou))