# Example no. 1
def test_get_class_weights() -> None:
    """
    Verify that get_class_weights produces inverse voxel-frequency class weights, rescaled so
    that they sum to the number of classes, for class_weight_power 1.0 and 2.0.
    """
    # Tensor layout is (batch=2, classes=3, voxels=4).
    # Class 0 is switched on at 5 voxels, class 1 at 3 voxels, class 2 at none.
    target = torch.zeros((2, 3, 4))
    hot_positions = [(0, 1, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3),
                     (1, 0, 0), (1, 0, 1), (1, 1, 2), (1, 1, 3)]
    for batch_idx, class_idx, voxel_idx in hot_positions:
        target[batch_idx, class_idx, voxel_idx] = 1
    num_classes = target.shape[1]

    per_class = target.sum((0, 2))
    # A class with no voxels is counted as having a single voxel, keeping its inverse finite.
    per_class[per_class == 0] = 1
    # noinspection PyTypeChecker
    inv_freq: torch.Tensor = 1.0 / per_class  # type: ignore
    inv_freq_sq = inv_freq * inv_freq

    # Each power must yield its un-normalised inverse frequencies rescaled to sum to num_classes.
    for power, unnormalised in ((1.0, inv_freq), (2.0, inv_freq_sq)):
        weights = get_class_weights(target, class_weight_power=power)
        reference = num_classes * unnormalised / unnormalised.sum()
        assert torch.allclose(reference, weights, atol=0.001)
# Example no. 2
    def forward_minibatch(self, output: torch.Tensor, target: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Multi-class cross entropy between class logits and a per-class target tensor.
        The implementation supports tensors with arbitrary spatial dimension.
        Input logits are normalised internally with `F.log_softmax`.
        :param output: Class logits (unnormalised), e.g. in 3D : BxCxWxHxD  or in 1D BxC
        :param target: Target labels encoded in one-hot representation, e.g. in 3D BxCxWxHxD or in 1D BxC
        :return: Scalar loss: the per-voxel negative log-likelihood (optionally class-weighted and
            focal-weighted), averaged over all samples and voxels.
        """
        # Check input tensors
        self._verify_inputs(output, target)

        # Determine class weights for unbalanced datasets
        class_weight = None
        if self.class_weight_power is not None and self.class_weight_power != 0.0:
            class_weight = get_class_weights(target, class_weight_power=self.class_weight_power)

        # Compute the per-voxel negative log-likelihood; every branch yields shape B x (spatial dims).
        log_prob = F.log_softmax(output, dim=1)
        if self.smoothing_eps > 0.0:
            # With smoothing enabled the target is presumably not strictly one-hot, so argmax would
            # discard information: take the full expectation over classes instead.
            loss = -1.0 * log_prob * target
            if class_weight is not None:
                loss = torch.einsum('bc...,c->b...', loss, class_weight)
            else:
                # Reduce over the class axis so this branch matches the weighted branch and nll_loss
                # below. Leaving C in place would make torch.mean also average over classes,
                # shrinking the loss by a factor of C.
                loss = loss.sum(dim=1)
        else:
            loss = F.nll_loss(log_prob, torch.argmax(target, dim=1), weight=class_weight, reduction='none')

        # If focal loss is specified, apply pixel weighting
        if self.focal_loss_gamma is not None:
            pixel_weights = self.get_focal_loss_pixel_weights(output, target)
            loss = loss * pixel_weights

        return torch.mean(loss)
    def forward_minibatch(self, output: torch.Tensor, target: torch.Tensor,
                          **kwargs: Any) -> torch.Tensor:
        """
        Computes the forward pass of soft-dice loss. It assumes the output and target have Batch x Classes x ...
        dimensions, with the last dimensions being an arbitrary number of spatial dimensions.

        For every (batch, class) pair, the ratio
        sum(output * target + eps) / (sum(output^2 + eps) + sum(target^2 + eps))
        is taken over the spatial dimensions. The numerator is optionally multiplied by class weights.
        The loss is 1 - 2 * mean of those ratios.

        :param output: The output of the network. Softmax over the class dimension is applied first
            when `self.apply_softmax` is set.
        :param target: The target of the network; must have the same shape as `output`.
        :return: The soft-dice loss, as a scalar tensor.
        :raises ValueError: If the shape of the tensors is incorrect.
        :raises TypeError: If output or target are not torch.tensors.
        """
        # Check Types
        if not torch.is_tensor(output) or not torch.is_tensor(target):
            raise TypeError(
                "Output and target must be torch.Tensors (type(output): {}, type(target): {})"
                .format(type(output), type(target)))

        # Check dimensions
        if len(output.shape) < 3:
            raise ValueError(
                "The shape of the output and target must be at least 3, Batch x Class x ... "
                "(output.shape: {})".format(output.shape))

        if output.shape != target.shape:
            raise ValueError(
                "The output and target must have the same shape (output.shape: {}, target.shape: {})"
                .format(output.shape, target.shape))

        if self.apply_softmax:
            output = torch.nn.functional.softmax(output, dim=1)
        # Get the spatial dimensions; we'll sum numerator and denominator over these for efficiency.
        axes = list(range(2, len(output.shape)))

        # Eps is added to every element before summing, avoiding division errors and problems
        # when a class does not exist in the current patch. It is created directly on the output's
        # device so that the addition works on any backend, not only CPU and CUDA.
        eps = torch.tensor([self.eps], device=output.device)

        # Shape (B, C) after summing over the spatial axes.
        intersection = torch.sum(output * target + eps, axes)

        if self.class_weight_power is not None and self.class_weight_power != 0.0:
            # Weight the per-class intersection (numerator only) by the class weights.
            class_weights = get_class_weights(target, self.class_weight_power)
            # noinspection PyTypeChecker
            intersection = torch.einsum("ij,j->ij", intersection,
                                        class_weights)

        output_sum_square = torch.sum(output * output + eps, axes)
        target_sum_square = torch.sum(target * target + eps, axes)
        sum_squares = output_sum_square + target_sum_square

        # Average per Batch and Class
        # noinspection PyTypeChecker
        return 1.0 - 2.0 * torch.mean(
            intersection / sum_squares)  # type: ignore
# Example no. 4
    def forward_minibatch(self, output: torch.Tensor, target: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Multi-class cross entropy between class logits and a one-hot target, for tensors with an
        arbitrary number of spatial dimensions. Logits are normalised internally with `F.log_softmax`.
        The summed (optionally class-weighted) negative log-likelihood is divided by the total number
        of voxels in the minibatch.
        :param output: Class logits (unnormalised), e.g. in 3D : BxCxWxHxD  or in 1D BxC
        :param target: Target labels encoded in one-hot representation, e.g. in 3D BxCxWxHxD or in 1D BxC
        :return: Scalar loss tensor.
        """
        self._verify_inputs(output, target)

        log_prob = F.log_softmax(output, 1)
        spatial_axes = list(range(2, output.dim()))

        # Focal loss re-weights the log posteriors before they are combined with the target.
        if self.focal_loss_gamma is not None:
            log_prob = log_prob * self.get_focal_loss_pixel_weights(output, target)

        use_class_weights = self.class_weight_power is not None and self.class_weight_power != 0.0
        class_weight = (get_class_weights(target, class_weight_power=self.class_weight_power)
                        if use_class_weights else None)

        # per_class: shape (B, C) once the spatial axes are summed out -- the total log posterior
        # of the target class over all voxels of each sample.
        per_class = target * log_prob
        if spatial_axes:
            per_class = per_class.sum(spatial_axes)

        # per_sample: shape (B).
        if class_weight is None:
            per_sample = -per_class.sum(dim=1)
        else:
            # noinspection PyTypeChecker
            per_sample = -torch.einsum("ij,j->i", per_class, class_weight)

        # Initializer 1 covers the 1D (BxC) case, which has no spatial dimensions.
        voxels_per_sample = reduce(lambda a, b: a * b, output.shape[2:], 1)
        # Mean over all voxels of all samples in the minibatch.
        return per_sample.sum() / (output.shape[0] * voxels_per_sample)