Example 1
import torch
from torch.nn import BCEWithLogitsLoss
from pytorch_toolbelt.losses import BinaryFocalLoss, DiceLoss  # assumed source of these losses


def get_criterion(version):
    """Return the loss criterion registered under the given version tag."""
    if version == "focal_v1":
        return BinaryFocalLoss(alpha=0.04,
                               gamma=2,
                               reduction="mean",
                               ignore_index=None,
                               normalized=False,
                               reduced_threshold=None)

    elif version == "focal_v2":
        return BinaryFocalLoss(alpha=0.04,
                               gamma=1.5,
                               reduction="mean",
                               ignore_index=None,
                               normalized=False,
                               reduced_threshold=None)

    elif version == "bce_v1":
        return BCEWithLogitsLoss(pos_weight=torch.tensor(24.16),
                                 reduction='mean',
                                 weight=None,
                                 size_average=None,
                                 reduce=None)

    elif version == "bce_v2":
        return BCEWithLogitsLoss(reduction='mean',
                                 weight=None,
                                 size_average=None,
                                 reduce=None)

    elif version == "bce_v3":
        return BCEWithLogitsLoss(pos_weight=torch.tensor(15.463626008840938),
                                 reduction='mean',
                                 weight=None,
                                 size_average=None,
                                 reduce=None)

    elif version == "dice_v1":
        return DiceLoss(mode="binary",
                        from_logits=True,
                        classes=None,
                        log_loss=False,
                        smooth=0.0,
                        ignore_index=None,
                        eps=1e-7)

    else:
        raise ValueError(f"Criterion version '{version}' is UNKNOWN!")
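
A quick usage sketch; `logits` and `targets` are hypothetical placeholders for model output and ground-truth labels:

criterion = get_criterion("bce_v1")
loss = criterion(logits, targets.float())  # BCEWithLogitsLoss expects float targets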
Example 2
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 text_encoder: Seq2SeqEncoder,
                 classifier_feedforward: FeedForward,
                 verbose_metrics: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 loss: Optional[dict] = None,
                 ) -> None:
        super(TextClassifier, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.text_encoder = text_encoder
        self.classifier_feedforward = classifier_feedforward
        self.prediction_layer = torch.nn.Linear(self.classifier_feedforward.get_output_dim(), self.num_classes)
        self.pool = lambda text, mask: util.get_final_encoder_states(text, mask, bidirectional=True)

        self.label_accuracy = CategoricalAccuracy()
        self.label_f1_metrics = {}
        for i in range(self.num_classes):
            self.label_f1_metrics[vocab.get_token_from_index(index=i, namespace="labels")] = F1Measure(positive_label=i)
        self.verbose_metrics = verbose_metrics

        if loss is None:
            self.loss = torch.nn.CrossEntropyLoss()
        else:
            alpha = loss.get('alpha')
            gamma = loss.get('gamma')
            weight = loss.get('weight')
            if alpha is not None:
                alpha = float(alpha)
            if gamma is not None:
                gamma = float(gamma)
            if weight is not None:
                weight = torch.tensor([1.0, float(weight)])
            if loss.get('type') == 'CrossEntropyLoss':
                self.loss = torch.nn.CrossEntropyLoss(weight=weight)
            elif loss.get('type') == 'BinaryFocalLoss':
                self.loss = BinaryFocalLoss(alpha=alpha, gamma=gamma)
            elif loss.get('type') == 'FocalLoss':
                self.loss = FocalLoss(alpha=alpha, gamma=gamma)
            elif loss.get('type') == 'MultiLabelMarginLoss':
                self.loss = torch.nn.MultiLabelMarginLoss()
            elif loss.get('type') == 'MultiLabelSoftMarginLoss':
                self.loss = torch.nn.MultiLabelSoftMarginLoss(weight)
            else:
                raise ValueError(f'Unexpected loss "{loss}"')

        initializer(self)
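
For reference, a sketch of the `loss` dict this constructor consumes. The keys mirror the `loss.get(...)` calls above; the concrete values are hypothetical. Note that a scalar `weight` is expanded to the two-class tensor `[1.0, weight]` before reaching CrossEntropyLoss.

loss_config = {
    "type": "BinaryFocalLoss",
    "alpha": 0.25,  # parsed with float(), so string values also work
    "gamma": 2.0,
}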
Example 3
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.model = object_from_dict(self.hparams["model"])
        if "resume_from_checkpoint" in self.hparams:
            corrections: Dict[str, str] = {"model.": ""}

            state_dict = state_dict_from_disk(
                file_path=self.hparams["resume_from_checkpoint"],
                rename_in_layers=corrections,
            )
            self.model.load_state_dict(state_dict)

        self.losses = [
            ("jaccard", 0.1, JaccardLoss(mode="binary", from_logits=True)),
            ("focal", 0.9, BinaryFocalLoss()),
        ]
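
`self.losses` pairs each criterion with a name and a weight. A minimal sketch of how such a list is usually reduced to a single training loss; this `training_step` is an assumption, not part of the original snippet:

    def training_step(self, batch, batch_idx):
        logits = self.model(batch["features"])
        # weighted sum: 0.1 * Jaccard + 0.9 * binary focal
        return sum(weight * loss_fn(logits, batch["masks"])
                   for _, weight, loss_fn in self.losses)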
Example 4
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        verbose_metrics: bool = False,
        dropout: float = 0.2,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
        loss: Optional[dict] = None,
    ) -> None:
        super(TextClassifier, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.dropout = torch.nn.Dropout(dropout)
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.classifier_feedforward = torch.nn.Linear(
            self.text_field_embedder.get_output_dim(), self.num_classes)

        self.label_accuracy = CategoricalAccuracy()
        self.label_f1_metrics = {}

        self.verbose_metrics = verbose_metrics

        for i in range(self.num_classes):
            self.label_f1_metrics[vocab.get_token_from_index(
                index=i, namespace="labels")] = F1Measure(positive_label=i)

        if loss is None or loss.get('type') == 'CrossEntropyLoss':
            self.loss = torch.nn.CrossEntropyLoss()
        elif loss.get('type') == 'BinaryFocalLoss':
            self.loss = BinaryFocalLoss(alpha=loss.get('alpha'),
                                        gamma=loss.get('gamma'))
        elif loss.get('type') == 'FocalLoss':
            self.loss = FocalLoss(alpha=loss.get('alpha'),
                                  gamma=loss.get('gamma'))
        elif loss.get('type') == 'MultiLabelMarginLoss':
            self.loss = torch.nn.MultiLabelMarginLoss()
        elif loss.get('type') == 'MultiLabelSoftMarginLoss':
            self.loss = torch.nn.MultiLabelSoftMarginLoss(
                weight=torch.tensor(loss['weight']) if 'weight' in loss else None)
        else:
            raise ValueError(f'Unexpected loss "{loss}"')

        initializer(self)
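
Unlike Example 2, this variant passes `alpha` and `gamma` through unconverted and reads `weight` as a per-class list. A hypothetical config for the multi-label branch:

loss_config = {
    "type": "MultiLabelSoftMarginLoss",
    "weight": [1.0, 3.5],  # one entry per label
}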
Example 5
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

        self.model = object_from_dict(self.hparams["model"])
        if "resume_from_checkpoint" in self.hparams:
            corrections: Dict[str, str] = {"model.": ""}

            checkpoint = load_checkpoint(
                file_path=self.hparams["resume_from_checkpoint"],
                rename_in_layers=corrections,
            )
            self.model.load_state_dict(checkpoint["state_dict"])

        if hparams["sync_bn"]:
            self.model = apex.parallel.convert_syncbn_model(self.model)

        self.losses = [
            ("jaccard", 0.1, JaccardLoss(mode="binary", from_logits=True)),
            ("focal", 0.9, BinaryFocalLoss()),
        ]
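
The `corrections` mapping `{"model.": ""}` strips the `model.` prefix that the Lightning wrapper adds to checkpoint keys, so the weights fit the bare model. Roughly what the project-specific `rename_in_layers` option amounts to, as a plain dict comprehension:

state_dict = {key.replace("model.", "", 1): value
              for key, value in checkpoint["state_dict"].items()}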
Example 6
import torch
from torch import nn
from pytorch_toolbelt.losses import (  # assumed source of these criteria
    BinaryFocalLoss,
    DiceLoss,
    FocalLoss,
    SoftBCEWithLogitsLoss,
    SoftCrossEntropyLoss,
)

# MSLELoss, PairwiseRankingLoss(V2), LogSoftmaxKLDivLoss, ContrastiveCosineEmbeddingLoss,
# EmbeddingLoss(V2), RocAucLoss(CE), OHEMCrossEntropyLoss, ResizeToTarget2d and ArcFaceLoss
# are project-specific and assumed importable from the surrounding codebase.


def get_loss(loss_name: str, tsa=False):
    """Look up a criterion by name; `tsa` switches to per-sample reduction."""
    loss_name = loss_name.lower()  # normalize once instead of per branch

    if loss_name == "dice":
        return DiceLoss(mode="binary")

    if loss_name == "mse":
        return nn.MSELoss()

    if loss_name == "msle":
        return MSLELoss()

    if loss_name == "smooth_l1":
        return nn.SmoothL1Loss()

    if loss_name == "mask_bce":
        return ResizeToTarget2d(SoftBCEWithLogitsLoss())

    if loss_name == "rank":
        return PairwiseRankingLoss()

    if loss_name == "kl":
        return LogSoftmaxKLDivLoss()

    if loss_name == "rank2":
        return PairwiseRankingLossV2()

    if loss_name == "ccos":
        return ContrastiveCosineEmbeddingLoss()

    if loss_name == "cntr":
        return EmbeddingLoss()

    if loss_name == "roc_auc":
        return RocAucLoss()

    if loss_name == "roc_auc_ce":
        return RocAucLossCE()

    if loss_name == "bce":
        return nn.BCEWithLogitsLoss(reduction="none" if tsa else "mean")

    if loss_name == "wbce":
        return nn.BCEWithLogitsLoss(reduction="none" if tsa else "mean", pos_weight=torch.tensor(0.33).float()).cuda()

    if loss_name == "wbce2":
        return nn.BCEWithLogitsLoss(reduction="none" if tsa else "mean", pos_weight=torch.tensor(0.66).float()).cuda()

    if loss_name == "ce":
        return nn.CrossEntropyLoss(reduction="none" if tsa else "mean")

    if loss_name == "soft_ce":
        return SoftCrossEntropyLoss(reduction="none" if tsa else "mean", smooth_factor=0.1)

    if loss_name == "soft_bce":
        return SoftBCEWithLogitsLoss(reduction="none" if tsa else "mean", smooth_factor=0.1, ignore_index=None)

    if loss_name == "wce":
        return nn.CrossEntropyLoss(
            reduction="none" if tsa else "mean", weight=torch.tensor([2, 1, 2, 1]).float()
        ).cuda()

    if loss_name == "focal":
        return FocalLoss(alpha=None, gamma=2, reduction="none" if tsa else "mean")

    if loss_name == "binary_focal":
        return BinaryFocalLoss(alpha=None, gamma=2, reduction="none" if tsa else "mean")

    if loss_name == "nfl":
        # normalized focal loss
        return FocalLoss(alpha=None, gamma=2, normalized=True, reduction="sum" if tsa else "mean")

    if loss_name == "ohem_ce":
        return OHEMCrossEntropyLoss()

    # losses for embedding
    if loss_name == "cntrv2":
        return EmbeddingLossV2()

    if loss_name == "arc_face":
        return ArcFaceLoss()

    raise KeyError(loss_name)
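
Why `tsa` flips the reduction to "none": training signal annealing masks out examples the model already predicts confidently, so the caller needs per-sample losses to mask and reduce manually. A hedged sketch, where `tsa_threshold` and the masking rule are hypothetical:

loss_fn = get_loss("bce", tsa=True)      # reduction="none" -> per-sample losses
per_sample = loss_fn(logits, targets)
keep = (torch.sigmoid(logits) < tsa_threshold).float()  # drop confident examples
loss = (per_sample * keep).sum() / keep.sum().clamp(min=1)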