Example #1
def train(model, train_loader, unlabeled_eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                           step_size=args.step_size,
                                           gamma=args.gamma)
    criterion1 = nn.CrossEntropyLoss()
    criterion2 = BCE()
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(
            epoch, args.rampup_length)
        for batch_idx, ((x, x_bar), label,
                        idx) in enumerate(tqdm(train_loader)):
            x, x_bar, label = x.to(device), x_bar.to(device), label.to(device)
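            # the model returns two classification heads plus an embedding: output1
            # (head1) scores the labeled classes, output2 (head2) the novel classes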
            output1, output2, feat = model(x)
            output1_bar, output2_bar, _ = model(x_bar)
            prob1 = F.softmax(output1, dim=1)
            prob1_bar = F.softmax(output1_bar, dim=1)
            prob2 = F.softmax(output2, dim=1)
            prob2_bar = F.softmax(output2_bar, dim=1)

            # labeled samples occupy the first `labeled_length` indices of the joint loader
            mask_lb = idx < train_loader.labeled_length

            # ranking statistics: pairs of unlabeled samples whose top-k feature
            # rankings agree are treated as similar (+1), otherwise dissimilar (-1)
            rank_feat = feat[~mask_lb].detach()
            rank_idx = torch.argsort(rank_feat, dim=1, descending=True)
            rank_idx1, rank_idx2 = PairEnum(rank_idx)
            rank_idx1, rank_idx2 = rank_idx1[:, :args.topk], rank_idx2[:, :args.topk]
            rank_idx1, _ = torch.sort(rank_idx1, dim=1)
            rank_idx2, _ = torch.sort(rank_idx2, dim=1)

            rank_diff = rank_idx1 - rank_idx2
            rank_diff = torch.sum(torch.abs(rank_diff), dim=1)
            target_ulb = torch.ones_like(rank_diff).float().to(device)
            target_ulb[rank_diff > 0] = -1

            # head2 probabilities of the two augmented views, enumerated as pairs
            prob1_ulb, _ = PairEnum(prob2[~mask_lb])
            _, prob2_ulb = PairEnum(prob2_bar[~mask_lb])

            # CE on head1 for labeled samples; pairwise BCE on head2 for unlabeled
            # samples, supervised by the ranking-based pseudo pair labels
            loss_ce = criterion1(output1[mask_lb], label[mask_lb])
            loss_bce = criterion2(prob1_ulb, prob2_ulb, target_ulb)

            # consistency (MSE) between predictions on the two augmentations
            consistency_loss = F.mse_loss(prob1, prob1_bar) + F.mse_loss(prob2, prob2_bar)

            loss = loss_ce + loss_bce + w * consistency_loss

            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # step the LR schedule once per epoch, after the optimizer updates
        exp_lr_scheduler.step()
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        print('test on unlabeled classes')
        args.head = 'head2'
        test(model, unlabeled_eval_loader, args)
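
Example #1 relies on several helpers from the surrounding project (PairEnum, BCE, AverageMeter, ramps.sigmoid_rampup) that are not shown here. The sketch below shows how PairEnum and BCE are assumed to behave, inferred only from the way the example calls them; the project's real implementations may differ in details.

import torch
import torch.nn as nn

def PairEnum(x):
    # Enumerate all ordered pairs of rows in a mini-batch: for an (N, D) input,
    # return two (N*N, D) tensors where row k of x1 is paired with row k of x2.
    assert x.dim() == 2, 'Input dimension must be 2'
    x1 = x.repeat(x.size(0), 1)
    x2 = x.repeat(1, x.size(0)).view(-1, x.size(1))
    return x1, x2

class BCE(nn.Module):
    # Pairwise binary cross-entropy: prob1/prob2 hold class probabilities for the two
    # members of each pair; simi is +1 for "same class" pairs and -1 for different ones.
    eps = 1e-7

    def forward(self, prob1, prob2, simi):
        p = (prob1 * prob2).sum(1)             # probability that the pair agrees
        p = p * simi + simi.eq(-1).type_as(p)  # use 1 - p for dissimilar pairs
        return -(p + BCE.eps).log().mean()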
Example #2
def train(model, model_ema, train_loader, labeled_eval_loader, unlabeled_eval_loader, args):

    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    criterion1 = nn.CrossEntropyLoss() 
    criterion2 = BCE() 
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        model_ema.train()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length) 
        for batch_idx, ((x, x_bar), label, idx) in enumerate(tqdm(train_loader)):
            x, x_bar, label = x.to(device), x_bar.to(device), label.to(device)

            output1, output2, feat = model(x)
            output1_bar, output2_bar, _ = model(x_bar)

            # teacher (EMA) model predictions; no gradients flow to the teacher
            with torch.no_grad():
                output1_ema, output2_ema, feat_ema = model_ema(x)
                output1_bar_ema, output2_bar_ema, _ = model_ema(x_bar)
            prob1 = F.softmax(output1, dim=1)
            prob1_bar = F.softmax(output1_bar, dim=1)
            prob2 = F.softmax(output2, dim=1)
            prob2_bar = F.softmax(output2_bar, dim=1)
            prob1_ema = F.softmax(output1_ema, dim=1)
            prob1_bar_ema = F.softmax(output1_bar_ema, dim=1)
            prob2_ema = F.softmax(output2_ema, dim=1)
            prob2_bar_ema = F.softmax(output2_bar_ema, dim=1)

            # labeled samples carry class ids below num_labeled_classes
            mask_lb = label < args.num_labeled_classes

            loss_ce = criterion1(output1[mask_lb], label[mask_lb])
            loss_bce = rank_bce(criterion2, feat, mask_lb, prob2, prob2_bar)

            consistency_loss = F.mse_loss(prob1, prob1_bar) + F.mse_loss(prob2, prob2_bar)
            consistency_loss_ema = F.mse_loss(prob1, prob1_bar_ema) + F.mse_loss(prob2, prob2_bar_ema)

            loss = loss_ce + loss_bce + w * consistency_loss + w * consistency_loss_ema #+ smooth_loss(feat,mask_lb) #+ MCR(feat, idx)

            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            _update_ema_variables(model, model_ema, 0.99, epoch * len(train_loader) + batch_idx)

        # step the LR schedule once per epoch, after the optimizer updates
        exp_lr_scheduler.step()
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        print('test on labeled classes')
        args.head = 'head1'
        test(model, labeled_eval_loader, args)
        print('test on unlabeled classes')
        args.head = 'head2'
        test(model, unlabeled_eval_loader, args)
        test(model_ema, unlabeled_eval_loader, args)
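
Example #2 additionally calls rank_bce and _update_ema_variables, also defined elsewhere in the project. Below is a rough sketch of both, inferred from the call sites: rank_bce is assumed to wrap the ranking-statistics block that Example #1 spells out inline (it reuses the PairEnum sketch above, and the topk default stands in for args.topk), and _update_ema_variables is assumed to be the usual mean-teacher exponential moving average of the student weights.

import torch

def rank_bce(criterion2, feat, mask_lb, prob2, prob2_bar, topk=5):
    # Pairwise pseudo-labels from top-k ranking statistics on the unlabeled features,
    # mirroring Example #1 (topk is a hypothetical default; normally args.topk).
    rank_feat = feat[~mask_lb].detach()
    rank_idx = torch.argsort(rank_feat, dim=1, descending=True)
    rank_idx1, rank_idx2 = PairEnum(rank_idx)
    rank_idx1, rank_idx2 = rank_idx1[:, :topk], rank_idx2[:, :topk]
    rank_idx1, _ = torch.sort(rank_idx1, dim=1)
    rank_idx2, _ = torch.sort(rank_idx2, dim=1)
    rank_diff = torch.sum(torch.abs(rank_idx1 - rank_idx2), dim=1)
    target_ulb = torch.ones_like(rank_diff).float()
    target_ulb[rank_diff > 0] = -1
    prob1_ulb, _ = PairEnum(prob2[~mask_lb])
    _, prob2_ulb = PairEnum(prob2_bar[~mask_lb])
    return criterion2(prob1_ulb, prob2_ulb, target_ulb)

def _update_ema_variables(model, ema_model, alpha, global_step):
    # Mean-teacher EMA update: the teacher tracks an exponential moving average of the
    # student weights, with alpha ramped up from 1/(step + 1) early in training.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)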
Example #3
def train(model, model_ema, memorybank, labeled_eval_loader_train,
          unlabeled_eval_loader_test, unlabeled_eval_loader_train, args):
    labeled_train_loader = CIFAR10Loader_iter(root=args.dataset_root,
                                              batch_size=args.batch_size // 2,
                                              split='train',
                                              aug='twice',
                                              shuffle=True,
                                              target_list=range(args.num_labeled_classes))
    optimizer = SGD(model.parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                           step_size=args.step_size,
                                           gamma=args.gamma)
    criterion1 = nn.CrossEntropyLoss()
    criterion2 = BCE()
    # label-smoothing CE for the unlabeled head (defined here but unused below)
    criterion3 = CrossEntropyLabelSmooth(num_classes=args.num_unlabeled_classes)

    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        model_ema.train()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(
            epoch, args.rampup_length)

        iters = 400  # nominal iterations per epoch, used for the EMA step counter and logging

        # every 5 epochs: re-cluster the unlabeled features with k-means and use the
        # cluster assignments as pseudo-labels (512-d features and 5 unlabeled classes
        # are hard-coded for this CIFAR-10 split)
        if epoch % 5 == 0:
            args.head = 'head2'
            feats, feats_mb, _ = test(model_ema, unlabeled_eval_loader_train, args)
            feats = F.normalize(torch.cat(feats, dim=0), dim=1)
            feats_mb = F.normalize(torch.cat(feats_mb, dim=0), dim=1)
            cluster = faiss.Kmeans(512, 5, niter=300, verbose=True, gpu=True)
            moving_avg_features = feats.numpy()
            cluster.train(moving_avg_features)
            _, labels_ = cluster.index.search(moving_avg_features, 1)
            # shift cluster ids into the unlabeled-class range (offset = num_labeled_classes)
            labels = labels_ + 5
            target_label = labels.reshape(-1).tolist()

            # centers=faiss.vector_to_array(cluster.centroids).reshape(5, 512)
            centers = cluster.centroids

            # Memory bank by zkc
            # if epoch == 0: memorybank.features = torch.cat((F.normalize(torch.tensor(centers).cuda(), dim=1), feats), dim=0).cuda()
            # memorybank.labels = torch.cat((torch.arange(args.num_unlabeled_classes), torch.Tensor(target_label).long())).cuda()
            if epoch == 0: memorybank.features = feats_mb.cuda()
            memorybank.labels = torch.Tensor(labels_.reshape(-1).tolist()).long().cuda()

            # re-initialize the unlabeled-class prototypes with the k-means centroids
            model.memory.prototypes[args.num_labeled_classes:] = F.normalize(
                torch.tensor(centers).cuda(), dim=1)
            model_ema.memory.prototypes[args.num_labeled_classes:] = F.normalize(
                torch.tensor(centers).cuda(), dim=1)

            # labeled-class prototypes: mean of the normalized features of each labeled class
            feats, _, labels = test(model_ema, labeled_eval_loader_train, args)
            feats = F.normalize(torch.cat(feats, dim=0), dim=1)
            centers = torch.zeros(args.num_labeled_classes, 512)
            for i in range(args.num_labeled_classes):
                idx = torch.where(torch.tensor(labels) == i)[0]
                centers[i] = torch.mean(feats[idx], 0)
            model.memory.prototypes[:args.num_labeled_classes] = centers.clone().cuda()
            model_ema.memory.prototypes[:args.num_labeled_classes] = centers.clone().cuda()

            unlabeled_train_loader = CIFAR10Loader_iter(
                root=args.dataset_root,
                batch_size=args.batch_size // 2,
                split='train',
                aug='twice',
                shuffle=True,
                target_list=range(args.num_labeled_classes, num_classes),  # num_classes: module-level total class count (assumed)
                new_labels=target_label)
            # model.head2.weight.data.copy_(
            #     torch.from_numpy(F.normalize(target_centers, axis=1)).float().cuda())

        # labeled_train_loader.new_epoch()
        # unlabeled_train_loader.new_epoch()
        # for batch_idx,_ in enumerate(range(iters)):
        #     ((x_l, x_bar_l), label_l, idx) = labeled_train_loader.next()
        #     ((x_u, x_bar_u), label_u, idx) = unlabeled_train_loader.next()
        for batch_idx, (((x_l, x_bar_l), label_l, idx_l),
                        ((x_u, x_bar_u), label_u, idx_u)) in enumerate(
                            zip(labeled_train_loader, unlabeled_train_loader)):

            x = torch.cat([x_l, x_u], dim=0)
            x_bar = torch.cat([x_bar_l, x_bar_u], dim=0)
            label = torch.cat([label_l, label_u], dim=0)
            idx = torch.cat([idx_l, idx_u], dim=0)
            x, x_bar, label = x.to(device), x_bar.to(device), label.to(device)

            output1, output2, feat, feat_mb = model(x)
            output1_bar, output2_bar, _, _ = model(x_bar)
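            # feat is the embedding used for ranking statistics; feat_mb is the extra
            # feature this model variant returns for the memory bank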

            with torch.no_grad():
                output1_ema, output2_ema, feat_ema, feat_mb_ema = model_ema(x)
                output1_bar_ema, output2_bar_ema, _, _ = model_ema(x_bar)
            prob1 = F.softmax(output1, dim=1)
            prob1_bar = F.softmax(output1_bar, dim=1)
            prob2 = F.softmax(output2, dim=1)
            prob2_bar = F.softmax(output2_bar, dim=1)
            prob1_ema = F.softmax(output1_ema, dim=1)
            prob1_bar_ema = F.softmax(output1_bar_ema, dim=1)
            prob2_ema = F.softmax(output2_ema, dim=1)
            prob2_bar_ema = F.softmax(output2_bar_ema, dim=1)

            mask_lb = label < args.num_labeled_classes

            loss_ce_label = criterion1(output1[mask_lb], label[mask_lb])
            # CE on head2 with the k-means pseudo-labels (swap in torch.tensor(0) to disable)
            loss_ce_unlabel = criterion1(output2[~mask_lb], label[~mask_lb])

            # instance-level memory-bank loss, currently disabled:
            # memorybank(feat_mb[~mask_lb], feat_mb_ema[~mask_lb], label[~mask_lb], idx[~mask_lb])
            loss_in_unlabel = torch.tensor(0)

            loss_ce = loss_ce_label + loss_ce_unlabel

            # pairwise BCE from ranking statistics (swap in torch.tensor(0) to disable)
            loss_bce = rank_bce(criterion2, feat, mask_lb, prob2, prob2_bar)

            consistency_loss = F.mse_loss(prob1, prob1_bar) + F.mse_loss(prob2, prob2_bar)
            consistency_loss_ema = F.mse_loss(prob1, prob1_bar_ema) + F.mse_loss(prob2, prob2_bar_ema)

            loss = loss_ce + loss_bce + w * consistency_loss + w * consistency_loss_ema + loss_in_unlabel

            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            _update_ema_variables(model, model_ema, 0.99,
                                  epoch * iters + batch_idx)

            if batch_idx % 200 == 0:
                print(
                    'Train Epoch: {}, iter {}/{} unl-CE Loss: {:.4f}, unl-instance Loss: {:.4f}, l-CE Loss: {:.4f}, BCE Loss: {:.4f}, CL Loss: {:.4f}, Avg Loss: {:.4f}'
                    .format(epoch, batch_idx, iters, loss_ce_unlabel.item(),
                            loss_in_unlabel.item(), loss_ce_label.item(),
                            loss_bce.item(), consistency_loss.item(),
                            loss_record.avg))
        # step the LR schedule once per epoch, after the optimizer updates
        exp_lr_scheduler.step()
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        # print('test on labeled classes')
        # args.head = 'head1'
        # test(model, labeled_eval_loader_test, args)

        # print('test on unlabeled classes')
        args.head = 'head2'
        # test(model, unlabeled_eval_loader_train, args)
        test(model_ema, unlabeled_eval_loader_test, args)
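
The re-clustering step in Example #3 is the part most likely to need adaptation: the feature dimension (512) and the number of unlabeled classes (5) are hard-coded for the CIFAR-10 split. A minimal CPU-only sketch of that pseudo-labeling step, with random features standing in for the ones extracted by model_ema:

import numpy as np
import faiss

feat_dim, num_unlabeled = 512, 5                            # literals hard-coded in Example #3
feats = np.random.randn(1000, feat_dim).astype('float32')   # stand-in for extracted features

kmeans = faiss.Kmeans(feat_dim, num_unlabeled, niter=300, verbose=False, gpu=False)
kmeans.train(feats)
_, assignments = kmeans.index.search(feats, 1)              # nearest centroid per sample
pseudo_labels = (assignments.reshape(-1) + 5).tolist()      # shift into the unlabeled-class id range
centroids = kmeans.centroids                                # (num_unlabeled, feat_dim) array

In the training loop above, pseudo_labels is what CIFAR10Loader_iter receives as new_labels, and the normalized centroids re-initialize memory.prototypes for the unlabeled classes.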