def focal_loss(
        input: torch.Tensor,
        target: torch.Tensor,
        alpha: float,
        gamma: float = 2.0,
        reduction: str = 'mean',
) -> torch.Tensor:
    r"""Function that computes Focal loss.

    See :class:`fastreid.modeling.losses.FocalLoss` for details.

    Args:
        input: raw logits of shape :math:`(B, C, *)`.
        target: class-index tensor of shape :math:`(B, *)`, matching
            ``input``'s trailing dimensions.
        alpha: weighting factor applied uniformly to every class term.
        gamma: focusing parameter; larger values down-weight easy examples.
        reduction: ``'none'`` | ``'mean'`` | ``'sum'``.

    Returns:
        The focal loss, reduced according to ``reduction``.

    Raises:
        TypeError: if ``input`` is not a ``torch.Tensor``.
        ValueError: on shape or device mismatches between ``input`` and ``target``.
        NotImplementedError: for an unrecognized ``reduction`` mode.
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))

    if input.dim() < 2:
        raise ValueError("Invalid input shape, we expect BxCx*. Got: {}"
                         .format(input.shape))

    if input.size(0) != target.size(0):
        raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
                         .format(input.size(0), target.size(0)))

    n = input.size(0)
    out_size = (n,) + input.size()[2:]
    if target.size()[1:] != input.size()[2:]:
        raise ValueError('Expected target size {}, got {}'.format(
            out_size, target.size()))

    if input.device != target.device:
        # BUG FIX: the original format string had a single "{}" placeholder but
        # was given two arguments, so the target device was silently dropped
        # from the error message.
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}".format(
                input.device, target.device))

    # compute softmax over the classes axis
    input_soft = F.softmax(input, dim=1)

    # create the labels one hot tensor (project-local `one_hot`, expected to
    # place classes on dim=1 to match `input_soft`)
    target_one_hot = one_hot(
        target, num_classes=input.shape[1],
        dtype=input.dtype)

    # compute the actual focal loss: -alpha * (1 - p)^gamma * log(p)
    # NOTE(review): log(input_soft) yields -inf for exactly-zero probabilities;
    # consider clamping with a small eps if NaN/inf losses are ever observed.
    weight = torch.pow(-input_soft + 1., gamma)
    focal = -alpha * weight * torch.log(input_soft)
    loss_tmp = torch.sum(target_one_hot * focal, dim=1)

    if reduction == 'none':
        loss = loss_tmp
    elif reduction == 'mean':
        loss = torch.mean(loss_tmp)
    elif reduction == 'sum':
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError("Invalid reduction mode: {}"
                                  .format(reduction))
    return loss
def forward(self, features, targets):
    """Build Circle-loss class logits from cosine similarities.

    Positive and negative pair terms are re-weighted by detached,
    similarity-dependent factors and scaled by ``self._s``.
    """
    cos_sim = F.linear(F.normalize(features), F.normalize(self.weight))

    # Adaptive weighting factors; detached so they act as constants w.r.t. grad.
    detached = cos_sim.detach()
    weight_pos = F.relu(1 + self._m - detached)
    weight_neg = F.relu(detached + self._m)

    # Fixed decision margins: O_p = 1 - m for positives, O_n = m for negatives.
    margin_pos = 1 - self._m
    margin_neg = self._m

    logits_pos = self._s * weight_pos * (cos_sim - margin_pos)
    logits_neg = self._s * weight_neg * (cos_sim - margin_neg)

    # One-hot mask routes each class column to its positive or negative branch.
    mask = one_hot(targets, self._num_classes)
    return mask * logits_pos + (1.0 - mask) * logits_neg
def forward(self, features, targets):
    """Compute ArcFace-style logits: add angular margin ``m`` to the
    target-class angle, then re-scale everything by ``s``."""
    # cos(theta) between normalized features and normalized class weights
    cos_theta = F.linear(F.normalize(features), F.normalize(self.weight))

    # Clamp keeps acos numerically safe at the +/-1 boundaries.
    angle = torch.acos(torch.clamp(cos_theta, -1.0 + 1e-7, 1.0 - 1e-7))
    cos_theta_m = torch.cos(angle + self._m)

    # Margin-augmented logit for the true class only; plain cosine elsewhere.
    mask = one_hot(targets, self._num_classes)
    logits = torch.where(mask.bool(), cos_theta_m, cos_theta)

    # logits re-scale
    return logits * self._s