예제 #1
0
def test12(test_loader, nets, epoch):
    """Evaluate both networks and their fused pseudo prediction on a loader.

    Args:
        test_loader: iterable yielding ``(img, target)`` batches.
        nets: mapping with the keys ``'snet'`` and ``'tnet'``; each network is
            called as ``net(img)`` and must return two outputs that are fed to
            ``F.cross_entropy`` together with ``target``.
        epoch: epoch index; only used in the printed summary line.

    Returns:
        A list with the five running top-1 averages, ordered as
        ``[tnet out1, tnet out2, snet out1, snet out2, pseudo]``.
    """
    # Function-scope import: this block did not reference ``torch`` directly
    # before, and it is only needed for the ``no_grad`` context below.
    import torch

    clf_losses = [AverageMeter() for _ in range(5)]
    top1 = [AverageMeter() for _ in range(5)]
    consistency_meters = [AverageMeter() for _ in range(2)]

    snet = nets['snet']
    tnet = nets['tnet']
    snet.eval()
    tnet.eval()

    # Nothing in this function calls backward(), so the forward passes do not
    # need an autograd graph.
    # NOTE(review): the ``volatile=True`` flag given to ``cpu_gpu`` presumably
    # targeted the pre-0.4 Variable API; ``no_grad`` is the supported way to
    # get the same effect -- confirm against ``cpu_gpu``.
    with torch.no_grad():
        for img, target in test_loader:
            img = cpu_gpu(args.cuda, img, volatile=True)
            target = cpu_gpu(args.cuda, target, volatile=True)
            batch_size = img.size(0)

            out1_t, out2_t = tnet(img)
            out1_s, out2_s = snet(img)
            cls_t1 = F.cross_entropy(out1_t, target)
            cls_t2 = F.cross_entropy(out2_t, target)
            cls_s1 = F.cross_entropy(out1_s, target)
            cls_s2 = F.cross_entropy(out2_s, target)

            if 'New' in args.pseudo_label_type:
                # Only the fused prediction is needed here; the loss value
                # returned alongside it is discarded.
                pseudo, _ = get_pseudo_loss_new(
                    [out1_t, out1_s], [out2_t, out2_s], consistency_meters,
                    weight_clip=True)
            elif 'Old' in args.pseudo_label_type:
                pseudo, _ = get_pseudo_loss(
                    [out1_t, out1_s], [out2_t, out2_s], consistency_meters,
                    softmax_pseudo_label=True)
            else:
                disc = soft_discrepancy if using_soft_discrepancy else discrepancy
                intra_t = disc(out1_t, out2_t) * intra_t_ratio
                intra_s = disc(out1_s, out2_s) * intra_s_ratio
                # Each network's mean output is weighted by its share of the
                # summed intra-network discrepancy.
                length = intra_t + intra_s
                wt = intra_t.data / length
                ws = intra_s.data / length
                mean_t = (out1_t + out2_t) * 0.5
                mean_s = (out1_s + out2_s) * 0.5
                pseudo = wt * mean_t.detach() + ws * mean_s.detach()

            pseudo_cls = F.cross_entropy(pseudo, target)

            outputs = [out1_t, out2_t, out1_s, out2_s, pseudo]
            for pred, meter in zip(outputs, top1):
                meter.update(accuracy(pred, target, topk=(1,))[0], batch_size)

            losses = [cls_t1, cls_t2, cls_s1, cls_s2, pseudo_cls]
            for loss, meter in zip(losses, clf_losses):
                meter.update(loss.item(), batch_size)

    loss_avgs = [meter.avg for meter in clf_losses]
    acc_avgs = [meter.avg for meter in top1]
    result = ('Epoch:{}, cls-loss:({:.3f},{:.3f},{:.3f},{:.3f},{:.3f}), '
              'top1:({:.4f},{:.4f},{:.4f},{:.4f},{:.4f})').format(
        epoch, *(loss_avgs + acc_avgs))
    print(result)
    return acc_avgs
예제 #2
0
def train(train_loader, nets, optimizers, epoch, mode=0):
    """Run one training epoch over two-view batches for the tnet/snet pair.

    Every batch goes through three phases, each with fresh forward passes:

    1. cross-entropy on both outputs of each network;
    2. ``inter_intra_step`` rounds on the discrepancy between the two outputs
       of the same network (only the ``*_fea_opt`` optimizers are passed to
       ``reset_grad`` / ``set_step`` here);
    3. ``pseudo_step`` rounds on a cross-network (inter) term plus a
       temperature-3 KL term towards a fused pseudo prediction.

    The statistics logged after the phases read values assigned inside the
    phase-3 loop, so ``pseudo_step`` must be at least 1.

    Args:
        train_loader: iterable yielding ``(img1, img2, target)`` batches.
        nets: mapping with the keys ``'tnet'`` and ``'snet'``; each network is
            called as ``net(img1, img2)`` and must return two outputs.
        optimizers: mapping with the keys ``'t_fea_opt'``, ``'t_clf_opt'``,
            ``'s_fea_opt'`` and ``'s_clf_opt'``.
        epoch: current epoch; drives the burn-in factor and the log lines.
        mode: accepted for call compatibility; not read by this function.

    Returns:
        None. Running averages are printed every ``args.print_freq`` batches.
    """
    clf_losses = [AverageMeter() for _ in range(5)]
    intra_losses = [AverageMeter() for _ in range(2)]
    inter_losses = [AverageMeter() for _ in range(4)]
    top1 = [AverageMeter() for _ in range(5)]
    pseudo_losses = [AverageMeter() for _ in range(2)]
    weight = [AverageMeter() for _ in range(2)]

    tnet = nets['tnet']
    snet = nets['snet']
    tnet.train()
    snet.train()

    t_fea_opt = optimizers['t_fea_opt']
    t_clf_opt = optimizers['t_clf_opt']
    s_fea_opt = optimizers['s_fea_opt']
    s_clf_opt = optimizers['s_clf_opt']

    temperature = 3.0

    def soft_kl(logits, ref_logits, batch_size):
        # KL divergence between temperature-softened distributions, scaled by
        # temperature**2 and divided by the batch size.
        # NOTE(review): reduction='mean' already averages over every element,
        # so the extra division by the batch size normalises twice. Kept
        # unchanged because altering it rescales the loss -- confirm intent.
        return F.kl_div(F.log_softmax(logits / temperature, dim=1),
                        F.softmax(ref_logits / temperature, dim=1),
                        reduction='mean') * (temperature * temperature) / batch_size

    for idx, (img1, img2, target) in enumerate(train_loader, start=1):
        img1 = cpu_gpu(args.cuda, img1, volatile=False)
        img2 = cpu_gpu(args.cuda, img2, volatile=False)
        target = cpu_gpu(args.cuda, target, volatile=False)
        batch_size = img1.size(0)

        # ---- phase 1: cross-entropy, each network updated separately ----
        out1_t, out2_t = tnet(img1, img2)
        out1_s, out2_s = snet(img1, img2)
        cls_t1 = F.cross_entropy(out1_t, target)
        cls_t2 = F.cross_entropy(out2_t, target)
        cls_s1 = F.cross_entropy(out1_s, target)
        cls_s2 = F.cross_entropy(out2_s, target)

        loss1 = cls_t1 + cls_t2
        reset_grad([t_fea_opt, t_clf_opt])
        loss1.backward()
        set_step([t_fea_opt, t_clf_opt])
        loss2 = cls_s1 + cls_s2
        reset_grad([s_fea_opt, s_clf_opt])
        loss2.backward()
        set_step([s_fea_opt, s_clf_opt])

        # The same discrepancy function is used for intra and (non-KL) inter terms.
        disc = soft_discrepancy if using_soft_discrepancy else discrepancy

        # ---- phase 2: intra-network discrepancy ----
        for _ in range(inter_intra_step):
            out1_t, out2_t = tnet(img1, img2)
            out1_s, out2_s = snet(img1, img2)
            # Recomputed so the logged classification losses reflect the
            # most recent forward pass.
            cls_t1 = F.cross_entropy(out1_t, target)
            cls_t2 = F.cross_entropy(out2_t, target)
            cls_s1 = F.cross_entropy(out1_s, target)
            cls_s2 = F.cross_entropy(out2_s, target)

            intra_t = disc(out1_t, out2_t) * intra_t_ratio
            intra_s = disc(out1_s, out2_s) * intra_s_ratio

            reset_grad([t_fea_opt])
            intra_t.backward()
            set_step([t_fea_opt])
            reset_grad([s_fea_opt])
            intra_s.backward()
            set_step([s_fea_opt])

        # ---- phase 3: inter-network + pseudo-label terms ----
        for _ in range(pseudo_step):
            out1_t, out2_t = tnet(img1, img2)
            out1_s, out2_s = snet(img1, img2)

            # The other network's output is always detached, so each inter
            # term only back-propagates into its own network.
            if inter_KL:
                inter_t1 = soft_kl(out1_t, out1_s.detach(), batch_size) * inter_ratio
                inter_t2 = soft_kl(out2_t, out2_s.detach(), batch_size) * inter_ratio
                inter_s1 = soft_kl(out1_s, out1_t.detach(), batch_size) * inter_ratio
                inter_s2 = soft_kl(out2_s, out2_t.detach(), batch_size) * inter_ratio
            else:
                inter_t1 = disc(out1_t, out1_s.detach()) * inter_ratio
                inter_t2 = disc(out2_t, out2_s.detach()) * inter_ratio
                inter_s1 = disc(out1_s, out1_t.detach()) * inter_ratio
                inter_s2 = disc(out2_s, out2_t.detach()) * inter_ratio

            intra_t = disc(out1_t, out2_t) * intra_t_ratio
            intra_s = disc(out1_s, out2_s) * intra_s_ratio
            # Normalised exp(-discrepancy): the network whose two outputs
            # agree more gets the larger weight in the fused prediction.
            intra_t_e = torch.exp(-intra_t)
            intra_s_e = torch.exp(-intra_s)
            length = intra_t_e + intra_s_e
            wt = intra_t_e / length
            ws = intra_s_e / length

            mean_t = (out1_t + out2_t) * 0.5
            mean_s = (out1_s + out2_s) * 0.5
            pseudo = wt * mean_t.detach() + ws * mean_s.detach()
            pseudo_target = pseudo.detach()

            kl_t1 = soft_kl(out1_t, pseudo_target, batch_size) * kl_loss_t
            kl_t2 = soft_kl(out2_t, pseudo_target, batch_size) * kl_loss_t
            kl_s1 = soft_kl(out1_s, pseudo_target, batch_size) * kl_loss_s
            kl_s2 = soft_kl(out2_s, pseudo_target, batch_size) * kl_loss_s
            pseudo_loss_t = kl_t1 + kl_t2
            pseudo_loss_s = kl_s1 + kl_s2

            # Burn-in factor that scales the pseudo-label term by epoch.
            if pseudo_burnin == 'exp':
                beta = 2 / (1 + math.exp(-10 * epoch / args.epochs)) - 1
            elif pseudo_burnin == 'linear':
                beta = min(1, 1.25 * (epoch / args.epochs))
            else:
                assert pseudo_burnin == 'none'
                beta = 1
            pseudo_cls = F.cross_entropy(pseudo, target)  # logged only
            pseudo_loss_t = pseudo_loss_t * beta * pseudo_loss_ratio
            pseudo_loss_s = pseudo_loss_s * beta * pseudo_loss_ratio

            loss_2t = inter_t1 + inter_t2 + pseudo_loss_t
            reset_grad([t_clf_opt, t_fea_opt])
            loss_2t.backward()
            set_step([t_clf_opt, t_fea_opt])

            loss_2s = inter_s1 + inter_s2 + pseudo_loss_s
            reset_grad([s_clf_opt, s_fea_opt])
            loss_2s.backward()
            set_step([s_clf_opt, s_fea_opt])

        # ---- bookkeeping ----
        outputs = [out1_t, out2_t, out1_s, out2_s, pseudo]
        for pred, meter in zip(outputs, top1):
            meter.update(accuracy(pred, target, topk=(1, ))[0], batch_size)

        for loss, meter in zip([cls_t1, cls_t2, cls_s1, cls_s2, pseudo_cls], clf_losses):
            meter.update(loss.item(), batch_size)
        for intra, meter in zip([intra_t, intra_s], intra_losses):
            meter.update(intra.item(), batch_size)
        for inter, meter in zip([inter_t1, inter_t2, inter_s1, inter_s2], inter_losses):
            meter.update(inter.item(), batch_size)
        # Log plain floats like every other meter; handing the tensors over
        # would keep their autograd history referenced by the running sums.
        pseudo_losses[0].update(pseudo_loss_t.item(), batch_size)
        pseudo_losses[1].update(pseudo_loss_s.item(), batch_size)
        weight[0].update(wt.item())
        weight[1].update(ws.item())

        if idx % args.print_freq == 0:
            averages = [meter.avg for meter in
                        clf_losses + intra_losses + inter_losses + pseudo_losses + weight]
            result = ('Epoch:{}, cls-loss:({:.3f},{:.3f},{:.3f},{:.3f},{:.3f}), '
                      'intra-loss:({:.3f},{:.3f}), inter-loss:({:.3f},{:.3f},{:.3f},{:.3f}), '
                      'pseudo-loss:({:.3f},{:.3f}), weight:({:.4f},{:.4f})').format(
                epoch, *averages)
            print(result)
            result1 = 'Epoch:{}, top1:({:.4f},{:.4f},{:.4f},{:.4f},{:.4f})'.format(
                epoch, *[meter.avg for meter in top1])
            print(result1)
예제 #3
0
def train(train_loader, nets, optimizers, epoch):
    """Run one training epoch over single-view batches for the tnet/snet pair.

    Every batch goes through three phases, each with fresh forward passes:

    1. cross-entropy on both outputs of each network;
    2. ``inter_intra_step`` rounds on inter- plus intra-network discrepancy
       (only the two ``*_fea_opt`` optimizers are passed to ``reset_grad`` /
       ``set_step`` here);
    3. ``inter_intra_step`` rounds on the pseudo-label loss (only the two
       ``*_clf_opt`` optimizers are passed to ``reset_grad`` / ``set_step``).

    The statistics logged after the phases read values assigned inside both
    loops, so ``inter_intra_step`` must be at least 1.

    Args:
        train_loader: iterable yielding ``(img, target)`` batches.
        nets: mapping with the keys ``'tnet'`` and ``'snet'``; each network is
            called as ``net(img)`` and must return two outputs.
        optimizers: mapping with the keys ``'t_fea_opt'``, ``'t_clf_opt'``,
            ``'s_fea_opt'`` and ``'s_clf_opt'``.
        epoch: current epoch; drives the burn-in factors and the log lines.

    Returns:
        None. Running averages are printed every ``args.print_freq`` batches.
    """
    clf_losses = [AverageMeter() for _ in range(5)]
    intra_losses = [AverageMeter() for _ in range(2)]
    inter_losses = [AverageMeter() for _ in range(2)]
    top1 = [AverageMeter() for _ in range(5)]
    consistency_meters = [AverageMeter() for _ in range(2)]
    pseudo_losses = AverageMeter()

    tnet = nets['tnet']
    snet = nets['snet']
    tnet.train()
    snet.train()

    t_fea_opt = optimizers['t_fea_opt']
    t_clf_opt = optimizers['t_clf_opt']
    s_fea_opt = optimizers['s_fea_opt']
    s_clf_opt = optimizers['s_clf_opt']

    temperature = 3.0

    def soft_kl(logits, ref_logits, batch_size):
        # KL divergence between temperature-softened distributions, scaled by
        # temperature**2 and divided by the batch size.
        # NOTE(review): reduction='mean' already averages over every element,
        # so the extra division by the batch size normalises twice. Kept
        # unchanged because altering it rescales the loss -- confirm intent.
        return F.kl_div(F.log_softmax(logits / temperature, dim=1),
                        F.softmax(ref_logits / temperature, dim=1),
                        reduction='mean') * (temperature * temperature) / batch_size

    def burnin_factor(schedule):
        # Epoch-dependent scale for a loss term: 'exp', 'linear' or 'none'.
        if schedule == 'exp':
            return 2 / (1 + math.exp(-10 * epoch / args.epochs)) - 1
        if schedule == 'linear':
            return min(1, 1.25 * (epoch / args.epochs))
        assert schedule == 'none'
        return 1

    for idx, (img, target) in enumerate(train_loader, start=1):
        img = cpu_gpu(args.cuda, img, volatile=False)
        target = cpu_gpu(args.cuda, target, volatile=False)
        batch_size = img.size(0)

        # ---- phase 1: cross-entropy, each network updated separately ----
        out1_t, out2_t = tnet(img)
        out1_s, out2_s = snet(img)
        cls_t1 = F.cross_entropy(out1_t, target)
        cls_t2 = F.cross_entropy(out2_t, target)
        cls_s1 = F.cross_entropy(out1_s, target)
        cls_s2 = F.cross_entropy(out2_s, target)

        loss1 = cls_t1 + cls_t2
        reset_grad([t_fea_opt, t_clf_opt])
        loss1.backward()
        set_step([t_fea_opt, t_clf_opt])

        loss2 = cls_s1 + cls_s2
        reset_grad([s_fea_opt, s_clf_opt])
        loss2.backward()
        set_step([s_fea_opt, s_clf_opt])

        # The same discrepancy function is used for intra and (non-KL) inter terms.
        disc = soft_discrepancy if using_soft_discrepancy else discrepancy

        # ---- phase 2: inter + intra discrepancy ----
        for _ in range(inter_intra_step):
            out1_t, out2_t = tnet(img)
            out1_s, out2_s = snet(img)
            # Recomputed so the logged classification losses reflect the
            # most recent forward pass.
            cls_t1 = F.cross_entropy(out1_t, target)
            cls_t2 = F.cross_entropy(out2_t, target)
            cls_s1 = F.cross_entropy(out1_s, target)
            cls_s2 = F.cross_entropy(out2_s, target)

            # Neither side is detached here, so this term reaches both networks.
            if inter_KL:
                inter_t1 = soft_kl(out1_t, out1_s, batch_size) * inter_ratio
                inter_t2 = soft_kl(out2_t, out2_s, batch_size) * inter_ratio
            else:
                inter_t1 = disc(out1_t, out1_s) * inter_ratio
                inter_t2 = disc(out2_t, out2_s) * inter_ratio
            beta = burnin_factor(inter_burnin)
            inter_t1 = inter_t1 * beta
            inter_t2 = inter_t2 * beta

            intra_t = disc(out1_t, out2_t) * intra_t_ratio
            intra_s = disc(out1_s, out2_s) * intra_s_ratio
            if args.using_intra1:
                # NOTE(review): ``intra_1`` is never assigned in this function;
                # unless it is a module-level global this branch raises
                # NameError -- confirm what term was intended.
                loss_1 = inter_t1 + inter_t2 + intra_t + intra_s + intra_1
            else:
                loss_1 = inter_t1 + inter_t2 + intra_t + intra_s

            reset_grad([t_fea_opt, s_fea_opt])
            loss_1.backward()
            set_step([t_fea_opt, s_fea_opt])

        # ---- phase 3: pseudo-label loss ----
        # NOTE(review): this phase reuses ``inter_intra_step`` as its repeat
        # count; confirm whether a separate step count was intended.
        for _ in range(inter_intra_step):
            out1_t, out2_t = tnet(img)
            out1_s, out2_s = snet(img)
            preds1_t = [out1_t, out1_s]
            preds2_t = [out2_t, out2_s]
            if 'New' in args.pseudo_label_type:
                pseudo, pseudo_loss = get_pseudo_loss_new(
                    preds1_t, preds2_t, consistency_meters, smooth=0.1, weight_clip=True)
            elif 'Old' in args.pseudo_label_type:
                pseudo, pseudo_loss = get_pseudo_loss(
                    preds1_t, preds2_t, consistency_meters, softmax_pseudo_label=True)
            else:
                intra_t = disc(out1_t, out2_t) * intra_t_ratio
                intra_s = disc(out1_s, out2_s) * intra_s_ratio
                # Each network's mean output is weighted by its share of the
                # summed intra-network discrepancy.
                length = intra_t + intra_s
                wt = intra_t / length
                ws = intra_s / length
                mean_t = (out1_t + out2_t) * 0.5
                mean_s = (out1_s + out2_s) * 0.5
                pseudo = wt * mean_t + ws * mean_s
                # The fused prediction acts as a fixed target for all four outputs.
                pseudo_target = pseudo.detach()

                kl_t1 = soft_kl(out1_t, pseudo_target, batch_size) * kl_loss_t
                kl_t2 = soft_kl(out2_t, pseudo_target, batch_size) * kl_loss_t
                kl_s1 = soft_kl(out1_s, pseudo_target, batch_size) * kl_loss_s
                kl_s2 = soft_kl(out2_s, pseudo_target, batch_size) * kl_loss_s
                pseudo_loss = kl_t1 + kl_t2 + kl_s1 + kl_s2

            beta = burnin_factor(pseudo_burnin)
            pseudo_cls = F.cross_entropy(pseudo, target)  # logged only
            pseudo_loss = pseudo_loss * beta * pseudo_loss_ratio

            reset_grad([t_clf_opt, s_clf_opt])
            pseudo_loss.backward()
            set_step([t_clf_opt, s_clf_opt])

        # ---- bookkeeping ----
        outputs = [out1_t, out2_t, out1_s, out2_s, pseudo]
        for pred, meter in zip(outputs, top1):
            meter.update(accuracy(pred, target, topk=(1,))[0], batch_size)

        for loss, meter in zip([cls_t1, cls_t2, cls_s1, cls_s2, pseudo_cls], clf_losses):
            meter.update(loss.item(), batch_size)
        for intra, meter in zip([intra_t, intra_s], intra_losses):
            meter.update(intra.item(), batch_size)
        for inter, meter in zip([inter_t1, inter_t2], inter_losses):
            meter.update(inter.item(), batch_size)
        # Log a plain float like every other meter; handing the tensor over
        # would keep its autograd history referenced by the running sum.
        pseudo_losses.update(pseudo_loss.item(), batch_size)

        if idx % args.print_freq == 0:
            averages = [meter.avg for meter in
                        clf_losses + intra_losses + inter_losses + [pseudo_losses]]
            result = ('Epoch:{}, cls-loss:({:.3f},{:.3f},{:.3f},{:.3f},{:.3f}), '
                      'intra-loss:({:.3f},{:.3f}), inter-loss:({:.3f},{:.3f}), '
                      'pseudo-loss:({:.3f})').format(epoch, *averages)
            print(result)
            result1 = 'Epoch:{}, top1:({:.4f},{:.4f},{:.4f},{:.4f},{:.4f})'.format(
                epoch, *[meter.avg for meter in top1])
            print(result1)