Example #1
def test_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_ori + "/target_F_" + args.savename + ".pt"
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_ori + "/target_B_" + args.savename + ".pt"
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_ori + "/target_C_" + args.savename + ".pt"
    netC.load_state_dict(torch.load(args.modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    acc, y, py = cal_acc(dset_loaders['test'], netF, netB, netC)
    log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc * 100)
    args.out_file.write(log_str)
    args.out_file.flush()
    print(log_str)

    return y, py
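For reference, a minimal sketch of the argument namespace this kind of test function expects. Every field below is inferred from the attribute accesses in the body above, not from the project's actual CLI; data_load will additionally need dataset-specific fields (paths, batch size, workers), and y/py are whatever cal_acc returns alongside the accuracy.

import argparse
import sys

# Hypothetical driver for the example above; all values are placeholders.
args = argparse.Namespace(
    net='resnet50',          # args.net[0:3] == 'res' selects network.ResBase
    classifier='bn',         # bottleneck type for network.feat_bootleneck
    layer='wn',              # head type for network.feat_classifier
    bottleneck=256,
    class_num=31,
    output_dir_ori='./ckps/target',
    savename='par_0.3',
    name='A2D',
    out_file=sys.stdout,     # any object with .write()/.flush()
)
y_true, y_pred = test_target(args)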
Example #2
def test(args):
    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u':
        netF = network.LeNetBase()  #.cuda()
    elif args.dset == 'm':
        netF = network.LeNetBase()  #.cuda()
    elif args.dset == 's':
        netF = network.DTNBase()  #.cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck)  #.cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck)  #.cuda()

    args.modelpath = args.output_dir + '/F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/B.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/C.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
    log_str = 'Task: {}, [DONT CARE] Accuracy = {:.2f}%'.format(args.dset, acc)
    try:
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
    except (AttributeError, ValueError):  # out_file may be absent or already closed
        pass
    print(log_str + '\n')
Example #3
def test_target(args):
    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir + '/source_F_val.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_B_val.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_C_val.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
    log_str = 'Task: {}, Accuracy = {:.2f}%'.format(args.dset, acc * 100)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')
Example #4
File: image_source.py, Project: yyht/SHOT
def test_target(args, zz):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_src + '/source_F_' + str(zz) + '.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B_' + str(zz) + '.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C_' + str(zz) + '.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, args.da == 'oda')
    log_str = '\nZz: {}, Task: {}, Accuracy = {:.2f}%'.format(
        zz, args.name, acc * 100)
    args.out_file.write(log_str)
    args.out_file.flush()
    print(log_str)
Example #5
def test_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()  

    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck).cuda()
    
    args.modelpath = args.output_dir_src + '/source_F.pt'   
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B.pt'   
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C.pt'   
    netC.load_state_dict(torch.load(args.modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    if args.da == 'oda':
        acc_os1, acc_os2, acc_unknown = cal_acc_oda(dset_loaders['test'], netF, netB, netC)
        log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}% / {:.2f}% / {:.2f}%'.format(args.trte, args.name, acc_os2, acc_os1, acc_unknown)
    else:
        if args.dset=='VISDA-C':
            acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True)
            log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(args.trte, args.name, acc) + '\n' + acc_list
        else:
            acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
            log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(args.trte, args.name, acc)

    args.out_file.write(log_str)
    args.out_file.flush()
    print(log_str)
Example #6
def test_target(args):
    dset_loaders = data_load(args)

    ## set base network
    if args.net[0:3] == 'res' or args.net[0:3] == 'vgg':
        if args.net[0:3] == 'res':
            netF = network.ResBase(res_name=args.net).cuda()
        else:
            netF = network.VGGBase(vgg_name=args.net).cuda()
        netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                       bottleneck_dim=args.bottleneck).cuda()
        netC = network.feat_classifier(type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck).cuda()


        args.modelpath = args.output_dir_src + '/source_F.pt'
        netF.load_state_dict(torch.load(args.modelpath))
        args.modelpath = args.output_dir_src + '/source_B.pt'
        netB.load_state_dict(torch.load(args.modelpath))
        args.modelpath = args.output_dir_src + '/source_C.pt'
        netC.load_state_dict(torch.load(args.modelpath))

        netF.eval()
        netB.eval()
        netC.eval()

    if args.dset == 'VISDA-RSUT' or args.dset == 'VISDA-RSUT-50' or args.dset == 'VISDA-RSUT-10':
        # For VisDA, print acc of each class.
        if args.net[0:3] == 'res' or args.net[0:3] == 'vgg':
            acc, acc_list, acc_cls_avg = cal_acc(dset_loaders['test'], netF=netF, netB=netB, netC=netC,
                                                 per_class_flag=True, visda_flag=True)
            log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}%'.format(args.trte,
                                                                args.name, acc, acc_cls_avg) + '\n' + acc_list
    else:
        # For Home, DomainNet, no need to print acc of each class.
        if args.net[0:3] == 'res' or args.net[0:3] == 'vgg':
            acc, acc_cls_avg, _ = cal_acc(dset_loaders['test'], netF=netF, netB=netB, netC=netC,
                                          per_class_flag=True, visda_flag=False)
            log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}%'.format(args.trte,
                                                                                    args.name, acc, acc_cls_avg)
    args.out_file.write(log_str)
    args.out_file.flush()
    print(log_str)
Example #7
def train_target(args):
    dset_loaders = data_load(args)
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath))
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0

    iter_sw = int(max_iter / 2.0)

    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except (NameError, StopIteration):
            # (re)create the iterator on first use or when exhausted
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        if inputs_test.size(0) == 1:
            continue

        if iter_num % interval_iter == 0 and args.cls_par > 0:
            netF.eval()
            netB.eval()
            mem_label_soft, mtx_infor_nh, feas_FC = obtain_label(
                dset_loaders['test'], netF, netB, netC, args, iter_num,
                iter_sw)
            mem_label_soft = torch.from_numpy(mem_label_soft).cuda()
            feas_all = feas_FC[0]
            ops_all = feas_FC[1]
            netF.train()
            netB.train()

        inputs_test = inputs_test.cuda()
        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        features_F_self = netF(inputs_test)
        features_F_nh = get_mtx_sam_wgt_nh(feas_all, mtx_infor_nh, tar_idx)
        features_F_nh = features_F_nh.cuda()
        features_F_mix = 0.8 * features_F_self + 0.2 * features_F_nh
        outputs_test_mix = netC(netB(features_F_mix))
        ops_test_self = netC(netB(features_F_self))
        outputs_test_nh = netC(netB(features_F_nh))

        if args.cls_par > 0:
            log_probs = nn.LogSoftmax(dim=1)(outputs_test_mix)
            targets = mem_label_soft[tar_idx]
            loss_soft = (-targets * log_probs).sum(dim=1)
            classifier_loss = loss_soft.mean()

            classifier_loss *= args.cls_par
            if iter_num < interval_iter and args.dset == "VISDA-C":
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(
                outputs_test_mix)  # outputs_test_mix
            entropy_loss = torch.mean(loss.Entropy(softmax_out))

            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax *
                                          torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            if args.dset == 'VISDA-C':
                acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netB,
                                             netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te) + '\n' + acc_list
            else:
                acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC,
                                      False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te)

            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(
            netF.state_dict(),
            osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(
            netB.state_dict(),
            osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(
            netC.state_dict(),
            osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netB, netC
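The soft pseudo-label objective above, isolated as a minimal sketch; it is equivalent to the loss_soft computation inside the loop, but the function name is ours:

import torch
import torch.nn as nn

def soft_cross_entropy(logits, soft_targets):
    # Mean over the batch of -sum_k p_k * log q_k, with q = softmax(logits).
    log_probs = nn.LogSoftmax(dim=1)(logits)
    return (-soft_targets * log_probs).sum(dim=1).mean()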
Example #8
def train_target(args):
    dset_loaders, txt_tar = data_load(args)
    dsets = dict()
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck).cuda()

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath))

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    for k, v in netC.named_parameters():
        v.requires_grad = False

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    iter_per_epoch = len(dset_loaders["target"])
    print("Iter per epoch: {}".format(iter_per_epoch))
    interval_iter = max_iter // args.interval
    iter_num = 0

    if args.paral:
        netF = torch.nn.DataParallel(netF)
        netB = torch.nn.DataParallel(netB)
        netC = torch.nn.DataParallel(netC)

    netF.train()
    netB.train()
    netC.eval()

    if args.scd_lamb:
        scd_lamb_init = args.scd_lamb   # hyperparameter for secondary label correction, set manually
    else:
        if args.dset[0:5] == "VISDA":
            scd_lamb_init = 0.1
        elif args.dset[0:11] == "office-home":
            scd_lamb_init = 0.2
            if args.s == 3 and args.t == 2:
                scd_lamb_init *= 0.1
        elif args.dset[0:9] == "domainnet":
            scd_lamb_init = 0.02

    scd_lamb = scd_lamb_init

    while iter_num < max_iter:
        k = 1.0
        k_s = 0.6
        if iter_num % interval_iter == 0 and args.cls_par > 0:  # interval_iter = iterations per epoch
            netF.eval()
            netB.eval()

            label_prob_dict = obtain_label(dset_loaders['test'], netF, netB, netC, args)
            mem_label, pseudo_lb_prob = label_prob_dict['primary_lb'], label_prob_dict['primary_lb_prob']
            mem_label, pseudo_lb_prob = torch.from_numpy(mem_label).cuda(), torch.from_numpy(pseudo_lb_prob).cuda()
            if args.scd_label:
                second_label, second_prob = label_prob_dict['secondary_lb'], label_prob_dict['secondary_lb_prob']
                second_label, second_prob = torch.from_numpy(second_label).cuda(), torch.from_numpy(second_prob).cuda()
            if args.third_label:
                third_label, third_prob = label_prob_dict['third_lb'], label_prob_dict['third_lb_prob']
                third_label, third_prob = torch.from_numpy(third_label).cuda(), torch.from_numpy(third_prob).cuda()
            if args.fourth_label:
                fourth_label, fourth_prob = label_prob_dict['fourth_lb'], label_prob_dict['fourth_lb_prob']
                fourth_label, fourth_prob = torch.from_numpy(fourth_label).cuda(), torch.from_numpy(fourth_prob).cuda()
            if args.topk_ent:
                all_entropy = label_prob_dict['entropy']
                all_entropy = torch.from_numpy(all_entropy)
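            # (Below: curriculum schedule. k is the fraction of most-confident
            # pseudo-labelled samples kept for self-training, raised per epoch.)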

            if args.dset[0:5] == "VISDA":
                if iter_num // iter_per_epoch < 1:
                    k = 0.6
                elif iter_num // iter_per_epoch < 2:
                    k = 0.7
                elif iter_num // iter_per_epoch < 3:
                    k = 0.8
                elif iter_num // iter_per_epoch < 4:
                    k = 0.9
                else:
                    k = 1.0

                if iter_num // iter_per_epoch >= 8:
                    scd_lamb *= 0.1

            elif args.dset[0:11] == "office-home" or args.dset[0:9] == "domainnet":
                if iter_num // iter_per_epoch < 2:
                    k = 0.2
                elif iter_num // iter_per_epoch < 4:
                    k = 0.4
                elif iter_num // iter_per_epoch < 8:
                    k = 0.6
                elif iter_num // iter_per_epoch < 12:
                    k = 0.8

            if args.topk:
                dsets["target"] = ImageList_PA(txt_tar, mem_label, pseudo_lb_prob, k_low=k, k_up=None,
                                                transform=image_train())
                dset_loaders["target"] = DataLoader(dsets["target"], batch_size=args.batch_size, shuffle=True,
                                                    num_workers=args.worker, drop_last=False)

            if args.topk_ent:
                dsets["target"] = ImageList_PA(txt_tar, mem_label, -1.0 * all_entropy, k_low=k, k_up=None,
                                                transform=image_train())
                dset_loaders["target"] = DataLoader(dsets["target"], batch_size=args.batch_size, shuffle=True,
                                                    num_workers=args.worker, drop_last=False)

            if args.scd_label:
                # 2nd label threshold: prob top 60%
                dsets["target_scd"] = ImageList_PA(txt_tar, second_label, second_prob, k_low=k_s, k_up=None,
                                                    transform=image_train())
                dset_loaders["target_scd"] = DataLoader(dsets["target_scd"], batch_size=args.batch_size, shuffle=True,
                                                        num_workers=args.worker, drop_last=False)

            if args.third_label:
                # 3rd label threshold: prob top 60%
                dsets["target_third"] = ImageList_PA(txt_tar, third_label, third_prob, k_low=k_s, k_up=None,
                                                    transform=image_train())
                dset_loaders["target_third"] = DataLoader(dsets["target_third"], batch_size=args.batch_size, shuffle=True,
                                                        num_workers=args.worker, drop_last=False)
            if args.fourth_label:
                # 4th label threshold: prob top 60%
                dsets["target_fourth"] = ImageList_PA(txt_tar, fourth_label, fourth_prob, k_low=k_s, k_up=None,
                                                      transform=image_train())
                dset_loaders["target_fourth"] = DataLoader(dsets["target_fourth"], batch_size=args.batch_size,
                                                          shuffle=True,
                                                          num_workers=args.worker, drop_last=False)
            netF.train()
            netB.train()

        try:
            inputs_test, _, tar_idx = next(iter_test)
        except (NameError, StopIteration):
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)  # tar_idx: chosen indices in current iteration

        if inputs_test.size(0) == 1:
            continue

        if args.scd_label:
            try:
                inputs_test_scd, _, tar_idx_scd = next(iter_test_scd)
            except (NameError, StopIteration):
                iter_test_scd = iter(dset_loaders["target_scd"])
                inputs_test_scd, _, tar_idx_scd = next(iter_test_scd)

            if inputs_test_scd.size(0) == 1:
                continue

        if args.third_label:
            try:
                inputs_test_third, _, tar_idx_third = next(iter_test_third)
            except (NameError, StopIteration):
                iter_test_third = iter(dset_loaders["target_third"])
                inputs_test_third, _, tar_idx_third = next(iter_test_third)

            if inputs_test_third.size(0) == 1:
                continue

        if args.fourth_label:
            try:
                inputs_test_fourth, _, tar_idx_fourth = next(iter_test_fourth)
            except (NameError, StopIteration):
                iter_test_fourth = iter(dset_loaders["target_fourth"])
                inputs_test_fourth, _, tar_idx_fourth = next(iter_test_fourth)

            if inputs_test_fourth.size(0) == 1:
                continue

        iter_num += 1

        inputs_test = inputs_test.cuda()
        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        if args.scd_label:
            inputs_test_scd = inputs_test_scd.cuda()
            if inputs_test_scd.ndim == 3:
                inputs_test_scd = inputs_test_scd.unsqueeze(0)

            features_test_scd = netB(netF(inputs_test_scd))
            outputs_test_scd = netC(features_test_scd)

            first_prob_of_scd = pseudo_lb_prob[tar_idx_scd]
            scd_prob = second_prob[tar_idx_scd]
            if not args.no_mask:
                mask = (scd_prob / first_prob_of_scd.float()).clamp(max=1.0)
            else:
                mask = torch.ones_like(scd_prob).cuda()

        if args.third_label:
            inputs_test_third = inputs_test_third.cuda()
            if inputs_test_third.ndim == 3:
                inputs_test_third = inputs_test_third.unsqueeze(0)

            features_test_third = netB(netF(inputs_test_third))
            outputs_test_third = netC(features_test_third)

            first_prob_of_third = pseudo_lb_prob[tar_idx_third]
            thi_prob = third_prob[tar_idx_third]

            mask_third = (thi_prob / first_prob_of_third.float()).clamp(max=1.0)

        if args.fourth_label:
            inputs_test_fourth = inputs_test_fourth.cuda()
            if inputs_test_fourth.ndim == 3:
                inputs_test_fourth = inputs_test_fourth.unsqueeze(0)

            features_test_fourth = netB(netF(inputs_test_fourth))
            outputs_test_fourth = netC(features_test_fourth)

            first_prob_of_fourth = pseudo_lb_prob[tar_idx_fourth]
            fth_prob = fourth_prob[tar_idx_fourth]

            mask_fourth = (fth_prob / first_prob_of_fourth.float()).clamp(max=1.0)

        if args.intra_dense or args.inter_sep:
            intra_dist = torch.zeros(1).cuda()
            inter_dist = torch.zeros(1).cuda()
            pred = mem_label[tar_idx]
            same_first = True
            diff_first = True
            cos = nn.CosineSimilarity(dim=1, eps=1e-6)
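            # Pairwise cosine distances within the batch: pairs that share a
            # pseudo-label accumulate into intra_dist (to be pulled together),
            # all other pairs into inter_dist (to be pushed apart).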
            for i in range(pred.size(0)):
                for j in range(i, pred.size(0)):
                    # dist = torch.norm(features_test[i] - features_test[j])
                    dist = 0.5 * (1 - cos(features_test[i].unsqueeze(0), features_test[j].unsqueeze(0)))
                    if pred[i].item() == pred[j].item():
                        if same_first:
                            intra_dist = dist.unsqueeze(0)
                            same_first = False
                        else:
                            intra_dist = torch.cat((intra_dist, dist.unsqueeze(0)))

                    else:
                        if diff_first:
                            inter_dist = dist.unsqueeze(0)
                            diff_first = False
                        else:
                            inter_dist = torch.cat((inter_dist, dist.unsqueeze(0)))

            intra_dist = torch.mean(intra_dist)
            inter_dist = torch.mean(inter_dist)

        if args.cls_par > 0:
            pred = mem_label[tar_idx]
            classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)  # self-train by pseudo label
            classifier_loss *= args.cls_par

            if args.scd_label:
                pred_scd = second_label[tar_idx_scd]
                classifier_loss_scd = nn.CrossEntropyLoss(reduction='none')(outputs_test_scd,
                                                                            pred_scd)  # self-train by pseudo label

                classifier_loss_scd = torch.mean(mask * classifier_loss_scd)

                classifier_loss_scd *= args.cls_par

                classifier_loss += classifier_loss_scd * scd_lamb

            if args.third_label:
                pred_third = third_label[tar_idx_third]
                classifier_loss_third = nn.CrossEntropyLoss(reduction='none')(outputs_test_third,
                                                                            pred_third)  # self-train by pseudo label

                classifier_loss_third = torch.mean(mask_third * classifier_loss_third)

                classifier_loss_third *= args.cls_par

                classifier_loss += classifier_loss_third * scd_lamb    # TODO: better weighting is possible

            if args.fourth_label:
                pred_fourth = fourth_label[tar_idx_fourth]
                classifier_loss_fourth = nn.CrossEntropyLoss(reduction='none')(outputs_test_fourth,
                                                                            pred_fourth)  # self-train by pseudo label

                classifier_loss_fourth = torch.mean(mask_fourth * classifier_loss_fourth)

                classifier_loss_fourth *= args.cls_par

                classifier_loss += classifier_loss_fourth * scd_lamb    # TODO: better weighting is possible

            if iter_num < interval_iter and (args.dset == "VISDA-C" or args.dset == "VISDA-RSUT" or args.dset == 'VISDA-RSUT-50' or args.dset == 'VISDA-RSUT-10'):
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))  # Minimize local entropy

            if args.scd_label:
                softmax_out_scd = nn.Softmax(dim=1)(outputs_test_scd)
                entropy_loss_scd = torch.mean(mask * loss.Entropy(softmax_out_scd))  # Minimize local entropy

            if args.third_label:
                softmax_out_third = nn.Softmax(dim=1)(outputs_test_third)
                entropy_loss_third = torch.mean(mask_third * loss.Entropy(softmax_out_third))  # Minimize local entropy

            if args.fourth_label:
                softmax_out_fourth = nn.Softmax(dim=1)(outputs_test_fourth)
                entropy_loss_fourth = torch.mean(mask_fourth * loss.Entropy(softmax_out_fourth))  # Minimize local entropy

            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss  # Maximize global entropy

                if args.scd_label:
                    msoftmax_scd = softmax_out_scd.mean(dim=0)
                    gentropy_loss_scd = torch.sum(-msoftmax_scd * torch.log(msoftmax_scd + args.epsilon))
                    entropy_loss_scd -= gentropy_loss_scd  # Maximize global entropy
                if args.third_label:
                    msoftmax_third = softmax_out_third.mean(dim=0)
                    gentropy_loss_third = torch.sum(-msoftmax_third * torch.log(msoftmax_third + args.epsilon))
                    entropy_loss_third -= gentropy_loss_third  # Maximize global entropy
                if args.fourth_label:
                    msoftmax_fourth = softmax_out_fourth.mean(dim=0)
                    gentropy_loss_fourth = torch.sum(-msoftmax_fourth * torch.log(msoftmax_fourth + args.epsilon))
                    entropy_loss_fourth -= gentropy_loss_fourth  # Maximize global entropy

            im_loss = entropy_loss * args.ent_par
            if args.scd_label:
                im_loss += entropy_loss_scd * args.ent_par * scd_lamb
            if args.third_label:
                im_loss += entropy_loss_third * args.ent_par * scd_lamb    # TODO: better weighting is possible
            if args.fourth_label:
                im_loss += entropy_loss_fourth * args.ent_par * scd_lamb   # TODO: better weighting is possible

            classifier_loss += im_loss

        if args.intra_dense:
            classifier_loss += args.lamb_intra * intra_dist.squeeze()
        if args.inter_sep:
            classifier_loss += args.lamb_inter * inter_dist.squeeze()

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            if args.dset == 'VISDA-RSUT' or args.dset == 'VISDA-RSUT-50' or args.dset == 'VISDA-RSUT-10':
                # For VisDA, print the acc of each cls
                acc_s_te, acc_list, acc_cls_avg = cal_acc(dset_loaders['test'], netF, netB, netC, visda_flag=True)    # flag for VisDA -> need cls avg acc.
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}%'.format(args.name, iter_num, max_iter,
                                                                            acc_s_te, acc_cls_avg) + '\n' + acc_list
            else:
                # In imbalanced setting, use per-class avg acc as metric
                # For Office-Home, DomainNet, no need to print the acc of each cls
                acc_s_te, acc_cls_avg, _ = cal_acc(dset_loaders['test'], netF, netB, netC, visda_flag=False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}%'.format(args.name, iter_num,
                                                                            max_iter, acc_s_te, acc_cls_avg)

            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netB, netC
Example #9
def train_target(args):
    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir + '/source_F_val.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_B_val.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_C_val.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    # Only netC is put in eval mode (i.e. only the classifier head); together with
    # the requires_grad = False below, its parameters stay fixed during training,
    # so the hypothesis h in f = g∘h is kept unchanged.
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': args.lr}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': args.lr}]
    optimizer = optim.SGD(param_group,
                          momentum=0.9,
                          weight_decay=5e-4,
                          nesterov=True)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    for epoch in tqdm(range(args.max_epoch), leave=False):
        # Put the front modules (feature extractor and bottleneck) in training mode
        netF.train()
        netB.train()
        iter_test = iter(dset_loaders["target"])

        # Also keep a snapshot of the g model here, set to eval mode.
        # Note: this copies the model from the previous epoch, not a fixed source model.
        prev_F = copy.deepcopy(netF)
        prev_B = copy.deepcopy(netB)
        prev_F.eval()
        prev_B.eval()

        # Obtain the class centroids, corresponding to the first formula in the paper
        center = obtain_center(dset_loaders['target'], prev_F, prev_B, netC,
                               args)

        for _, (inputs_test, _) in tqdm(enumerate(iter_test), leave=False):
            if inputs_test.size(0) == 1:
                continue
            inputs_test = inputs_test.cuda()
            with torch.no_grad():
                # Note: pseudo-labels are re-predicted on every batch iteration,
                # but always with the model snapshot from the previous epoch.
                # The two lines below correspond to the second formula in the paper.
                # TODO: the paper also has a third and a fourth formula; where are they implemented?
                features_test = prev_B(prev_F(inputs_test))
                pred = obtain_label(features_test, center)

            # Regular forward pass through the network being trained
            features_test = netB(netF(inputs_test))
            outputs_test = netC(features_test)
            # Compute the classification loss against the pseudo-labels
            classifier_loss = CrossEntropyLabelSmooth(
                num_classes=args.class_num, epsilon=0)(outputs_test, pred)

            # Information-maximization (IM) loss
            # Softmax of the outputs over the class dimension (dim=1)
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            # Entropy computes -sum(softmax_out * log(softmax_out), dim=1) per sample;
            # the mean then averages this over the batch.
            # This im_loss term corresponds to L_ent in the paper.
            im_loss = torch.mean(Entropy(softmax_out))
            # msoftmax is the batch-mean probability of each class, i.e. p^k in the paper.
            msoftmax = softmax_out.mean(dim=0)
            # The -= together with the minus inside the sum acts as +=; this is the
            # diversity term over the K classes, L_div in the paper.
            im_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
            # args.par weighs classifier_loss against the IM loss.
            total_loss = im_loss + args.par * classifier_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        netF.eval()
        netB.eval()
        acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
        log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
            args.dset, epoch + 1, args.max_epoch, acc * 100)
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        print(log_str + '\n')

    # torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F.pt"))
    # torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B.pt"))
    # torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C.pt"))
    return netF, netB, netC
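The comments above refer to two formulas from the paper: softmax-weighted class centroids, then nearest-centroid pseudo-labels. A minimal sketch of that pair follows, under the assumption that obtain_center/obtain_label implement SHOT-style clustering; the repo's versions may iterate the assignment or use a different distance.

import torch
import torch.nn as nn
import torch.nn.functional as F

def obtain_center_sketch(features, logits):
    # features: (N, D) bottleneck outputs; logits: (N, K) classifier outputs.
    probs = nn.Softmax(dim=1)(logits)                                  # (N, K)
    # Softmax-weighted mean of features per class ("first formula").
    centers = probs.t() @ features                                     # (K, D)
    return centers / probs.sum(dim=0).clamp(min=1e-8).unsqueeze(1)

def obtain_label_sketch(features, centers):
    # Nearest centroid under cosine similarity ("second formula").
    f = F.normalize(features, dim=1)
    c = F.normalize(centers, dim=1)
    return (f @ c.t()).argmax(dim=1)                                   # (N,) pseudo-labels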
Example #10
def train_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    wandb.watch(netB)
    wandb.watch(netC)
    wandb.watch(netF)

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath))
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0

    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except (NameError, StopIteration):
            # (re)create the iterator on first use or when exhausted
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        if inputs_test.size(0) == 1:
            continue

        if iter_num % interval_iter == 0 and args.cls_par > 0:
            netF.eval()
            netB.eval()
            # pseudolabels
            mem_label = obtain_label(dset_loaders['test'], netF, netB, netC,
                                     args)
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netB.train()

        inputs_test = inputs_test.cuda()

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        if args.cls_par > 0:
            pred = mem_label[tar_idx]
            # cross-entropy loss between
            classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)
            classifier_loss *= args.cls_par
            if iter_num < interval_iter and (args.dset == "VISDA18"
                                             or args.dset == 'VISDA-C'):
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(
                loss.Entropy(softmax_out))  # expectation over samples drawn
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax *
                                          torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        log_dict = {"target_classifier_loss": classifier_loss}
        if args.ent:
            log_dict["info_max_loss"] = im_loss  # im_loss only exists when args.ent is set
        wandb.log(log_dict)

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            if args.dset == 'VISDA18' or args.dset == 'VISDA-C':
                acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netB,
                                             netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te) + '\n' + acc_list
                wandb.log({"accuracy": acc_s_te})  # true test accuracy
            else:
                acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC,
                                      False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te)

            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(
            netF.state_dict(),
            osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(
            netB.state_dict(),
            osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(
            netC.state_dict(),
            osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netB, netC
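The entropy/diversity pair driven by args.ent and args.gent above, condensed into one function as a minimal sketch. It matches entropy_loss and gentropy_loss up to naming: minimizing the return value minimizes per-sample entropy while maximizing the entropy of the batch-mean prediction.

import torch
import torch.nn as nn

def im_loss_sketch(logits, epsilon=1e-5):
    softmax_out = nn.Softmax(dim=1)(logits)                            # (N, K)
    # L_ent: mean per-sample entropy (to be minimized).
    ent = (-softmax_out * torch.log(softmax_out + epsilon)).sum(dim=1).mean()
    # L_div: entropy of the batch-mean prediction (to be maximized).
    msoftmax = softmax_out.mean(dim=0)
    div = (-msoftmax * torch.log(msoftmax + epsilon)).sum()
    return ent - div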
Example #11
def train(args):

    ent_loss_record = []
    gent_loss_record = []
    sent_loss_record = []
    total_loss_record = []

    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u':
        netF = network.LeNetBase()  #.cuda()
    elif args.dset == 'm':
        netF = network.LeNetBase()  #.cuda()
    elif args.dset == 's':
        netF = network.DTNBase()  #.cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck)  #.cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck)  #.cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0
    max_iter = args.max_epoch * len(dset_loaders["train"])
    interval_iter = max_iter // 10
    iter_num = 0

    netF.train()
    netB.train()
    netC.train()

    while iter_num < max_iter:
        try:
            inputs_source, strong_inputs, target = next(iter_source)
        except (NameError, StopIteration):
            iter_source = iter(dset_loaders["train"])
            inputs_source, strong_inputs, target = next(iter_source)

        if inputs_source.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source = inputs_source  #.cuda()
        outputs_source = netC(netB(netF(inputs_source)))

        total_loss = torch.tensor(0.0)  #.cuda()
        softmax_out = nn.Softmax(dim=1)(outputs_source)
        if args.ent:
            ent_loss = torch.mean(loss.Entropy(softmax_out))
            total_loss += ent_loss
            ent_loss_record.append(ent_loss.detach().cpu())

        if args.gent:
            msoftmax = softmax_out.mean(dim=0)
            gent_loss = -torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
            gent_loss_record.append(gent_loss.detach().cpu())
            total_loss += gent_loss

        if args.sent:
            sent_loss = compute_aug_loss(strong_inputs, target, netC, netB,
                                         netF)
            total_loss += sent_loss
            sent_loss_record.append(sent_loss.detach().cpu())

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        total_loss_record.append(total_loss.detach().cpu())

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            print(iter_num, interval_iter, max_iter)
        #     netF.eval()
        #     netB.eval()
        #     netC.eval()
        #     acc_s_tr, _ = cal_acc(dset_loaders['train'], netF, netB, netC)
        #     acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
        #     log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%/ {:.2f}%'.format(args.dset, iter_num, max_iter, acc_s_tr, acc_s_te)
        #     args.out_file.write(log_str + '\n')
        #     args.out_file.flush()
        #     print(log_str+'\n')

        #     if acc_s_te >= acc_init:
        #         acc_init = acc_s_te
        #         best_netF = netF.state_dict()
        #         best_netB = netB.state_dict()
        #         best_netC = netC.state_dict()

        #     netF.train()
        #     netB.train()
        #     netC.train()

    best_netF = netF.state_dict()
    best_netB = netB.state_dict()
    best_netC = netC.state_dict()

    torch.save(best_netF, osp.join(args.output_dir, "F.pt"))
    torch.save(best_netB, osp.join(args.output_dir, "B.pt"))
    torch.save(best_netC, osp.join(args.output_dir, "C.pt"))

    fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=4,
                                             sharex=True,
                                             figsize=(16, 8))
    ax1.plot(list(range(len(ent_loss_record))), ent_loss_record, 'r')
    ax2.plot(list(range(len(gent_loss_record))), gent_loss_record, 'g')
    ax3.plot(list(range(len(sent_loss_record))), sent_loss_record, 'b')
    ax4.plot(list(range(len(total_loss_record))), total_loss_record, 'm')
    plt.tight_layout()
    plt.savefig(args.output_dir + '/loss.png')

    return netF, netB, netC
Example #12
def train_source(args):
    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0
    max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    interval_iter = max_iter // 10
    iter_num = 0

    netF.train()
    netB.train()
    netC.train()

    while iter_num < max_iter:
        try:
            inputs_source, labels_source = next(iter_source)
        except (NameError, StopIteration):
            iter_source = iter(dset_loaders["source_tr"])
            inputs_source, labels_source = next(iter_source)

        if inputs_source.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source, labels_source = inputs_source.cuda(
        ), labels_source.cuda()
        outputs_source = netC(netB(netF(inputs_source)))
        classifier_loss = loss.CrossEntropyLabelSmooth(
            num_classes=args.class_num, epsilon=args.smooth)(outputs_source,
                                                             labels_source)
        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            netC.eval()
            acc_s_tr, _ = cal_acc(dset_loaders['source_tr'], netF, netB, netC)
            acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%/ {:.2f}%'.format(
                args.dset, iter_num, max_iter, acc_s_tr, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if acc_s_te >= acc_init:
                acc_init = acc_s_te
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()

            netF.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir, "source_F.pt"))
    torch.save(best_netB, osp.join(args.output_dir, "source_B.pt"))
    torch.save(best_netC, osp.join(args.output_dir, "source_C.pt"))

    return netF, netB, netC
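loss.CrossEntropyLabelSmooth above is the repo's label-smoothing loss. A minimal functional sketch of the same idea follows; the repo's class may differ in its reduction details.

import torch
import torch.nn as nn

def label_smooth_ce(logits, targets, num_classes, epsilon=0.1):
    # Mix the one-hot target with the uniform distribution, then take the
    # mean negative log-likelihood under the smoothed target.
    log_probs = nn.LogSoftmax(dim=1)(logits)
    one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
    smoothed = (1 - epsilon) * one_hot + epsilon / num_classes
    return (-smoothed * log_probs).sum(dim=1).mean()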
Example #13
def split_target(args):
    train_bs = args.batch_size
    if args.dset == 's2m':
        train_target = mnist.MNIST(
            './data/mnist/',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.Resize(32),
                transforms.Lambda(lambda x: x.convert("RGB")),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]))
        train_target2 = mnist.MNIST_twice(
            './data/mnist/',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.Resize(32),
                transforms.Lambda(lambda x: x.convert("RGB")),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]))
        test_target = mnist.MNIST(
            './data/mnist/',
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.Resize(32),
                transforms.Lambda(lambda x: x.convert("RGB")),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]))
    elif args.dset == 'u2m':
        train_target = mnist.MNIST('./data/mnist/',
                                   train=True,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, ), (0.5, ))
                                   ]))
        train_target2 = mnist.MNIST_twice('./data/mnist/',
                                          train=True,
                                          download=True,
                                          transform=transforms.Compose([
                                              transforms.ToTensor(),
                                              transforms.Normalize((0.5, ),
                                                                   (0.5, ))
                                          ]))
        test_target = mnist.MNIST('./data/mnist/',
                                  train=False,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, ), (0.5, ))
                                  ]))
    elif args.dset == 'm2u':
        train_target = usps.USPS(
            './data/usps/',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # transforms.Lambda(lambda x: _gaussian_blur(x, sigma=0.1)),
                transforms.Normalize((0.5, ), (0.5, ))
            ]))
        train_target2 = usps.USPS_twice(
            './data/usps/',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # transforms.Lambda(lambda x: _gaussian_blur(x, sigma=0.1)),
                transforms.Normalize((0.5, ), (0.5, ))
            ]))
        test_target = usps.USPS(
            './data/usps/',
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # transforms.Lambda(lambda x: _gaussian_blur(x, sigma=0.1)),
                transforms.Normalize((0.5, ), (0.5, ))
            ]))
    dset_loaders = {}
    dset_loaders["target_te"] = DataLoader(test_target,
                                           batch_size=train_bs,
                                           shuffle=False,
                                           num_workers=args.worker,
                                           drop_last=False)
    dset_loaders["target"] = DataLoader(train_target,
                                        batch_size=train_bs,
                                        shuffle=False,
                                        num_workers=args.worker,
                                        drop_last=False)
    dset_loaders["target2"] = DataLoader(train_target2,
                                         batch_size=train_bs,
                                         shuffle=False,
                                         num_workers=args.worker,
                                         drop_last=False)

    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    if args.model == 'source':
        modelpath = args.output_dir + "/source_F.pt"
        netF.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_B.pt"
        netB.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_C.pt"
        netC.load_state_dict(torch.load(modelpath))
    else:
        modelpath = args.output_dir + "/target_F_" + args.savename + ".pt"
        netF.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_B_" + args.savename + ".pt"
        netB.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_C_" + args.savename + ".pt"
        netC.load_state_dict(torch.load(modelpath))

    netF.eval()
    netB.eval()
    netC.eval()

    start_test = True
    with torch.no_grad():
        iter_test = iter(dset_loaders['target_te'])
        for i in range(len(dset_loaders['target_te'])):
            data = next(iter_test)
            # pdb.set_trace()
            inputs = data[0]
            labels = data[1]
            inputs = inputs.cuda()
            outputs = netC(netB(netF(inputs)))
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
    top_pred, predict = torch.max(all_output, 1)
    acc = torch.sum(
        torch.squeeze(predict).float() == all_label).item() / float(
            all_label.size()[0]) * 100
    mean_ent = loss.Entropy(nn.Softmax(dim=1)(all_output))
    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
        args.dset + '_test', 0, 0, acc, mean_ent.mean())
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')

    start_test = True
    with torch.no_grad():
        iter_test = iter(dset_loaders['target'])
        for i in range(len(dset_loaders['target'])):
            data = next(iter_test)
            inputs = data[0]
            labels = data[1]
            inputs = inputs.cuda()
            outputs = netC(netB(netF(inputs)))
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
    top_pred, predict = torch.max(all_output, 1)
    acc = torch.sum(
        torch.squeeze(predict).float() == all_label).item() / float(
            all_label.size()[0]) * 100
    mean_ent = loss.Entropy(nn.Softmax(dim=1)(all_output))

    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
        args.dset + '_train', 0, 0, acc, mean_ent.mean())
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')

    if args.ps == 0:
        est_p = (mean_ent < mean_ent.mean()).sum().item() / mean_ent.size(0)
        log_str = 'Estimated PS = {:.2f}'.format(est_p)
        print(log_str + '\n')
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        PS = est_p
    else:
        PS = args.ps

    if args.choice == "ent":
        value = mean_ent
    elif args.choice == "maxp":
        value = -top_pred
    elif args.choice == "marginp":
        pred, _ = torch.sort(all_output, 1, descending=True)
        # Top-2 margin, negated so that smaller values mean higher confidence.
        value = pred[:, 1] - pred[:, 0]
    else:
        value = torch.rand(len(mean_ent))

    predict = predict.numpy()
    train_idx = np.zeros(predict.shape)

    cls_k = args.class_num
    for c in range(cls_k):
        c_idx = np.where(predict == c)
        c_idx = c_idx[0]
        c_value = value[c_idx]

        _, idx_ = torch.sort(c_value)
        c_num = len(idx_)
        c_num_s = int(c_num * PS)
        # print(c, c_num, c_num_s)

        for ei in range(0, c_num_s):
            ee = c_idx[idx_[ei]]
            train_idx[ee] = 1

    train_target.targets = predict
    new_src = copy.deepcopy(train_target)
    new_tar = copy.deepcopy(train_target2)

    if args.dset == 'm2u':

        new_src.train_data = np.delete(new_src.train_data,
                                       np.where(train_idx == 0)[0],
                                       axis=0)
        new_src.train_labels = np.delete(new_src.train_labels,
                                         np.where(train_idx == 0)[0],
                                         axis=0)

        new_tar.train_data = np.delete(new_tar.train_data,
                                       np.where(train_idx == 1)[0],
                                       axis=0)
        new_tar.train_labels = np.delete(new_tar.train_labels,
                                         np.where(train_idx == 1)[0],
                                         axis=0)

    else:

        new_src.data = np.delete(new_src.data,
                                 np.where(train_idx == 0)[0],
                                 axis=0)
        new_src.targets = np.delete(new_src.targets,
                                    np.where(train_idx == 0)[0],
                                    axis=0)

        new_tar.data = np.delete(new_tar.data,
                                 np.where(train_idx == 1)[0],
                                 axis=0)
        new_tar.targets = np.delete(new_tar.targets,
                                    np.where(train_idx == 1)[0],
                                    axis=0)

    return new_src, new_tar
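
Note: the selection above ranks each sample by the per-sample entropy returned by loss.Entropy, which is never defined in these examples. A minimal sketch of how such a helper is commonly written (an assumption, not the verbatim loss module):

import torch

def Entropy(input_):
    # input_: softmax probabilities of shape (batch, num_classes).
    # Returns one entropy value per sample, so callers can take .mean()
    # or compare against a threshold as done above.
    epsilon = 1e-5
    entropy = -input_ * torch.log(input_ + epsilon)
    return torch.sum(entropy, dim=1)
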
Example #14
def copy_target_simp(args):
    dset_loaders = data_load(args)
    if args.net_src[0:3] == 'res':
        netF = network.ResBase(res_name=args.net_src).cuda()
    netC = network.feat_classifier_simpl(class_num=args.class_num,
                                         feat_dim=netF.in_features).cuda()

    args.modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    source_model = nn.Sequential(netF, netC).cuda()
    source_model.eval()

    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net, pretrain=True).cuda()
    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 0.1}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    ent_best = 1.0
    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // 10
    iter_num = 0

    model = nn.Sequential(netF, netB, netC).cuda()
    model.eval()

    start_test = True
    with torch.no_grad():
        iter_test = iter(dset_loaders["target_te"])
        for i in range(len(dset_loaders["target_te"])):
            data = next(iter_test)
            inputs, labels = data[0], data[1]
            inputs = inputs.cuda()
            outputs = source_model(inputs)
            outputs = nn.Softmax(dim=1)(outputs)
            _, src_idx = torch.sort(outputs, 1, descending=True)
            if args.topk > 0:
                topk = np.min([args.topk, args.class_num])
                # Use a fresh index name to avoid shadowing the outer loop's i.
                for j in range(outputs.size()[0]):
                    outputs[j, src_idx[j, topk:]] = (
                        1.0 - outputs[j, src_idx[j, :topk]].sum()) / (
                            outputs.size()[1] - topk)

            if start_test:
                all_output = outputs.float()
                all_label = labels
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float()), 0)
                all_label = torch.cat((all_label, labels), 0)
        mem_P = all_output.detach()

    model.train()
    while iter_num < max_iter:

        if args.ema < 1.0 and iter_num > 0 and iter_num % interval_iter == 0:
            model.eval()
            start_test = True
            with torch.no_grad():
                iter_test = iter(dset_loaders["target_te"])
                for i in range(len(dset_loaders["target_te"])):
                    data = next(iter_test)
                    inputs = data[0]
                    inputs = inputs.cuda()
                    outputs = model(inputs)
                    outputs = nn.Softmax(dim=1)(outputs)
                    if start_test:
                        all_output = outputs.float()
                        start_test = False
                    else:
                        all_output = torch.cat((all_output, outputs.float()),
                                               0)
                mem_P = mem_P * args.ema + all_output.detach() * (1 - args.ema)
            model.train()

        try:
            inputs_target, y, tar_idx = next(iter_target)
        except (NameError, StopIteration):
            # (Re)create the iterator on first use or once it is exhausted.
            iter_target = iter(dset_loaders["target"])
            inputs_target, y, tar_idx = next(iter_target)

        if inputs_target.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer,
                     iter_num=iter_num,
                     max_iter=max_iter,
                     power=1.5)
        inputs_target = inputs_target.cuda()
        with torch.no_grad():
            outputs_target_by_source = mem_P[tar_idx, :]
            _, src_idx = torch.sort(outputs_target_by_source,
                                    1,
                                    descending=True)
        outputs_target = model(inputs_target)
        outputs_target = torch.nn.Softmax(dim=1)(outputs_target)
        classifier_loss = nn.KLDivLoss(reduction='batchmean')(
            outputs_target.log(), outputs_target_by_source)
        optimizer.zero_grad()

        entropy_loss = torch.mean(loss.Entropy(outputs_target))
        msoftmax = outputs_target.mean(dim=0)
        gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
        entropy_loss -= gentropy_loss
        classifier_loss += entropy_loss

        classifier_loss.backward()

        if args.mix > 0:
            alpha = 0.3
            lam = np.random.beta(alpha, alpha)
            index = torch.randperm(inputs_target.size()[0]).cuda()
            mixed_input = lam * inputs_target + (1 -
                                                 lam) * inputs_target[index, :]
            mixed_output = (lam * outputs_target +
                            (1 - lam) * outputs_target[index, :]).detach()

            update_batch_stats(model, False)
            outputs_target_m = model(mixed_input)
            update_batch_stats(model, True)
            outputs_target_m = torch.nn.Softmax(dim=1)(outputs_target_m)
            classifier_loss = args.mix * nn.KLDivLoss(reduction='batchmean')(
                outputs_target_m.log(), mixed_output)
            classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            model.eval()
            acc_s_te, mean_ent = cal_acc(dset_loaders['test'], netF, netB,
                                         netC, False)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%, Ent = {:.4f}'.format(
                args.name, iter_num, max_iter, acc_s_te, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            model.train()

    torch.save(netF.state_dict(), osp.join(args.output_dir, "source_F.pt"))
    torch.save(netB.state_dict(), osp.join(args.output_dir, "source_B.pt"))
    torch.save(netC.state_dict(), osp.join(args.output_dir, "source_C.pt"))
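
Note: copy_target_simp leans on three helpers that never appear in these examples: op_copy, lr_scheduler (called with power=1.5 above; other examples use variants with slightly different signatures), and update_batch_stats. A plausible sketch of all three, assuming the polynomial learning-rate decay common in this family of code; the momentum and weight-decay values set inside the scheduler are assumptions as well:

import torch.nn as nn

def op_copy(optimizer):
    # Remember each group's initial learning rate so the scheduler can
    # rescale it from scratch at every iteration.
    for param_group in optimizer.param_groups:
        param_group['lr0'] = param_group['lr']
    return optimizer

def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    # Polynomial decay: lr = lr0 * (1 + gamma * iter_num / max_iter) ** (-power).
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr0'] * decay
        param_group['weight_decay'] = 1e-3
        param_group['momentum'] = 0.9
        param_group['nesterov'] = True
    return optimizer

def update_batch_stats(model, flag):
    # Toggle whether BatchNorm layers update their running statistics, so the
    # extra mixup forward pass does not disturb them.
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.track_running_stats = flag
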
Example #15
def split_target(args):
    test_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    txt_tar = open(args.t_dset_path).readlines()
    dset_loaders = {}

    test_set = ImageList(txt_tar, transform=test_transform)
    dset_loaders["target"] = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size*3,
        shuffle=False, num_workers=args.worker, drop_last=False)

    netF = network.ResBase(res_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()

    if args.model == "source":
        modelpath = args.output_dir + "/source_F.pt" 
        netF.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_B.pt"   
        netB.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_C.pt"    
        netC.load_state_dict(torch.load(modelpath))
    else:
        modelpath = args.output_dir + "/target_F_" + args.savename + ".pt" 
        netF.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_B_" + args.savename + ".pt"   
        netB.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_C_" + args.savename + ".pt"    
        netC.load_state_dict(torch.load(modelpath))

    netF.eval()
    netB.eval()
    netC.eval()

    start_test = True
    with torch.no_grad():
        iter_test = iter(dset_loaders['target'])
        for i in range(len(dset_loaders['target'])):
            data = next(iter_test)
            inputs = data[0]
            labels = data[1]
            inputs = inputs.cuda()
            outputs = netC(netB(netF(inputs)))
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
    top_pred, predict = torch.max(all_output, 1)
    acc = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) * 100
    mean_ent = loss.Entropy(nn.Softmax(dim=1)(all_output))

    if args.dset == 'VISDA-C':
        matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
        matrix = matrix[np.unique(all_label).astype(int),:]
        all_acc = matrix.diagonal()/matrix.sum(axis=1) * 100
        acc = all_acc.mean()
        aa = [str(np.round(i, 2)) for i in all_acc]
        acc_list = ' '.join(aa)
        print(acc_list)
        args.out_file.write(acc_list + '\n')
        args.out_file.flush()

    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(args.name, 0, 0, acc, mean_ent.mean())
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str+'\n')     

    if args.ps == 0:
        est_p = (mean_ent<mean_ent.mean()).sum().item() / mean_ent.size(0)
        log_str = 'Estimated PS = {:.2f}'.format(est_p)
        print(log_str + '\n')
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        PS = est_p
    else:
        PS = args.ps

    if args.choice == "ent":
        value = mean_ent
    elif args.choice == "maxp":
        value = - top_pred
    elif args.choice == "marginp":
        pred, _ = torch.sort(all_output, 1, descending=True)
        # Top-2 margin, negated so that smaller values mean higher confidence.
        value = pred[:,1] - pred[:,0]
    else:
        value = torch.rand(len(mean_ent))

    ori_target = txt_tar.copy()
    new_tar = []
    new_src = []

    predict = predict.numpy()

    cls_k = args.class_num
    for c in range(cls_k):
        c_idx = np.where(predict==c)
        c_idx = c_idx[0]
        c_value = value[c_idx]

        _, idx_ = torch.sort(c_value)
        c_num = len(idx_)
        c_num_s = int(c_num * PS)
        
        for ei in range(0, c_num_s):
            ee = c_idx[idx_[ei]]
            reci = ori_target[ee].strip().split(' ')
            line = reci[0] + ' ' + str(c) + '\n' 
            new_src.append(line)
        for ei in range(c_num_s, c_num):
            ee = c_idx[idx_[ei]]
            reci = ori_target[ee].strip().split(' ')
            line = reci[0] + ' ' + str(c) + '\n' 
            new_tar.append(line)

    return new_src.copy(), new_tar.copy()
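
Note: the two lists returned here hold 'path label' lines, so they can be wrapped straight back into ImageList datasets for a subsequent training round. A hypothetical usage sketch (the transform and loader arguments are assumptions):

new_src, new_tar = split_target(args)
src_set = ImageList(new_src, transform=test_transform)   # treated as labelled source
tar_set = ImageList(new_tar, transform=test_transform)   # remaining target pool
src_loader = torch.utils.data.DataLoader(src_set, batch_size=args.batch_size,
                                         shuffle=True, num_workers=args.worker)
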
Example #16
def train_source(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res' or args.net[0:3] == 'vgg':
        if args.net[0:3] == 'res':
            netF = network.ResBase(res_name=args.net).cuda()
        else:
            netF = network.VGGBase(vgg_name=args.net).cuda()
        netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                       bottleneck_dim=args.bottleneck).cuda()   # classifier: bn
        netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                       bottleneck_dim=args.bottleneck).cuda()   # layer: wn

        if args.resume:
            args.modelpath = args.output_dir_src + '/source_F.pt'
            netF.load_state_dict(torch.load(args.modelpath))
            args.modelpath = args.output_dir_src + '/source_B.pt'
            netB.load_state_dict(torch.load(args.modelpath))
            args.modelpath = args.output_dir_src + '/source_C.pt'
            netC.load_state_dict(torch.load(args.modelpath))

        param_group = []
        learning_rate = args.lr
        for k, v in netF.named_parameters():
            param_group += [{'params': v, 'lr': learning_rate * 0.1}]
        for k, v in netB.named_parameters():
            param_group += [{'params': v, 'lr': learning_rate}]
        for k, v in netC.named_parameters():
            param_group += [{'params': v, 'lr': learning_rate}]
        optimizer = optim.SGD(param_group)
        optimizer = op_copy(optimizer)

    acc_init = 0.
    max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    print_loss_interval = 25
    interval_iter = 100
    iter_num = 0

    if args.net[0:3] == 'res' or args.net[0:3] == 'vgg':
        netF.train()
        netB.train()
        netC.train()

    while iter_num < max_iter:
        try:
            inputs_source, labels_source = next(iter_source)
        except (NameError, StopIteration):
            # (Re)create the iterator on first use or once it is exhausted.
            iter_source = iter(dset_loaders["source_tr"])
            inputs_source, labels_source = next(iter_source)

        if inputs_source.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
        outputs_source = netC(netB(netF(inputs_source)))
        classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source,
                                                                                                   labels_source)
        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % print_loss_interval == 0:
            print("Iter:{:>4d}/{} | Classification loss on Source: {:.2f}".format(iter_num, max_iter,
                                                                              classifier_loss.item()))
        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            netC.eval()
            if args.dset == 'VISDA-RSUT' or args.dset == 'VISDA-RSUT-50' or args.dset == 'VISDA-RSUT-10':
                # The small classes in VisDA-C (RSUT) still have relatively many samples.
                # Safe to use per-class average accuracy.
                acc_s_te, acc_list, acc_cls_avg_te= cal_acc(dset_loaders['source_te'], netF, netB, netC,
                                                            per_class_flag=True, visda_flag=True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}'.format(args.name_src, iter_num, max_iter,
                                                                acc_s_te, acc_cls_avg_te) + '\n' + acc_list
                cur_acc = acc_cls_avg_te

            else:
                if args.trte == 'stratified':
                    # Stratified cross validation ensures the existence of every class in the validation set.
                    # Safe to use per-class average accuracy.
                    acc_s_te, acc_cls_avg_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC,
                                                          per_class_flag=True, visda_flag=False)
                    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}'.format(args.name_src,
                                                                iter_num, max_iter, acc_s_te, acc_cls_avg_te)
                    cur_acc = acc_cls_avg_te
                else:
                    # Conventional cross validation may lead to the absence of certain classes in validation set,
                    # esp. when the dataset includes some very small classes, e.g., Office-Home (RSUT), DomainNet.
                    # Use overall accuracy to avoid 'nan' issue.
                    acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC,
                                          per_class_flag=False, visda_flag=False)
                    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, max_iter, acc_s_te)
                    cur_acc = acc_s_te

            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if cur_acc >= acc_init and iter_num >= 3 * len(dset_loaders["source_tr"]):
                # first 3 epochs: not stable yet
                acc_init = cur_acc
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()

            netF.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
    torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
    torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))

    return netF, netB, netC
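
Note: CrossEntropyLabelSmooth is used above but not defined in these examples. A sketch following the standard label-smoothing formulation, where the one-hot target becomes (1 - epsilon) * one_hot + epsilon / num_classes:

import torch
import torch.nn as nn

class CrossEntropyLabelSmooth(nn.Module):
    def __init__(self, num_classes, epsilon=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        # inputs: raw logits; targets: integer class labels.
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        return (-smoothed * log_probs).sum(dim=1).mean()
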
Example #17
def train_target(args):
    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir + '/source_F_val.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_B_val.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_C_val.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': args.lr}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': args.lr}]
    optimizer = optim.SGD(param_group,
                          momentum=0.9,
                          weight_decay=5e-4,
                          nesterov=True)

    for epoch in tqdm(range(args.max_epoch), leave=False):
        iter_test = iter(dset_loaders["target"])

        netF.eval()
        netB.eval()
        mem_label = obtain_label(dset_loaders['target_te'], netF, netB, netC,
                                 args)
        mem_label = torch.from_numpy(mem_label).cuda()
        netF.train()
        netB.train()
        for _, (inputs_test, _, tar_idx) in tqdm(enumerate(iter_test),
                                                 leave=False):
            if inputs_test.size(0) == 1:
                continue
            inputs_test = inputs_test.cuda()
            pred = mem_label[tar_idx]

            features_test = netB(netF(inputs_test))
            outputs_test = netC(features_test)
            classifier_loss = CrossEntropyLabelSmooth(
                num_classes=args.class_num, epsilon=0)(outputs_test, pred)

            softmax_out = nn.Softmax(dim=1)(outputs_test)
            im_loss = torch.mean(Entropy(softmax_out))
            msoftmax = softmax_out.mean(dim=0)
            im_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
            total_loss = im_loss + args.cls_par * classifier_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        netF.eval()
        netB.eval()
        acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
        log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
            args.dset, epoch + 1, args.max_epoch, acc * 100)
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        print(log_str + '\n')

    # torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F.pt"))
    # torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B.pt"))
    # torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C.pt"))
    return netF, netB, netC
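
Note: obtain_label provides the pseudo labels refreshed each epoch above but is not reproduced in these examples. The sketch below follows the SHOT-style centroid self-labelling idea; the loader format and the number of refinement rounds are assumptions:

import numpy as np
import torch
import torch.nn as nn
from scipy.spatial.distance import cdist

def obtain_label(loader, netF, netB, netC, args):
    # Collect bottleneck features and soft predictions for the whole set.
    feats, logits = [], []
    with torch.no_grad():
        for data in loader:
            f = netB(netF(data[0].cuda()))
            feats.append(f.cpu())
            logits.append(netC(f).cpu())
    all_fea = torch.cat(feats, 0)
    all_out = nn.Softmax(dim=1)(torch.cat(logits, 0))
    # Append a constant feature and L2-normalise, then cluster around
    # class centroids seeded by the network's own soft predictions.
    all_fea = torch.cat([all_fea, torch.ones(all_fea.size(0), 1)], 1)
    all_fea = (all_fea / torch.norm(all_fea, p=2, dim=1, keepdim=True)).numpy()
    aff = all_out.numpy()
    for _ in range(2):  # a couple of centroid-refinement rounds
        centroids = aff.T.dot(all_fea) / (aff.sum(axis=0)[:, None] + 1e-8)
        pred = cdist(all_fea, centroids, 'cosine').argmin(axis=1)
        aff = np.eye(all_out.size(1))[pred]  # one-hot labels for the next round
    return pred.astype('int')
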
Example #18
def train_target_bait(args, zz=''):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()
    oldC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_src + '/source_F_' + str(zz) + '.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B_' + str(zz) + '.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C_' + str(zz) + '.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    oldC.load_state_dict(torch.load(args.modelpath))
    oldC.eval()
    netC.train()
    for k, v in oldC.named_parameters():
        v.requires_grad = False

    param_group = []
    param_group_c = []
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': args.lr * 0.01}]  #0.1
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': args.lr * .1}]  # 1
    for k, v in netC.named_parameters():
        param_group_c += [{'params': v, 'lr': args.lr * .1}]  #1
    optimizer = optim.SGD(param_group,
                          momentum=0.9,
                          weight_decay=5e-4,
                          nesterov=True)
    optimizer_c = optim.SGD(param_group_c,
                            momentum=0.9,
                            weight_decay=5e-4,
                            nesterov=True)

    netF.train()
    netB.train()

    iter_num = 0
    iter_target = iter(dset_loaders["target"])
    while iter_num < args.max_epoch * len(dset_loaders["target"]):
        try:
            inputs_test, _, tar_idx = next(iter_target)
        except StopIteration:
            iter_target = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_target)
        if inputs_test.size(0) == 1:
            continue
        iter_num += 1
        inputs_test = inputs_test.cuda()
        batch_size = inputs_test.shape[0]

        if True:
            total_loss = 0
            features_test = netB(netF(inputs_test))
            outputs_test = netC(features_test)
            outputs_test_old = oldC(features_test)

            softmax_out = nn.Softmax(dim=1)(outputs_test)
            softmax_out_old = nn.Softmax(dim=1)(outputs_test_old)

            loss_cast = loss.SKL(softmax_out, softmax_out_old).sum(dim=1)

            entropy_old = Entropy(softmax_out_old)
            indx = entropy_old.topk(int(batch_size * 0.5), largest=True)[-1]
            ones_mask = torch.ones(batch_size).cuda() * -1
            ones_mask[indx] = 1
            loss_cast = loss_cast * ones_mask
            total_loss -= torch.mean(loss_cast) * 10

            optimizer_c.zero_grad()
            total_loss.backward()
            optimizer_c.step()

        for _ in range(1):
            total_loss = 0
            features_test = netB(netF(inputs_test))
            outputs_test = netC(features_test)

            softmax_out = nn.Softmax(dim=1)(outputs_test)

            outputs_test_old = oldC(features_test)
            softmax_out_old = nn.Softmax(dim=1)(outputs_test_old)

            msoftmax = softmax_out_old.mean(dim=0)
            cb_loss = torch.sum(msoftmax * torch.log(msoftmax + 1e-5))
            total_loss += cb_loss

            msoftmax = softmax_out.mean(dim=0)
            cb_loss = torch.sum(msoftmax * torch.log(msoftmax + 1e-5))
            total_loss += cb_loss

            loss_bite = (-softmax_out_old * torch.log(softmax_out + 1e-5)).sum(
                1) - (softmax_out * torch.log(softmax_out_old + 1e-5)).sum(1)
            total_loss += torch.mean(loss_bite)  #*0.8

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        if iter_num % int(args.interval * len(dset_loaders["target"])) == 0:
            netF.eval()
            netB.eval()
            netC.eval()
            acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, oldC,
                                    args.dset == "visda17")
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, \
                args.max_epoch * len(dset_loaders["target"]), acc) + '\n' + str(acc_list)

            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            netF.train()
            netB.train()
            netC.train()
    return netF, netB, netC
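
Note: loss.SKL above is presumably a symmetric KL divergence between the new and the frozen classifier's softmax outputs. A sketch consistent with the .sum(dim=1) call site (an assumption, not the verbatim loss module):

import torch

def SKL(out1, out2, epsilon=1e-5):
    # Elementwise symmetric KL between two probability tensors; summing over
    # dim=1 (as the caller does) yields one divergence value per sample.
    kl12 = out1 * torch.log((out1 + epsilon) / (out2 + epsilon))
    kl21 = out2 * torch.log((out2 + epsilon) / (out1 + epsilon))
    return (kl12 + kl21) / 2
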
Example #19
def train_source(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate*0.1}]#1
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate*1}]#10
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate*1}]#10
    optimizer = optim.SGD(param_group, momentum=0.9, weight_decay=5e-4, nesterov=True)

    acc_init = 0
    netF.train()
    netB.train()
    netC.train()
    iter_num = 0
    iter_source = iter(dset_loaders["source_tr"])
    while iter_num < args.max_epoch * len(dset_loaders["source_tr"]):
        try:
            inputs_source, labels_source = next(iter_source)
        except StopIteration:
            iter_source = iter(dset_loaders["source_tr"])
            inputs_source, labels_source = next(iter_source)
        if inputs_source.size(0) == 1:
            continue
        iter_num += 1
        inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
        outputs_source = netC(netB(netF(inputs_source)))
        classifier_loss = loss.CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source)
        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if (iter_num % int(args.interval*len(dset_loaders["source_tr"])) == 0):
            netF.eval()
            netB.eval()
            netC.eval()
            acc_s_te, acc_list = cal_acc(dset_loaders['source_te'], netF, netB, netC, args.dset=="visda17")
            log_str = 'Task: {}, Iter:{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, acc_s_te) + '\n' + str(acc_list)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str+'\n')
            if acc_s_te >= acc_init:
                acc_init = acc_s_te
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()
            netF.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, "source_F_val.pt"))
    torch.save(best_netB, osp.join(args.output_dir_src, "source_B_val.pt"))
    torch.save(best_netC, osp.join(args.output_dir_src, "source_C_val.pt"))
    return netF, netB, netC
Example #20
def train(args, txt_src, txt_tgt):
    ## set pre-process
    dset_loaders = data_load(args, txt_src, txt_tgt)
    max_len = max(len(dset_loaders["source"]), len(dset_loaders["target"]))
    max_iter = args.max_epoch * max_len
    interval_iter = max_iter // 10

    if args.dset == 'u2m':
        netG = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netG = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netG = network.DTNBase().cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netG.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    if args.model == 'source':
        modelpath = args.output_dir + "/source_F.pt"
        netG.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_B.pt"
        netB.load_state_dict(torch.load(modelpath))
    else:
        modelpath = args.output_dir + "/target_F_" + args.savename + ".pt"
        netG.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_B_" + args.savename + ".pt"
        netB.load_state_dict(torch.load(modelpath))

    netF = nn.Sequential(netB, netC)
    optimizer_g = optim.SGD(netG.parameters(), lr=args.lr * 0.1)
    optimizer_f = optim.SGD(netF.parameters(), lr=args.lr)

    base_network = nn.Sequential(netG, netF)
    source_loader_iter = iter(dset_loaders["source"])
    target_loader_iter = iter(dset_loaders["target"])

    list_acc = []
    best_ent = 100

    for iter_num in range(1, max_iter + 1):
        base_network.train()
        lr_scheduler(optimizer_g,
                     init_lr=args.lr * 0.1,
                     iter_num=iter_num,
                     max_iter=max_iter)
        lr_scheduler(optimizer_f,
                     init_lr=args.lr,
                     iter_num=iter_num,
                     max_iter=max_iter)

        try:
            inputs_source, labels_source = next(source_loader_iter)
        except StopIteration:
            source_loader_iter = iter(dset_loaders["source"])
            inputs_source, labels_source = next(source_loader_iter)
        try:
            inputs_target, _, target_idx = next(target_loader_iter)
        except StopIteration:
            target_loader_iter = iter(dset_loaders["target"])
            inputs_target, _, target_idx = next(target_loader_iter)

        targets_s = torch.zeros(args.batch_size, args.class_num).scatter_(
            1, labels_source.view(-1, 1), 1)
        inputs_s = inputs_source.cuda()
        targets_s = targets_s.cuda()
        inputs_t = inputs_target[0].cuda()
        inputs_t2 = inputs_target[1].cuda()

        with torch.no_grad():
            # compute guessed labels of unlabel samples
            outputs_u = base_network(inputs_t)
            outputs_u2 = base_network(inputs_t2)
            p = (torch.softmax(outputs_u, dim=1) +
                 torch.softmax(outputs_u2, dim=1)) / 2
            pt = p**(1 / args.T)
            targets_u = pt / pt.sum(dim=1, keepdim=True)
            targets_u = targets_u.detach()

        ####################################################################
        all_inputs = torch.cat([inputs_s, inputs_t, inputs_t2], dim=0)
        all_targets = torch.cat([targets_s, targets_u, targets_u], dim=0)
        if args.alpha > 0:
            l = np.random.beta(args.alpha, args.alpha)
            l = max(l, 1 - l)
        else:
            l = 1
        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]
        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # interleave labeled and unlabeled samples between batches to get a correct batchnorm calculation
        mixed_input = list(torch.split(mixed_input, args.batch_size))
        mixed_input = utils.interleave(mixed_input, args.batch_size)
        # s = [sa, sb, sc]
        # t1 = [t1a, t1b, t1c]
        # t2 = [t2a, t2b, t2c]
        # => s' = [sa, t1b, t2c]   t1' = [t1a, sb, t1c]   t2' = [t2a, t2b, sc]

        logits = base_network(mixed_input[0])
        logits = [logits]
        for inp in mixed_input[1:]:  # `inp`, not `input`, to avoid shadowing the builtin
            temp = base_network(inp)
            logits.append(temp)

        # put interleaved samples back
        # [i[:,0] for i in aa]
        logits = utils.interleave(logits, args.batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        train_criterion = utils.SemiLoss()

        Lx, Lu, w = train_criterion(logits_x, mixed_target[:args.batch_size],
                                    logits_u, mixed_target[args.batch_size:],
                                    iter_num, max_iter, args.lambda_u)
        total_loss = Lx + w * Lu  # avoid shadowing the loss module

        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        total_loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            base_network.eval()

            acc, py, score, y = cal_acc(dset_loaders["train"],
                                        base_network,
                                        flag=False)
            mean_ent = torch.mean(Entropy(score))
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
                args.dset + '_train', iter_num, max_iter, acc, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            acc, py, score, y = cal_acc(dset_loaders["test"],
                                        base_network,
                                        flag=False)
            mean_ent = torch.mean(Entropy(score))
            list_acc.append(acc)

            if best_ent > mean_ent:
                val_acc = acc
                best_ent = mean_ent
                best_y = y
                best_py = py
                best_score = score

            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
                args.dset + '_test', iter_num, max_iter, acc, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

    idx = np.argmax(np.array(list_acc))
    max_acc = list_acc[idx]
    final_acc = list_acc[-1]

    log_str = '\n==========================================\n'
    log_str += '\nVal Acc = {:.2f}\nMax Acc = {:.2f}\nFin Acc = {:.2f}\n'.format(
        val_acc, max_acc, final_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()

    # torch.save(base_network.state_dict(), osp.join(args.output_dir, args.log + ".pt"))
    # sio.savemat(osp.join(args.output_dir, args.log + ".mat"), {'y':best_y.cpu().numpy(),
    #     'py':best_py.cpu().numpy(), 'score':best_score.cpu().numpy()})

    return base_network, py
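
Note: this loop borrows utils.interleave and utils.SemiLoss, which appear to match the MixMatch reference implementation; sketches of both under that assumption:

import torch
import torch.nn.functional as F

def interleave_offsets(batch, nu):
    groups = [batch // (nu + 1)] * (nu + 1)
    for x in range(batch - sum(groups)):
        groups[-x - 1] += 1
    offsets = [0]
    for g in groups:
        offsets.append(offsets[-1] + g)
    assert offsets[-1] == batch
    return offsets

def interleave(xy, batch):
    # Swap slices between the labelled batch and the unlabelled batches so
    # that BatchNorm statistics are computed on mixed chunks.
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
    for i in range(1, nu + 1):
        xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
    return [torch.cat(v, dim=0) for v in xy]

class SemiLoss(object):
    # Cross entropy on the labelled chunk, an L2 consistency term on the
    # unlabelled chunks, with the unlabelled weight ramped up linearly.
    def __call__(self, outputs_x, targets_x, outputs_u, targets_u,
                 iter_num, max_iter, lambda_u):
        probs_u = torch.softmax(outputs_u, dim=1)
        Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
        Lu = torch.mean((probs_u - targets_u) ** 2)
        w = lambda_u * min(iter_num / max_iter, 1.0)
        return Lx, Lu, w
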
Example #21
def train_target(args):
    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_B.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': args.lr}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': args.lr}]

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = len(dset_loaders["target"])
    # interval_iter = max_iter // args.interval
    iter_num = 0

    while iter_num < max_iter:
        optimizer.zero_grad()
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except (NameError, StopIteration):
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        if inputs_test.size(0) == 1:
            continue

        if iter_num % interval_iter == 0 and args.cls_par > 0:
            netF.eval()
            netB.eval()
            mem_label = obtain_label(dset_loaders['target_te'], netF, netB,
                                     netC, args)
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netB.train()

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_test = inputs_test.cuda()
        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        if args.cls_par > 0:
            pred = mem_label[tar_idx]
            classifier_loss = args.cls_par * nn.CrossEntropyLoss()(
                outputs_test, pred)
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            # if args.gent:
            #     msoftmax = softmax_out.mean(dim=0)
            #     entropy_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))

            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                args.dset, iter_num, max_iter, acc)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(
            netF.state_dict(),
            osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(
            netB.state_dict(),
            osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(
            netC.state_dict(),
            osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netB, netC
Example #22
def train_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    tt = 0
    iter_num = 0
    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval

    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except (NameError, StopIteration):
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        if inputs_test.size(0) == 1:
            continue

        if iter_num % interval_iter == 0:
            netF.eval()
            netB.eval()
            mem_label, ENT_THRESHOLD = obtain_label(dset_loaders['test'], netF,
                                                    netB, netC, args)
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netB.train()

        inputs_test = inputs_test.cuda()

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        pred = mem_label[tar_idx]
        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        softmax_out = nn.Softmax(dim=1)(outputs_test)
        outputs_test_known = outputs_test[pred < args.class_num, :]
        pred = pred[pred < args.class_num]

        if len(pred) == 0:
            # The whole batch was pseudo-labelled "unknown"; skip it.
            print('unknown-only batches so far:', tt)
            del features_test
            del outputs_test
            tt += 1
            continue

        if args.cls_par > 0:
            classifier_loss = nn.CrossEntropyLoss()(outputs_test_known, pred)
            classifier_loss *= args.cls_par
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out_known = nn.Softmax(dim=1)(outputs_test_known)
            entropy_loss = torch.mean(loss.Entropy(softmax_out_known))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax *
                                          torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            classifier_loss += entropy_loss * args.ent_par

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            acc_os1, acc_os2, acc_unknown = cal_acc(dset_loaders['test'], netF,
                                                    netB, netC, True,
                                                    ENT_THRESHOLD)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}% / {:.2f}% / {:.2f}%'.format(
                args.name, iter_num, max_iter, acc_os2, acc_os1, acc_unknown)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(
            netF.state_dict(),
            osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(
            netB.state_dict(),
            osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(
            netC.state_dict(),
            osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netB, netC
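
Note: this variant is open-set, so obtain_label additionally returns an entropy threshold, and samples above it receive the extra "unknown" label args.class_num, which is why the mask pred < args.class_num appears above. A hedged sketch of that rejection step (helper name and exact rule are assumptions):

import torch
import torch.nn as nn

def reject_unknown(all_output, ent_threshold, class_num):
    # Assign the extra "unknown" index to samples whose prediction entropy
    # exceeds the threshold; everything else keeps its argmax label.
    softmax_out = nn.Softmax(dim=1)(all_output)
    ent = -(softmax_out * torch.log(softmax_out + 1e-5)).sum(dim=1)
    pred = all_output.argmax(dim=1)
    pred[ent > ent_threshold] = class_num
    return pred
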
Example #23
def train_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()  

    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda()
    
    if not args.ssl == 0:
        netR = network.feat_classifier(type='linear', class_num=4, bottleneck_dim=2*args.bottleneck).cuda()
        netR_dict, acc_rot = train_target_rot(args)
        netR.load_state_dict(netR_dict)
    
    modelpath = args.output_dir_src + '/source_F.pt'   
    netF.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_B.pt'   
    netB.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_C.pt'    
    netC.load_state_dict(torch.load(modelpath))
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    if not args.ssl == 0:
        for k, v in netR.named_parameters():
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        netR.train()

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0

    while iter_num < max_iter:
        optimizer.zero_grad()
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except (NameError, StopIteration):
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        if inputs_test.size(0) == 1:
            continue

        if iter_num % interval_iter == 0 and args.cls_par > 0:
            netF.eval()
            netB.eval()
            mem_label = obtain_label(dset_loaders['target_te'], netF, netB, netC, args)
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netB.train()

        inputs_test = inputs_test.cuda()

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
        if args.cls_par > 0:
            pred = mem_label[tar_idx]

        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        if args.cls_par > 0:
            classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)
            classifier_loss *= args.cls_par
            if iter_num < interval_iter and args.dset == "VISDA-C":
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        classifier_loss.backward()

        if not args.ssl == 0:
            r_labels_target = np.random.randint(0, 4, len(inputs_test))
            r_inputs_target = rotation.rotate_batch_with_labels(inputs_test, r_labels_target)
            r_labels_target = torch.from_numpy(r_labels_target).cuda()
            r_inputs_target = r_inputs_target.cuda()

            f_outputs = netB(netF(inputs_test))
            f_outputs = f_outputs.detach()
            f_r_outputs = netB(netF(r_inputs_target))
            r_outputs_target = netR(torch.cat((f_outputs, f_r_outputs), 1))

            rotation_loss = args.ssl * nn.CrossEntropyLoss()(r_outputs_target, r_labels_target)   
            rotation_loss.backward() 

        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            if args.dset=='VISDA-C':
                acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, max_iter, acc_s_te) + '\n' + acc_list
            else:
                acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, max_iter, acc_s_te)

            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str+'\n')
            netF.train()
            netB.train()

    if args.issave:   
        torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))
        
    return netF, netB, netC
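
Note: the rotation self-supervision calls rotation.rotate_batch_with_labels, not shown in these examples. A minimal sketch using torch.rot90, assuming labels in {0, 1, 2, 3} encode multiples of 90 degrees:

import torch

def rotate_batch_with_labels(batch, labels):
    # Rotate each (C, H, W) image in the batch by label * 90 degrees.
    images = [torch.rot90(img, int(k), dims=[1, 2])
              for img, k in zip(batch, labels)]
    return torch.stack(images)
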
Example #24
def pretrain_on_source(src_data_loader, src_data_loader_eval, output_dir):
    ## set base network
    if params.mode == 'u2m':
        netF = network.LeNetBase().cuda()
    elif params.mode == 'm2u':
        netF = network.LeNetBase().cuda()
    elif params.mode == 's2m':
        netF = network.DTNBase().cuda()
    netB = network.feat_bootleneck(type=params.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=params.bottleneck).cuda()
    netC = network.feat_classifier(type=params.layer,
                                   class_num=params.class_num,
                                   bottleneck_dim=params.bottleneck).cuda()

    param_group = []
    learning_rate = params.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    optimizer = optim.SGD(param_group,
                          momentum=0.9,
                          weight_decay=5e-4,
                          nesterov=True)

    acc_init = 0
    out_file = open(os.path.join(output_dir, 'log_pretrain.txt'), 'w')
    for epoch in range(params.epochs):
        # scheduler.step()
        netF.train()
        netB.train()
        netC.train()
        iter_source = iter(src_data_loader)
        for _, (inputs_source, labels_source) in enumerate(iter_source):
            if inputs_source.size(0) == 1:
                continue
            inputs_source, labels_source = inputs_source.cuda(
            ), labels_source.cuda()
            outputs_source = netC(netB(netF(inputs_source)))
            classifier_loss = network.CrossEntropyLabelSmooth(
                num_classes=params.class_num,
                epsilon=params.smooth)(outputs_source, labels_source)
            optimizer.zero_grad()
            classifier_loss.backward()
            optimizer.step()

        netF.eval()
        netB.eval()
        netC.eval()
        acc_s_tr, _ = network.cal_acc(src_data_loader, netF, netB, netC)
        acc_s_te, _ = network.cal_acc(src_data_loader_eval, netF, netB, netC)
        log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%/ {:.2f}%'.format(
            params.mode, epoch + 1, params.epochs, acc_s_tr * 100,
            acc_s_te * 100)
        out_file.write(log_str + '\n')
        out_file.flush()
        print(log_str + '\n')

        if acc_s_te >= acc_init:
            acc_init = acc_s_te
            best_netF = netF.state_dict()
            best_netB = netB.state_dict()
            best_netC = netC.state_dict()

    torch.save(best_netF, os.path.join(output_dir, "source_F_val.pt"))
    torch.save(best_netB, os.path.join(output_dir, "source_B_val.pt"))
    torch.save(best_netC, os.path.join(output_dir, "source_C_val.pt"))
    out_file.close()

    return netF, netB, netC
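
Note: pretrain_on_source optimizes network.CrossEntropyLabelSmooth, whose definition is not included here. A minimal sketch consistent with the usual label-smoothing formulation (mix the one-hot target with a uniform distribution before taking the negative log-likelihood):

import torch
import torch.nn as nn

class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy against an epsilon-smoothed one-hot target (sketch)."""

    def __init__(self, num_classes, epsilon=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        soft = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        return (-soft * log_probs).sum(dim=1).mean()
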
Example #25
def train_target_rot(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()  

    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda()
    netR = network.feat_classifier(type='linear', class_num=4, bottleneck_dim=2*args.bottleneck).cuda()

    modelpath = args.output_dir_src + '/source_F.pt'   
    netF.load_state_dict(torch.load(modelpath))
    netF.eval()
    for k, v in netF.named_parameters():
        v.requires_grad = False
    modelpath = args.output_dir_src + '/source_B.pt'   
    netB.load_state_dict(torch.load(modelpath))
    netB.eval()
    for k, v in netB.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netR.named_parameters():
        param_group += [{'params': v, 'lr': args.lr*1}]
    netR.train()
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // 10
    iter_num = 0

    rot_acc = 0
    while iter_num < max_iter:
        optimizer.zero_grad()
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except (NameError, StopIteration):
            # (re)start the target iterator on the first pass or when exhausted
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        if inputs_test.size(0) == 1:
            continue

        inputs_test = inputs_test.cuda()

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        r_labels_target = np.random.randint(0, 4, len(inputs_test))
        r_inputs_target = rotation.rotate_batch_with_labels(inputs_test, r_labels_target)
        r_labels_target = torch.from_numpy(r_labels_target).cuda()
        r_inputs_target = r_inputs_target.cuda()
        
        f_outputs = netB(netF(inputs_test))
        f_r_outputs = netB(netF(r_inputs_target))
        r_outputs_target = netR(torch.cat((f_outputs, f_r_outputs), 1))

        rotation_loss = nn.CrossEntropyLoss()(r_outputs_target, r_labels_target)
        rotation_loss.backward() 

        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netR.eval()
            acc_rot = cal_acc_rot(dset_loaders['target'], netF, netB, netR)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, max_iter, acc_rot)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netR.train()

            if rot_acc < acc_rot:
                rot_acc = acc_rot
                best_netR = netR.state_dict()

    log_str = 'Best Accuracy = {:.2f}%'.format(rot_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')

    return best_netR, rot_acc
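
Note: Example #25 leans on three helpers that are not shown: rotation.rotate_batch_with_labels, op_copy, and lr_scheduler. Sketches under the usual SHOT assumptions (torch.rot90 for the rotation; a polynomial decay that rescales each group's stashed initial learning rate). The args-taking lr_scheduler variant in later examples presumably reads gamma and power from args instead:

import torch

def rotate_batch_with_labels(batch, labels):
    # Rotate image i of an (N, C, H, W) batch by labels[i] * 90 degrees.
    return torch.stack([torch.rot90(img, int(k), dims=(1, 2))
                        for img, k in zip(batch, labels)])

def op_copy(optimizer):
    # Stash each group's initial lr so the scheduler can rescale from it.
    for param_group in optimizer.param_groups:
        param_group['lr0'] = param_group['lr']
    return optimizer

def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr0'] * decay
        param_group['weight_decay'] = 1e-3
        param_group['momentum'] = 0.9
        param_group['nesterov'] = True
    return optimizer
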
Example #26
def train_source(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.norm_layer == 'batchnorm':
        norm_layer = nn.BatchNorm2d
    elif args.norm_layer == 'groupnorm':

        def gn_helper(planes):
            return nn.GroupNorm(8, planes)

        norm_layer = gn_helper
    if args.net[0:3] == 'res':
        if '26' in args.net:
            netF = network.ResCifarBase(26, norm_layer=norm_layer)
            args.bottleneck = netF.in_features // 2
        else:
            netF = network.ResBase(res_name=args.net, args=args)
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net)

    if args.ssl_before_btn:
        netH = network.ssl_head(ssl_task=args.ssl_task,
                                feature_dim=netF.in_features,
                                embedding_dim=args.embedding_dim)
    else:
        netH = network.ssl_head(ssl_task=args.ssl_task,
                                feature_dim=args.bottleneck,
                                embedding_dim=args.embedding_dim)
    if args.bottleneck != 0:
        netB = network.feat_bootleneck(type=args.classifier,
                                       feature_dim=netF.in_features,
                                       bottleneck_dim=args.bottleneck,
                                       norm_btn=args.norm_btn)
        netC = network.feat_classifier(type=args.layer,
                                       class_num=args.class_num,
                                       bottleneck_dim=args.bottleneck,
                                       bias=args.classifier_bias,
                                       temp=args.angular_temp,
                                       args=args)
    else:
        netB = nn.Identity()
        netC = network.feat_classifier(type=args.layer,
                                       class_num=args.class_num,
                                       bottleneck_dim=netF.in_features,
                                       bias=args.classifier_bias,
                                       temp=args.angular_temp,
                                       args=args)

    if args.dataparallel:
        netF = nn.DataParallel(netF).cuda()
        netH = nn.DataParallel(netH).cuda()
        netB = nn.DataParallel(netB).cuda()
        netC = nn.DataParallel(netC).cuda()
    else:
        netF.cuda()
        netH.cuda()
        netB.cuda()
        netC.cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        if args.separate_wd and ('bias' in k or 'norm' in k):
            param_group += [{
                'params': v,
                'lr': learning_rate * 0.1,
                'weight_decay': 0
            }]
        else:
            param_group += [{
                'params': v,
                'lr': learning_rate * 0.1,
                'weight_decay': args.weight_decay
            }]
    for k, v in netH.named_parameters():
        if args.separate_wd and ('bias' in k or 'norm' in k):
            param_group += [{
                'params': v,
                'lr': learning_rate,
                'weight_decay': 0
            }]
        else:
            param_group += [{
                'params': v,
                'lr': learning_rate,
                'weight_decay': args.weight_decay
            }]
    for k, v in netB.named_parameters():
        if args.separate_wd and ('bias' in k or 'norm' in k):
            param_group += [{
                'params': v,
                'lr': learning_rate,
                'weight_decay': 0
            }]
        else:
            param_group += [{
                'params': v,
                'lr': learning_rate,
                'weight_decay': args.weight_decay
            }]
    for k, v in netC.named_parameters():
        if args.separate_wd and ('bias' in k or 'norm' in k):
            param_group += [{
                'params': v,
                'lr': learning_rate,
                'weight_decay': 0
            }]
        else:
            param_group += [{
                'params': v,
                'lr': learning_rate,
                'weight_decay': args.weight_decay
            }]

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0
    if args.class_stratified:
        max_iter = args.max_epoch * len(
            dset_loaders["source_tr"].batch_sampler)
    else:
        max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    interval_iter = max_iter // 10
    iter_num = 0
    epoch = 0

    netF.train()
    netH.train()
    netB.train()
    netC.train()

    if args.use_focal_loss:
        cls_loss_fn = FocalLoss(alpha=args.focal_alpha,
                                gamma=args.focal_gamma,
                                reduction='mean')
    else:
        if args.ce_weighting:
            w = torch.Tensor(args.ce_weight).cuda()
            w.requires_grad = False
            if args.smooth == 0:
                cls_loss_fn = nn.CrossEntropyLoss(weight=w).cuda()
            else:
                cls_loss_fn = CrossEntropyLabelSmooth(
                    num_classes=args.class_num, epsilon=args.smooth,
                    weight=w).cuda()
        else:
            if args.smooth == 0:
                cls_loss_fn = nn.CrossEntropyLoss().cuda()
            else:
                cls_loss_fn = CrossEntropyLabelSmooth(
                    num_classes=args.class_num, epsilon=args.smooth).cuda()

    if args.ssl_task in ['simclr', 'crs']:
        if args.use_new_ntxent:
            ssl_loss_fn = SupConLoss(temperature=args.temperature,
                                     base_temperature=args.temperature).cuda()
        else:
            ssl_loss_fn = NTXentLoss(args.batch_size, args.temperature,
                                     True).cuda()
    elif args.ssl_task in ['supcon', 'crsc']:
        ssl_loss_fn = SupConLoss(temperature=args.temperature,
                                 base_temperature=args.temperature).cuda()
    elif args.ssl_task == 'ls_supcon':
        ssl_loss_fn = LabelSmoothedSCLLoss(args.batch_size, args.temperature,
                                           args.class_num, args.ssl_smooth)

    if args.cr_weight > 0:
        if args.cr_metric == 'cos':
            dist = nn.CosineSimilarity(dim=1).cuda()
        elif args.cr_metric == 'l1':
            dist = nn.PairwiseDistance(p=1).cuda()
        elif args.cr_metric == 'l2':
            dist = nn.PairwiseDistance(p=2).cuda()
        elif args.cr_metric == 'bce':
            dist = nn.BCEWithLogitsLoss(reduction='sum').cuda()
        elif args.cr_metric == 'kl':
            dist = nn.KLDivLoss(reduction='sum').cuda()
        elif args.cr_metric == 'js':
            dist = JSDivLoss(reduction='sum').cuda()

    use_second_pass = (args.ssl_task in ['simclr', 'supcon', 'ls_supcon']
                       and args.ssl_weight > 0)
    use_third_pass = (args.cr_weight > 0
                      or (args.ssl_task in ['crsc', 'crs']
                          and args.ssl_weight > 0) or args.cls3)

    while iter_num < max_iter:
        try:
            inputs_source, labels_source = next(iter_source)
        except (NameError, StopIteration):
            # (re)start the source iterator on the first pass or when exhausted
            iter_source = iter(dset_loaders["source_tr"])
            if args.class_stratified:
                dset_loaders["source_tr"].batch_sampler.set_epoch(epoch)
            epoch += 1
            inputs_source, labels_source = next(iter_source)

        try:
            # inputs_source is a tensor, or a list of augmented views
            if inputs_source.size(0) == 1:
                continue
        except AttributeError:
            if inputs_source[0].size(0) == 1:
                continue

        iter_num += 1
        lr_scheduler(args, optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source1 = None
        inputs_source2 = None
        inputs_source3 = None
        labels_source = labels_source.cuda()

        if args.layer in ['add_margin', 'arc_margin', 'sphere']:
            labels_forward = labels_source
        else:
            labels_forward = None

        if type(inputs_source) is list:
            inputs_source1 = inputs_source[0].cuda()
            inputs_source2 = inputs_source[1].cuda()
            if len(inputs_source) == 3:
                inputs_source3 = inputs_source[2].cuda()
        else:
            inputs_source1 = inputs_source.cuda()

        if inputs_source1 is not None:
            f1 = netF(inputs_source1)
            b1 = netB(f1)
            outputs_source = netC(b1, labels_forward)
        if use_second_pass:
            f2 = netF(inputs_source2)
            b2 = netB(f2)
        if use_third_pass:
            if args.sg3:
                with torch.no_grad():
                    f3 = netF(inputs_source3)
                    b3 = netB(f3)
                    c3 = netC(b3, labels_forward)
                    conf = torch.max(F.softmax(c3, dim=1), dim=1)[0]
            else:
                f3 = netF(inputs_source3)
                b3 = netB(f3)
                c3 = netC(b3, labels_forward)
                conf = torch.max(F.softmax(c3, dim=1), dim=1)[0]

        if args.cr_weight > 0:
            if args.cr_site == 'feat':
                f_hard = f1
                f_weak = f3
            elif args.cr_site == 'btn':
                f_hard = b1
                f_weak = b3
            elif args.cr_site == 'cls':
                f_hard = outputs_source
                f_weak = c3
                if args.cr_metric in ['kl', 'js']:
                    f_hard = F.softmax(f_hard, dim=-1)
                if args.cr_metric in ['bce', 'kl', 'js']:
                    f_weak = F.softmax(f_weak, dim=-1)
            else:
                raise NotImplementedError

        classifier_loss = cls_loss_fn(outputs_source, labels_source)
        if args.cls3:
            classifier_loss += cls_loss_fn(c3, labels_source)

        if args.ssl_weight > 0:
            if args.ssl_before_btn:
                z1 = netH(f1, args.norm_feat)
                if use_second_pass:
                    z2 = netH(f2, args.norm_feat)
                if use_third_pass:
                    z3 = netH(f3, args.norm_feat)
            else:
                z1 = netH(b1, args.norm_feat)
                if use_second_pass:
                    z2 = netH(b2, args.norm_feat)
                if use_third_pass:
                    z3 = netH(b3, args.norm_feat)

            if args.ssl_task == 'simclr':
                if args.use_new_ntxent:
                    z = torch.cat([z1.unsqueeze(1), z2.unsqueeze(1)], dim=1)
                    ssl_loss = ssl_loss_fn(z)
                else:
                    ssl_loss = ssl_loss_fn(z1, z2)
            elif args.ssl_task == 'supcon':
                z = torch.cat([z1.unsqueeze(1), z2.unsqueeze(1)], dim=1)
                ssl_loss = ssl_loss_fn(z, labels=labels_source)
            elif args.ssl_task == 'ls_supcon':
                ssl_loss = ssl_loss_fn(z1, z2, labels_source)
            elif args.ssl_task == 'crsc':
                z = torch.cat([z1.unsqueeze(1), z3.unsqueeze(1)], dim=1)
                ssl_loss = ssl_loss_fn(z, labels_source)
            elif args.ssl_task == 'crs':
                if args.use_new_ntxent:
                    z = torch.cat([z1.unsqueeze(1), z3.unsqueeze(1)], dim=1)
                    ssl_loss = ssl_loss_fn(z)
                else:
                    ssl_loss = ssl_loss_fn(z1, z3)
        else:
            ssl_loss = torch.tensor(0.0).cuda()

        if args.cr_weight > 0:
            try:
                cr_loss = dist(f_hard[conf <= args.cr_threshold],
                               f_weak[conf <= args.cr_threshold]).mean()

                if args.cr_metric == 'cos':
                    cr_loss *= -1
            except Exception:
                print('Error computing CR loss')
                cr_loss = torch.tensor(0.0).cuda()
        else:
            cr_loss = torch.tensor(0.0).cuda()

        if args.ent_weight > 0:
            softmax_out = nn.Softmax(dim=1)(outputs_source)
            entropy_loss = torch.mean(Entropy(softmax_out))
            classifier_loss += args.ent_weight * entropy_loss

        if args.gent_weight > 0:
            softmax_out = nn.Softmax(dim=1)(outputs_source)
            msoftmax = softmax_out.mean(dim=0)
            gentropy_loss = torch.sum(-msoftmax *
                                      torch.log(msoftmax + args.epsilon))
            classifier_loss -= args.gent_weight * gentropy_loss

        loss = classifier_loss + args.ssl_weight * ssl_loss + args.cr_weight * cr_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netH.eval()
            netB.eval()
            netC.eval()
            if args.dset == 'visda-c':
                acc_s_te, acc_list = cal_acc(dset_loaders['source_te'], netF,
                                             netH, netB, netC, args, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name_src, iter_num, max_iter,
                    acc_s_te) + '\n' + acc_list
            else:
                acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netH,
                                      netB, netC, args, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name_src, iter_num, max_iter, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if acc_s_te >= acc_init:
                acc_init = acc_s_te

                if args.dataparallel:
                    best_netF = netF.module.state_dict()
                    best_netH = netH.module.state_dict()
                    best_netB = netB.module.state_dict()
                    best_netC = netC.module.state_dict()
                else:
                    best_netF = netF.state_dict()
                    best_netH = netH.state_dict()
                    best_netB = netB.state_dict()
                    best_netC = netC.state_dict()

            netF.train()
            netH.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
    torch.save(best_netH, osp.join(args.output_dir_src, "source_H.pt"))
    torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
    torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))

    return netF, netH, netB, netC
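
Note: the ent/gent terms above implement information maximization: per-sample entropy pushes each prediction toward confidence, while subtracting the entropy of the batch-mean softmax (gent) keeps predictions diverse across classes. A minimal sketch of the Entropy helper these snippets call:

import torch

def Entropy(input_):
    # Shannon entropy of each row of a softmax output: (N, C) -> (N,)
    epsilon = 1e-5
    return -(input_ * torch.log(input_ + epsilon)).sum(dim=1)
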
Example #27
File: image_source.py  Project: yyht/SHOT
def train_source(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 10}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 10}]
    optimizer = optim.SGD(param_group,
                          momentum=0.9,
                          weight_decay=5e-4,
                          nesterov=True)

    acc_init = 0
    for epoch in tqdm(range(args.max_epoch), leave=False):
        netF.train()
        netB.train()
        netC.train()
        iter_source = iter(dset_loaders['source_tr'])
        for _, (inputs_source, labels_source) in tqdm(enumerate(iter_source),
                                                      leave=False):
            if inputs_source.size(0) == 1:
                continue
            inputs_source = inputs_source.cuda()
            labels_source = labels_source.cuda()
            outputs_source = netC(netB(netF(inputs_source)))
            classifier_loss = loss.CrossEntropyLabelSmooth(
                num_classes=args.class_num,
                epsilon=args.smooth)(outputs_source, labels_source)
            optimizer.zero_grad()
            classifier_loss.backward()
            optimizer.step()

        if (epoch + 1) % 5 == 0 and args.trte == 'full':
            netF.eval()
            netB.eval()
            netC.eval()
            acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC,
                                  args.dset == 'visda17')
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                args.name_src, epoch + 1, args.max_epoch, acc_s_te * 100)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            best_netF = netF.state_dict()
            best_netB = netB.state_dict()
            best_netC = netC.state_dict()
            torch.save(
                best_netF,
                osp.join(args.output_dir_src,
                         'source_F_' + str(epoch + 1) + '.pt'))
            torch.save(
                best_netB,
                osp.join(args.output_dir_src,
                         'source_B_' + str(epoch + 1) + '.pt'))
            torch.save(
                best_netC,
                osp.join(args.output_dir_src,
                         'source_C_' + str(epoch + 1) + '.pt'))

        if args.trte == 'val':
            netF.eval()
            netB.eval()
            netC.eval()
            acc_s_tr, _ = cal_acc(dset_loaders['source_tr'], netF, netB, netC)
            acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%/ {:.2f}%'.format(
                args.name_src, epoch + 1, args.max_epoch, acc_s_tr * 100,
                acc_s_te * 100)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if acc_s_te > acc_init:
                acc_init = acc_s_te
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()

    torch.save(best_netF, osp.join(args.output_dir_src, 'source_F_val.pt'))
    torch.save(best_netB, osp.join(args.output_dir_src, 'source_B_val.pt'))
    torch.save(best_netC, osp.join(args.output_dir_src, 'source_C_val.pt'))
    return netF, netB, netC
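
Note: every example builds its optimizer parameter groups by hand; here the bottleneck and classifier run at 10x the backbone's learning rate, while other examples scale the backbone down instead. The pattern can be factored into a small helper; make_param_groups below is a hypothetical name, not part of the source:

import torch.optim as optim

def make_param_groups(base_lr, *net_mult_pairs):
    # net_mult_pairs: (module, lr_multiplier) tuples, e.g. (netF, 1), (netB, 10)
    groups = []
    for net, mult in net_mult_pairs:
        for _, v in net.named_parameters():
            groups.append({'params': v, 'lr': base_lr * mult})
    return groups

# usage sketch:
# optimizer = optim.SGD(
#     make_param_groups(args.lr, (netF, 1), (netB, 10), (netC, 10)),
#     momentum=0.9, weight_decay=5e-4, nesterov=True)
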
Example #28
def test_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.norm_layer == 'batchnorm':
        norm_layer = nn.BatchNorm2d
    elif args.norm_layer == 'groupnorm':

        def gn_helper(planes):
            return nn.GroupNorm(8, planes)

        norm_layer = gn_helper
    if args.net[0:3] == 'res':
        if '26' in args.net:
            netF = network.ResCifarBase(26, norm_layer=norm_layer)
        else:
            netF = network.ResBase(res_name=args.net, args=args)
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net)

    if args.ssl_before_btn:
        netH = network.ssl_head(ssl_task=args.ssl_task,
                                feature_dim=netF.in_features,
                                embedding_dim=args.embedding_dim)
    else:
        netH = network.ssl_head(ssl_task=args.ssl_task,
                                feature_dim=args.bottleneck,
                                embedding_dim=args.embedding_dim)
    if args.bottleneck != 0:
        netB = network.feat_bootleneck(type=args.classifier,
                                       feature_dim=netF.in_features,
                                       bottleneck_dim=args.bottleneck,
                                       norm_btn=args.norm_btn)
        netC = network.feat_classifier(type=args.layer,
                                       class_num=args.class_num,
                                       bottleneck_dim=args.bottleneck,
                                       bias=args.classifier_bias,
                                       args=args)
    else:
        netB = nn.Identity()
        netC = network.feat_classifier(type=args.layer,
                                       class_num=args.class_num,
                                       bottleneck_dim=netF.in_features,
                                       bias=args.classifier_bias,
                                       args=args)

    args.modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_H.pt'
    netH.load_state_dict(torch.load(args.modelpath))
    try:
        args.modelpath = args.output_dir_src + '/source_B.pt'
        netB.load_state_dict(torch.load(args.modelpath))
    except Exception:
        print('Skipped loading btn for version compatibility')
    args.modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))

    if args.dataparallel:
        netF = nn.DataParallel(netF).cuda()
        netH = nn.DataParallel(netH).cuda()
        netB = nn.DataParallel(netB).cuda()
        netC = nn.DataParallel(netC).cuda()
    else:
        netF.cuda()
        netH.cuda()
        netB.cuda()
        netC.cuda()

    netF.eval()
    netH.eval()
    netB.eval()
    netC.eval()

    if args.da == 'oda':
        acc_os1, acc_os2, acc_unknown = cal_acc_oda(dset_loaders['test'], netF,
                                                    netH, netB, netC)
        log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}% / {:.2f}% / {:.2f}%'.format(
            args.trte, args.name, acc_os2, acc_os1, acc_unknown)
    else:
        if args.dset in ['visda-c', 'CIFAR-10-C', 'CIFAR-100-C']:
            acc, acc_list = cal_acc(dset_loaders['test'], netF, netH, netB,
                                    netC, args, True)
            log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(
                args.trte, args.name, acc) + '\n' + acc_list
        else:
            acc, _ = cal_acc(dset_loaders['test'], netF, netH, netB, netC,
                             args, False)
            log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(
                args.trte, args.name, acc)

    args.out_file.write(log_str)
    args.out_file.flush()
    print(log_str)
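
Note: Example #28 loads the plain state dicts first and only then wraps the networks in nn.DataParallel, so checkpoint keys stay free of the 'module.' prefix. If a checkpoint was instead saved from an already-wrapped model, the prefix must be stripped on load; load_flexible below is a hypothetical helper, not part of the source:

import torch

def load_flexible(net, path):
    # Accept checkpoints saved from either a bare module or nn.DataParallel,
    # whose keys carry a leading 'module.' prefix.
    state = torch.load(path, map_location='cpu')
    state = {k[len('module.'):] if k.startswith('module.') else k: v
             for k, v in state.items()}
    net.load_state_dict(state)
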
Example #29
File: image_target.py  Project: yyht/SHOT
def train_target(args, zz=''):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_src + '/source_F_' + str(zz) + '.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B_' + str(zz) + '.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C_' + str(zz) + '.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    optimizer = optim.SGD(param_group,
                          momentum=0.9,
                          weight_decay=5e-4,
                          nesterov=True)

    for epoch in tqdm(range(args.max_epoch), leave=False):
        netF.eval()
        netB.eval()
        mem_label = obtain_label(dset_loaders['test'], netF, netB, netC, args)
        mem_label = torch.from_numpy(mem_label).cuda()
        netF.train()
        netB.train()
        iter_test = iter(dset_loaders['target'])

        for _, (inputs_test, _, tar_idx) in tqdm(enumerate(iter_test),
                                                 leave=False):
            if inputs_test.size(0) == 1:
                continue
            inputs_test = inputs_test.cuda()

            pred = mem_label[tar_idx]
            features_test = netB(netF(inputs_test))
            outputs_test = netC(features_test)

            classifier_loss = loss.CrossEntropyLabelSmooth(
                num_classes=args.class_num, epsilon=0)(outputs_test, pred)
            classifier_loss *= args.cls_par

            if args.ent:
                softmax_out = nn.Softmax(dim=1)(outputs_test)
                entropy_loss = torch.mean(loss.Entropy(softmax_out))
                if args.gent:
                    msoftmax = softmax_out.mean(dim=0)
                    gentropy_loss = torch.sum(
                        -msoftmax * torch.log(msoftmax + args.epsilon))
                    entropy_loss -= gentropy_loss
                classifier_loss += entropy_loss * args.ent_par

            optimizer.zero_grad()
            classifier_loss.backward()
            optimizer.step()

        netF.eval()
        netB.eval()
        acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
        log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
            args.name, epoch + 1, args.max_epoch, acc * 100)
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        print(log_str + '\n')

    if args.issave:
        torch.save(
            netF.state_dict(),
            osp.join(args.output_dir, 'target_F_' + args.savename + '.pt'))
        torch.save(
            netB.state_dict(),
            osp.join(args.output_dir, 'target_B_' + args.savename + '.pt'))
        torch.save(
            netC.state_dict(),
            osp.join(args.output_dir, 'target_C_' + args.savename + '.pt'))

    return netF, netB, netC
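
Note: obtain_label in Example #29 is SHOT's self-labelling step: class centroids are initialized from softmax-weighted features, each sample is assigned to the nearest centroid under cosine distance, and the assignment is refined once more with hardened (one-hot) weights. A condensed sketch, omitting the feature normalization of the reference code:

import numpy as np
import torch
import torch.nn.functional as F
from scipy.spatial.distance import cdist

def obtain_label(loader, netF, netB, netC, args):
    feats, outs = [], []
    with torch.no_grad():
        for inputs, *_ in loader:
            f = netB(netF(inputs.cuda()))
            feats.append(f.cpu())
            outs.append(netC(f).cpu())
    all_fea = torch.cat(feats).numpy()
    aff = F.softmax(torch.cat(outs), dim=1).numpy()
    for _ in range(2):  # init from soft assignments, refine once on hard ones
        initc = aff.T @ all_fea / (1e-8 + aff.sum(axis=0)[:, None])
        pred = cdist(all_fea, initc, 'cosine').argmin(axis=1)
        aff = np.eye(args.class_num)[pred]
    return pred.astype('int')
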
Example #30
def train_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.norm_layer == 'batchnorm':
        norm_layer = nn.BatchNorm2d
    elif args.norm_layer == 'groupnorm':

        def gn_helper(planes):
            return nn.GroupNorm(8, planes)

        norm_layer = gn_helper
    if args.net[0:3] == 'res':
        if '26' in args.net:
            netF = network.ResCifarBase(26, norm_layer=norm_layer)
            args.bottleneck = netF.in_features // 2
        else:
            netF = network.ResBase(res_name=args.net, args=args)
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net)

    # print(args.ssl_before_btn)
    if args.ssl_before_btn:
        netH = network.ssl_head(ssl_task=args.ssl_task,
                                feature_dim=netF.in_features,
                                embedding_dim=args.embedding_dim)
    else:
        netH = network.ssl_head(ssl_task=args.ssl_task,
                                feature_dim=args.bottleneck,
                                embedding_dim=args.embedding_dim)
    if args.bottleneck != 0:
        netB = network.feat_bootleneck(type=args.classifier,
                                       feature_dim=netF.in_features,
                                       bottleneck_dim=args.bottleneck,
                                       norm_btn=args.norm_btn)
        if args.reset_running_stats and args.classifier == 'bn':
            netB.norm.running_mean.fill_(0.)
            netB.norm.running_var.fill_(1.)

        if args.reset_bn_params and args.classifier == 'bn':
            netB.norm.weight.data.fill_(1.)
            netB.norm.bias.data.fill_(0.)
        netC = network.feat_classifier(type=args.layer,
                                       class_num=args.class_num,
                                       bottleneck_dim=args.bottleneck,
                                       bias=args.classifier_bias,
                                       temp=args.angular_temp,
                                       args=args)
    else:
        netB = nn.Identity()
        netC = network.feat_classifier(type=args.layer,
                                       class_num=args.class_num,
                                       bottleneck_dim=netF.in_features,
                                       bias=args.classifier_bias,
                                       temp=args.angular_temp,
                                       args=args)

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath), strict=False)
    modelpath = args.output_dir_src + '/source_H.pt'
    netH.load_state_dict(torch.load(modelpath), strict=False)
    try:
        modelpath = args.output_dir_src + '/source_B.pt'
        netB.load_state_dict(torch.load(modelpath), strict=False)
    except Exception:
        print('Skipped loading btn for version compatibility')
    modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath), strict=False)
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    if args.dataparallel:
        netF = nn.DataParallel(netF).cuda()
        netH = nn.DataParallel(netH).cuda()
        netB = nn.DataParallel(netB).cuda()
        netC = nn.DataParallel(netC).cuda()
    else:
        netF.cuda()
        netH.cuda()
        netB.cuda()
        netC.cuda()

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    for k, v in netH.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False

    if args.ssl_task in ['simclr', 'crs']:
        ssl_loss_fn = NTXentLoss(args.batch_size, args.temperature,
                                 True).cuda()
    elif args.ssl_task in ['supcon', 'crsc']:
        ssl_loss_fn = SupConLoss(temperature=args.temperature,
                                 base_temperature=args.temperature).cuda()
    elif args.ssl_task == 'ls_supcon':
        ssl_loss_fn = LabelSmoothedSCLLoss(args.batch_size, args.temperature,
                                           args.class_num, args.ssl_smooth)

    if args.cr_weight > 0:
        if args.cr_metric == 'cos':
            dist = nn.CosineSimilarity(dim=1).cuda()
        elif args.cr_metric == 'l1':
            dist = nn.PairwiseDistance(p=1)
        elif args.cr_metric == 'l2':
            dist = nn.PairwiseDistance(p=2)
        elif args.cr_metric == 'bce':
            dist = nn.BCEWithLogitsLoss(reduction='sum').cuda()
        elif args.cr_metric == 'kl':
            dist = nn.KLDivLoss(reduction='sum').cuda()

    use_second_pass = (args.ssl_task in ['simclr', 'supcon', 'ls_supcon']
                       and args.ssl_weight > 0)
    use_third_pass = (args.cr_weight > 0
                      or (args.ssl_task in ['crsc', 'crs']
                          and args.ssl_weight > 0) or args.cls3)

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0

    centroid = None

    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except (NameError, StopIteration):
            # (re)start the target iterator on the first pass or when exhausted
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        try:
            # inputs_test is a tensor, or a list of augmented views
            if inputs_test.size(0) == 1:
                continue
        except AttributeError:
            if inputs_test[0].size(0) == 1:
                continue

        if iter_num % interval_iter == 0 and (
                args.cls_par > 0
                or args.ssl_task in ['supcon', 'ls_supcon', 'crsc']):
            netF.eval()
            netH.eval()
            netB.eval()
            if centroid is None or args.recompute_centroid:
                mem_label, mem_conf, centroid, labelset = obtain_label(
                    dset_loaders['pl'], netF, netH, netB, netC, args)
                mem_label = torch.from_numpy(mem_label).cuda()

            netF.train()
            netH.train()
            netB.train()

        inputs_test1 = None
        inputs_test2 = None
        inputs_test3 = None

        pred = mem_label[tar_idx]

        if type(inputs_test) is list:
            inputs_test1 = inputs_test[0].cuda()
            inputs_test2 = inputs_test[1].cuda()
            if len(inputs_test) == 3:
                inputs_test3 = inputs_test[2].cuda()
        else:
            inputs_test1 = inputs_test.cuda()

        if args.layer in ['add_margin', 'arc_margin', 'sphere'
                          ] and args.use_margin_forward:
            labels_forward = pred
        else:
            labels_forward = None

        if inputs_test1 is not None:
            f1 = netF(inputs_test1)
            b1 = netB(f1)
            outputs_test = netC(b1, labels_forward)
        if use_second_pass:
            f2 = netF(inputs_test2)
            b2 = netB(f2)
        if use_third_pass:
            if args.sg3:
                with torch.no_grad():
                    f3 = netF(inputs_test3)
                    b3 = netB(f3)
                    c3 = netC(b3, labels_forward)
                    conf = torch.max(F.softmax(c3, dim=1), dim=1)[0]
            else:
                f3 = netF(inputs_test3)
                b3 = netB(f3)
                c3 = netC(b3, labels_forward)
                conf = torch.max(F.softmax(c3, dim=1), dim=1)[0]

        iter_num += 1
        lr_scheduler(args,
                     optimizer,
                     iter_num=iter_num,
                     max_iter=max_iter,
                     gamma=args.gamma,
                     power=args.power)

        pred = compute_pl(args, b3, centroid, labelset)

        if args.cr_weight > 0:
            if args.cr_site == 'feat':
                f_hard = f1
                f_weak = f3
            elif args.cr_site == 'btn':
                f_hard = b1
                f_weak = b3
            elif args.cr_site == 'cls':
                f_hard = outputs_test
                f_weak = c3
                if args.cr_metric != 'cos':
                    f_hard = F.softmax(f_hard, dim=-1)
                    f_weak = F.softmax(f_weak, dim=-1)
            else:
                raise NotImplementedError

        if args.cls_par > 0:
            # with torch.no_grad():
            #    conf, _ = torch.max(F.softmax(outputs_test, dim=-1), dim=-1)
            #    conf = conf.cpu().numpy()
            conf_cls = mem_conf[tar_idx]

            #pred = mem_label[tar_idx]
            if args.cls_smooth > 0:
                classifier_loss = CrossEntropyLabelSmooth(
                    num_classes=args.class_num, epsilon=args.cls_smooth)(
                        outputs_test[conf_cls >= args.conf_threshold],
                        pred[conf_cls >= args.conf_threshold])
            else:
                classifier_loss = nn.CrossEntropyLoss()(
                    outputs_test[conf_cls >= args.conf_threshold],
                    pred[conf_cls >= args.conf_threshold])
            if args.cls3:
                # also supervise the weak-view logits, as in train_source
                if args.cls_smooth > 0:
                    classifier_loss += CrossEntropyLabelSmooth(
                        num_classes=args.class_num, epsilon=args.cls_smooth)(
                            c3[conf_cls >= args.conf_threshold],
                            pred[conf_cls >= args.conf_threshold])
                else:
                    classifier_loss += nn.CrossEntropyLoss()(
                        c3[conf_cls >= args.conf_threshold],
                        pred[conf_cls >= args.conf_threshold])
            classifier_loss *= args.cls_par
            if iter_num < interval_iter and args.dset == "visda-c":
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax *
                                          torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        if args.ssl_weight > 0:
            if args.ssl_before_btn:
                z1 = netH(f1, args.norm_feat)
                if use_second_pass:
                    z2 = netH(f2, args.norm_feat)
                if use_third_pass:
                    z3 = netH(f3, args.norm_feat)
            else:
                z1 = netH(b1, args.norm_feat)
                if use_second_pass:
                    z2 = netH(b2, args.norm_feat)
                if use_third_pass:
                    z3 = netH(b3, args.norm_feat)

            if args.ssl_task == 'simclr':
                ssl_loss = ssl_loss_fn(z1, z2)
            elif args.ssl_task == 'supcon':
                z = torch.cat([z1.unsqueeze(1), z2.unsqueeze(1)], dim=1)
                pl = mem_label[tar_idx]
                ssl_loss = ssl_loss_fn(z, pl)
            elif args.ssl_task == 'ls_supcon':
                pl = mem_label[tar_idx]
                ssl_loss = ssl_loss_fn(z1, z2, pl).squeeze()
            elif args.ssl_task == 'crsc':
                z = torch.cat([z1.unsqueeze(1), z3.unsqueeze(1)], dim=1)
                pl = mem_label[tar_idx]
                ssl_loss = ssl_loss_fn(z, pl)
            elif args.ssl_task == 'crs':
                ssl_loss = ssl_loss_fn(z1, z3)
            classifier_loss += args.ssl_weight * ssl_loss

        if args.cr_weight > 0:
            try:
                cr_loss = dist(f_hard[conf >= args.cr_threshold],
                               f_weak[conf >= args.cr_threshold]).mean()

                if args.cr_metric == 'cos':
                    cr_loss *= -1
            except Exception:
                print('Error computing CR loss')
                cr_loss = torch.tensor(0.0).cuda()
            classifier_loss += args.cr_weight * cr_loss

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        centroid = update_centroid(args, b3, centroid, c3, labelset)

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netH.eval()
            netB.eval()
            if args.dset in ['visda-c', 'CIFAR-10-C', 'CIFAR-100-C']:
                acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netH,
                                             netB, netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te) + '\n' + acc_list
            else:
                acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netH, netB,
                                      netC, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te)

            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netH.train()
            netB.train()

    if args.issave:
        if args.dataparallel:
            torch.save(
                netF.module.state_dict(),
                osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
            torch.save(
                netH.module.state_dict(),
                osp.join(args.output_dir, "target_H_" + args.savename + ".pt"))
            torch.save(
                netB.module.state_dict(),
                osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
            torch.save(
                netC.module.state_dict(),
                osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))
        else:
            torch.save(
                netF.state_dict(),
                osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
            torch.save(
                netH.state_dict(),
                osp.join(args.output_dir, "target_H_" + args.savename + ".pt"))
            torch.save(
                netB.state_dict(),
                osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
            torch.save(
                netC.state_dict(),
                osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netH, netB, netC
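
Note: Example #30 maintains a running centroid table: compute_pl re-labels each batch's weak-view bottleneck features against the current centroids, and update_centroid refreshes the table after the optimizer step. Neither helper is shown; a hypothetical nearest-centroid sketch of compute_pl, assuming labelset is a LongTensor of class ids:

import torch
import torch.nn.functional as F

def compute_pl(args, feats, centroid, labelset):
    # Hypothetical: cosine similarity of each feature to every centroid,
    # then map the argmax back through labelset to a class id.
    sim = F.normalize(feats, dim=1) @ F.normalize(centroid, dim=1).t()
    return labelset[sim.argmax(dim=1)]
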