Ejemplo n.º 1
0
def _focal_loss(p: torch.Tensor, t: torch.Tensor, gamma: float,
                loss_val: torch.Tensor, reduction: str):
    """
    Focal loss helper for integer class-index targets

    The target ``t`` is one-hot encoded with as many classes as ``p`` has
    entries along dimension 1, and the actual computation is delegated to
    ``_general_focal_loss`` with ``alpha_weight`` fixed to 1.

    Parameters
    ----------
    p : torch.Tensor
        the prediction tensor (classes along dimension 1)
    t : torch.Tensor
        the target tensor holding integer class indices; a channel
        dimension is inserted at position 1 before one-hot encoding
    gamma : float
        focusing parameter
    loss_val : torch.Tensor
        value coming from the previous (weighted) loss function
    reduction : str
        reduction parameter

    Returns
    -------
    torch.Tensor
        loss value

    Raises
    ------
    ValueError
        invalid reduction parameter (presumably raised inside
        ``_general_focal_loss``; this helper does not check it itself)

    """
    # one-hot encode with the number of classes taken from the predictions
    encoded_target = one_hot_batch(t.unsqueeze(1), num_classes=p.size(1))
    return _general_focal_loss(
        p=p,
        t=encoded_target,
        gamma=gamma,
        loss_val=loss_val,
        reduction=reduction,
        alpha_weight=1.,
    )
Ejemplo n.º 2
0
 def test_one_hot_batch_2dim(self):
     """One-hot encode a (1, 1, 3, 3) label map holding classes 0 and 1.

     Only position (0, 0) holds class 1, so channel 1 must be hot there and
     channel 0 must be hot at every other position. ``num_classes`` is not
     passed; the expected result has two channels.
     """
     inp = torch.zeros(1, 1, 3, 3).long()
     inp[0, 0, 0, 0] = 1
     outp = one_hot_batch(inp)
     exp = torch.zeros(1, 2, 3, 3).long()
     exp[0, 0] = 1
     exp[0, 0, 0, 0] = 0
     exp[0, 1, 0, 0] = 1
     # ``==`` broadcasts, so check the shape explicitly; otherwise a
     # wrongly shaped (but broadcastable) output could still pass
     self.assertEqual(outp.shape, exp.shape)
     self.assertTrue((outp == exp).all())
Ejemplo n.º 3
0
    def forward(self, imgs, labels=None, z=None, code=None):
        """
        Forwards a single set of batches through the network

        Parameters
        ----------
        imgs : :class:`torch.Tensor`
            the image batch
        labels : :class:`torch.Tensor`
            the labels batch, will be sampled if not given. Either integer
            class indices of shape ``(N,)`` or ``(N, 1)``, or labels that
            are already one-hot encoded (last dimension equal to the number
            of classes)
        z : :class:`torch.Tensor`
            the noise batch, will be sampled if not given
        code : :class:`torch.Tensor`
            the code batch, will be sampled if not given

        Returns
        -------
        dict
            a dictionary containing all (intermediate) results for loss
            calculation and training

        """
        batch_size = imgs.size(0)

        if z is None:
            z = torch.randn(batch_size, self._latent_dim, device=imgs.device,
                            dtype=imgs.dtype)

        if labels is None:
            labels = torch.randint(self._n_classes,
                                   (batch_size, 1),
                                   device=imgs.device,
                                   dtype=torch.long)

        # Bring 1-D class indices into the (N, 1) layout that one_hot_batch
        # encodes along dim 1. Doing this before the size check below keeps
        # an (N,) index vector from being mistaken for one-hot labels when
        # the batch size happens to equal the number of classes.
        if labels.dim() == 1:
            labels = labels.unsqueeze(1)

        # (N, 1) class indices (given or sampled) are encoded directly; no
        # extra unsqueeze, so they presumably end up as (N, n_classes) like
        # any other index input instead of gaining a trailing dimension.
        if labels.size(-1) != self._n_classes:
            labels = one_hot_batch(labels, num_classes=self._n_classes)

        if code is None:
            code = torch.empty(batch_size, self._code_dim,
                               device=imgs.device, dtype=imgs.dtype)
            code.uniform_(-1, 1)

        gen_imgs = self.generator(z, labels, code)

        validity_real, _, _ = self.discriminator(imgs)
        validity_fake, labels_fake, code_fake = self.discriminator(gen_imgs)

        return {
            "validity_real": validity_real, "validity_fake": validity_fake,
            "labels_real": labels, "labels_fake": labels_fake,
            "code_real": code, "code_fake": code_fake,
            "gen_imgs": gen_imgs,
        }
Ejemplo n.º 4
0
 def test_one_hot_batch_2dim(self):
     """One-hot encode a (1, 1, 3, 3) label map for several output dtypes.

     ``dtype=None`` is expected to yield ``torch.long``; an explicit
     ``torch.float`` must be honoured. Values and shape are checked for
     every case.
     """
     cases = ((None, torch.long), (torch.float, torch.float))
     for dtype, expected_dtype in cases:
         with self.subTest(dtype=dtype, expected_dtype=expected_dtype):
             inp = torch.zeros(1, 1, 3, 3).long()
             inp[0, 0, 0, 0] = 1
             outp = one_hot_batch(inp, dtype=dtype)
             exp = torch.zeros(1, 2, 3, 3).long()
             exp[0, 0] = 1
             exp[0, 0, 0, 0] = 0
             exp[0, 1, 0, 0] = 1
             # ``==`` broadcasts, so the shape must be asserted separately
             self.assertEqual(outp.shape, exp.shape)
             self.assertTrue((outp == exp).all())
             self.assertEqual(outp.dtype, expected_dtype)
Ejemplo n.º 5
0
 def test_one_hot_batch_1dim(self):
     """A 1-D index vector [0, 1, 2] encodes to the 3x3 identity (long)."""
     inp = torch.tensor([0, 1, 2]).long()
     outp = one_hot_batch(inp)
     exp = torch.eye(3)
     # ``==`` broadcasts, so the shape must be asserted separately
     self.assertEqual(outp.shape, exp.shape)
     self.assertTrue((outp == exp).all())
     # assertEqual reports both dtypes on failure (assertTrue would not)
     self.assertEqual(outp.dtype, torch.long)
Ejemplo n.º 6
0
 def test_one_hot_batch_float_error(self):
     """Floating-point input must be rejected with a TypeError."""
     # build the input outside the context manager so that only the call
     # under test can satisfy assertRaises
     inp = torch.zeros(1, 1, 3, 3).float()
     with self.assertRaises(TypeError):
         one_hot_batch(inp)
Ejemplo n.º 7
0
def tversky_loss(predictions: torch.Tensor, targets: torch.Tensor,
                 alpha: float = 0.5, beta: float = 0.5,
                 weight: torch.Tensor = None,
                 non_lin: Callable = None, square_nom: bool = False,
                 square_denom: bool = False, smooth: float = 1.,
                 reduction: str = 'elementwise_mean') -> torch.Tensor:
    """
    Calculates the tversky loss

    For every sample and class the smoothed Tversky index

        TI = (TP + smooth) / (TP + alpha * FN + beta * FP + smooth)

    is computed, where TP, FN and FP are soft counts summed over the
    spatial dimensions. The returned loss is
    ``1 - mean_over_classes(weight * TI)`` passed through ``reduce`` with
    ``reduction``, so minimizing it maximizes the Tversky index.

    Parameters
    ----------
    predictions : torch.Tensor
        the predicted segmentation (of shape NxCx(Dx)HxW)
    targets : torch.Tensor
        the groundtruth segmentation (of shape Nx(Dx)HxW), holding integer
        class indices
    alpha : float
        scaling factor for false negatives
    beta : float
        scaling factor for false positives
    weight : torch.Tensor
        weighting factors for each class
    non_lin : Callable
        a non linearity to apply on the predictions before calculating
        the loss value
    square_nom : bool
        whether to square the true-positive summands (this affects the
        TP term in both nominator and denominator)
    square_denom : bool
        whether to square the false-positive and false-negative summands
    smooth : float
        smoothing value (to avoid divisions by 0)
    reduction : str
        kind of reduction to apply to the final loss

    Returns
    -------
    torch.Tensor
        reduced loss value

    """
    n_classes = predictions.shape[1]
    dims = tuple(range(2, predictions.dim()))

    if non_lin is not None:
        predictions = non_lin(predictions)

    target_onehot = one_hot_batch(targets.unsqueeze(1),
                                  num_classes=n_classes).float()

    tp = predictions * target_onehot
    fp = predictions * (1 - target_onehot)
    fn = (1 - predictions) * target_onehot

    if square_nom:
        tp = tp ** 2
    if square_denom:
        fp = fp ** 2
        fn = fn ** 2

    # compute nominator
    tp_sum = torch.sum(tp, dim=dims)
    nom = tp_sum + smooth

    # compute denominator
    denom = tp_sum + alpha * torch.sum(fn, dim=dims) + \
        beta * torch.sum(fp, dim=dims) + smooth

    # per-sample, per-class Tversky index
    frac = nom / denom

    # apply weights to individual classes
    if weight is not None:
        frac = weight * frac

    # 1 - TI is smallest when the index is largest. Do not negate this again:
    # -(1 - TI) = TI - 1 would be minimized by *lowering* the index.
    loss = 1 - torch.mean(frac, dim=1)
    return reduce(loss, reduction)
Ejemplo n.º 8
0
def soft_dice_loss(predictions: torch.Tensor,
                   targets: torch.Tensor,
                   weight: torch.Tensor = None,
                   non_lin: Callable = None,
                   square_nom: bool = False,
                   square_denom: bool = False,
                   smooth: float = 1.,
                   reduction: str = 'elementwise_mean') -> torch.Tensor:
    """
    Calculates the (negated) soft dice loss

    For every sample and class the smoothed dice coefficient

        dice = (2 * sum(P * T) + smooth) / (sum(P) + sum(T) + smooth)

    is computed over the spatial dimensions, with P being the predictions
    (after ``non_lin``) and T the one-hot encoded targets. The coefficients
    are optionally multiplied by ``weight``, averaged over classes, negated
    (so minimizing the loss maximizes the dice coefficient) and finally
    passed through ``reduce`` with ``reduction``.

    Parameters
    ----------
    predictions : torch.Tensor
        the predicted segmentation (of shape NxCx(Dx)HxW)
    targets : torch.Tensor
        the groundtruth segmentation (of shape Nx(Dx)HxW), holding integer
        class indices
    weight : torch.Tensor
        weighting factors for each class
    non_lin : Callable
        a non linearity to apply on the predictions before calculating
        the loss value
    square_nom : bool
        whether to square the summands of the nominator, i.e. sum((P*T)**2)
    square_denom : bool
        whether to square the summands of the denominator, i.e.
        sum(P**2) + sum(T**2)
    smooth : float
        smoothing value (to avoid divisions by 0)
    reduction : str
        kind of reduction to apply to the final loss

    Returns
    -------
    torch.Tensor
        reduced loss value

    """
    num_channels = predictions.shape[1]

    # the target encoding is computed outside of autograd tracking
    with torch.no_grad():
        onehot = one_hot_batch(targets.unsqueeze(1),
                               num_classes=num_channels)

    # everything after the class dimension is summed over
    spatial = tuple(range(2, predictions.dim()))

    def _spatial_sum(tensor):
        return torch.sum(tensor, dim=spatial)

    probs = predictions if non_lin is None else non_lin(predictions)

    # nominator: (squared) soft intersection
    intersection = probs * onehot.float()
    if square_nom:
        intersection = intersection ** 2
    numerator = 2 * _spatial_sum(intersection) + smooth

    # denominator: (squared) prediction mass plus (squared) target mass;
    # the target sum is taken on the integer encoding and cast afterwards
    if square_denom:
        pred_total = _spatial_sum(probs ** 2)
        target_total = _spatial_sum(onehot ** 2)
    else:
        pred_total = _spatial_sum(probs)
        target_total = _spatial_sum(onehot)
    denominator = pred_total + target_total.float() + smooth

    dice = numerator / denominator

    # optional per-class weighting
    if weight is not None:
        dice = weight * dice

    # class average, negated so that a larger overlap yields a smaller loss
    return reduce(-torch.mean(dice, dim=1), reduction)