Code example #1
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, lr_scheduler
from tqdm import tqdm

# BCE, PairEnum, ramps, AverageMeter, and test() are project utilities; a
# sketch of plausible implementations follows this example.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train(model, train_loader, unlabeled_eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                           step_size=args.step_size,
                                           gamma=args.gamma)
    criterion1 = nn.CrossEntropyLoss()  # supervised CE on the labeled head
    criterion2 = BCE()  # pairwise BCE on the unlabeled head
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        # Weight of the consistency loss, ramped up with a sigmoid schedule
        # over the first args.rampup_length epochs.
        w = args.rampup_coefficient * ramps.sigmoid_rampup(
            epoch, args.rampup_length)
        for batch_idx, ((x, x_bar), label,
                        idx) in enumerate(tqdm(train_loader)):
            x, x_bar, label = x.to(device), x_bar.to(device), label.to(device)
            output1, output2, feat = model(x)
            output1_bar, output2_bar, _ = model(x_bar)
            prob1 = F.softmax(output1, dim=1)
            prob1_bar = F.softmax(output1_bar, dim=1)
            prob2 = F.softmax(output2, dim=1)
            prob2_bar = F.softmax(output2_bar, dim=1)

            # Labeled samples precede unlabeled ones in the loader, so a
            # sample index below labeled_length marks a labeled example.
            mask_lb = idx < train_loader.labeled_length

            # Ranking statistics on the unlabeled samples: rank each feature
            # vector's dimensions and keep the top-k indices per sample.
            rank_feat = (feat[~mask_lb]).detach()
            rank_idx = torch.argsort(rank_feat, dim=1, descending=True)
            rank_idx1, rank_idx2 = PairEnum(rank_idx)
            rank_idx1 = rank_idx1[:, :args.topk]
            rank_idx2 = rank_idx2[:, :args.topk]
            rank_idx1, _ = torch.sort(rank_idx1, dim=1)
            rank_idx2, _ = torch.sort(rank_idx2, dim=1)

            # Pairwise pseudo-label: +1 if the two samples share the same
            # top-k index set (a positive pair), -1 otherwise.
            rank_diff = rank_idx1 - rank_idx2
            rank_diff = torch.sum(torch.abs(rank_diff), dim=1)
            target_ulb = torch.ones_like(rank_diff).float().to(device)
            target_ulb[rank_diff > 0] = -1

            prob1_ulb, _ = PairEnum(prob2[~mask_lb])
            _, prob2_ulb = PairEnum(prob2_bar[~mask_lb])

            # Supervised CE on the labeled head, pairwise BCE on the
            # unlabeled head using the ranking-statistics pseudo-labels.
            loss_ce = criterion1(output1[mask_lb], label[mask_lb])
            loss_bce = criterion2(prob1_ulb, prob2_ulb, target_ulb)

            # Consistency between the two augmented views of each sample.
            consistency_loss = F.mse_loss(prob1, prob1_bar) + F.mse_loss(
                prob2, prob2_bar)

            loss = loss_ce + loss_bce + w * consistency_loss

            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Step the LR schedule once per epoch, after the optimizer updates
        # (the required ordering since PyTorch 1.1).
        exp_lr_scheduler.step()
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(
            epoch, loss_record.avg))
        print('test on unlabeled classes')
        args.head = 'head2'
        test(model, unlabeled_eval_loader, args)
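
Code example #1 leans on several project utilities that are not shown: PairEnum, BCE, ramps.sigmoid_rampup, and AverageMeter. Below is a minimal sketch of plausible implementations, assuming the pairwise-BCE formulation commonly used with ranking statistics for novel category discovery; treat it as illustrative, not the repository's exact code.

import numpy as np
import torch.nn as nn


def PairEnum(x):
    # Enumerates all ordered pairs of rows of a 2-D tensor: for n rows it
    # returns two (n*n, d) tensors such that row k pairs x1[k] with x2[k].
    assert x.ndimension() == 2, 'Input dimension must be 2'
    x1 = x.repeat(x.size(0), 1)
    x2 = x.repeat(1, x.size(0)).view(-1, x.size(1))
    return x1, x2


class BCE(nn.Module):
    # Pairwise binary cross-entropy on the inner product of two probability
    # vectors; simi is +1 for a positive pair and -1 for a negative pair.
    eps = 1e-7

    def forward(self, prob1, prob2, simi):
        p = (prob1 * prob2).sum(1)             # probability of "same class"
        p = p * simi + simi.eq(-1).type_as(p)  # p for +1 pairs, 1-p for -1
        return -(p + BCE.eps).log().mean()


def sigmoid_rampup(current, rampup_length):
    # Sigmoid-shaped ramp-up from 0 to 1, as in Mean Teacher
    # (Tarvainen & Valpola, 2017); referenced above as ramps.sigmoid_rampup.
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))


class AverageMeter:
    # Running average of a scalar, here the per-batch loss.
    def __init__(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count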
Code example #2
def rank_bce(criterion2, feat, mask_lb, prob2, prob2_bar):
    # Factors the ranking-statistics BCE block out of the training loop.
    # Note: relies on module-level `args` (for args.topk) and `device`.
    rank_feat = (feat[~mask_lb]).detach()

    # Top-k ranking statistics per unlabeled sample.
    rank_idx = torch.argsort(rank_feat, dim=1, descending=True)
    rank_idx1, rank_idx2 = PairEnum(rank_idx)
    rank_idx1, rank_idx2 = rank_idx1[:, :args.topk], rank_idx2[:, :args.topk]
    rank_idx1, _ = torch.sort(rank_idx1, dim=1)
    rank_idx2, _ = torch.sort(rank_idx2, dim=1)
    # Pairwise pseudo-label: +1 where the top-k index sets match, -1 elsewhere.
    rank_diff = rank_idx1 - rank_idx2
    rank_diff = torch.sum(torch.abs(rank_diff), dim=1)
    target_ulb = torch.ones_like(rank_diff).float().to(device)
    target_ulb[rank_diff > 0] = -1

    prob1_ulb, _ = PairEnum(prob2[~mask_lb])
    _, prob2_ulb = PairEnum(prob2_bar[~mask_lb])
    loss_bce = criterion2(prob1_ulb, prob2_ulb, target_ulb)
    return loss_bce
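
Code example #2 is the ranking-statistics block of example #1 factored into a helper. A hypothetical call site inside the batch loop of train() (variable names match example #1) would look like:

# Inside the batch loop of train(), the inline block from rank_feat through
# loss_bce collapses to a single call:
loss_bce = rank_bce(criterion2, feat, mask_lb, prob2, prob2_bar)
loss = loss_ce + loss_bce + w * consistency_loss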