示例#1
0
 def _update_histograms(self, inputs, target, gen_pred):
     """Record intensity samples used for histogram logging.

     Writes into ``self.custom_variables``:

     * the flattened generated (``gen_pred``) and input intensities;
     * a per-dataset histogram image, read with ``cv2.imread`` from the value returned by
       ``construct_class_histogram`` and transposed with axes ``(2, 0, 1)``;
     * for label values 0..3 of ``target[IMAGE_TARGET]`` (named Background, CSF, GM and WM),
       the generated and then the input intensities at the positions where the label map
       equals that value.

     Every stored tensor is moved to the CPU and detached from the autograd graph.
     """
     generated_keys = (
         "Background Generated Intensity Histogram",
         "CSF Generated Intensity Histogram",
         "GM Generated Intensity Histogram",
         "WM Generated Intensity Histogram",
     )
     input_keys = (
         "Background Input Intensity Histogram",
         "CSF Input Intensity Histogram",
         "GM Input Intensity Histogram",
         "WM Input Intensity Histogram",
     )

     self.custom_variables["Generated Intensity Histogram"] = flatten(
         gen_pred.cpu().detach())
     self.custom_variables["Input Intensity Histogram"] = flatten(
         inputs.cpu().detach())

     histogram_file = construct_class_histogram(inputs, target, gen_pred)
     histogram_image = cv2.imread(histogram_file)
     # Move the last (channel) axis of the OpenCV image to the front.
     self.custom_variables["Per-Dataset Histograms"] = histogram_image.transpose(
         (2, 0, 1))

     # The position in each key tuple is the label value selected from the label map.
     for label, key in enumerate(generated_keys):
         selected = gen_pred[torch.where(target[IMAGE_TARGET] == label)]
         self.custom_variables[key] = selected.cpu().detach()
     for label, key in enumerate(input_keys):
         selected = inputs[torch.where(target[IMAGE_TARGET] == label)]
         self.custom_variables[key] = selected.cpu().detach()
示例#2
0
    def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
        """
        Computes the Tversky loss based on https://arxiv.org/pdf/1706.05721.pdf

        Per class, the Tversky index is ``TP / (TP + alpha * FP + beta * FN + EPSILON)``.
        The soft counts are taken over the flattened tensors:
        ``TP = sum(inputs * targets)``, ``FP = sum(inputs * (1 - targets))`` and
        ``FN = sum((1 - inputs) * targets)``. PyTorch optimizers minimize a loss, while we
        want to maximize the index, so the returned loss is ``1 - index``.

        Args:
            inputs (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The model prediction on which the loss has to
             be computed.
            targets (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The ground truth.
        Returns:
            :obj:`torch.Tensor`: The Tversky loss for each class, or its mean when ``self.reduction == "mean"``.
        Raises:
            ValueError: If ``inputs`` and ``targets`` do not have the same shape.
            IndexError: If ``self._ignore_index`` is not -100 and is not one of the class indices.
        """
        if inputs.size() != targets.size():
            raise ValueError(
                "'Inputs' and 'Targets' must have the same shape.")

        # NOTE(review): assumes `flatten` yields one row per class, so that `.sum(-1)`
        # gives per-class totals — confirm against its definition.
        inputs = flatten(inputs)
        targets = flatten(targets).float()
        # A full-size float32 tensor (not a Python scalar), so that low-precision inputs
        # are promoted to float32 before the subtractions below.
        ones = torch.ones_like(inputs, dtype=torch.float)

        true_positives = (inputs * targets).sum(-1)
        if self.weight is not None:
            # The weighted TP feeds both the numerator and the denominator of the index.
            true_positives = self.weight * true_positives

        false_positives = (inputs * (ones - targets)).sum(-1)
        false_negatives = ((ones - inputs) * targets).sum(-1)

        tversky = true_positives / (true_positives + self._alpha * false_positives +
                                    self._beta * false_negatives + EPSILON)
        # `tversky` is at least float32 (targets is float32), so subtracting it from a
        # Python scalar keeps that dtype.
        tversky_loss = 1.0 - tversky

        if self._ignore_index != -100:
            # -100 means "ignore nothing"; otherwise drop that class from the per-class vector.
            indices = list(range(len(tversky_loss)))
            try:
                indices.remove(self._ignore_index)
            except ValueError as error:
                raise IndexError(
                    "'ignore_index' must be non-negative, and lower than the number of classes in confusion matrix, but {} was given. "
                    .format(self._ignore_index)) from error
            tversky_loss = tversky_loss[indices]

        if self.reduction == "mean":
            tversky_loss = tversky_loss.mean()

        return tversky_loss
示例#3
0
def mean_hausdorff_distance(seg_pred, target):
    """Compute the symmetric Hausdorff distance between prediction and target, per channel.

    For every channel ``c`` of ``seg_pred`` (dimension 1), both tensors are processed in
    three steps:

    1. Slice them as ``[:, c, ...]``.
    2. Pass the slices through ``flatten`` and convert them to NumPy.
    3. Compute ``max(h(pred, target), h(target, pred))`` with
       :func:`scipy.spatial.distance.directed_hausdorff`.

    NOTE(review): despite the name, no mean is taken. Also, ``directed_hausdorff`` treats
    each row of its 2-D arguments as a point; confirm that ``flatten`` produces the
    intended point layout.

    Args:
        seg_pred: Prediction tensor with channels along dimension 1 (presumably (B, C, ...)).
        target: Ground-truth tensor with the same layout as ``seg_pred``.

    Returns:
        np.ndarray: One distance per channel. The array always has at least 4 entries
        (trailing entries stay 0 when there are fewer channels). It grows with the channel
        count instead of raising IndexError when there are more than 4 channels.
    """
    num_channels = seg_pred.size(1)
    distances = np.zeros((max(num_channels, 4), ))
    for channel in range(num_channels):
        # Convert each side once and reuse it for both directed distances.
        pred_points = flatten(seg_pred[:, channel, ...]).detach().cpu().numpy()
        target_points = flatten(target[:, channel, ...]).detach().cpu().numpy()
        forward_distance = directed_hausdorff(pred_points, target_points)[0]
        backward_distance = directed_hausdorff(target_points, pred_points)[0]
        distances[channel] = max(forward_distance, backward_distance)
    return distances
示例#4
0
    def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
        """
        Computes the Sørensen–Dice loss.

        Per channel, the soft Dice coefficient is
        ``2 * sum(inputs * targets) / sum(inputs + targets)`` over the flattened tensors,
        with the denominator clamped to at least ``EPSILON``. PyTorch optimizers minimize a
        loss, while we want to maximize the coefficient, so the returned loss is
        ``1 - coefficient``.

        Args:
            inputs (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The model prediction on which the loss has to
             be computed.
            targets (:obj:`torch.Tensor`) : A tensor of shape (B, C, ..). The ground truth.
        Returns:
            :obj:`torch.Tensor`: The Sørensen–Dice loss for each class, or its mean when ``self.reduction == "mean"``.
        Raises:
            ValueError: If ``inputs`` and ``targets`` do not have the same shape.
            IndexError: If ``self._ignore_index`` is not -100 and is not one of the class indices.
        """
        if inputs.size() != targets.size():
            raise ValueError(
                "'Inputs' and 'Targets' must have the same shape.")

        # NOTE(review): assumes `flatten` yields one row per channel, so that `.sum(-1)`
        # gives per-channel totals — confirm against its definition.
        inputs = flatten(inputs)
        targets = flatten(targets).float()

        # Compute per channel Dice Coefficient
        intersection = (inputs * targets).sum(-1)
        if self.weight is not None:
            # NOTE(review): the weight scales only the numerator, so weights > 1 can push
            # the coefficient above 1 and the loss below 0 — confirm this is intended.
            intersection = self.weight * intersection

        cardinality = (inputs + targets).sum(-1)

        # The ratio is at least float32 (targets is float32), so subtracting it from a
        # Python scalar keeps that dtype.
        dice = 1.0 - (2.0 * intersection / cardinality.clamp(min=EPSILON))

        if self._ignore_index != -100:
            # -100 means "ignore nothing"; otherwise drop that class from the per-class vector.
            indices = list(range(len(dice)))
            try:
                indices.remove(self._ignore_index)
            except ValueError as error:
                raise IndexError(
                    "'ignore_index' must be non-negative, and lower than the number of classes in confusion matrix, but {} was given. "
                    .format(self._ignore_index)) from error
            dice = dice[indices]

        if self.reduction == "mean":
            dice = dice.mean()

        return dice
示例#5
0
 def _update_histograms(self, inputs, target):
     """Record input-intensity samples used for histogram logging.

     Stores into ``self.custom_variables``:

     * the flattened input intensities;
     * for each label value 0..3 of ``target[IMAGE_TARGET]`` (named Background, CSF, GM
       and WM), the input intensities at the positions where the label map equals that
       value.

     Every stored tensor is moved to the CPU and detached from the autograd graph.
     """
     per_class_keys = (
         "Background Input Intensity Histogram",
         "CSF Input Intensity Histogram",
         "GM Input Intensity Histogram",
         "WM Input Intensity Histogram",
     )
     all_intensities = inputs.cpu().detach()
     self.custom_variables["Input Intensity Histogram"] = flatten(all_intensities)
     # The position in the tuple is the label value selected from the label map.
     for label, key in enumerate(per_class_keys):
         mask_indices = torch.where(target[IMAGE_TARGET] == label)
         self.custom_variables[key] = inputs[mask_indices].cpu().detach()
示例#6
0
 def compute_class_weights(inputs: torch.Tensor):
     """
     Compute weights for each class as described in https://arxiv.org/pdf/1707.03237.pdf

     Let ``flat = flatten(inputs)``, ``N = flat.shape[1]`` and ``S = flat.sum(-1)``. The
     weight of entry ``c`` is then ``(N - S[c]) / S[c]``. The division is not guarded: an
     entry whose values sum to zero gets an infinite weight (NaN if ``N`` is also 0).

     Args:
         inputs: (:obj:`torch.Tensor`): A tensor of shape (B, C, ..). The model's prediction on which the loss has to be computed.
     Returns:
         :obj:`torch.Tensor`: A tensor containing class weights.
     """
     per_class = flatten(inputs)
     num_elements = per_class.shape[1]
     class_totals = per_class.sum(-1)
     return (num_elements - class_totals) / class_totals
示例#7
0
    def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
        """
        Rebuild the generalized Dice metric for this batch, then update the confusion matrix.

        The targets ``output[1]`` are one-hot encoded into ``output[0].size(1)`` classes
        and flattened. Each class weight is then ``1 / max(volume ** 2, EPSILON)``, where
        ``volume`` is the class's summed one-hot count. The weights come only from the
        current batch.

        Args:
            output (tuple of :obj:`torch.Tensor`): A tuple containing predictions and ground truth of the form `(y_pred, y)`.
        Raises:
            NotImplementedError: If ``self._reduction`` is neither ``"mean"`` nor ``None``.
                In that case ``self._metric`` has already been replaced, but the confusion
                matrix is not updated.
        """
        y_pred, y = output[0], output[1]
        num_classes = y_pred.size(1)

        one_hot_targets = flatten(to_onehot(y, num_classes)).double()
        squared_volumes = one_hot_targets.sum(-1).pow(2).clamp(min=EPSILON)
        unit = torch.ones((num_classes, ),
                          dtype=torch.double,
                          device=one_hot_targets.device)
        weights = unit / squared_volumes

        # Assigned before the reduction check so the attribute is replaced even when it raises.
        self._metric = self.create_generalized_dice_metric(self._cm, weights)
        if self._reduction == "mean":
            self._metric = self._metric.mean()
        elif self._reduction is not None:
            raise NotImplementedError("Reduction method not implemented.")

        self._cm.update(output)