예제 #1
0
    def forward(self, input, target):
        """
        Compute the Dice loss between ``input`` logits and a label tensor.

        ``target`` is cast to ``long`` and expanded to one-hot form with
        ``self.classes`` channels, optionally pruned with
        ``self.skip_target_channels``, and then scored against the normalized
        ``input`` with ``self.dice``.

        Args:
            input: network output (logits); must be 5-dimensional.
            target: label tensor; expanded to one-hot before comparison.

        Returns:
            tuple: ``(loss, per_channel_dice)`` where ``loss`` is the tensor
            ``1 - mean(per_channel_dice)`` (still attached to the autograd graph)
            and ``per_channel_dice`` is a detached NumPy array of per-channel scores.
        """
        # cast to long so the labels can serve as class indices for the one-hot expansion
        target = expand_as_one_hot(target.long(), self.classes)

        assert input.dim() == target.dim() == 5, (
            "'input' and 'target' must both be 5D, got {} and {} dims".format(
                input.dim(), target.dim()))

        if self.skip_index_after is not None:
            # presumably removes target channels beyond skip_index_after so the
            # target matches the input's channel count (checked just below)
            target = self.skip_target_channels(target, self.skip_index_after)

        assert input.size() == target.size(), "'input' and 'target' must have the same shape"
        # get probabilities from logits
        input = self.normalization(input)

        # compute per channel Dice coefficient
        per_channel_dice = self.dice(input, target, weight=self.weight)

        # average Dice score across all channels/classes; loss keeps the graph for backprop
        loss = 1. - torch.mean(per_channel_dice)
        # detached NumPy copy of the scores; not part of the autograd graph
        per_channel_dice = per_channel_dice.detach().cpu().numpy()

        return loss, per_channel_dice
예제 #2
0
    def forward(self, input, target):
        """
        Smooth-L1 style loss that up/down-weights elements chosen by a threshold.

        ``target`` is expanded to one-hot form (``self.classes`` channels) and the
        parent ``forward`` is evaluated on it. The result is indexed with a boolean
        mask below, so it is expected to be element-wise. Elements whose expanded
        target is below ``self.threshold`` (when ``self.apply_below_threshold``)
        or at/above it (otherwise) are multiplied by ``self.weight``.

        Returns:
            The mean of the re-weighted element-wise loss.
        """
        one_hot = expand_as_one_hot(target, self.classes)
        assert input.size() == one_hot.size(), "'input' and 'target' must have the same shape"
        elementwise = super().forward(input, one_hot)

        selected = (one_hot < self.threshold) if self.apply_below_threshold else (one_hot >= self.threshold)

        rescaled = elementwise[selected] * self.weight
        elementwise[selected] = rescaled

        return elementwise.mean()
예제 #3
0
    def forward(self, inputs, targets, weight=None):
        """
        Sum of per-head squared angular losses weighted by ``self.tags_coefficients``.

        Args:
            inputs: list of output tensors, one per head.
            targets: list/tuple of target tensors (one per head), or a single
                tensor when there is exactly one head.
            weight: optional weight forwarded unchanged to ``square_angular_loss``.

        Returns:
            ``sum(alpha_i * square_angular_loss(input_i, one_hot(target_i), weight))``
            over the heads.
        """
        assert isinstance(inputs, list)
        # if there is just one output head the 'inputs' is going to be a singleton list [tensor]
        # and 'targets' is just going to be a tensor (that's how the HDF5Dataloader works)
        # so wrap targets in a list in this case; a caller that already passes a
        # one-element list/tuple is left alone instead of being double-wrapped
        if len(inputs) == 1 and not isinstance(targets, (list, tuple)):
            targets = [targets]
        assert len(inputs) == len(targets) == len(self.tags_coefficients)
        loss = 0
        for head_input, head_target, alpha in zip(inputs, targets, self.tags_coefficients):
            # expand the target to one-hot, for consistency with the other losses
            head_target = expand_as_one_hot(head_target, self.classes)
            assert head_input.size() == head_target.size(), "'input' and 'target' must have the same shape"

            loss += alpha * square_angular_loss(head_input, head_target, weight)

        return loss
예제 #4
0
 def forward(self, input, target):
     """
     Weighted combination of BCE and Dice losses against the one-hot target.

     Args:
         input: network output; must match the one-hot target's shape.
         target: label tensor; cast to ``long`` and expanded to ``self.classes`` channels.

     Returns:
         tuple: ``(self.alpha * bce + self.beta * dice_loss, channel_score)`` where
         ``dice_loss`` and ``channel_score`` are the two values returned by ``self.dice``.
     """
     target_expanded = expand_as_one_hot(target.long(), self.classes)
     assert input.size() == target_expanded.size(), "'input' and 'target' must have the same shape"
     loss_1 = self.alpha * self.bce(input, target_expanded)
     # self.dice returns a (loss, channel_score) pair: scale only the loss.
     # Multiplying the tuple itself is sequence repetition, which only works
     # for beta == 1 (int) and fails for any other weight.
     loss_2, channel_score = self.dice(input, target_expanded)
     loss_2 = self.beta * loss_2
     return loss_1 + loss_2, channel_score