def train(model, train_loader, unlabeled_eval_loader, args):
    """Train ``model`` on a mixed labeled/unlabeled loader and evaluate each epoch.

    The per-batch loss is the sum of three terms:

    * cross-entropy of the first head's logits on the labeled samples,
    * a pairwise loss (``BCE``) on the second head's probabilities of the
      unlabeled samples, with +1/-1 pair targets derived from whether two
      samples share the same set of top-``args.topk`` feature indices,
    * an MSE consistency term between the softmax outputs of the two views
      ``x`` and ``x_bar``, weighted by a sigmoid ramp-up.

    Args:
        model: Called as ``model(x)``; must return three values
            ``(output1, output2, feat)``.
        train_loader: Yields ``((x, x_bar), label, idx)`` batches and exposes
            a ``labeled_length`` attribute; samples with
            ``idx < labeled_length`` are treated as labeled.
        unlabeled_eval_loader: Passed unchanged to ``test`` after each epoch.
        args: Namespace providing ``lr``, ``step_size``, ``gamma``, ``epochs``,
            ``rampup_coefficient``, ``rampup_length`` and ``topk``.

    Side effects:
        Sets ``args.head = 'head2'`` before every evaluation and prints the
        average training loss per epoch.
    """
    optimizer = Adam(model.parameters(), lr=args.lr)
    exp_lr_scheduler = lr_scheduler.StepLR(
        optimizer, step_size=args.step_size, gamma=args.gamma)
    criterion1 = nn.CrossEntropyLoss()
    criterion2 = BCE()

    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        # Weight of the consistency term for this epoch.
        w = args.rampup_coefficient * ramps.sigmoid_rampup(
            epoch, args.rampup_length)

        for (x, x_bar), label, idx in tqdm(train_loader):
            x, x_bar, label = x.to(device), x_bar.to(device), label.to(device)
            output1, output2, feat = model(x)
            output1_bar, output2_bar, _ = model(x_bar)
            prob1 = F.softmax(output1, dim=1)
            prob1_bar = F.softmax(output1_bar, dim=1)
            prob2 = F.softmax(output2, dim=1)
            prob2_bar = F.softmax(output2_bar, dim=1)

            # True for samples whose dataset index falls in the labeled range.
            mask_lb = idx < train_loader.labeled_length

            # Pairwise pseudo-targets for the unlabeled samples. Features are
            # detached so the targets themselves carry no gradient.
            rank_feat = feat[~mask_lb].detach()
            rank_idx = torch.argsort(rank_feat, dim=1, descending=True)
            # PairEnum presumably enumerates all sample pairs in the batch
            # (defined elsewhere) -- TODO confirm.
            rank_idx1, rank_idx2 = PairEnum(rank_idx)
            rank_idx1 = rank_idx1[:, :args.topk]
            rank_idx2 = rank_idx2[:, :args.topk]
            # Sorting makes the comparison order-independent: the summed
            # absolute difference is zero only when both rows contain the
            # same top-k indices.
            rank_idx1, _ = torch.sort(rank_idx1, dim=1)
            rank_idx2, _ = torch.sort(rank_idx2, dim=1)
            rank_diff = torch.sum(torch.abs(rank_idx1 - rank_idx2), dim=1)
            target_ulb = torch.ones_like(rank_diff).float().to(device)
            target_ulb[rank_diff > 0] = -1

            # Pair the second head's probabilities of view ``x`` with those
            # of view ``x_bar``.
            prob1_ulb, _ = PairEnum(prob2[~mask_lb])
            _, prob2_ulb = PairEnum(prob2_bar[~mask_lb])

            loss_ce = criterion1(output1[mask_lb], label[mask_lb])
            loss_bce = criterion2(prob1_ulb, prob2_ulb, target_ulb)
            consistency_loss = (F.mse_loss(prob1, prob1_bar)
                                + F.mse_loss(prob2, prob2_bar))
            loss = loss_ce + loss_bce + w * consistency_loss

            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Step the scheduler only after this epoch's optimizer updates.
        # Stepping it first skips the first value of the learning-rate
        # schedule on PyTorch >= 1.1.
        exp_lr_scheduler.step()

        print('Train Epoch: {} Avg Loss: {:.4f}'.format(
            epoch, loss_record.avg))
        print('test on unlabeled classes')
        args.head = 'head2'
        test(model, unlabeled_eval_loader, args)
def rank_bce(criterion2, feat, mask_lb, prob2, prob2_bar, *, topk=None):
    """Compute the pairwise loss on unlabeled samples from ranking statistics.

    For every pair of unlabeled samples produced by ``PairEnum``, the target
    is +1 when both samples have the same set of top-``topk`` feature indices
    (by descending activation) and -1 otherwise. ``criterion2`` is then
    evaluated on the paired probabilities against those targets.

    Args:
        criterion2: Callable invoked as
            ``criterion2(prob1_ulb, prob2_ulb, target_ulb)``.
        feat: Feature tensor indexed by sample along dim 0; ranked along
            dim 1. Detached before use, so no gradient flows through it.
        mask_lb: Boolean mask over the batch; ``True`` marks labeled samples,
            which are excluded here.
        prob2: Probabilities for the first view; the first output of
            ``PairEnum`` is used.
        prob2_bar: Probabilities for the second view; the second output of
            ``PairEnum`` is used.
        topk: Number of leading ranked indices to compare. When ``None``
            (the default) the module-level ``args.topk`` is used, which
            keeps existing five-argument calls working.

    Returns:
        Whatever ``criterion2`` returns for the paired inputs.
    """
    if topk is None:
        # Backward-compatible fallback: earlier callers relied on a
        # module-level ``args`` namespace being present.
        topk = args.topk

    unlabeled = ~mask_lb
    rank_feat = feat[unlabeled].detach()
    rank_idx = torch.argsort(rank_feat, dim=1, descending=True)
    # PairEnum presumably enumerates all sample pairs (defined elsewhere)
    # -- TODO confirm.
    rank_idx1, rank_idx2 = PairEnum(rank_idx)
    # Sort the truncated index lists so the comparison ignores rank order:
    # the summed absolute difference is zero only for identical index sets.
    rank_idx1, _ = torch.sort(rank_idx1[:, :topk], dim=1)
    rank_idx2, _ = torch.sort(rank_idx2[:, :topk], dim=1)
    rank_diff = torch.sum(torch.abs(rank_idx1 - rank_idx2), dim=1)

    # ones_like keeps the target on the same device as the inputs, so no
    # module-level ``device`` is required.
    target_ulb = torch.ones_like(rank_diff, dtype=torch.float)
    target_ulb[rank_diff > 0] = -1

    prob1_ulb, _ = PairEnum(prob2[unlabeled])
    _, prob2_ulb = PairEnum(prob2_bar[unlabeled])
    return criterion2(prob1_ulb, prob2_ulb, target_ulb)