Example #1
def PI_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)
    w = 0
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        exp_lr_scheduler.step()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length) 
        for batch_idx, ((x, x_bar), label, idx) in enumerate(tqdm(train_loader)):
            x, x_bar = x.to(device), x_bar.to(device)
            _, feat = model(x)
            _, feat_bar = model(x_bar)
            prob = feat2prob(feat, model.center)
            prob_bar = feat2prob(feat_bar, model.center)
            sharp_loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            consistency_loss = F.mse_loss(prob, prob_bar)
            loss = sharp_loss + w * consistency_loss 
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args, epoch)

        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(probs) 
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
Example #2
def train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(
            epoch, args.rampup_length)
        for batch_idx, (x, g_x, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            _, feat_g = model(g_x.to(device))
            prob = feat2prob(feat, model.center)
            prob_g = feat2prob(feat_g, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            mse_loss = F.mse_loss(prob, prob_g)
            loss = loss + w * mse_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(
            epoch, loss_record.avg))
        _, _, _, probs = test(model, eval_loader, args)
        args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
Example #3
def TE_train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    w = 0
    alpha = 0.6
    ntrain = len(train_loader.dataset)
    Z = torch.zeros(ntrain, args.n_clusters).float().to(device)        # intermediate values
    z_ema = torch.zeros(ntrain, args.n_clusters).float().to(device)        # temporal outputs
    z_epoch = torch.zeros(ntrain, args.n_clusters).float().to(device)  # current outputs
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length) 
        for batch_idx, (x, _, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            prob = feat2prob(feat, model.center)
            z_epoch[idx, :] = prob
            prob_bar = z_ema[idx, :].detach()  # Variable is deprecated; .detach() gives the same constant target
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            mse_loss = F.mse_loss(prob, prob_bar)
            loss = loss + w * mse_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))

        Z = alpha * Z + (1. - alpha) * z_epoch
        z_ema = Z * (1. / (1. - alpha ** (epoch + 1)))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eval_loader, args)

        if epoch % args.update_interval == 0:
            args.p_targets = target_distribution(probs) 
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
Example #4
def train(model, train_loader, unlabeled_eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                           step_size=args.step_size,
                                           gamma=args.gamma)
    criterion1 = nn.CrossEntropyLoss()
    criterion2 = BCE()
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        exp_lr_scheduler.step()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(
            epoch, args.rampup_length)
        for batch_idx, ((x, x_bar), label,
                        idx) in enumerate(tqdm(train_loader)):
            x, x_bar, label = x.to(device), x_bar.to(device), label.to(device)
            output1, output2, feat = model(x)
            output1_bar, output2_bar, _ = model(x_bar)
            prob1, prob1_bar, prob2, prob2_bar = F.softmax(
                output1, dim=1), F.softmax(output1_bar, dim=1), F.softmax(
                    output2, dim=1), F.softmax(output2_bar, dim=1)

            mask_lb = idx < train_loader.labeled_length

            rank_feat = (feat[~mask_lb]).detach()
            rank_idx = torch.argsort(rank_feat, dim=1, descending=True)
            rank_idx1, rank_idx2 = PairEnum(rank_idx)

            rank_idx1, rank_idx2 = rank_idx1[:, :args.
                                             topk], rank_idx2[:, :args.topk]
            rank_idx1, _ = torch.sort(rank_idx1, dim=1)
            rank_idx2, _ = torch.sort(rank_idx2, dim=1)

            rank_diff = rank_idx1 - rank_idx2
            rank_diff = torch.sum(torch.abs(rank_diff), dim=1)
            target_ulb = torch.ones_like(rank_diff).float().to(device)
            target_ulb[rank_diff > 0] = -1

            prob1_ulb, _ = PairEnum(prob2[~mask_lb])
            _, prob2_ulb = PairEnum(prob2_bar[~mask_lb])

            loss_ce = criterion1(output1[mask_lb], label[mask_lb])
            loss_bce = criterion2(prob1_ulb, prob2_ulb, target_ulb)

            consistency_loss = F.mse_loss(prob1, prob1_bar) + F.mse_loss(
                prob2, prob2_bar)

            loss = loss_ce + loss_bce + w * consistency_loss

            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(
            epoch, loss_record.avg))
        print('test on unlabeled classes')
        args.head = 'head2'
        test(model, unlabeled_eval_loader, args)
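The pairwise pseudo-labels in Example #4 follow the ranking-statistics recipe: PairEnum enumerates all ordered pairs of rows in a batch, and BCE is a binary cross-entropy on the inner product of the two class-probability vectors of each pair. Sketches of both, assuming the usual AutoNovel-style definitions (the versions in the original repository may differ in detail):

import torch
import torch.nn as nn

def PairEnum(x):
    # Enumerate all ordered pairs of rows of a 2-D tensor:
    # for a batch of size N, x1[i * N + j] = x[j] and x2[i * N + j] = x[i].
    assert x.ndimension() == 2, 'Input dimension must be 2'
    x1 = x.repeat(x.size(0), 1)
    x2 = x.repeat(1, x.size(0)).view(-1, x.size(1))
    return x1, x2

class BCE(nn.Module):
    eps = 1e-7

    def forward(self, prob1, prob2, simi):
        # simi is +1 for pairs judged similar and -1 for dissimilar pairs.
        # P is the probability that the two samples of a pair share a class.
        P = (prob1 * prob2).sum(1)
        P = P * simi + simi.eq(-1).type_as(P)  # P for simi = +1, 1 - P for simi = -1
        return -(P + BCE.eps).log().mean()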
Example #5
def train(model, model_ema, train_loader, labeled_eval_loader, unlabeled_eval_loader, args):

    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    criterion1 = nn.CrossEntropyLoss() 
    criterion2 = BCE() 
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        model_ema.train()
        exp_lr_scheduler.step()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length) 
        for batch_idx, ((x, x_bar), label, idx) in enumerate(tqdm(train_loader)):
            x, x_bar, label = x.to(device), x_bar.to(device), label.to(device)

            output1, output2, feat = model(x)
            output1_bar, output2_bar, _ = model(x_bar)

            with torch.no_grad():
                output1_ema, output2_ema, feat_ema = model_ema(x)
                output1_bar_ema, output2_bar_ema, _ = model_ema(x_bar)
            prob1, prob1_bar, prob2, prob2_bar = F.softmax(output1, dim=1),  F.softmax(output1_bar, dim=1), F.softmax(output2, dim=1), F.softmax(output2_bar, dim=1)
            prob1_ema, prob1_bar_ema, prob2_ema, prob2_bar_ema = F.softmax(output1_ema, dim=1),  F.softmax(output1_bar_ema, dim=1), F.softmax(output2_ema, dim=1), F.softmax(output2_bar_ema, dim=1)

            mask_lb = label < args.num_labeled_classes

            loss_ce = criterion1(output1[mask_lb], label[mask_lb])
            loss_bce = rank_bce(criterion2,feat,mask_lb,prob2,prob2_bar)

            consistency_loss = F.mse_loss(prob1, prob1_bar) + F.mse_loss(prob2, prob2_bar)
            consistency_loss_ema = F.mse_loss(prob1, prob1_bar_ema) + F.mse_loss(prob2, prob2_bar_ema)

            loss = loss_ce + loss_bce + w * consistency_loss + w * consistency_loss_ema #+ smooth_loss(feat,mask_lb) #+ MCR(feat, idx)

            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            _update_ema_variables(model, model_ema, 0.99, epoch * len(train_loader) + batch_idx)

        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        print('test on labeled classes')
        args.head = 'head1'
        test(model, labeled_eval_loader, args)
        print('test on unlabeled classes')
        args.head = 'head2'
        test(model, unlabeled_eval_loader, args)
        test(model_ema, unlabeled_eval_loader, args)
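The EMA helper (_update_ema_variables here, update_ema_variables in the later examples) keeps model_ema as an exponential moving average of the student weights, in the usual Mean Teacher fashion. A sketch under that assumption:

def _update_ema_variables(model, ema_model, alpha, global_step):
    # Use the true average until the exponential average becomes more accurate.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)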
Example #6
def TE_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer,
                                                milestones=args.milestones,
                                                gamma=args.gamma)
    w = 0
    alpha = 0.6
    ntrain = len(train_loader.dataset)
    Z = torch.zeros(ntrain,
                    args.n_clusters).float().to(device)  # intermediate values
    z_ema = torch.zeros(ntrain,
                        args.n_clusters).float().to(device)  # temporal outputs
    z_epoch = torch.zeros(ntrain, args.n_clusters).float().to(
        device)  # current outputs
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        exp_lr_scheduler.step()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(
            epoch, args.rampup_length)
        for batch_idx, (x, label, idx) in enumerate(tqdm(train_loader)):
            x = x.to(device)
            feat = model(x)
            prob = feat2prob(feat, model.center)
            z_epoch[idx, :] = prob
            prob_bar = z_ema[idx, :].detach()  # Variable is deprecated; .detach() gives the same constant target
            sharp_loss = F.kl_div(prob.log(),
                                  args.p_targets[idx].float().to(device))
            consistency_loss = F.mse_loss(prob, prob_bar)
            loss = sharp_loss + w * consistency_loss
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        Z = alpha * Z + (1. - alpha) * z_epoch
        z_ema = Z * (1. / (1. - alpha**(epoch + 1)))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(
            epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args)
        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
Example #7
def TEP_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    w = 0
    alpha = 0.6
    ntrain = len(train_loader.dataset)
    Z = torch.zeros(ntrain,
                    args.n_clusters).float().to(device)  # intermediate values
    z_ema = torch.zeros(ntrain,
                        args.n_clusters).float().to(device)  # temporal outputs
    z_epoch = torch.zeros(ntrain, args.n_clusters).float().to(
        device)  # current outputs
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        acc_record = AverageMeter()
        model.train()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(
            epoch, args.rampup_length)
        for batch_idx, (x, label, idx) in enumerate(tqdm(train_loader)):
            x = x.to(device)
            feat = model(x)
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('Train Epoch: {} Avg Loss: {:.4f}'.format(
            epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args, epoch)
        z_epoch = probs.float().to(device)
        Z = alpha * Z + (1. - alpha) * z_epoch
        z_ema = Z * (1. / (1. - alpha**(epoch + 1)))
        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(z_ema).float().to(device)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
Example #8
def get_current_consistency_weight(epoch):
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    return args.consistency * ramps.sigmoid_rampup(epoch,
                                                   args.consistency_rampup)
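Every example here scales its consistency term with ramps.sigmoid_rampup. In the reference implementation of the paper cited above this is the exponential ramp exp(-5 * (1 - t / T)^2), rising smoothly from about 0 to 1 over T steps; a sketch assuming that definition:

import numpy as np

def sigmoid_rampup(current, rampup_length):
    """Exponential ramp-up from 0 to 1 over rampup_length steps (Laine & Aila, 2017)."""
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))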
def train(args, snapshot_path):
    base_lr = args.base_lr
    num_classes = args.num_classes
    batch_size = args.batch_size
    max_iterations = args.max_iterations

    def create_model(ema=False):
        # Network definition
        model = net_factory(net_type=args.model,
                            in_chns=1,
                            class_num=num_classes)
        if ema:
            for param in model.parameters():
                param.detach_()
        return model

    model = create_model()
    ema_model = create_model(ema=True)

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    db_train = BaseDataSets(base_dir=args.root_path,
                            split="train",
                            num=None,
                            transform=transforms.Compose(
                                [RandomGenerator(args.patch_size)]))
    db_val = BaseDataSets(base_dir=args.root_path, split="val")
    total_slices = len(db_train)
    labeled_slice = patients_to_slices(args.root_path, args.labeled_num)
    print("Total silices is: {}, labeled slices is: {}".format(
        total_slices, labeled_slice))
    labeled_idxs = list(range(0, labeled_slice))
    unlabeled_idxs = list(range(labeled_slice, total_slices))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs,
                                          batch_size,
                                          batch_size - args.labeled_bs)

    trainloader = DataLoader(db_train,
                             batch_sampler=batch_sampler,
                             num_workers=4,
                             pin_memory=True,
                             worker_init_fn=worker_init_fn)

    model.train()

    valloader = DataLoader(db_val, batch_size=1, shuffle=False, num_workers=1)

    optimizer = optim.SGD(model.parameters(),
                          lr=base_lr,
                          momentum=0.9,
                          weight_decay=0.0001)
    ce_loss = CrossEntropyLoss()
    dice_loss = losses.DiceLoss(num_classes)

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))

    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    best_performance = 0.0
    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        for i_batch, sampled_batch in enumerate(trainloader):

            volume_batch, label_batch = sampled_batch['image'], sampled_batch[
                'label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            unlabeled_volume_batch = volume_batch[args.labeled_bs:]

            noise = torch.clamp(
                torch.randn_like(unlabeled_volume_batch) * 0.1, -0.2, 0.2)
            ema_inputs = unlabeled_volume_batch + noise

            outputs = model(volume_batch)
            outputs_soft = torch.softmax(outputs, dim=1)
            with torch.no_grad():
                ema_output = ema_model(ema_inputs)
            T = 8
            _, _, w, h = unlabeled_volume_batch.shape
            volume_batch_r = unlabeled_volume_batch.repeat(2, 1, 1, 1)
            stride = volume_batch_r.shape[0] // 2
            preds = torch.zeros([stride * T, num_classes, w, h]).cuda()
            for i in range(T // 2):
                ema_inputs = volume_batch_r + \
                    torch.clamp(torch.randn_like(
                        volume_batch_r) * 0.1, -0.2, 0.2)
                with torch.no_grad():
                    preds[2 * stride * i:2 * stride *
                          (i + 1)] = ema_model(ema_inputs)
            preds = F.softmax(preds, dim=1)
            preds = preds.reshape(T, stride, num_classes, w, h)
            preds = torch.mean(preds, dim=0)
            uncertainty = -1.0 * \
                torch.sum(preds*torch.log(preds + 1e-6), dim=1, keepdim=True)

            loss_ce = ce_loss(outputs[:args.labeled_bs],
                              label_batch[:args.labeled_bs][:].long())
            loss_dice = dice_loss(outputs_soft[:args.labeled_bs],
                                  label_batch[:args.labeled_bs].unsqueeze(1))
            supervised_loss = 0.5 * (loss_dice + loss_ce)
            consistency_weight = get_current_consistency_weight(iter_num //
                                                                150)
            consistency_dist = losses.softmax_mse_loss(
                outputs[args.labeled_bs:],
                ema_output)  # (batch, 2, 112,112,80)
            threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(
                iter_num, max_iterations)) * np.log(2)
            mask = (uncertainty < threshold).float()
            consistency_loss = torch.sum(
                mask * consistency_dist) / (2 * torch.sum(mask) + 1e-16)

            loss = supervised_loss + consistency_weight * consistency_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_ema_variables(model, ema_model, args.ema_decay, iter_num)

            lr_ = base_lr * (1.0 - iter_num / max_iterations)**0.9
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_

            iter_num = iter_num + 1
            writer.add_scalar('info/lr', lr_, iter_num)
            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/loss_ce', loss_ce, iter_num)
            writer.add_scalar('info/loss_dice', loss_dice, iter_num)
            writer.add_scalar('info/consistency_loss', consistency_loss,
                              iter_num)
            writer.add_scalar('info/consistency_weight', consistency_weight,
                              iter_num)
            logging.info(
                'iteration %d : loss : %f, loss_ce: %f, loss_dice: %f' %
                (iter_num, loss.item(), loss_ce.item(), loss_dice.item()))

            if iter_num % 20 == 0:
                image = volume_batch[1, 0:1, :, :]
                writer.add_image('train/Image', image, iter_num)
                outputs = torch.argmax(torch.softmax(outputs, dim=1),
                                       dim=1,
                                       keepdim=True)
                writer.add_image('train/Prediction', outputs[1, ...] * 50,
                                 iter_num)
                labs = label_batch[1, ...].unsqueeze(0) * 50
                writer.add_image('train/GroundTruth', labs, iter_num)

            if iter_num > 0 and iter_num % 200 == 0:
                model.eval()
                metric_list = 0.0
                for i_batch, sampled_batch in enumerate(valloader):
                    metric_i = test_single_volume(sampled_batch["image"],
                                                  sampled_batch["label"],
                                                  model,
                                                  classes=num_classes)
                    metric_list += np.array(metric_i)
                metric_list = metric_list / len(db_val)
                for class_i in range(num_classes - 1):
                    writer.add_scalar('info/val_{}_dice'.format(class_i + 1),
                                      metric_list[class_i, 0], iter_num)
                    writer.add_scalar('info/val_{}_hd95'.format(class_i + 1),
                                      metric_list[class_i, 1], iter_num)

                performance = np.mean(metric_list, axis=0)[0]

                mean_hd95 = np.mean(metric_list, axis=0)[1]
                writer.add_scalar('info/val_mean_dice', performance, iter_num)
                writer.add_scalar('info/val_mean_hd95', mean_hd95, iter_num)

                if performance > best_performance:
                    best_performance = performance
                    save_mode_path = os.path.join(
                        snapshot_path, 'iter_{}_dice_{}.pth'.format(
                            iter_num, round(best_performance, 4)))
                    save_best = os.path.join(
                        snapshot_path, '{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best)

                logging.info('iteration %d : mean_dice : %f mean_hd95 : %f' %
                             (iter_num, performance, mean_hd95))
                model.train()

            if iter_num % 3000 == 0:
                save_mode_path = os.path.join(snapshot_path,
                                              'iter_' + str(iter_num) + '.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num >= max_iterations:
                break
        if iter_num >= max_iterations:
            iterator.close()
            break
    writer.close()
    return "Training Finished!"
def train(train_loader,
          model,
          optimizer,
          epoch,
          ema_model=None,
          weak_mask=None,
          strong_mask=None):
    """ One epoch of a Mean Teacher model
    :param train_loader: torch.utils.data.DataLoader, iterator of training batches for an epoch.
    Should return 3 values: teacher input, student input, labels
    :param model: torch.nn.Module, student model to be trained, should return a weak and a strong prediction
    :param optimizer: torch.optim.Optimizer, optimizer used to train the model
    :param epoch: int, the current epoch of training
    :param ema_model: torch.nn.Module, teacher model (EMA of the student), should return a weak and a strong prediction
    :param weak_mask: mask selecting the weakly labeled part of the batch (used to calculate the loss)
    :param strong_mask: mask selecting the strongly labeled part of the batch (used to calculate the loss)
    """
    class_criterion = nn.BCELoss()

    ##################################################
    class_criterion1 = nn.BCELoss(reduction='none')
    ##################################################

    consistency_criterion = nn.MSELoss()

    # [class_criterion, consistency_criterion] = to_cuda_if_available(
    #     [class_criterion, consistency_criterion])
    [class_criterion, class_criterion1,
     consistency_criterion] = to_cuda_if_available(
         [class_criterion, class_criterion1, consistency_criterion])

    meters = AverageMeterSet()

    LOG.debug("Nb batches: {}".format(len(train_loader)))
    start = time.time()
    rampup_length = len(train_loader) * cfg.n_epoch // 2

    print("Train\n")
    # LOG.info("Weak[k] -> Weak[k]")
    # LOG.info("Weak[k] -> strong[k]")

    # print(weak_mask.start)
    # print(strong_mask.start)
    # exit()
    count = 0
    check_cus_weak = 0
    difficulty_loss = 0
    loss_w = 1
    LOG.info("loss paramater:{}".format(loss_w))
    for i, (batch_input, ema_batch_input, target) in enumerate(train_loader):
        # print(batch_input.shape)
        # print(ema_batch_input.shape)
        # exit()
        global_step = epoch * len(train_loader) + i
        if global_step < rampup_length:
            rampup_value = ramps.sigmoid_rampup(global_step, rampup_length)
        else:
            rampup_value = 1.0

        # Todo check if this improves the performance
        # adjust_learning_rate(optimizer, rampup_value, rampdown_value)
        meters.update('lr', optimizer.param_groups[0]['lr'])

        [batch_input, ema_batch_input,
         target] = to_cuda_if_available([batch_input, ema_batch_input, target])
        LOG.debug("batch_input:{}".format(batch_input.mean()))

        # print(batch_input)
        # exit()

        # Outputs
        ##################################################
        # strong_pred_ema, weak_pred_ema = ema_model(ema_batch_input)
        strong_pred_ema, weak_pred_ema, sof_ema = ema_model(ema_batch_input)
        sof_ema = sof_ema.detach()
        ##################################################

        strong_pred_ema = strong_pred_ema.detach()
        weak_pred_ema = weak_pred_ema.detach()

        ##################################################
        # strong_pred, weak_pred = model(batch_input)
        strong_pred, weak_pred, sof = model(batch_input)
        ##################################################

        ##################################################
        # custom_ema_loss = Custom_BCE_Loss(ema_batch_input, class_criterion1)

        if difficulty_loss == 0:
            LOG.info("############### Deffine Difficulty Loss ###############")
            difficulty_loss = 1
        custom_ema_loss = Custom_BCE_Loss_difficulty(ema_batch_input,
                                                     class_criterion1,
                                                     paramater=loss_w)
        custom_ema_loss.initialize(strong_pred_ema, sof_ema)

        # custom_loss = Custom_BCE_Loss(batch_input, class_criterion1)
        custom_loss = Custom_BCE_Loss_difficulty(batch_input,
                                                 class_criterion1,
                                                 paramater=loss_w)
        custom_loss.initialize(strong_pred, sof)
        ##################################################

        # print(strong_pred.shape)
        # print(strong_pred)
        # print(weak_pred.shape)
        # print(weak_pred)
        # exit()

        loss = None
        # Weak BCE Loss
        # Take the max in the time axis
        # torch.set_printoptions(threshold=10000)
        # print(target[-10])
        # # print(target.max(-2))
        # # print(target.max(-2)[0])
        # print(target.max(-1)[0][-10])
        # exit()

        target_weak = target.max(-2)[0]
        if weak_mask is not None:
            weak_class_loss = class_criterion(weak_pred[weak_mask],
                                              target_weak[weak_mask])
            ema_class_loss = class_criterion(weak_pred_ema[weak_mask],
                                             target_weak[weak_mask])

            print(
                "noraml_weak:",
                class_criterion(weak_pred[weak_mask], target_weak[weak_mask]))

            ##################################################
            custom_weak_class_loss = custom_loss.weak(target_weak, weak_mask)
            custom_ema_class_loss = custom_ema_loss.weak(
                target_weak, weak_mask)
            print("custom_weak:", custom_weak_class_loss)
            ##################################################

            count += 1
            check_cus_weak += custom_weak_class_loss
            # print(custom_weak_class_loss.item())

            if i == 0:
                LOG.debug("target: {}".format(target.mean(-2)))
                LOG.debug("Target_weak: {}".format(target_weak))
                LOG.debug("Target_weak mask: {}".format(
                    target_weak[weak_mask]))
                LOG.debug(custom_weak_class_loss)  ###
                LOG.debug("rampup_value: {}".format(rampup_value))
            meters.update('weak_class_loss',
                          custom_weak_class_loss.item())  ###
            meters.update('Weak EMA loss', custom_ema_class_loss.item())  ###

            # loss = weak_class_loss
            loss = custom_weak_class_loss

            ####################################################################################
            # weak_class_loss = class_criterion(strong_pred[weak_mask], target[weak_mask])
            # ema_class_loss = class_criterion(strong_pred_ema[weak_mask], target[weak_mask])
            # # if i == 0:
            # #     LOG.debug("target: {}".format(target.mean(-2)))
            # #     LOG.debug("Target_weak: {}".format(target))
            # #     LOG.debug("Target_weak mask: {}".format(target[weak_mask]))
            # #     LOG.debug(weak_class_loss)
            # #     LOG.debug("rampup_value: {}".format(rampup_value))
            # meters.update('weak_class_loss', weak_class_loss.item())
            # meters.update('Weak EMA loss', ema_class_loss.item())

            # loss = weak_class_loss
            ####################################################################################

        # Strong BCE loss
        if strong_mask is not None:
            strong_class_loss = class_criterion(strong_pred[strong_mask],
                                                target[strong_mask])
            # meters.update('Strong loss', strong_class_loss.item())

            strong_ema_class_loss = class_criterion(
                strong_pred_ema[strong_mask], target[strong_mask])
            # meters.update('Strong EMA loss', strong_ema_class_loss.item())

            print(
                "normal_strong:",
                class_criterion(strong_pred[strong_mask], target[strong_mask]))

            ##################################################
            custom_strong_class_loss = custom_loss.strong(target, strong_mask)
            meters.update('Strong loss', custom_strong_class_loss.item())

            custom_strong_ema_class_loss = custom_ema_loss.strong(
                target, strong_mask)
            meters.update('Strong EMA loss',
                          custom_strong_ema_class_loss.item())
            print("custom_strong:", custom_strong_class_loss)
            ##################################################

            if loss is not None:
                # loss += strong_class_loss
                loss += custom_strong_class_loss
            else:
                # loss = strong_class_loss
                loss = custom_strong_class_loss

        # print("check_weak:", class_criterion1(weak_pred[weak_mask], target_weak[weak_mask]).mean())
        # print("check_strong:", class_criterion1(strong_pred[strong_mask], target[strong_mask]).mean())
        # print("\n")

        # exit()

        # Teacher-student consistency cost
        if ema_model is not None:

            consistency_cost = cfg.max_consistency_cost * rampup_value
            meters.update('Consistency weight', consistency_cost)
            # Take consistency about strong predictions (all data)
            consistency_loss_strong = consistency_cost * consistency_criterion(
                strong_pred, strong_pred_ema)
            meters.update('Consistency strong', consistency_loss_strong.item())
            if loss is not None:
                loss += consistency_loss_strong
            else:
                loss = consistency_loss_strong

            meters.update('Consistency weight', consistency_cost)
            # Take consistency about weak predictions (all data)
            consistency_loss_weak = consistency_cost * consistency_criterion(
                weak_pred, weak_pred_ema)
            meters.update('Consistency weak', consistency_loss_weak.item())
            if loss is not None:
                loss += consistency_loss_weak
            else:
                loss = consistency_loss_weak

        assert not (np.isnan(loss.item())
                    or loss.item() > 1e5), 'Loss explosion: {}'.format(
                        loss.item())
        assert not loss.item() < 0, 'Loss problem, cannot be negative'
        meters.update('Loss', loss.item())

        # compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        if ema_model is not None:
            update_ema_variables(model, ema_model, 0.999, global_step)

    epoch_time = time.time() - start

    LOG.info('Epoch: {}\t'
             'Time {:.2f}\t'
             '{meters}'.format(epoch, epoch_time, meters=meters))

    print("\ncheck_cus_weak:\n", check_cus_weak / count)
Example #11
def get_current_consistency_weight(epoch):
    return sigmoid_rampup(epoch, 10)
def train(args, snapshot_path):
    base_lr = args.base_lr
    train_data_path = args.root_path
    batch_size = args.batch_size
    max_iterations = args.max_iterations
    num_classes = 2

    def create_model(ema=False):
        # Network definition
        net = net_factory_3d(net_type=args.model,
                             in_chns=1,
                             class_num=num_classes)
        model = net.cuda()
        if ema:
            for param in model.parameters():
                param.detach_()
        return model

    model = create_model()
    ema_model = create_model(ema=True)

    db_train = BraTS2019(base_dir=train_data_path,
                         split='train',
                         num=None,
                         transform=transforms.Compose([
                             RandomRotFlip(),
                             RandomCrop(args.patch_size),
                             ToTensor(),
                         ]))

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    labeled_idxs = list(range(0, args.labeled_num))
    unlabeled_idxs = list(range(args.labeled_num, 250))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs,
                                          batch_size,
                                          batch_size - args.labeled_bs)

    trainloader = DataLoader(db_train,
                             batch_sampler=batch_sampler,
                             num_workers=4,
                             pin_memory=True,
                             worker_init_fn=worker_init_fn)

    model.train()
    ema_model.train()

    optimizer = optim.SGD(model.parameters(),
                          lr=base_lr,
                          momentum=0.9,
                          weight_decay=0.0001)
    ce_loss = CrossEntropyLoss()
    dice_loss = losses.DiceLoss(2)

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))

    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    best_performance = 0.0
    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        for i_batch, sampled_batch in enumerate(trainloader):

            volume_batch, label_batch = sampled_batch['image'], sampled_batch[
                'label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            unlabeled_volume_batch = volume_batch[args.labeled_bs:]

            noise = torch.clamp(
                torch.randn_like(unlabeled_volume_batch) * 0.1, -0.2, 0.2)
            ema_inputs = unlabeled_volume_batch + noise

            outputs = model(volume_batch)
            outputs_soft = torch.softmax(outputs, dim=1)
            with torch.no_grad():
                ema_output = ema_model(ema_inputs)
            T = 8
            _, _, d, w, h = unlabeled_volume_batch.shape
            volume_batch_r = unlabeled_volume_batch.repeat(2, 1, 1, 1, 1)
            stride = volume_batch_r.shape[0] // 2
            preds = torch.zeros([stride * T, 2, d, w, h]).cuda()
            for i in range(T // 2):
                ema_inputs = volume_batch_r + \
                    torch.clamp(torch.randn_like(
                        volume_batch_r) * 0.1, -0.2, 0.2)
                with torch.no_grad():
                    preds[2 * stride * i:2 * stride *
                          (i + 1)] = ema_model(ema_inputs)
            preds = torch.softmax(preds, dim=1)
            preds = preds.reshape(T, stride, 2, d, w, h)
            preds = torch.mean(preds, dim=0)
            uncertainty = -1.0 * \
                torch.sum(preds*torch.log(preds + 1e-6), dim=1, keepdim=True)

            loss_ce = ce_loss(outputs[:args.labeled_bs],
                              label_batch[:args.labeled_bs][:])
            loss_dice = dice_loss(outputs_soft[:args.labeled_bs],
                                  label_batch[:args.labeled_bs].unsqueeze(1))
            supervised_loss = 0.5 * (loss_dice + loss_ce)
            consistency_weight = get_current_consistency_weight(iter_num //
                                                                150)
            consistency_dist = losses.softmax_mse_loss(
                outputs[args.labeled_bs:],
                ema_output)  # (batch, 2, 112,112,80)
            threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(
                iter_num, max_iterations)) * np.log(2)
            mask = (uncertainty < threshold).float()
            consistency_loss = torch.sum(
                mask * consistency_dist) / (2 * torch.sum(mask) + 1e-16)

            loss = supervised_loss + consistency_weight * consistency_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_ema_variables(model, ema_model, args.ema_decay, iter_num)

            lr_ = base_lr * (1.0 - iter_num / max_iterations)**0.9
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_

            iter_num = iter_num + 1
            writer.add_scalar('info/lr', lr_, iter_num)
            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/loss_ce', loss_ce, iter_num)
            writer.add_scalar('info/loss_dice', loss_dice, iter_num)
            writer.add_scalar('info/consistency_loss', consistency_loss,
                              iter_num)
            writer.add_scalar('info/consistency_weight', consistency_weight,
                              iter_num)

            logging.info(
                'iteration %d : loss : %f, loss_ce: %f, loss_dice: %f' %
                (iter_num, loss.item(), loss_ce.item(), loss_dice.item()))
            writer.add_scalar('loss/loss', loss, iter_num)

            if iter_num % 20 == 0:
                image = volume_batch[0, 0:1, :, :,
                                     20:61:10].permute(3, 0, 1,
                                                       2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/Image', grid_image, iter_num)

                image = outputs_soft[0, 1:2, :, :,
                                     20:61:10].permute(3, 0, 1,
                                                       2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Predicted_label', grid_image, iter_num)

                image = label_batch[0, :, :, 20:61:10].unsqueeze(0).permute(
                    3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Groundtruth_label', grid_image,
                                 iter_num)

            if iter_num > 0 and iter_num % 200 == 0:
                model.eval()
                avg_metric = test_all_case(model,
                                           args.root_path,
                                           test_list="val.txt",
                                           num_classes=2,
                                           patch_size=args.patch_size,
                                           stride_xy=64,
                                           stride_z=64)
                if avg_metric[:, 0].mean() > best_performance:
                    best_performance = avg_metric[:, 0].mean()
                    save_mode_path = os.path.join(
                        snapshot_path, 'iter_{}_dice_{}.pth'.format(
                            iter_num, round(best_performance, 4)))
                    save_best = os.path.join(
                        snapshot_path, '{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best)

                writer.add_scalar('info/val_dice_score', avg_metric[0, 0],
                                  iter_num)
                writer.add_scalar('info/val_hd95', avg_metric[0, 1], iter_num)
                logging.info('iteration %d : dice_score : %f hd95 : %f' %
                             (iter_num, avg_metric[0, 0].mean(),
                              avg_metric[0, 1].mean()))
                model.train()

            if iter_num % 3000 == 0:
                save_mode_path = os.path.join(snapshot_path,
                                              'iter_' + str(iter_num) + '.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num >= max_iterations:
                break
        if iter_num >= max_iterations:
            iterator.close()
            break
    writer.close()
    return "Training Finished!"
                    preds[2 * stride * i:2 * stride * (i + 1)] = ema_model(ema_inputs)
            preds = F.softmax(preds, dim=1)
            preds = preds.reshape(T, stride, 2, 112, 112, 80)
            preds = torch.mean(preds, dim=0)  #(batch, 2, 112,112,80)
            uncertainty = -1.0*torch.sum(preds*torch.log(preds + 1e-6), dim=1, keepdim=True) #(batch, 1, 112,112,80)


            ## calculate the loss
            loss_seg = F.cross_entropy(outputs[:labeled_bs], label_batch[:labeled_bs])
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = losses.dice_loss(outputs_soft[:labeled_bs, 1, :, :, :], label_batch[:labeled_bs] == 1)
            supervised_loss = 0.5*(loss_seg+loss_seg_dice)

            consistency_weight = get_current_consistency_weight(iter_num//150)
            consistency_dist = consistency_criterion(outputs[labeled_bs:], ema_output) #(batch, 2, 112,112,80)
            threshold = (0.75+0.25*ramps.sigmoid_rampup(iter_num, max_iterations))*np.log(2)
            mask = (uncertainty<threshold).float()
            consistency_dist = torch.sum(mask*consistency_dist)/(2*torch.sum(mask)+1e-16)
            consistency_loss = consistency_weight * consistency_dist
            loss = supervised_loss + consistency_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_ema_variables(model, ema_model, args.ema_decay, iter_num)

            iter_num = iter_num + 1
            writer.add_scalar('uncertainty/mean', uncertainty[0,0].mean(), iter_num)
            writer.add_scalar('uncertainty/max', uncertainty[0,0].max(), iter_num)
            writer.add_scalar('uncertainty/min', uncertainty[0,0].min(), iter_num)
            writer.add_scalar('uncertainty/mask_per', torch.sum(mask)/mask.numel(), iter_num)
Example #14
    else:
        model_path_pretrain = os.path.join(model_directory, model_name_triplet, "epoch_" + str(f_args.epochs))
    print("path of model : " + model_path_pretrain)
    create_folder(os.path.join(model_directory, model_name_triplet))

    batch_size_classif = cfg.batch_size_classif
    # Hard coded because no semi_hard in this version
    semi_hard_embed = None
    semi_hard_input = None
    if not os.path.exists(model_path_pretrain) or cfg.recompute_embedding:
        margin = triplet_margin
        for epoch in range(f_args.epochs):
            t_start_epoch = time.time()

            if cfg.rampup_margin_length is not None:
                margin = sigmoid_rampup(epoch, cfg.rampup_margin_length) * triplet_margin
            model_triplet.train()
            model_triplet, loss_mean_triplet, ratio_used = train_triplet_epoch(triplet_loader,
                                                                               # triplet_loader,
                                                                               model_triplet, optimizer,
                                                                               semi_hard_input, semi_hard_embed,
                                                                               pit=pit,
                                                                               margin=margin, swap=swap,
                                                                               acc_grad=segment
                                                                               )
            model_triplet.eval()
            loss_mean_triplet = np.mean(loss_mean_triplet)

            embed_dir = "stored_data/embeddings"
            embed_dir = os.path.join(embed_dir, model_name_triplet, "embeddings")
            create_folder(embed_dir)
                    labeled_outputs_soft[:, i, :, :, :], label_batch == i)
                loss_seg_dice += loss_mid
                print('dice score (1-dice_loss): {:.3f}'.format(1 - loss_mid))
            supervised_loss = (loss_seg + loss_seg_dice) / 2.0
            #             supervised_loss = loss_seg_dice
            #             if epoch_num==20 and i_batch==5:
            #                 import pdb
            #                 pdb.set_trace()

            ## calculate the loss ( only for unlabeled samples )
            consistency_weight = args.consistency  # get_current_consistency_weight(epoch_num/20)
            consistency_dist = consistency_criterion(
                outputs[:, 1:, :, :, :],
                ema_outputs[:, 1:, :, :, :])  #(batch, num_classes, 112,112,80)
            threshold = (0.75 +
                         0.25 * ramps.sigmoid_rampup(epoch_num, 20)) * np.log(
                             num_classes)  # the maximum uncertainty of an N-class problem is log(N)
            mask = (uncertainty < threshold).int()
            consistency_dist = torch.sum(
                mask * consistency_dist) / (torch.sum(mask) + 1e-16)
            consistency_loss = consistency_weight * consistency_dist
            loss = supervised_loss + consistency_loss

            # Standard PyTorch training steps: zero the gradients, compute the loss
            # (including the forward pass, depending on how the network is written),
            # then loss.backward() and optimizer.step().
            optimizer.zero_grad()  # reset the parameter gradients to zero
            loss.backward()
            optimizer.step()
            update_ema_variables(model, ema_model, args.ema_decay, epoch_num)

            iter_num = iter_num + 1
Example #16
def train(cfg,
          train_loader,
          model,
          optimizer,
          epoch,
          ema_model=None,
          weak_mask=None,
          strong_mask=None):
    """ One epoch of a Mean Teacher model
    :param train_loader: torch.utils.data.DataLoader, iterator of training batches for an epoch.
    Should return 3 values: teacher input, student input, labels
    :param model: torch.nn.Module, student model to be trained, should return a weak and a strong prediction
    :param optimizer: torch.optim.Optimizer, optimizer used to train the model
    :param epoch: int, the current epoch of training
    :param ema_model: torch.nn.Module, teacher model (EMA of the student), should return a weak and a strong prediction
    :param weak_mask: mask selecting the weakly labeled part of the batch (used to calculate the loss)
    :param strong_mask: mask selecting the strongly labeled part of the batch (used to calculate the loss)
    """
    class_criterion = nn.BCELoss()
    consistency_criterion_strong = nn.MSELoss()
    lds_criterion = LDSLoss(xi=cfg.vat_xi,
                            eps=cfg.vat_eps,
                            n_power_iter=cfg.vat_n_power_iter)
    [class_criterion, consistency_criterion_strong,
     lds_criterion] = to_cuda_if_available(
         [class_criterion, consistency_criterion_strong, lds_criterion])

    meters = AverageMeterSet()

    LOG.debug("Nb batches: {}".format(len(train_loader)))
    start = time.time()
    rampup_length = len(train_loader) * cfg.n_epoch // 2
    for i, (batch_input, ema_batch_input, target) in enumerate(train_loader):
        global_step = epoch * len(train_loader) + i
        if global_step < rampup_length:
            rampup_value = ramps.sigmoid_rampup(global_step, rampup_length)
        else:
            rampup_value = 1.0

        # Todo check if this improves the performance
        # adjust_learning_rate(optimizer, rampup_value, rampdown_value)
        meters.update('lr', optimizer.param_groups[0]['lr'])

        [batch_input, ema_batch_input,
         target] = to_cuda_if_available([batch_input, ema_batch_input, target])
        LOG.debug(batch_input.mean())
        # Outputs
        strong_pred_ema, weak_pred_ema = ema_model(ema_batch_input)
        strong_pred_ema = strong_pred_ema.detach()
        weak_pred_ema = weak_pred_ema.detach()

        strong_pred, weak_pred = model(batch_input)
        loss = None
        # Weak BCE Loss
        # Take the max in axis 2 (assumed to be time)
        if len(target.shape) > 2:
            target_weak = target.max(-2)[0]
        else:
            target_weak = target

        if weak_mask is not None:
            weak_class_loss = class_criterion(weak_pred[weak_mask],
                                              target_weak[weak_mask])
            ema_class_loss = class_criterion(weak_pred_ema[weak_mask],
                                             target_weak[weak_mask])

            if i == 0:
                LOG.debug("target: {}".format(target.mean(-2)))
                LOG.debug("Target_weak: {}".format(target_weak))
                LOG.debug("Target_weak mask: {}".format(
                    target_weak[weak_mask]))
                LOG.debug(weak_class_loss)
                LOG.debug("rampup_value: {}".format(rampup_value))
            meters.update('weak_class_loss', weak_class_loss.item())

            meters.update('Weak EMA loss', ema_class_loss.item())

            loss = weak_class_loss

        # Strong BCE loss
        if strong_mask is not None:
            strong_class_loss = class_criterion(strong_pred[strong_mask],
                                                target[strong_mask])
            meters.update('Strong loss', strong_class_loss.item())

            strong_ema_class_loss = class_criterion(
                strong_pred_ema[strong_mask], target[strong_mask])
            meters.update('Strong EMA loss', strong_ema_class_loss.item())
            if loss is not None:
                loss += strong_class_loss
            else:
                loss = strong_class_loss

        # Teacher-student consistency cost
        if ema_model is not None:

            consistency_cost = cfg.max_consistency_cost * rampup_value
            meters.update('Consistency weight', consistency_cost)
            # Take only the consistence with weak and unlabel
            consistency_loss_strong = consistency_cost * consistency_criterion_strong(
                strong_pred, strong_pred_ema)
            meters.update('Consistency strong', consistency_loss_strong.item())
            if loss is not None:
                loss += consistency_loss_strong
            else:
                loss = consistency_loss_strong

            meters.update('Consistency weight', consistency_cost)
            # Take only the consistence with weak and unlabel
            consistency_loss_weak = consistency_cost * consistency_criterion_strong(
                weak_pred, weak_pred_ema)
            meters.update('Consistency weak', consistency_loss_weak.item())
            if loss is not None:
                loss += consistency_loss_weak
            else:
                loss = consistency_loss_weak

        # LDS loss
        if cfg.vat_enabled:
            lds_loss = cfg.vat_coeff * lds_criterion(model, batch_input,
                                                     weak_pred)
            LOG.info('loss: {:.3f}, lds loss: {:.3f}'.format(
                loss,
                lds_loss.detach().cpu().numpy()))  # lds_loss already includes cfg.vat_coeff
            loss += lds_loss
        else:
            if i % 25 == 0:
                LOG.info('loss: {:.3f}'.format(loss))

        assert not (np.isnan(loss.item())
                    or loss.item() > 1e5), 'Loss explosion: {}'.format(
                        loss.item())
        assert not loss.item() < 0, 'Loss problem, cannot be negative'
        meters.update('Loss', loss.item())

        # compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        if ema_model is not None:
            update_ema_variables(model, ema_model, 0.999, global_step)

    epoch_time = time.time() - start

    LOG.info('Epoch: {}\t'
             'Time {:.2f}\t'
             '{meters}'.format(epoch, epoch_time, meters=meters))
Example #17
                                            label_batch == i)
                loss_seg_dice += loss_mid
                print('dice score (1-dice_loss): {:.3f}'.format(1 - loss_mid))
#             import pdb
#             pdb.set_trace()
#             print('dicetotal:{:.3f}'.format( loss_seg_dice))
#loss_seg_dice = losses.dice_loss(outputs_soft[:labeled_bs, 1, :, :, :], label_batch[:labeled_bs] == 1)
            supervised_loss = 0.5 * (loss_seg + loss_seg_dice)

            # only for unlabeled samples
            consistency_weight = get_current_consistency_weight(iter_num //
                                                                150)
            consistency_dist = consistency_criterion(
                unlabeled_outputs,
                ema_output)  #(batch, num_classes, 112,112,80)
            threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(
                iter_num, max_iterations)) * np.sqrt(3)  # the maximum uncertainty of an N-class problem is taken as sqrt(N) here
            mask = (uncertainty < threshold).float()
            #             print("consistency_dist:",consistency_dist.item())
            asd = np.prod(list(mask.shape))
            #print("mask:",np.sum(mask.item())/asd )
            consistency_dist = torch.sum(
                mask * consistency_dist) / (2 * torch.sum(mask) + 1e-16)
            consistency_loss = consistency_weight * consistency_dist
            loss = supervised_loss + consistency_loss

            # Standard PyTorch training steps: zero the gradients, compute the loss
            # (including the forward pass, depending on how the network is written),
            # then loss.backward() and optimizer.step().
            optimizer.zero_grad()  # reset the parameter gradients to zero
            loss.backward()
            optimizer.step()
            update_ema_variables(model, ema_model, args.ema_decay, iter_num)
Example #18
def get_current_consistency_weight(epoch):
    return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_rampup)
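Assuming the exponential sigmoid ramp-up sketched earlier, this weight rises smoothly from roughly zero to args.consistency over args.consistency_rampup epochs. For illustration, with the hypothetical settings consistency = 0.1 and consistency_rampup = 40.0:

# epoch  0 -> 0.1 * exp(-5.0)  ~ 0.0007
# epoch 20 -> 0.1 * exp(-1.25) ~ 0.029
# epoch 40 -> 0.1 * 1.0        = 0.1  (and stays there afterwards)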
Example #19
def train(model, model_ema, memorybank, labeled_eval_loader_train,
          unlabeled_eval_loader_test, unlabeled_eval_loader_train, args):
    labeled_train_loader = CIFAR10Loader_iter(root=args.dataset_root,
                                              batch_size=args.batch_size // 2,
                                              split='train',
                                              aug='twice',
                                              shuffle=True,
                                              target_list=range(
                                                  args.num_labeled_classes))
    optimizer = SGD(model.parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                           step_size=args.step_size,
                                           gamma=args.gamma)
    criterion1 = nn.CrossEntropyLoss()
    criterion2 = BCE()
    criterion3 = CrossEntropyLabelSmooth(
        num_classes=args.num_unlabeled_classes)

    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        model_ema.train()
        exp_lr_scheduler.step()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(
            epoch, args.rampup_length)

        iters = 400

        if epoch % 5 == 0:
            args.head = 'head2'
            feats, feats_mb, _ = test(model_ema, unlabeled_eval_loader_train,
                                      args)
            feats = F.normalize(torch.cat(feats, dim=0), dim=1)
            feats_mb = F.normalize(torch.cat(feats_mb, dim=0), dim=1)
            cluster = faiss.Kmeans(512, 5, niter=300, verbose=True, gpu=True)
            moving_avg_features = feats.numpy()
            cluster.train(moving_avg_features)
            _, labels_ = cluster.index.search(moving_avg_features, 1)
            labels = labels_ + 5
            target_label = labels.reshape(-1).tolist()

            # centers=faiss.vector_to_array(cluster.centroids).reshape(5, 512)
            centers = cluster.centroids

            # Memory bank by zkc
            # if epoch == 0: memorybank.features = torch.cat((F.normalize(torch.tensor(centers).cuda(), dim=1), feats), dim=0).cuda()
            # memorybank.labels = torch.cat((torch.arange(args.num_unlabeled_classes), torch.Tensor(target_label).long())).cuda()
            if epoch == 0: memorybank.features = feats_mb.cuda()
            memorybank.labels = torch.Tensor(
                labels_.reshape(-1).tolist()).long().cuda()

            model.memory.prototypes[args.num_labeled_classes:] = F.normalize(
                torch.tensor(centers).cuda(), dim=1)
            model_ema.memory.prototypes[args.
                                        num_labeled_classes:] = F.normalize(
                                            torch.tensor(centers).cuda(),
                                            dim=1)

            feats, _, labels = test(model_ema, labeled_eval_loader_train, args)
            feats = F.normalize(torch.cat(feats, dim=0), dim=1)
            centers = torch.zeros(args.num_labeled_classes, 512)
            for i in range(args.num_labeled_classes):
                idx = torch.where(torch.tensor(labels) == i)[0]
                centers[i] = torch.mean(feats[idx], 0)
            model.memory.prototypes[:args.num_labeled_classes] = torch.tensor(
                centers).cuda()
            model_ema.memory.prototypes[:args.
                                        num_labeled_classes] = torch.tensor(
                                            centers).cuda()

            unlabeled_train_loader = CIFAR10Loader_iter(
                root=args.dataset_root,
                batch_size=args.batch_size // 2,
                split='train',
                aug='twice',
                shuffle=True,
                target_list=range(args.num_labeled_classes, num_classes),
                new_labels=target_label)
            # model.head2.weight.data.copy_(
            #     torch.from_numpy(F.normalize(target_centers, axis=1)).float().cuda())

        # labeled_train_loader.new_epoch()
        # unlabeled_train_loader.new_epoch()
        # for batch_idx,_ in enumerate(range(iters)):
        #     ((x_l, x_bar_l), label_l, idx) = labeled_train_loader.next()
        #     ((x_u, x_bar_u), label_u, idx) = unlabeled_train_loader.next()
        for batch_idx, (((x_l, x_bar_l), label_l, idx_l),
                        ((x_u, x_bar_u), label_u, idx_u)) in enumerate(
                            zip(labeled_train_loader, unlabeled_train_loader)):

            x = torch.cat([x_l, x_u], dim=0)
            x_bar = torch.cat([x_bar_l, x_bar_u], dim=0)
            label = torch.cat([label_l, label_u], dim=0)
            idx = torch.cat([idx_l, idx_u], dim=0)
            x, x_bar, label = x.to(device), x_bar.to(device), label.to(device)

            output1, output2, feat, feat_mb = model(x)
            output1_bar, output2_bar, _, _ = model(x_bar)

            with torch.no_grad():
                output1_ema, output2_ema, feat_ema, feat_mb_ema = model_ema(x)
                output1_bar_ema, output2_bar_ema, _, _ = model_ema(x_bar)
            prob1, prob1_bar, prob2, prob2_bar = F.softmax(
                output1, dim=1), F.softmax(output1_bar, dim=1), F.softmax(
                    output2, dim=1), F.softmax(output2_bar, dim=1)
            prob1_ema, prob1_bar_ema, prob2_ema, prob2_bar_ema = F.softmax(
                output1_ema,
                dim=1), F.softmax(output1_bar_ema, dim=1), F.softmax(
                    output2_ema, dim=1), F.softmax(output2_bar_ema, dim=1)

            mask_lb = label < args.num_labeled_classes

            loss_ce_label = criterion1(output1[mask_lb], label[mask_lb])
            loss_ce_unlabel = criterion1(output2[~mask_lb],
                                         label[~mask_lb])  # torch.tensor(0)#

            loss_in_unlabel = torch.tensor(
                0
            )  #memorybank(feat_mb[~mask_lb], feat_mb_ema[~mask_lb], label[~mask_lb], idx[~mask_lb])

            loss_ce = loss_ce_label + loss_ce_unlabel

            loss_bce = rank_bce(criterion2, feat, mask_lb, prob2,
                                prob2_bar)  # torch.tensor(0)#

            consistency_loss = F.mse_loss(prob1, prob1_bar) + F.mse_loss(
                prob2, prob2_bar)
            consistency_loss_ema = F.mse_loss(
                prob1, prob1_bar_ema) + F.mse_loss(prob2, prob2_bar_ema)

            loss = loss_ce + loss_bce + w * consistency_loss + w * consistency_loss_ema + loss_in_unlabel

            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            _update_ema_variables(model, model_ema, 0.99,
                                  epoch * iters + batch_idx)

            if batch_idx % 200 == 0:
                print(
                    'Train Epoch: {}, iter {}/{} unl-CE Loss: {:.4f}, unl-instance Loss: {:.4f}, l-CE Loss: {:.4f}, BCE Loss: {:.4f}, CL Loss: {:.4f}, Avg Loss: {:.4f}'
                    .format(epoch, batch_idx, 400, loss_ce_unlabel.item(),
                            loss_in_unlabel.item(), loss_ce_label.item(),
                            loss_bce.item(), consistency_loss.item(),
                            loss_record.avg))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(
            epoch, loss_record.avg))
        # print('test on labeled classes')
        # args.head = 'head1'
        # test(model, labeled_eval_loader_test, args)

        # print('test on unlabeled classes')
        args.head = 'head2'
        # test(model, unlabeled_eval_loader_train, args)
        test(model_ema, unlabeled_eval_loader_test, args)
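The clustering step in Example #19 relies on faiss: faiss.Kmeans(d, k, ...) fits k centroids in d dimensions, kmeans.train(x) runs the clustering, and kmeans.index.search(x, 1) returns, for each row of x, the distance to and the id of its nearest centroid; those ids are what labels_ holds before being offset (here by 5, evidently the number of labeled classes). A minimal usage sketch with made-up shapes:

import faiss
import numpy as np

feats = np.random.rand(1000, 512).astype('float32')     # stand-in for the normalized features above
kmeans = faiss.Kmeans(512, 5, niter=300, verbose=False, gpu=False)
kmeans.train(feats)
distances, cluster_ids = kmeans.index.search(feats, 1)  # nearest centroid per sample, shape (1000, 1)
centers = kmeans.centroids                               # numpy array of shape (5, 512)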