def _focal_loss(p: torch.Tensor, t: torch.Tensor, gamma: float,
                loss_val: torch.Tensor, reduction: str):
    """
    Focal loss helper for index-encoded targets

    Converts the class-index targets into a one-hot tensor with as many
    channels as the prediction and forwards everything to
    ``_general_focal_loss`` with a class weighting (``alpha_weight``) of 1.

    Parameters
    ----------
    p : torch.Tensor
        the prediction tensor; its second dimension is used as the number
        of classes
    t : torch.Tensor
        the target tensor holding class indices
    gamma : float
        focusing parameter
    loss_val : torch.Tensor
        value coming from the previous (weighted) loss function
    reduction : str
        reduction parameter

    Returns
    -------
    torch.Tensor
        loss value

    Raises
    ------
    ValueError
        invalid reduction parameter (not checked here; presumably raised
        by ``_general_focal_loss``)
    """
    num_classes = p.size(1)
    # insert a channel axis so the indices can be expanded to one channel per class
    onehot_target = one_hot_batch(t.unsqueeze(1), num_classes=num_classes)
    return _general_focal_loss(
        p=p,
        t=onehot_target,
        gamma=gamma,
        loss_val=loss_val,
        reduction=reduction,
        alpha_weight=1.,
    )
def test_one_hot_batch_2dim(self):
    # single-channel index map: all class 0 except one pixel of class 1
    index_map = torch.zeros(1, 1, 3, 3).long()
    index_map[0, 0, 0, 0] = 1

    encoded = one_hot_batch(index_map)

    # channel 1 marks the class-1 pixel, channel 0 is its complement
    class_one = torch.zeros(3, 3).long()
    class_one[0, 0] = 1
    expected = torch.stack([1 - class_one, class_one]).unsqueeze(0)

    self.assertTrue((encoded == expected).all())
def forward(self, imgs, labels=None, z=None, code=None):
    """
    Forwards a single set of batches through the network

    Parameters
    ----------
    imgs : :class:`torch.Tensor`
        the image batch
    labels : :class:`torch.Tensor`
        the labels batch, will be sampled if not given. One-dimensional
        labels are treated as class indices and one-hot encoded; labels
        whose last dimension already equals the number of classes are
        passed through unchanged
    z : :class:`torch.Tensor`
        the noise batch, will be sampled if not given
    code : :class:`torch.Tensor`
        the code batch, will be sampled if not given (uniformly from
        ``[-1, 1]``)

    Returns
    -------
    dict
        a dictionary containing all (intermediate) results for loss
        calculation and training
    """
    batch_size = imgs.size(0)

    if z is None:
        z = torch.randn(batch_size, self._latent_dim,
                        device=imgs.device, dtype=imgs.dtype)

    if labels is None:
        # sample flat class indices (shape N) so they take the same encoding
        # path as user-provided index labels; an N x 1 sample would be
        # unsqueezed to N x 1 x 1 and presumably encoded to N x C x 1
        # instead of N x C
        labels = torch.randint(self._n_classes, (batch_size,),
                               device=imgs.device, dtype=torch.long)

    # 1D labels are always class indices: looking only at the last dimension
    # would skip the encoding whenever batch_size == n_classes
    if labels.dim() == 1 or labels.size(-1) != self._n_classes:
        labels = one_hot_batch(labels.unsqueeze(1),
                               num_classes=self._n_classes)

    if code is None:
        code = torch.empty(batch_size, self._code_dim,
                           device=imgs.device, dtype=imgs.dtype)
        code.uniform_(-1, 1)

    gen_imgs = self.generator(z, labels, code)

    # label/code predictions for the real images are not used
    validity_real, _, _ = self.discriminator(imgs)
    validity_fake, labels_fake, code_fake = self.discriminator(gen_imgs)

    return {
        "validity_real": validity_real,
        "validity_fake": validity_fake,
        "labels_real": labels,
        "labels_fake": labels_fake,
        "code_real": code,
        "code_fake": code_fake,
        "gen_imgs": gen_imgs,
    }
def test_one_hot_batch_2dim(self):
    # (requested dtype, dtype the encoded result must have)
    cases = ((None, torch.long), (torch.float, torch.float))
    for requested, expected_dtype in cases:
        with self.subTest(dtype=requested, expected_dtype=expected_dtype):
            # single-channel index map: all class 0 except one class-1 pixel
            index_map = torch.zeros(1, 1, 3, 3).long()
            index_map[0, 0, 0, 0] = 1

            encoded = one_hot_batch(index_map, dtype=requested)

            # channel 1 marks the class-1 pixel, channel 0 is its complement
            class_one = torch.zeros(3, 3).long()
            class_one[0, 0] = 1
            expected = torch.stack([1 - class_one, class_one]).unsqueeze(0)

            self.assertTrue((encoded == expected).all())
            self.assertEqual(encoded.dtype, expected_dtype)
def test_one_hot_batch_1dim(self):
    """A 1D index tensor is encoded into an identity matrix of dtype long."""
    inp = torch.tensor([0, 1, 2]).long()
    outp = one_hot_batch(inp)
    exp = torch.eye(3)
    # the elementwise comparison broadcasts, so pin the shape explicitly
    self.assertEqual(outp.shape, exp.shape)
    self.assertTrue((outp == exp).all())
    # assertEqual reports both dtypes on failure (unlike assertTrue(a == b))
    self.assertEqual(outp.dtype, torch.long)
def test_one_hot_batch_float_error(self):
    """Floating point (non-index) input must be rejected with a TypeError."""
    # build the input outside the context manager so that only the call
    # under test can satisfy assertRaises
    inp = torch.zeros(1, 1, 3, 3).float()
    with self.assertRaises(TypeError):
        one_hot_batch(inp)
def tversky_loss(predictions: torch.Tensor, targets: torch.Tensor,
                 alpha: float = 0.5, beta: float = 0.5,
                 weight: torch.Tensor = None, non_lin: Callable = None,
                 square_nom: bool = False, square_denom: bool = False,
                 smooth: float = 1., reduction: str = 'elementwise_mean'
                 ) -> torch.Tensor:
    """
    Calculates the tversky loss

    Per sample the loss is ``1 - mean_c(TI_c)``, where ``TI_c`` is the
    (optionally weighted) Tversky index of class ``c``:
    ``(TP + smooth) / (TP + alpha * FN + beta * FP + smooth)``.
    Minimizing it therefore maximizes the overlap.

    Parameters
    ----------
    predictions : torch.Tensor
        the predicted segmentation (of shape NxCx(Dx)HxW)
    targets : torch.Tensor
        the groundtruth segmentation holding class indices
        (of shape Nx(Dx)HxW)
    alpha : float
        scaling factor for false negatives
    beta : float
        scaling factor for false positives
    weight : torch.Tensor
        weighting factors for each class
    non_lin : Callable
        a non linearity to apply on the predictions before calculating
        the loss value
    square_nom : bool
        whether to square the true positives (numerator)
    square_denom : bool
        whether to square the false positives and false negatives
        (denominator)
    smooth : float
        smoothing value (to avoid divisions by 0)
    reduction : str
        kind of reduction to apply to the final loss

    Returns
    -------
    torch.Tensor
        reduced loss value
    """
    n_classes = predictions.shape[1]
    # sum over all spatial dimensions, keeping batch and class dimensions
    dims = tuple(range(2, predictions.dim()))

    if non_lin is not None:
        predictions = non_lin(predictions)

    target_onehot = one_hot_batch(targets.unsqueeze(1),
                                  num_classes=n_classes).float()

    tp = predictions * target_onehot
    fp = predictions * (1 - target_onehot)
    fn = (1 - predictions) * target_onehot

    if square_nom:
        tp = tp ** 2
    if square_denom:
        fp = fp ** 2
        fn = fn ** 2

    # per-sample, per-class sums (shape NxC)
    tp_sum = torch.sum(tp, dim=dims)
    nom = tp_sum + smooth
    denom = (tp_sum + alpha * torch.sum(fn, dim=dims)
             + beta * torch.sum(fp, dim=dims) + smooth)

    # Tversky index per sample and class
    frac = nom / denom

    # apply weights to individual classes
    if weight is not None:
        frac = weight * frac

    # 1 - class-averaged index (shape N); must NOT be negated again,
    # otherwise minimizing the loss would minimize the overlap
    loss = 1 - torch.mean(frac, dim=1)
    return reduce(loss, reduction)
def soft_dice_loss(predictions: torch.Tensor, targets: torch.Tensor,
                   weight: torch.Tensor = None, non_lin: Callable = None,
                   square_nom: bool = False, square_denom: bool = False,
                   smooth: float = 1., reduction: str = 'elementwise_mean'
                   ) -> torch.Tensor:
    """
    Calculates the soft dice loss

    Per sample the value is the negated, class-averaged (optionally
    weighted) dice score
    ``(2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``,
    so it lies below zero and decreases as the overlap grows.

    Parameters
    ----------
    predictions : torch.Tensor
        the predicted segmentation (of shape NxCx(Dx)HxW)
    targets : torch.Tensor
        the groundtruth segmentation holding class indices
        (of shape Nx(Dx)HxW)
    weight : torch.Tensor
        weighting factors for each class
    non_lin : Callable
        a non linearity to apply on the predictions before calculating
        the loss value
    square_nom : bool
        whether to square the intersection terms (numerator)
    square_denom : bool
        whether to square the prediction and target terms (denominator)
    smooth : float
        smoothing value (to avoid divisions by 0)
    reduction : str
        kind of reduction to apply to the final loss

    Returns
    -------
    torch.Tensor
        reduced loss value
    """
    num_classes = predictions.shape[1]

    # the encoded target needs no autograd graph
    with torch.no_grad():
        targets_onehot = one_hot_batch(targets.unsqueeze(1),
                                       num_classes=num_classes)

    # reduce over everything but batch and class dimensions
    spatial_dims = tuple(range(2, predictions.dim()))

    if non_lin is not None:
        predictions = non_lin(predictions)

    overlap = predictions * targets_onehot.float()
    if square_nom:
        overlap = overlap ** 2
    numerator = 2 * torch.sum(overlap, dim=spatial_dims) + smooth

    pred_term = predictions ** 2 if square_denom else predictions
    target_term = targets_onehot ** 2 if square_denom else targets_onehot
    # the (integer) target sum is only cast to float after summation
    denominator = (torch.sum(pred_term, dim=spatial_dims)
                   + torch.sum(target_term, dim=spatial_dims).float()
                   + smooth)

    dice = numerator / denominator

    # per-class weighting
    if weight is not None:
        dice = weight * dice

    # negated mean over classes (shape N) before the final reduction
    return reduce(-torch.mean(dice, dim=1), reduction)