Ejemplo n.º 1
0
    def forward(self, predict, target):
        """Return the dynamic top-k loss plus the Dice loss.

        ``predict`` and ``target`` must report identical ``size()``; this is
        checked with an ``assert``.
        """
        assert predict.size() == target.size()

        # NOTE(review): both criteria are rebuilt on every call; if
        # DynamicTopKLoss keeps per-step state it is reset each time -- confirm.
        dice_criterion = DiceLoss(weight=self.weight,
                                  ignore_index=self.ignore_index)
        dice_term = dice_criterion(predict, target)

        topk_criterion = DynamicTopKLoss(weight=self.weight, **self.kwargs)
        topk_term = topk_criterion(predict, target)

        return topk_term + dice_term
Ejemplo n.º 2
0
    def forward(self, predict, target):
        """Return the cross-entropy loss plus the Dice loss.

        ``predict`` and ``target`` must report identical ``size()``; this is
        checked with an ``assert``.
        """
        assert predict.size() == target.size()

        # Dice term: configured from the stored weight, ignore index and any
        # extra keyword arguments kept on the instance.
        dice_term = DiceLoss(
            weight=self.weight,
            ignore_index=self.ignore_index,
            **self.kwargs,
        )(predict, target)

        # Cross-entropy term: only the class weight is forwarded.
        ce_term = CrossentropyLoss(weight=self.weight)(predict, target)

        return ce_term + dice_term
Ejemplo n.º 3
0
    def __init__(
        self,
        span_loss_candidates='all',
        loss_type='bce',
        dice_smooth=1e-8,
    ):
        """Store the loss configuration and create the selected criterion.

        Args:
            span_loss_candidates: stored unchanged on the instance.
            loss_type: ``"bce"`` creates ``self.bce_loss``; any other value
                creates ``self.dice_loss`` instead.
            dice_smooth: ``smooth`` value handed to ``DiceLoss``.
        """
        super(CustomAdaptiveLoss, self).__init__()
        self.loss_type = loss_type
        self.span_loss_candidates = span_loss_candidates

        # Exactly one of the two criteria is instantiated.
        wants_dice = self.loss_type != "bce"
        if wants_dice:
            self.dice_loss = DiceLoss(with_logits=True, smooth=dice_smooth)
        else:
            self.bce_loss = BCEWithLogitsLoss(reduction="none")

        # Two learnable scalars, initialised to zero.
        initial_log_vars = torch.zeros(2)
        self.log_vars = nn.Parameter(initial_log_vars)
Ejemplo n.º 4
0
    def forward(self, predict, target):
        """Return BCE-with-logits loss plus Dice loss for paired outputs.

        ``predict`` and ``target`` must both be lists of length two. Element 0
        of each feeds the Dice loss and element 1 feeds the BCE loss.
        """
        assert isinstance(predict, list)
        assert isinstance(target, list)
        assert len(predict) == len(target) and len(predict) == 2

        # Dice term on the first prediction/target pair.
        dice_criterion = DiceLoss(
            weight=self.weight,
            ignore_index=self.ignore_index,
            **self.kwargs,
        )
        dice_term = dice_criterion(predict[0], target[0])

        # BCE term on the second pair, using the same weight.
        bce_criterion = nn.BCEWithLogitsLoss(weight=self.weight)
        bce_term = bce_criterion(predict[1], target[1])

        return bce_term + dice_term
Ejemplo n.º 5
0
    def compute_loss(self, logits, labels):
        """Compute the loss selected by ``self.loss_type``.

        Args:
            logits: model outputs. They are reshaped to ``(-1, 2)`` for the
                ``"ce"`` and ``"focal"`` losses and to
                ``(-1, self.num_classes)`` for the ``"dice"`` loss.
            labels: target labels. They are flattened for ``"ce"`` and
                ``"focal"`` and passed through unchanged for ``"dice"``.

        Returns:
            The value returned by the selected criterion.

        Raises:
            ValueError: if ``self.loss_type`` is not ``"ce"``, ``"focal"`` or
                ``"dice"``.
        """
        if self.loss_type == "ce":
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, 2), labels.view(-1))
        elif self.loss_type == "focal":
            loss_fct = FocalLoss(gamma=self.args.focal_gamma, reduction="mean")
            loss = loss_fct(logits.view(-1, 2), labels.view(-1))
        elif self.loss_type == "dice":
            loss_fct = DiceLoss(with_logits=True,
                                smooth=self.args.dice_smooth,
                                ohem_ratio=self.args.dice_ohem,
                                alpha=self.args.dice_alpha,
                                square_denominator=self.args.dice_square,
                                reduction="mean")
            loss = loss_fct(logits.view(-1, self.num_classes), labels)
        else:
            # Name the rejected value so a misconfiguration is easy to spot.
            raise ValueError(
                "Unsupported loss_type: {!r}; expected 'ce', 'focal' or "
                "'dice'".format(self.loss_type))

        return loss
Ejemplo n.º 6
0
    def __init__(
        self,
        weight_start=1.,
        weight_end=1.,
        weight_span=1.,
        span_loss_candidates='all',
        loss_type='bce',
        dice_smooth=1e-8,
    ):
        """Store normalised loss weights and create the selected criterion.

        Args:
            weight_start, weight_end, weight_span: raw weights; each is
                divided by the sum of the three before being stored.
            span_loss_candidates: stored unchanged on the instance.
            loss_type: ``"bce"`` creates ``self.bce_loss``; any other value
                creates ``self.dice_loss`` instead.
            dice_smooth: ``smooth`` value handed to ``DiceLoss``.
        """
        super(CustomLoss, self).__init__()

        # Normalise the three weights so that they add up to one.
        total = weight_start + weight_end + weight_span
        self.weight_start, self.weight_end, self.weight_span = (
            raw / total for raw in (weight_start, weight_end, weight_span))

        self.span_loss_candidates = span_loss_candidates
        self.loss_type = loss_type

        # Exactly one of the two criteria is instantiated.
        wants_dice = self.loss_type != "bce"
        if wants_dice:
            self.dice_loss = DiceLoss(with_logits=True, smooth=dice_smooth)
        else:
            self.bce_loss = BCEWithLogitsLoss(reduction="none")
def _criterion():
    """Return the training criterion: a default-constructed ``DiceLoss``."""
    # Alternatives listed in the original note:
    # WeightedBinaryCrossEntropyLoss(class_frequency=True),
    # WBCECenterlineLoss(0.1), MixedDiceLoss(0.1).
    return DiceLoss()
Ejemplo n.º 8
0
def _criterion():
    """Return the training criterion: a default-constructed ``DiceLoss``."""
    # Alternative listed in the original note:
    # WeightedBinaryCrossEntropyLoss(class_frequency=True).
    return DiceLoss()
Ejemplo n.º 9
0
    def compute_loss(self,
                     start_logits,
                     end_logits,
                     span_logits,
                     start_labels,
                     end_labels,
                     match_labels,
                     start_label_mask,
                     end_label_mask,
                     answerable_cls_logits=None,
                     answerable_cls_labels=None):
        """Compute start, end, span-match and optional answerable losses.

        The criterion is picked by ``self.loss_type``: ``"bce"`` uses
        BCE / cross-entropy, ``"dice"`` and ``"adaptive_dice"`` use
        ``DiceLoss``, and every other value falls back to ``FocalLoss``.
        ``self.args.span_loss_candidates`` decides which (start, end) pairs
        may contribute to the match loss.

        Args:
            start_logits, end_logits: per-token logits; dim 0 is the batch and
                dim 1 the sequence. The ``"bce"`` branch and every candidate
                mode other than ``"all"`` require a last dimension of 1 or 2.
            span_logits: logits for (start, end) pairs, or ``None`` to skip
                the match loss.
            start_labels, end_labels: per-token labels.
            match_labels: labels for (start, end) pairs.
            start_label_mask, end_label_mask: 2-D ``(batch, seq_len)`` masks.
                They weight the start/end losses in the ``"bce"`` and focal
                branches and always restrict the span candidates.
            answerable_cls_logits, answerable_cls_labels: optional logits and
                labels for an additional answerable loss.

        Returns:
            ``(start_loss, end_loss, match_loss)``, with ``answerable_loss``
            appended when ``answerable_cls_logits`` is given. ``match_loss``
            is ``None`` when ``span_logits`` is ``None``.

        Raises:
            ValueError: if the logits' last dimension is not 1 or 2 where that
                is required, or if ``self.args.span_loss_candidates`` is not a
                supported mode.
        """
        batch_size, seq_len = start_logits.size()[0], start_logits.size()[1]
        start_float_label_mask = start_label_mask.view(-1).float()
        end_float_label_mask = end_label_mask.view(-1).float()
        match_label_row_mask = start_label_mask.bool().unsqueeze(-1).expand(
            -1, -1, seq_len)
        match_label_col_mask = end_label_mask.bool().unsqueeze(-2).expand(
            -1, seq_len, -1)
        match_label_mask = match_label_row_mask & match_label_col_mask
        # torch.triu keeps the upper-triangular part and zeroes the rest:
        # a named entity's start position must be <= its end position.
        match_label_mask = torch.triu(match_label_mask, 0)

        def _random_candidates():
            # Random 0/1 matrix of candidate spans. It is created on the CPU
            # and then moved to the device of the inputs; the previous
            # hard-coded ``.cuda()`` broke CPU runs and inputs placed on a
            # non-default device.
            data_generator = torch.Generator()
            data_generator.manual_seed(self.args.seed)
            random_matrix = torch.empty(batch_size, seq_len,
                                        seq_len).uniform_(0, 1)
            random_matrix = torch.bernoulli(
                random_matrix, generator=data_generator).long()
            return random_matrix.to(start_logits.device)

        if self.args.span_loss_candidates == "all":
            # naive mask
            float_match_label_mask = match_label_mask.view(batch_size,
                                                           -1).float()
        else:
            # use only pred or golden start/end to compute match loss
            logits_size = start_logits.shape[-1]
            if logits_size == 1:
                start_preds, end_preds = start_logits > 0, end_logits > 0
                start_preds, end_preds = torch.squeeze(
                    start_preds, dim=-1), torch.squeeze(end_preds, dim=-1)
            elif logits_size == 2:
                start_preds, end_preds = torch.argmax(
                    start_logits, dim=-1), torch.argmax(end_logits, dim=-1)
            else:
                raise ValueError(
                    "start/end logits must have a last dimension of 1 or 2, "
                    "got {}".format(logits_size))

            if self.args.span_loss_candidates == "gold":
                match_candidates = (
                    (start_labels.unsqueeze(-1).expand(-1, -1, seq_len) > 0)
                    & (end_labels.unsqueeze(-2).expand(-1, seq_len, -1) > 0))
            elif self.args.span_loss_candidates == "gold_random":
                gold_matrix = (
                    (start_labels.unsqueeze(-1).expand(-1, -1, seq_len) > 0)
                    & (end_labels.unsqueeze(-2).expand(-1, seq_len, -1) > 0))
                match_candidates = torch.logical_or(gold_matrix,
                                                    _random_candidates())
            elif self.args.span_loss_candidates == "gold_pred":
                match_candidates = torch.logical_or(
                    (start_preds.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_preds.unsqueeze(-2).expand(-1, seq_len, -1)),
                    (start_labels.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_labels.unsqueeze(-2).expand(-1, seq_len, -1)))
            elif self.args.span_loss_candidates == "gold_pred_random":
                gold_and_pred = torch.logical_or(
                    (start_preds.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_preds.unsqueeze(-2).expand(-1, seq_len, -1)),
                    (start_labels.unsqueeze(-1).expand(-1, -1, seq_len)
                     & end_labels.unsqueeze(-2).expand(-1, seq_len, -1)))
                match_candidates = torch.logical_or(gold_and_pred,
                                                    _random_candidates())
            else:
                raise ValueError(
                    "Unsupported span_loss_candidates: {!r}".format(
                        self.args.span_loss_candidates))
            match_label_mask = match_label_mask & match_candidates
            float_match_label_mask = match_label_mask.view(batch_size,
                                                           -1).float()

        if self.loss_type == "bce":
            start_end_logits_size = start_logits.shape[-1]
            if start_end_logits_size == 1:
                loss_fct = BCEWithLogitsLoss(reduction="none")
                start_loss = loss_fct(start_logits.view(-1),
                                      start_labels.view(-1).float())
                # Masked mean: only positions flagged by the label mask count.
                start_loss = (start_loss * start_float_label_mask
                              ).sum() / start_float_label_mask.sum()
                end_loss = loss_fct(end_logits.view(-1),
                                    end_labels.view(-1).float())
                end_loss = (end_loss * end_float_label_mask
                            ).sum() / end_float_label_mask.sum()
            elif start_end_logits_size == 2:
                loss_fct = CrossEntropyLoss(reduction='none')
                start_loss = loss_fct(start_logits.view(-1, 2),
                                      start_labels.view(-1))
                start_loss = (start_loss * start_float_label_mask
                              ).sum() / start_float_label_mask.sum()
                end_loss = loss_fct(end_logits.view(-1, 2),
                                    end_labels.view(-1))
                end_loss = (end_loss * end_float_label_mask
                            ).sum() / end_float_label_mask.sum()
            else:
                raise ValueError(
                    "start/end logits must have a last dimension of 1 or 2, "
                    "got {}".format(start_end_logits_size))

            if span_logits is not None:
                loss_fct = BCEWithLogitsLoss(reduction="mean")
                # Keep only the candidate (start, end) pairs.
                select_span_logits = torch.masked_select(
                    span_logits.view(-1),
                    match_label_mask.view(-1).bool())
                select_span_labels = torch.masked_select(
                    match_labels.view(-1),
                    match_label_mask.view(-1).bool())
                match_loss = loss_fct(select_span_logits.view(-1, 1),
                                      select_span_labels.float().view(-1, 1))
            else:
                match_loss = None

            if answerable_cls_logits is not None:
                loss_fct = BCEWithLogitsLoss(reduction="mean")
                answerable_loss = loss_fct(
                    answerable_cls_logits.view(-1, 1),
                    answerable_cls_labels.float().view(-1, 1))
            else:
                answerable_loss = None

        elif self.loss_type in ["dice", "adaptive_dice"]:
            # One DiceLoss instance is shared by every term in this branch.
            loss_fct = DiceLoss(with_logits=True,
                                smooth=self.args.dice_smooth,
                                ohem_ratio=self.args.dice_ohem,
                                alpha=self.args.dice_alpha,
                                square_denominator=self.args.dice_square,
                                reduction="mean",
                                index_label_position=False)
            start_end_logits_size = start_logits.shape[-1]
            start_loss = loss_fct(
                start_logits.view(-1, start_end_logits_size),
                start_labels.view(-1, 1),
            )
            end_loss = loss_fct(
                end_logits.view(-1, start_end_logits_size),
                end_labels.view(-1, 1),
            )

            if span_logits is not None:
                select_span_logits = torch.masked_select(
                    span_logits.view(-1),
                    match_label_mask.view(-1).bool())
                select_span_labels = torch.masked_select(
                    match_labels.view(-1),
                    match_label_mask.view(-1).bool())
                match_loss = loss_fct(
                    select_span_logits.view(-1, 1),
                    select_span_labels.view(-1, 1),
                )
            else:
                match_loss = None

            if answerable_cls_logits is not None:
                answerable_loss = loss_fct(answerable_cls_logits.view(-1, 1),
                                           answerable_cls_labels.view(-1, 1))
            else:
                answerable_loss = None

        else:
            # Any other loss_type falls back to focal loss.
            loss_fct = FocalLoss(gamma=self.args.focal_gamma, reduction="none")
            start_loss = loss_fct(
                FocalLoss.convert_binary_pred_to_two_dimension(
                    start_logits.view(-1)), start_labels.view(-1))
            start_loss = (start_loss * start_float_label_mask
                          ).sum() / start_float_label_mask.sum()
            end_loss = loss_fct(
                FocalLoss.convert_binary_pred_to_two_dimension(
                    end_logits.view(-1)), end_labels.view(-1))
            end_loss = (end_loss * end_float_label_mask
                        ).sum() / end_float_label_mask.sum()
            if answerable_cls_logits is not None:
                answerable_loss = loss_fct(
                    FocalLoss.convert_binary_pred_to_two_dimension(
                        answerable_cls_logits.view(-1)),
                    answerable_cls_labels.view(-1))
                answerable_loss = answerable_loss.mean()
            else:
                answerable_loss = None

            if span_logits is not None:
                match_loss = loss_fct(
                    FocalLoss.convert_binary_pred_to_two_dimension(
                        span_logits.view(-1)), match_labels.view(-1))
                match_loss = match_loss * float_match_label_mask.view(-1)
                # The small epsilon guards against an empty candidate mask.
                match_loss = match_loss.sum() / (float_match_label_mask.sum() +
                                                 1e-10)
            else:
                match_loss = None

        if answerable_loss is not None:
            return start_loss, end_loss, match_loss, answerable_loss
        return start_loss, end_loss, match_loss
Ejemplo n.º 10
0
    def _get_loss(self, loss_fun, class_weight=None):
        """Build the loss criterion named by ``loss_fun``.

        Args:
            loss_fun: name of the loss to create (for example
                ``'Cross_Entropy'``, ``'DiceLoss'``, ``'CEPlusDice'``).
            class_weight: optional per-class weights; converted with
                ``torch.tensor`` and passed to the criterion as its weight.

        Returns:
            The constructed criterion. Top-k variants read ``self.topk`` and
            the dynamic top-k variants read ``self.step_pre_epoch``.

        Raises:
            ValueError: if ``loss_fun`` is not one of the supported names.
        """
        if class_weight is not None:
            class_weight = torch.tensor(class_weight)

        # Project-local loss modules are imported lazily, inside the branch
        # that needs them, so only the selected loss is loaded.
        if loss_fun == 'Cross_Entropy':
            from loss.cross_entropy import CrossentropyLoss
            loss = CrossentropyLoss(weight=class_weight)

        elif loss_fun == 'DynamicTopKLoss':
            from loss.cross_entropy import DynamicTopKLoss
            loss = DynamicTopKLoss(weight=class_weight,
                                   step_threshold=self.step_pre_epoch)

        elif loss_fun == 'DynamicTopkCEPlusDice':
            from loss.combine_loss import DynamicTopkCEPlusDice
            loss = DynamicTopkCEPlusDice(weight=class_weight,
                                         ignore_index=0,
                                         step_threshold=self.step_pre_epoch)

        elif loss_fun == 'TopKLoss':
            from loss.cross_entropy import TopKLoss
            loss = TopKLoss(weight=class_weight, k=self.topk)

        elif loss_fun == 'DiceLoss':
            from loss.dice_loss import DiceLoss
            loss = DiceLoss(weight=class_weight, ignore_index=0, p=1)

        elif loss_fun == 'ShiftDiceLoss':
            from loss.dice_loss import ShiftDiceLoss
            loss = ShiftDiceLoss(weight=class_weight,
                                 ignore_index=0,
                                 reduction='topk',
                                 shift=0.5,
                                 p=1,
                                 k=self.topk)

        elif loss_fun == 'TopkDiceLoss':
            from loss.dice_loss import DiceLoss
            loss = DiceLoss(weight=class_weight,
                            ignore_index=0,
                            reduction='topk',
                            k=self.topk)

        elif loss_fun == 'PowDiceLoss':
            from loss.dice_loss import DiceLoss
            loss = DiceLoss(weight=class_weight, ignore_index=0, p=2)

        elif loss_fun == 'TverskyLoss':
            from loss.tversky_loss import TverskyLoss
            loss = TverskyLoss(weight=class_weight, ignore_index=0, alpha=0.7)

        elif loss_fun == 'FocalTverskyLoss':
            from loss.tversky_loss import TverskyLoss
            loss = TverskyLoss(weight=class_weight,
                               ignore_index=0,
                               alpha=0.7,
                               gamma=0.75)

        elif loss_fun == 'BCEWithLogitsLoss':
            loss = nn.BCEWithLogitsLoss(class_weight)

        elif loss_fun == 'BCEPlusDice':
            from loss.combine_loss import BCEPlusDice
            loss = BCEPlusDice(weight=class_weight, ignore_index=0, p=1)

        elif loss_fun == 'CEPlusDice':
            from loss.combine_loss import CEPlusDice
            loss = CEPlusDice(weight=class_weight, ignore_index=0)

        elif loss_fun == 'CEPlusTopkDice':
            from loss.combine_loss import CEPlusTopkDice
            loss = CEPlusTopkDice(weight=class_weight,
                                  ignore_index=0,
                                  reduction='topk',
                                  k=self.topk)

        elif loss_fun == 'TopkCEPlusTopkDice':
            from loss.combine_loss import TopkCEPlusTopkDice
            loss = TopkCEPlusTopkDice(weight=class_weight,
                                      ignore_index=0,
                                      reduction='topk',
                                      k=self.topk)

        elif loss_fun == 'TopkCEPlusDice':
            from loss.combine_loss import TopkCEPlusDice
            loss = TopkCEPlusDice(weight=class_weight,
                                  ignore_index=0,
                                  k=self.topk)

        elif loss_fun == 'TopkCEPlusShiftDice':
            from loss.combine_loss import TopkCEPlusShiftDice
            loss = TopkCEPlusShiftDice(weight=class_weight,
                                       ignore_index=0,
                                       shift=0.5,
                                       k=self.topk)

        elif loss_fun == 'TopkCEPlusTopkShiftDice':
            from loss.combine_loss import TopkCEPlusTopkShiftDice
            loss = TopkCEPlusTopkShiftDice(weight=class_weight,
                                           ignore_index=0,
                                           reduction='topk',
                                           shift=0.5,
                                           k=self.topk)

        else:
            # An unknown name used to fall through to ``return loss`` and
            # fail with UnboundLocalError; report the bad name instead.
            raise ValueError(
                'Unsupported loss function: {!r}'.format(loss_fun))

        return loss