# Assumes the usual aliases for this snippet:
#   import segmentation_models_pytorch as smp
#   from pytorch_toolbelt import losses as L
def __init__(self, hparams):
    super().__init__()
    self.hparams = hparams

    # FPN segmentation model on top of the configured encoder backbone
    self.fpn = smp.FPN(encoder_name=hparams.encoder_name)

    self.iou = smp.utils.metrics.IoU(activation='sigmoid')
    # weighted sum: 0.7 * binary focal loss + 0.3 * binary Lovasz loss
    self.mixed_loss = L.JointLoss(L.BinaryFocalLoss(),
                                  L.BinaryLovaszLoss(), 0.7, 0.3)
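Since both constituent losses take raw logits, the joint loss can be sanity-checked on dummy tensors. A minimal sketch; the tensor shapes are illustrative assumptions:

import torch
from pytorch_toolbelt import losses as L

logits = torch.randn(2, 1, 64, 64)                    # fake model output
target = torch.randint(0, 2, (2, 1, 64, 64)).float()  # fake binary mask
mixed = L.JointLoss(L.BinaryFocalLoss(), L.BinaryLovaszLoss(), 0.7, 0.3)
# equals 0.7 * focal(logits, target) + 0.3 * lovasz(logits, target)
print(mixed(logits, target))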
Code example #2
def __init__(self, hparams):
    super().__init__()
    self.hparams = hparams

    # resolve the architecture class (e.g. smp.Unet, smp.FPN) by its name
    net = getattr(smp, self.hparams.architecture)

    self.net = net(encoder_name=self.hparams.encoder_name, classes=1)

    self.iou = smp.utils.metrics.IoU(activation='sigmoid')

    # weighted sum: 0.7 * binary focal loss + 0.3 * binary Lovasz loss
    self.loss = L.JointLoss(L.BinaryFocalLoss(), L.BinaryLovaszLoss(), 0.7,
                            0.3)
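The getattr pattern above picks any segmentation_models_pytorch architecture by its string name. A small sketch of the same idea; the values 'Unet' and 'resnet34' are assumptions, not taken from the example:

from types import SimpleNamespace
import segmentation_models_pytorch as smp

hparams = SimpleNamespace(architecture='Unet', encoder_name='resnet34')
net_cls = getattr(smp, hparams.architecture)   # resolves to smp.Unet
net = net_cls(encoder_name=hparams.encoder_name, classes=1)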
Code example #3
def main():
    global best_acc

    if not os.path.isdir(args.out):
        mkdir_p(args.out)

    # Data
    print('==> Preparing freesound')

    (train_labeled_set, train_unlabeled_set, val_set, test_set,
     train_unlabeled_warmstart_set, num_classes,
     pos_weights) = dataset.get_freesound()
    labeled_trainloader = data.DataLoader(train_labeled_set,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=args.num_cpu,
                                          drop_last=True,
                                          collate_fn=dataset.collate_fn,
                                          pin_memory=True)
    noisy_train_loader = data.DataLoader(train_unlabeled_warmstart_set,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.num_cpu,
                                         drop_last=True,
                                         collate_fn=dataset.collate_fn)
    unlabeled_trainloader = data.DataLoader(
        train_unlabeled_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_cpu,
        drop_last=True,
        collate_fn=dataset.collate_fn_unlabbelled)
    val_loader = data.DataLoader(val_set,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_cpu,
                                 collate_fn=dataset.collate_fn,
                                 pin_memory=True)
    test_loader = data.DataLoader(test_set,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_cpu,
                                  collate_fn=dataset.collate_fn,
                                  pin_memory=True)

    # Model
    print("==> creating WRN-28-2")

    def create_model(ema=False):
        model = nn.DataParallel(models.WideResNet(num_classes=num_classes))
        if use_cuda:
            model = model.cuda()

        if ema:
            # EMA weights are updated by WeightEMA, not by backprop
            for param in model.parameters():
                param.detach_()

        return model

    model = create_model()
    ema_model = create_model(ema=True)
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    focal_loss = L.BinaryFocalLoss()
    criterion = nn.BCEWithLogitsLoss()
    train_criterion = SemiLoss(criterion)
    noisy_criterion = NoisyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=1e-6)

    ema_optimizer = WeightEMA(model,
                              ema_model,
                              num_classes,
                              alpha=args.ema_decay)
    start_epoch = 0

    # Resume
    title = 'freesound'
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(args.resume), 'Error: no checkpoint found!'
        args.out = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        ema_model.load_state_dict(checkpoint['ema_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.out, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        logger = Logger(os.path.join(args.out, 'log.txt'), title=title)
        logger.set_names([
            'Train Loss', 'Train Loss X', 'Train Loss U', 'Train Loss N',
            'Train Acc.', 'Valid Loss', 'Valid Acc.', 'Test Loss', 'Test Acc.'
        ])

    writer = SummaryWriter(args.out)
    step = 0
    test_accs = []
    # Train and val
    for epoch in range(start_epoch, args.epochs):
        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, optimizer.param_groups[0]['lr']))
        train_loss, train_loss_x, train_loss_u, train_loss_n = train(
            labeled_trainloader, unlabeled_trainloader, noisy_train_loader,
            model, optimizer, ema_optimizer, train_criterion, noisy_criterion,
            epoch, use_cuda)
        _, train_acc = validate(labeled_trainloader,
                                ema_model,
                                criterion,
                                epoch,
                                use_cuda,
                                mode='Train Stats')
        val_loss, val_acc = validate(val_loader,
                                     ema_model,
                                     criterion,
                                     epoch,
                                     use_cuda,
                                     mode='Valid Stats')
        test_loss, test_acc = validate(test_loader,
                                       ema_model,
                                       criterion,
                                       epoch,
                                       use_cuda,
                                       mode='Test Stats ')

        step = args.batch_size * args.val_iteration * (epoch + 1)

        writer.add_scalar('losses/train_loss', train_loss, step)
        writer.add_scalar('losses/valid_loss', val_loss, step)
        writer.add_scalar('losses/test_loss', test_loss, step)

        writer.add_scalar('accuracy/train_acc', train_acc, step)
        writer.add_scalar('accuracy/val_acc', val_acc, step)
        writer.add_scalar('accuracy/test_acc', test_acc, step)

        # append logger file
        logger.append([
            train_loss, train_loss_x, train_loss_u, train_loss_n, train_acc,
            val_loss, val_acc, test_loss, test_acc
        ])

        # save model
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'ema_state_dict': ema_model.state_dict(),
                'acc': val_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best, val_acc)
        test_accs.append(test_acc)
    logger.close()
    writer.close()

    print('Best acc:')
    print(best_acc)

    print('Mean acc:')
    print(np.mean(test_accs[-20:]))
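WeightEMA above comes from the project's own MixMatch-style code and is not shown. A minimal sketch of the exponential-moving-average update such a helper typically performs; this simplified body is an assumption (the project's version also receives num_classes):

class WeightEMA:
    def __init__(self, model, ema_model, alpha=0.999):
        self.params = list(model.parameters())
        self.ema_params = list(ema_model.parameters())
        self.alpha = alpha

    def step(self):
        for p, ema_p in zip(self.params, self.ema_params):
            # ema <- alpha * ema + (1 - alpha) * online weights
            ema_p.data.mul_(self.alpha).add_(p.data, alpha=1 - self.alpha)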
Code example #4
class JointLoss(torch.nn.Module):
    def __init__(self, first, second, third,
                 first_weight, second_weight, third_weight):
        super().__init__()
        self.first = WeightedLoss(first, first_weight)
        self.second = WeightedLoss(second, second_weight)
        self.third = WeightedLoss(third, third_weight)

    def forward(self, *input):
        return self.first(*input) + self.second(*input) + self.third(*input)


model = get_seg_model()

loss_1 = smp.utils.losses.BCEWithLogitsLoss()
loss_2 = torch.nn.BCELoss()
loss_3 = L.BinaryFocalLoss()
loss_4 = L.BinaryLovaszLoss()
loss_5 = L.DiceLoss(mode='binary')
loss_6 = L.SoftBCEWithLogitsLoss()
loss_7 = L.JaccardLoss(mode='binary')

# the first term has zero weight, so this is 0.7 * Lovasz + 0.3 * focal
loss = JointLoss(loss_3, loss_4, loss_3, 0.0, 0.7, 0.3)
metrics = [
    smp.utils.metrics.IoU(threshold=0.5, activation='sigmoid'),
]

optimizer = torch.optim.Adam([
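The snippet above is truncated and relies on a WeightedLoss helper that is not shown. A minimal sketch of such a wrapper, modeled on pytorch_toolbelt's helper of the same name; treat it as an assumption for this example:

import torch.nn as nn

class WeightedLoss(nn.Module):
    def __init__(self, loss, weight=1.0):
        super().__init__()
        self.loss = loss
        self.weight = weight

    def forward(self, *input):
        # scale the wrapped criterion by a fixed weight
        return self.loss(*input) * self.weight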
Code example #5
File: train_segment.py  Project: sowmen/imanip_main
def get_lossfn():
    bce = nn.BCEWithLogitsLoss()
    # log_loss=True uses -log(dice score); smooth avoids division by zero
    dice = DiceLoss(mode='binary', log_loss=True, smooth=1e-7)
    # reduced focal loss: easy examples above the threshold are down-weighted
    focal = losses.BinaryFocalLoss(alpha=0.25, reduced_threshold=0.5)
    criterion = ImanipLoss(bce, seglossA=dice, seglossB=focal)
    return criterion
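ImanipLoss is specific to the imanip_main project and its body is not shown. Purely as an illustration, a wrapper with this constructor signature might combine an image-level loss with two mask-level losses; everything below is a guess, not the project's actual implementation:

import torch.nn as nn

class ImanipLoss(nn.Module):  # hypothetical sketch, not the real class
    def __init__(self, bce, seglossA, seglossB):
        super().__init__()
        self.bce, self.seglossA, self.seglossB = bce, seglossA, seglossB

    def forward(self, cls_out, cls_target, mask_out, mask_target):
        # image-level BCE plus both segmentation losses on the mask
        return (self.bce(cls_out, cls_target)
                + self.seglossA(mask_out, mask_target)
                + self.seglossB(mask_out, mask_target))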
Code example #6
def focal(**kwargs):
    # thin factory forwarding all keyword arguments to BinaryFocalLoss
    return L.BinaryFocalLoss(**kwargs)
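A hypothetical call; the keyword arguments are forwarded unchanged to pytorch_toolbelt's BinaryFocalLoss, and the values shown are illustrative:

loss_fn = focal(alpha=0.25, gamma=2.0)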