Example #1
File: main_swa.py Project: dodler/kgl
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.model = CassavaModel(cfg=cfg)
        trn_params = cfg['train_params']
        self.fold = get_or_default(trn_params, 'fold', 0)
        self.batch_size = get_or_default(trn_params, 'batch_size', 16)
        self.num_workers = get_or_default(trn_params, 'num_workers', 2)
        self.aug_type = get_or_default(cfg, 'aug', '0')
        self.csv_path = get_or_default(cfg, 'csv_path',
                                       'input/train_folds.csv')
        self.trn_path = get_or_default(cfg, 'image_path', 'input/train_merged')
        self.mixup = get_or_default(cfg, 'mixup', False)
        self.do_cutmix = False

        if 'crit' not in cfg:
            self.crit = nn.CrossEntropyLoss()
        elif cfg['crit'] == 'focal':
            self.crit = FocalLoss()
        elif cfg['crit'] == 'smooth':
            self.crit = SmoothCrossEntropyLoss(smoothing=0.05)
        elif cfg['crit'] == 'cutmix':
            self.crit = CutMixCrossEntropyLoss(True)
            self.do_cutmix = True
        elif cfg['crit'] == 'focal_cosine':
            self.crit = FocalCosineLoss()
        elif cfg['crit'] == 'ldam':
            labels_list = list(
                Counter(pd.read_csv(self.csv_path).label.values).values())
            self.crit = LDAMLoss(labels_list)
        else:
            raise ValueError(f"unknown criterion: {cfg['crit']}")
        print('mixup', self.mixup)

        print('using fold', self.fold)
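
A minimal sketch of a config that would exercise the branches above; the key names follow the snippet, but the enclosing class name and the full config schema are assumptions:

cfg = {
    'train_params': {'fold': 0, 'batch_size': 16, 'num_workers': 2},
    'csv_path': 'input/train_folds.csv',
    'crit': 'smooth',  # selects SmoothCrossEntropyLoss(smoothing=0.05)
}
module = CassavaModule(cfg)  # hypothetical name for the class this __init__ belongs to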
Example #2
def get_criterion(loss_type,
                  weights=None,
                  dim=None,
                  n_class=None,
                  s=30,
                  **kwargs):
    if loss_type == 'ce':
        return F.cross_entropy
    elif loss_type == 'weighted_ce':
        return nn.CrossEntropyLoss(weight=weights).cuda()
    elif loss_type == 'ohem':
        return OHEMCrossEntropyLoss(**kwargs).cuda()
    elif loss_type == 'ns':
        return NormSoftmaxLoss(dim, n_class).cuda()
    elif loss_type == 'af':
        return ArcFaceLoss(dim, n_class, s=s, m=0.4).cuda()
    elif loss_type == 'focal':
        return FocalLoss().cuda()
    elif loss_type == 'reduced_focal':
        return FocalLoss(reduced_threshold=0.5).cuda()
    else:
        raise ValueError(loss_type)
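
A sketch of calling the factory; the class weights are illustrative, and since the factory moves the returned criteria to CUDA, the inputs must live on the GPU as well:

import torch

# illustrative 4-class weighted cross-entropy; the weights are made up
weights = torch.tensor([1.0, 2.0, 1.0, 2.0])
criterion = get_criterion('weighted_ce', weights=weights)

logits = torch.randn(8, 4, device='cuda')           # batch of 8, 4 classes
targets = torch.randint(0, 4, (8,), device='cuda')  # integer class labels
loss = criterion(logits, targets)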
Example #3
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 text_encoder: Seq2SeqEncoder,
                 classifier_feedforward: FeedForward,
                 verbose_metrics: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 loss: Optional[dict] = None,
                 ) -> None:
        super(TextClassifier, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.text_encoder = text_encoder
        self.classifier_feedforward = classifier_feedforward
        self.prediction_layer = torch.nn.Linear(self.classifier_feedforward.get_output_dim(), self.num_classes)
        self.pool = lambda text, mask: util.get_final_encoder_states(text, mask, bidirectional=True)

        self.label_accuracy = CategoricalAccuracy()
        self.label_f1_metrics = {}
        for i in range(self.num_classes):
            self.label_f1_metrics[vocab.get_token_from_index(index=i, namespace="labels")] = F1Measure(positive_label=i)
        self.verbose_metrics = verbose_metrics

        if loss is None:
            self.loss = torch.nn.CrossEntropyLoss()
        else:
            alpha = loss.get('alpha')
            gamma = loss.get('gamma')
            weight = loss.get('weight')
            if alpha is not None:
                alpha = float(alpha)
            if gamma is not None:
                gamma = float(gamma)
            if weight is not None:
                weight = torch.tensor([1.0, float(weight)])
            if loss.get('type') == 'CrossEntropyLoss':
                self.loss = torch.nn.CrossEntropyLoss(weight=weight)
            elif loss.get('type') == 'BinaryFocalLoss':
                self.loss = BinaryFocalLoss(alpha=alpha, gamma=gamma)
            elif loss.get('type') == 'FocalLoss':
                self.loss = FocalLoss(alpha=alpha, gamma=gamma)
            elif loss.get('type') == 'MultiLabelMarginLoss':
                self.loss = torch.nn.MultiLabelMarginLoss()
            elif loss.get('type') == 'MultiLabelSoftMarginLoss':
                self.loss = torch.nn.MultiLabelSoftMarginLoss(weight)
            else:
                raise ValueError(f'Unexpected loss "{loss}"')

        initializer(self)
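
The loss argument is a plain dict; two illustrative shapes this constructor accepts (the values are made up):

# illustrative loss configs for the constructor above
focal_cfg = {'type': 'FocalLoss', 'alpha': 0.25, 'gamma': 2.0}
weighted_ce_cfg = {'type': 'CrossEntropyLoss', 'weight': 4.0}
# 'weight' is expanded to torch.tensor([1.0, 4.0]), i.e. a binary
# positive-class weighting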
Example #4
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        verbose_metrics: bool = False,
        dropout: float = 0.2,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
        loss: Optional[dict] = None,
    ) -> None:
        super(TextClassifier, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.dropout = torch.nn.Dropout(dropout)
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.classifier_feedforward = torch.nn.Linear(
            self.text_field_embedder.get_output_dim(), self.num_classes)

        self.label_accuracy = CategoricalAccuracy()
        self.label_f1_metrics = {}

        self.verbose_metrics = verbose_metrics

        for i in range(self.num_classes):
            self.label_f1_metrics[vocab.get_token_from_index(
                index=i, namespace="labels")] = F1Measure(positive_label=i)

        if loss is None or loss.get('type') == 'CrossEntropyLoss':
            self.loss = torch.nn.CrossEntropyLoss()
        elif loss.get('type') == 'BinaryFocalLoss':
            self.loss = BinaryFocalLoss(alpha=loss.get('alpha'),
                                        gamma=loss.get('gamma'))
        elif loss.get('type') == 'FocalLoss':
            self.loss = FocalLoss(alpha=loss.get('alpha'),
                                  gamma=loss.get('gamma'))
        elif loss.get('type') == 'MultiLabelMarginLoss':
            self.loss = torch.nn.MultiLabelMarginLoss()
        elif loss.get('type') == 'MultiLabelSoftMarginLoss':
            self.loss = torch.nn.MultiLabelSoftMarginLoss(
                weight=torch.tensor(loss.get('weight')) if 'weight' in
                loss else None)
        else:
            raise ValueError(f'Unexpected loss "{loss}"')

        initializer(self)
Example #5
def _loss_factory(loss, weight, params, device):
    params = deepcopy(params)
    if loss == 'BCE':
        # note: despite the 'BCE' key, this branch builds a multi-class
        # CrossEntropyLoss rather than a binary cross-entropy
        if 'weight' in params:
            params['weight'] = torch.FloatTensor(params['weight']).to(device)
        criterion = CrossEntropyLoss(**params)
    elif loss == 'SoftLabelCE':
        if 'weight' in params:
            params['weight'] = torch.FloatTensor(params['weight']).to(device)
        criterion = SoftLabelCE(**params)
    elif loss == 'Dice':
        criterion = DiceLoss(**params)
    elif loss == 'Focal':
        criterion = FocalLoss(**params)
    elif loss == 'Lovasz':
        criterion = LovaszLoss(**params)
    else:
        raise ValueError(f'Unknown loss type: {loss}')
    return WeightedLoss(criterion, weight=weight)
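
A sketch of one call, assuming WeightedLoss scales the wrapped criterion by weight and that this project's FocalLoss accepts a gamma parameter (both assumptions, not confirmed by the snippet):

# hypothetical usage: a focal loss contributing half of the total objective
criterion = _loss_factory(loss='Focal',
                          weight=0.5,
                          params={'gamma': 2.0},
                          device='cuda')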
Example #6
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.model = CassavaModel(cfg=cfg)
        trn_params = cfg['train_params']
        self.fold = get_or_default(trn_params, 'fold', 0)
        self.batch_size = get_or_default(trn_params, 'batch_size', 16)
        self.num_workers = get_or_default(trn_params, 'num_workers', 2)
        self.aug_type = get_or_default(cfg, 'aug', '0')
        self.csv_path = get_or_default(cfg, 'csv_path',
                                       'input/train_folds.csv')
        self.trn_path = get_or_default(cfg, 'image_path', 'input/train_merged')
        self.mixup = get_or_default(cfg, 'mixup', False)
        self.do_cutmix = False

        self.disc_is_healty = nn.Sequential(
            nn.Linear(self.model.n_out, 256),
            nn.BatchNorm1d(256),
            nn.Tanh(),
            nn.Linear(256, 1),
            RevGrad(),
        )

        if 'crit' not in cfg:
            self.crit = nn.CrossEntropyLoss()
        elif cfg['crit'] == 'focal':
            self.crit = FocalLoss()
        elif cfg['crit'] == 'smooth':
            self.crit = SmoothCrossEntropyLoss()
        elif cfg['crit'] == 'cutmix':
            self.crit = CutMixCrossEntropyLoss(True)
            self.do_cutmix = True
        else:
            raise ValueError(f"unknown criterion: {cfg['crit']}")
        print('mixup', self.mixup)

        print('using fold', self.fold)
Example #7
val_augmentations = albu.Compose(
    [
        albu.PadIfNeeded(min_height=1024,
                         min_width=2048,
                         border_mode=cv2.BORDER_CONSTANT,
                         mask_value=ignore_index,
                         p=1),
        normalization,
    ],
    p=1,
)

test_augmentations = albu.Compose([normalization], p=1)

scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda epoch: 1 - (epoch / train_parameters["num_epochs"]) ** 0.9,
)

train_image_path = Path("data/train/images")
train_mask_path = Path("data/train/masks")

val_image_path = Path("data/val/images")
val_mask_path = Path("data/val/masks")

loss = FocalLoss(ignore_index=ignore_index)

callbacks = []

logdir = f"runs/2_2080Ti_{model.name}_e/baseline"
Example #8
def get_loss(loss_name: str, tsa=False):
    # normalize the name once instead of lowercasing in every branch
    loss_name = loss_name.lower()

    if loss_name == "dice":
        return DiceLoss(mode="binary")

    if loss_name == "mse":
        return nn.MSELoss()

    if loss_name == "msle":
        return MSLELoss()

    if loss_name == "smooth_l1":
        return nn.SmoothL1Loss()

    if loss_name == "mask_bce":
        return ResizeToTarget2d(SoftBCEWithLogitsLoss())

    if loss_name == "rank":
        return PairwiseRankingLoss()

    if loss_name == "kl":
        return LogSoftmaxKLDivLoss()

    if loss_name == "rank2":
        return PairwiseRankingLossV2()

    if loss_name == "ccos":
        return ContrastiveCosineEmbeddingLoss()

    if loss_name == "cntr":
        return EmbeddingLoss()

    if loss_name == "roc_auc":
        return RocAucLoss()

    if loss_name == "roc_auc_ce":
        return RocAucLossCE()

    if loss_name == "bce":
        return nn.BCEWithLogitsLoss(reduction="none" if tsa else "mean")

    if loss_name == "wbce":
        return nn.BCEWithLogitsLoss(reduction="none" if tsa else "mean", pos_weight=torch.tensor(0.33).float()).cuda()

    if loss_name == "wbce2":
        return nn.BCEWithLogitsLoss(reduction="none" if tsa else "mean", pos_weight=torch.tensor(0.66).float()).cuda()

    if loss_name == "ce":
        return nn.CrossEntropyLoss(reduction="none" if tsa else "mean")

    if loss_name == "soft_ce":
        return SoftCrossEntropyLoss(reduction="none" if tsa else "mean", smooth_factor=0.1)

    if loss_name == "soft_bce":
        return SoftBCEWithLogitsLoss(reduction="none" if tsa else "mean", smooth_factor=0.1, ignore_index=None)

    if loss_name == "wce":
        return nn.CrossEntropyLoss(
            reduction="none" if tsa else "mean", weight=torch.tensor([2, 1, 2, 1]).float()
        ).cuda()

    if loss_name == "focal":
        return FocalLoss(alpha=None, gamma=2, reduction="none" if tsa else "mean")

    if loss_name == "binary_focal":
        return BinaryFocalLoss(alpha=None, gamma=2, reduction="none" if tsa else "mean")

    if loss_name == "nfl":
        return FocalLoss(alpha=None, gamma=2, normalized=True, reduction="sum" if tsa else "mean")

    if loss_name == "ohem_ce":
        return OHEMCrossEntropyLoss()

    # losses for embedding
    if loss_name == "cntrv2":
        return EmbeddingLossV2()

    if loss_name == "arc_face":
        return ArcFaceLoss()
    raise KeyError(loss_name)
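
The tsa flag switches the reduction to "none" so a training-signal-annealing mask can be applied per sample before averaging; a sketch of that pattern (the threshold and masking rule are illustrative, not taken from the snippet):

import torch

criterion = get_loss("bce", tsa=True)
logits = torch.randn(8, 1)
targets = torch.rand(8, 1)                 # soft binary targets in [0, 1]
per_sample = criterion(logits, targets)    # shape (8, 1), unreduced

threshold = 0.2                            # illustrative annealing threshold
keep = (per_sample > threshold).float()    # drop examples the model already fits
loss = (per_sample * keep).sum() / keep.sum().clamp(min=1.0)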