def test_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_ori + "/target_F_" + args.savename + ".pt"
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_ori + "/target_B_" + args.savename + ".pt"
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_ori + "/target_C_" + args.savename + ".pt"
    netC.load_state_dict(torch.load(args.modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    acc, y, py = cal_acc(dset_loaders['test'], netF, netB, netC)
    log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc * 100)
    args.out_file.write(log_str)
    args.out_file.flush()
    print(log_str)
    return y, py

def test(args):
    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u':
        netF = network.LeNetBase()  # .cuda()
    elif args.dset == 'm':
        netF = network.LeNetBase()  # .cuda()
    elif args.dset == 's':
        netF = network.DTNBase()  # .cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck)  # .cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck)  # .cuda()

    args.modelpath = args.output_dir + '/F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/B.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/C.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
    log_str = 'Task: {}, [DONT CARE] Accuracy = {:.2f}%'.format(args.dset, acc)
    try:
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
    except Exception:  # out_file may be absent in this test-only path
        pass
    print(log_str + '\n')

def test_target(args):
    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir + '/source_F_val.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_B_val.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_C_val.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
    log_str = 'Task: {}, Accuracy = {:.2f}%'.format(args.dset, acc * 100)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')

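# ---------------------------------------------------------------------------
# cal_acc is called by the test/train functions in this file but is not defined
# in this section. A minimal sketch, assuming the digit-style variant used
# above: it returns the accuracy as a fraction in [0, 1] together with the mean
# prediction entropy (the second value is discarded above). The variants used
# further below take extra flags (per_class_flag, visda_flag) and also return
# per-class accuracies; those are not reproduced here.
def cal_acc_sketch(loader, netF, netB, netC):
    all_output, all_label = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = netC(netB(netF(inputs.cuda())))
            all_output.append(outputs.float().cpu())
            all_label.append(labels.float())
    all_output = torch.cat(all_output, 0)
    all_label = torch.cat(all_label, 0)
    softmax_out = nn.Softmax(dim=1)(all_output)
    _, predict = torch.max(all_output, 1)
    accuracy = (predict.float() == all_label).float().mean().item()
    mean_ent = (-softmax_out * torch.log(softmax_out + 1e-5)).sum(dim=1).mean().item()
    return accuracy, mean_ent
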
def test_target(args, zz):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_src + '/source_F_' + str(zz) + '.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B_' + str(zz) + '.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C_' + str(zz) + '.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, args.da == 'oda')
    log_str = '\nZz: {}, Task: {}, Accuracy = {:.2f}%'.format(zz, args.name, acc * 100)
    args.out_file.write(log_str)
    args.out_file.flush()
    print(log_str)

def test_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    if args.da == 'oda':
        acc_os1, acc_os2, acc_unknown = cal_acc_oda(dset_loaders['test'], netF, netB, netC)
        log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}% / {:.2f}% / {:.2f}%'.format(
            args.trte, args.name, acc_os2, acc_os1, acc_unknown)
    else:
        if args.dset == 'VISDA-C':
            acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True)
            log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(
                args.trte, args.name, acc) + '\n' + acc_list
        else:
            acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
            log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(args.trte, args.name, acc)

    args.out_file.write(log_str)
    args.out_file.flush()
    print(log_str)

def test_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    if args.dset == 'VISDA-RSUT' or args.dset == 'VISDA-RSUT-50' or args.dset == 'VISDA-RSUT-10':
        # For VisDA, print the accuracy of each class.
        acc, acc_list, acc_cls_avg = cal_acc(dset_loaders['test'], netF=netF, netB=netB, netC=netC,
                                             per_class_flag=True, visda_flag=True)
        log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}%'.format(
            args.trte, args.name, acc, acc_cls_avg) + '\n' + acc_list
    else:
        # For Office-Home and DomainNet, no need to print the accuracy of each class.
        acc, acc_cls_avg, _ = cal_acc(dset_loaders['test'], netF=netF, netB=netB, netC=netC,
                                      per_class_flag=True, visda_flag=False)
        log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}%'.format(
            args.trte, args.name, acc, acc_cls_avg)

    args.out_file.write(log_str)
    args.out_file.flush()
    print(log_str)

def train_target(args):
    dset_loaders = data_load(args)
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath))

    # freeze the classifier (hypothesis)
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0
    iter_sw = int(max_iter / 2.0)

    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except (NameError, StopIteration):
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        if inputs_test.size(0) == 1:
            continue

        if iter_num % interval_iter == 0 and args.cls_par > 0:
            netF.eval()
            netB.eval()
            mem_label_soft, mtx_infor_nh, feas_FC = obtain_label(
                dset_loaders['test'], netF, netB, netC, args, iter_num, iter_sw)
            mem_label_soft = torch.from_numpy(mem_label_soft).cuda()
            feas_all = feas_FC[0]
            ops_all = feas_FC[1]
            netF.train()
            netB.train()

        inputs_test = inputs_test.cuda()
        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        features_F_self = netF(inputs_test)
        features_F_nh = get_mtx_sam_wgt_nh(feas_all, mtx_infor_nh, tar_idx)
        features_F_nh = features_F_nh.cuda()
        # mix each sample's own feature with its neighborhood feature
        features_F_mix = 0.8 * features_F_self + 0.2 * features_F_nh

        outputs_test_mix = netC(netB(features_F_mix))
        ops_test_self = netC(netB(features_F_self))
        outputs_test_nh = netC(netB(features_F_nh))

        if args.cls_par > 0:
            log_probs = nn.LogSoftmax(dim=1)(outputs_test_mix)
            targets = mem_label_soft[tar_idx]
            loss_soft = (-targets * log_probs).sum(dim=1)
            classifier_loss = loss_soft.mean()
            classifier_loss *= args.cls_par
            if iter_num < interval_iter and args.dset == "VISDA-C":
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test_mix)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            if args.dset == 'VISDA-C':
                acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te) + '\n' + acc_list
            else:
                acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netB, netC

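# op_copy and lr_scheduler are used by every training function in this file but
# defined elsewhere. A minimal sketch consistent with their call sites (op_copy
# stashes the initial learning rate per parameter group; lr_scheduler applies
# the usual polynomial decay); the decay constants and the weight-decay /
# momentum values are assumptions:
def op_copy_sketch(optimizer):
    for param_group in optimizer.param_groups:
        param_group['lr0'] = param_group['lr']  # remember each group's initial lr
    return optimizer

def lr_scheduler_sketch(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr0'] * decay
        param_group['weight_decay'] = 1e-3
        param_group['momentum'] = 0.9
        param_group['nesterov'] = True
    return optimizer
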
def train_target(args):
    dset_loaders, txt_tar = data_load(args)
    dsets = dict()
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath))

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    for k, v in netC.named_parameters():
        v.requires_grad = False

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    iter_per_epoch = len(dset_loaders["target"])
    print("Iter per epoch: {}".format(iter_per_epoch))
    interval_iter = max_iter // args.interval
    iter_num = 0

    if args.paral:
        netF = torch.nn.DataParallel(netF)
        netB = torch.nn.DataParallel(netB)
        netC = torch.nn.DataParallel(netC)

    netF.train()
    netB.train()
    netC.eval()

    if args.scd_lamb:
        # hyperparameter for secondary label correction, specified manually
        scd_lamb_init = args.scd_lamb
    else:
        if args.dset[0:5] == "VISDA":
            scd_lamb_init = 0.1
        elif args.dset[0:11] == "office-home":
            scd_lamb_init = 0.2
            if args.s == 3 and args.t == 2:
                scd_lamb_init *= 0.1
        elif args.dset[0:9] == "domainnet":
            scd_lamb_init = 0.02
    scd_lamb = scd_lamb_init

    while iter_num < max_iter:
        k = 1.0
        k_s = 0.6

        if iter_num % interval_iter == 0 and args.cls_par > 0:  # interval_iter = iters per epoch
            netF.eval()
            netB.eval()
            label_prob_dict = obtain_label(dset_loaders['test'], netF, netB, netC, args)
            mem_label, pseudo_lb_prob = label_prob_dict['primary_lb'], label_prob_dict['primary_lb_prob']
            mem_label, pseudo_lb_prob = torch.from_numpy(mem_label).cuda(), torch.from_numpy(pseudo_lb_prob).cuda()
            if args.scd_label:
                second_label, second_prob = label_prob_dict['secondary_lb'], label_prob_dict['secondary_lb_prob']
                second_label, second_prob = torch.from_numpy(second_label).cuda(), torch.from_numpy(second_prob).cuda()
            if args.third_label:
                third_label, third_prob = label_prob_dict['third_lb'], label_prob_dict['third_lb_prob']
                third_label, third_prob = torch.from_numpy(third_label).cuda(), torch.from_numpy(third_prob).cuda()
            if args.fourth_label:
                fourth_label, fourth_prob = label_prob_dict['fourth_lb'], label_prob_dict['fourth_lb_prob']
                fourth_label, fourth_prob = torch.from_numpy(fourth_label).cuda(), torch.from_numpy(fourth_prob).cuda()
            if args.topk_ent:
                all_entropy = label_prob_dict['entropy']
                all_entropy = torch.from_numpy(all_entropy)

            if args.dset[0:5] == "VISDA":
                if iter_num // iter_per_epoch < 1:
                    k = 0.6
                elif iter_num // iter_per_epoch < 2:
                    k = 0.7
                elif iter_num // iter_per_epoch < 3:
                    k = 0.8
                elif iter_num // iter_per_epoch < 4:
                    k = 0.9
                else:
                    k = 1.0
                if iter_num // iter_per_epoch >= 8:
                    scd_lamb *= 0.1
            elif args.dset[0:11] == "office-home" or args.dset[0:9] == "domainnet":
                if iter_num // iter_per_epoch < 2:
                    k = 0.2
                elif iter_num // iter_per_epoch < 4:
                    k = 0.4
                elif iter_num // iter_per_epoch < 8:
                    k = 0.6
                elif iter_num // iter_per_epoch < 12:
                    k = 0.8

            if args.topk:
                dsets["target"] = ImageList_PA(txt_tar, mem_label, pseudo_lb_prob,
                                               k_low=k, k_up=None, transform=image_train())
                dset_loaders["target"] = DataLoader(dsets["target"], batch_size=args.batch_size,
                                                    shuffle=True, num_workers=args.worker, drop_last=False)
            if args.topk_ent:
                dsets["target"] = ImageList_PA(txt_tar, mem_label, -1.0 * all_entropy,
                                               k_low=k, k_up=None, transform=image_train())
                dset_loaders["target"] = DataLoader(dsets["target"], batch_size=args.batch_size,
                                                    shuffle=True, num_workers=args.worker, drop_last=False)
            if args.scd_label:
                # 2nd label threshold: prob top 60%
                dsets["target_scd"] = ImageList_PA(txt_tar, second_label, second_prob,
                                                   k_low=k_s, k_up=None, transform=image_train())
                dset_loaders["target_scd"] = DataLoader(dsets["target_scd"], batch_size=args.batch_size,
                                                        shuffle=True, num_workers=args.worker, drop_last=False)
            if args.third_label:
                # 3rd label threshold: prob top 60%
                dsets["target_third"] = ImageList_PA(txt_tar, third_label, third_prob,
                                                     k_low=k_s, k_up=None, transform=image_train())
                dset_loaders["target_third"] = DataLoader(dsets["target_third"], batch_size=args.batch_size,
                                                          shuffle=True, num_workers=args.worker, drop_last=False)
            if args.fourth_label:
                # 4th label threshold: prob top 60%
                dsets["target_fourth"] = ImageList_PA(txt_tar, fourth_label, fourth_prob,
                                                      k_low=k_s, k_up=None, transform=image_train())
                dset_loaders["target_fourth"] = DataLoader(dsets["target_fourth"], batch_size=args.batch_size,
                                                           shuffle=True, num_workers=args.worker, drop_last=False)
            netF.train()
            netB.train()

        try:
            inputs_test, _, tar_idx = next(iter_test)  # tar_idx: indices chosen in the current iteration
        except (NameError, StopIteration):
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)
        if inputs_test.size(0) == 1:
            continue

        if args.scd_label:
            try:
                inputs_test_scd, _, tar_idx_scd = next(iter_test_scd)
            except (NameError, StopIteration):
                iter_test_scd = iter(dset_loaders["target_scd"])
                inputs_test_scd, _, tar_idx_scd = next(iter_test_scd)
            if inputs_test_scd.size(0) == 1:
                continue
        if args.third_label:
            try:
                inputs_test_third, _, tar_idx_third = next(iter_test_third)
            except (NameError, StopIteration):
                iter_test_third = iter(dset_loaders["target_third"])
                inputs_test_third, _, tar_idx_third = next(iter_test_third)
            if inputs_test_third.size(0) == 1:
                continue
        if args.fourth_label:
            try:
                inputs_test_fourth, _, tar_idx_fourth = next(iter_test_fourth)
            except (NameError, StopIteration):
                iter_test_fourth = iter(dset_loaders["target_fourth"])
                inputs_test_fourth, _, tar_idx_fourth = next(iter_test_fourth)
            if inputs_test_fourth.size(0) == 1:
                continue

        iter_num += 1
        inputs_test = inputs_test.cuda()
        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        if args.scd_label:
            inputs_test_scd = inputs_test_scd.cuda()
            if inputs_test_scd.ndim == 3:
                inputs_test_scd = inputs_test_scd.unsqueeze(0)
            features_test_scd = netB(netF(inputs_test_scd))
            outputs_test_scd = netC(features_test_scd)
            first_prob_of_scd = pseudo_lb_prob[tar_idx_scd]
            scd_prob = second_prob[tar_idx_scd]
            if not args.no_mask:
                mask = (scd_prob / first_prob_of_scd.float()).clamp(max=1.0)
            else:
                mask = torch.ones_like(scd_prob).cuda()
        if args.third_label:
            inputs_test_third = inputs_test_third.cuda()
            if inputs_test_third.ndim == 3:
                inputs_test_third = inputs_test_third.unsqueeze(0)
            features_test_third = netB(netF(inputs_test_third))
            outputs_test_third = netC(features_test_third)
            first_prob_of_third = pseudo_lb_prob[tar_idx_third]
            thi_prob = third_prob[tar_idx_third]
            mask_third = (thi_prob / first_prob_of_third.float()).clamp(max=1.0)
        if args.fourth_label:
            inputs_test_fourth = inputs_test_fourth.cuda()
            if inputs_test_fourth.ndim == 3:
                inputs_test_fourth = inputs_test_fourth.unsqueeze(0)
            features_test_fourth = netB(netF(inputs_test_fourth))
            outputs_test_fourth = netC(features_test_fourth)
            first_prob_of_fourth = pseudo_lb_prob[tar_idx_fourth]
            fth_prob = fourth_prob[tar_idx_fourth]
            mask_fourth = (fth_prob / first_prob_of_fourth.float()).clamp(max=1.0)

        if args.intra_dense or args.inter_sep:
            intra_dist = torch.zeros(1).cuda()
            inter_dist = torch.zeros(1).cuda()
            pred = mem_label[tar_idx]
            same_first = True
            diff_first = True
            cos = nn.CosineSimilarity(dim=1, eps=1e-6)
            for i in range(pred.size(0)):
                for j in range(i, pred.size(0)):
                    # dist = torch.norm(features_test[i] - features_test[j])
                    dist = 0.5 * (1 - cos(features_test[i].unsqueeze(0), features_test[j].unsqueeze(0)))
                    if pred[i].item() == pred[j].item():
                        if same_first:
                            intra_dist = dist.unsqueeze(0)
                            same_first = False
                        else:
                            intra_dist = torch.cat((intra_dist, dist.unsqueeze(0)))
                    else:
                        if diff_first:
                            inter_dist = dist.unsqueeze(0)
                            diff_first = False
                        else:
                            inter_dist = torch.cat((inter_dist, dist.unsqueeze(0)))
            intra_dist = torch.mean(intra_dist)
            inter_dist = torch.mean(inter_dist)

        if args.cls_par > 0:
            pred = mem_label[tar_idx]
            classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)  # self-training on pseudo labels
            classifier_loss *= args.cls_par
            if args.scd_label:
                pred_scd = second_label[tar_idx_scd]
                classifier_loss_scd = nn.CrossEntropyLoss(reduction='none')(outputs_test_scd, pred_scd)
                classifier_loss_scd = torch.mean(mask * classifier_loss_scd)
                classifier_loss_scd *= args.cls_par
                classifier_loss += classifier_loss_scd * scd_lamb
            if args.third_label:
                pred_third = third_label[tar_idx_third]
                classifier_loss_third = nn.CrossEntropyLoss(reduction='none')(outputs_test_third, pred_third)
                classifier_loss_third = torch.mean(mask_third * classifier_loss_third)
                classifier_loss_third *= args.cls_par
                classifier_loss += classifier_loss_third * scd_lamb  # TODO: better weighting is possible
            if args.fourth_label:
                pred_fourth = fourth_label[tar_idx_fourth]
                classifier_loss_fourth = nn.CrossEntropyLoss(reduction='none')(outputs_test_fourth, pred_fourth)
                classifier_loss_fourth = torch.mean(mask_fourth * classifier_loss_fourth)
                classifier_loss_fourth *= args.cls_par
                classifier_loss += classifier_loss_fourth * scd_lamb  # TODO: better weighting is possible
            if iter_num < interval_iter and (args.dset == "VISDA-C" or args.dset == "VISDA-RSUT"
                                             or args.dset == 'VISDA-RSUT-50' or args.dset == 'VISDA-RSUT-10'):
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))  # minimize local entropy
            if args.scd_label:
                softmax_out_scd = nn.Softmax(dim=1)(outputs_test_scd)
                entropy_loss_scd = torch.mean(mask * loss.Entropy(softmax_out_scd))
            if args.third_label:
                softmax_out_third = nn.Softmax(dim=1)(outputs_test_third)
                entropy_loss_third = torch.mean(mask_third * loss.Entropy(softmax_out_third))
            if args.fourth_label:
                softmax_out_fourth = nn.Softmax(dim=1)(outputs_test_fourth)
                entropy_loss_fourth = torch.mean(mask_fourth * loss.Entropy(softmax_out_fourth))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss  # maximize global (diversity) entropy
                if args.scd_label:
                    msoftmax_scd = softmax_out_scd.mean(dim=0)
                    gentropy_loss_scd = torch.sum(-msoftmax_scd * torch.log(msoftmax_scd + args.epsilon))
                    entropy_loss_scd -= gentropy_loss_scd
                if args.third_label:
                    msoftmax_third = softmax_out_third.mean(dim=0)
                    gentropy_loss_third = torch.sum(-msoftmax_third * torch.log(msoftmax_third + args.epsilon))
                    entropy_loss_third -= gentropy_loss_third
                if args.fourth_label:
                    msoftmax_fourth = softmax_out_fourth.mean(dim=0)
                    gentropy_loss_fourth = torch.sum(-msoftmax_fourth * torch.log(msoftmax_fourth + args.epsilon))
                    entropy_loss_fourth -= gentropy_loss_fourth
            im_loss = entropy_loss * args.ent_par
            if args.scd_label:
                im_loss += entropy_loss_scd * args.ent_par * scd_lamb
            if args.third_label:
                im_loss += entropy_loss_third * args.ent_par * scd_lamb  # TODO: better weighting is possible
            if args.fourth_label:
                im_loss += entropy_loss_fourth * args.ent_par * scd_lamb  # TODO: better weighting is possible
            classifier_loss += im_loss

        if args.intra_dense:
            classifier_loss += args.lamb_intra * intra_dist.squeeze()
        if args.inter_sep:
            classifier_loss += args.lamb_inter * inter_dist.squeeze()

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            if args.dset == 'VISDA-RSUT' or args.dset == 'VISDA-RSUT-50' or args.dset == 'VISDA-RSUT-10':
                # For VisDA, print the accuracy of each class; the flag requests class-averaged accuracy.
                acc_s_te, acc_list, acc_cls_avg = cal_acc(dset_loaders['test'], netF, netB, netC,
                                                          visda_flag=True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te, acc_cls_avg) + '\n' + acc_list
            else:
                # In the imbalanced setting, per-class average accuracy is the metric;
                # for Office-Home and DomainNet there is no need to print per-class accuracy.
                acc_s_te, acc_cls_avg, _ = cal_acc(dset_loaders['test'], netF, netB, netC,
                                                   visda_flag=False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te, acc_cls_avg)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netB, netC

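# loss.Entropy, used by the information-maximization terms above, is the
# per-sample prediction entropy. A minimal sketch (the 1e-5 epsilon matches the
# constant used elsewhere in this file):
def Entropy_sketch(input_):
    # input_: B x K softmax probabilities; returns a length-B entropy vector
    return -torch.sum(input_ * torch.log(input_ + 1e-5), dim=1)
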
def train_target(args):
    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir + '/source_F_val.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_B_val.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_C_val.pt'
    netC.load_state_dict(torch.load(args.modelpath))

    # Only netC (the classifier/hypothesis) is put in eval mode and frozen, so its
    # parameters never change during training; this keeps h fixed in f = g o h.
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': args.lr}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': args.lr}]
    optimizer = optim.SGD(param_group, momentum=0.9, weight_decay=5e-4, nesterov=True)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    for epoch in tqdm(range(args.max_epoch), leave=False):
        # put the feature extractor and bottleneck in train mode
        netF.train()
        netB.train()
        iter_test = iter(dset_loaders["target"])

        # Snapshot the feature extractor g and set it to eval mode.
        # Note: the copy is of the previous epoch's model, not the fixed source model.
        prev_F = copy.deepcopy(netF)
        prev_B = copy.deepcopy(netB)
        prev_F.eval()
        prev_B.eval()
        # Obtain the class centroids (the first equation in the paper).
        center = obtain_center(dset_loaders['target'], prev_F, prev_B, netC, args)

        for _, (inputs_test, _) in tqdm(enumerate(iter_test), leave=False):
            if inputs_test.size(0) == 1:
                continue
            inputs_test = inputs_test.cuda()
            with torch.no_grad():
                # Pseudo labels are re-assigned once per batch iteration, but no matter
                # how many iterations run, labeling always uses the snapshot from the
                # previous epoch (the second equation in the paper).
                # TODO: the paper also has a third and a fourth equation; they do not
                # appear to be applied here.
                features_test = prev_B(prev_F(inputs_test))
                pred = obtain_label(features_test, center)

            # Forward the current model as usual.
            features_test = netB(netF(inputs_test))
            outputs_test = netC(features_test)

            # Classification loss on the pseudo labels.
            classifier_loss = CrossEntropyLabelSmooth(
                num_classes=args.class_num, epsilon=0)(outputs_test, pred)

            # IM loss: softmax over dim=1 of the outputs.
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            # Entropy computes -sum(softmax_out * log(softmax_out), dim=1) per sample;
            # its batch mean is L_ent in the paper.
            im_loss = torch.mean(Entropy(softmax_out))
            # msoftmax is the batch-mean probability of each class, p^k in the paper.
            msoftmax = softmax_out.mean(dim=0)
            # The -= together with the minus inside the sum acts as +=; this average
            # over the K classes is L_div in the paper.
            im_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
            # args.par trades off the IM loss against the classifier loss.
            total_loss = im_loss + args.par * classifier_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        netF.eval()
        netB.eval()
        acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
        log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
            args.dset, epoch + 1, args.max_epoch, acc * 100)
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        print(log_str + '\n')

    # torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F.pt"))
    # torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B.pt"))
    # torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C.pt"))
    return netF, netB, netC

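# obtain_center / obtain_label above are not defined in this section. A minimal
# sketch of the centroid-based pseudo-labeling that the comments describe (the
# first and second equations of the paper): prediction-weighted class centroids,
# then nearest-centroid assignment. The cosine similarity and the dropped `args`
# parameter are assumptions.
def obtain_center_sketch(loader, netF, netB, netC):
    feats, probs = [], []
    with torch.no_grad():
        for inputs, _ in loader:
            f = netB(netF(inputs.cuda()))
            feats.append(f.cpu())
            probs.append(nn.Softmax(dim=1)(netC(f)).cpu())
    feats = torch.cat(feats, 0)  # N x D features
    probs = torch.cat(probs, 0)  # N x K soft predictions
    # c_k = sum_x p_k(x) f(x) / sum_x p_k(x)
    center = (probs.t() @ feats) / probs.sum(0).unsqueeze(1).clamp(min=1e-8)
    return center                # K x D

def obtain_label_sketch(features, center):
    # assign each sample to its nearest centroid (cosine similarity assumed)
    f = features.cpu()
    f = f / f.norm(dim=1, keepdim=True).clamp(min=1e-8)
    c = center / center.norm(dim=1, keepdim=True).clamp(min=1e-8)
    return (f @ c.t()).argmax(dim=1).cuda()
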
def train_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()
    wandb.watch(netB)
    wandb.watch(netC)
    wandb.watch(netF)

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath))

    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0

    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except (NameError, StopIteration):
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        if inputs_test.size(0) == 1:
            continue

        if iter_num % interval_iter == 0 and args.cls_par > 0:
            netF.eval()
            netB.eval()
            # recompute the pseudo labels
            mem_label = obtain_label(dset_loaders['test'], netF, netB, netC, args)
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netB.train()

        inputs_test = inputs_test.cuda()
        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        if args.cls_par > 0:
            pred = mem_label[tar_idx]
            # cross-entropy loss between predictions and pseudo labels
            classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)
            classifier_loss *= args.cls_par
            if iter_num < interval_iter and (args.dset == "VISDA18" or args.dset == 'VISDA-C'):
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))  # expectation over the batch
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss
            wandb.log({"target_classifier_loss": classifier_loss, "info_max_loss": im_loss})

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            if args.dset == 'VISDA18' or args.dset == 'VISDA-C':
                acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te) + '\n' + acc_list
                wandb.log({"accuracy": acc_s_te})  # true test accuracy
            else:
                acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netB, netC

def train(args):
    ent_loss_record = []
    gent_loss_record = []
    sent_loss_record = []
    total_loss_record = []
    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u':
        netF = network.LeNetBase()  # .cuda()
    elif args.dset == 'm':
        netF = network.LeNetBase()  # .cuda()
    elif args.dset == 's':
        netF = network.DTNBase()  # .cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck)  # .cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck)  # .cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0
    max_iter = args.max_epoch * len(dset_loaders["train"])
    interval_iter = max_iter // 10
    iter_num = 0

    netF.train()
    netB.train()
    netC.train()

    while iter_num < max_iter:
        try:
            inputs_source, strong_inputs, target = next(iter_source)
        except (NameError, StopIteration):
            iter_source = iter(dset_loaders["train"])
            inputs_source, strong_inputs, target = next(iter_source)

        if inputs_source.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source = inputs_source  # .cuda()
        outputs_source = netC(netB(netF(inputs_source)))
        total_loss = torch.tensor(0.0)  # .cuda()
        softmax_out = nn.Softmax(dim=1)(outputs_source)

        if args.ent:
            ent_loss = torch.mean(loss.Entropy(softmax_out))
            total_loss += ent_loss
            ent_loss_record.append(ent_loss.detach().cpu())
        if args.gent:
            msoftmax = softmax_out.mean(dim=0)
            gent_loss = -torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
            gent_loss_record.append(gent_loss.detach().cpu())
            total_loss += gent_loss
        if args.sent:
            sent_loss = compute_aug_loss(strong_inputs, target, netC, netB, netF)
            total_loss += sent_loss
            sent_loss_record.append(sent_loss.detach().cpu())

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        total_loss_record.append(total_loss.detach().cpu())

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            print(iter_num, interval_iter, max_iter)
            # netF.eval()
            # netB.eval()
            # netC.eval()
            # acc_s_tr, _ = cal_acc(dset_loaders['train'], netF, netB, netC)
            # acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
            # log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%/ {:.2f}%'.format(args.dset, iter_num, max_iter, acc_s_tr, acc_s_te)
            # args.out_file.write(log_str + '\n')
            # args.out_file.flush()
            # print(log_str + '\n')
            # if acc_s_te >= acc_init:
            #     acc_init = acc_s_te
            #     best_netF = netF.state_dict()
            #     best_netB = netB.state_dict()
            #     best_netC = netC.state_dict()
            # netF.train()
            # netB.train()
            # netC.train()

    best_netF = netF.state_dict()
    best_netB = netB.state_dict()
    best_netC = netC.state_dict()
    torch.save(best_netF, osp.join(args.output_dir, "F.pt"))
    torch.save(best_netB, osp.join(args.output_dir, "B.pt"))
    torch.save(best_netC, osp.join(args.output_dir, "C.pt"))

    fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=4, sharex=True, figsize=(16, 8))
    ax1.plot(list(range(len(ent_loss_record))), ent_loss_record, 'r')
    ax2.plot(list(range(len(gent_loss_record))), gent_loss_record, 'g')
    ax3.plot(list(range(len(sent_loss_record))), sent_loss_record, 'b')
    ax4.plot(list(range(len(total_loss_record))), total_loss_record, 'm')
    plt.tight_layout()
    plt.savefig(args.output_dir + '/loss.png')
    return netF, netB, netC

def train_source(args):
    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0
    max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    interval_iter = max_iter // 10
    iter_num = 0

    netF.train()
    netB.train()
    netC.train()

    while iter_num < max_iter:
        try:
            inputs_source, labels_source = next(iter_source)
        except (NameError, StopIteration):
            iter_source = iter(dset_loaders["source_tr"])
            inputs_source, labels_source = next(iter_source)

        if inputs_source.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
        outputs_source = netC(netB(netF(inputs_source)))
        classifier_loss = loss.CrossEntropyLabelSmooth(
            num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source)

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            netC.eval()
            acc_s_tr, _ = cal_acc(dset_loaders['source_tr'], netF, netB, netC)
            acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%/ {:.2f}%'.format(
                args.dset, iter_num, max_iter, acc_s_tr, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            if acc_s_te >= acc_init:
                acc_init = acc_s_te
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()
            netF.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir, "source_F.pt"))
    torch.save(best_netB, osp.join(args.output_dir, "source_B.pt"))
    torch.save(best_netC, osp.join(args.output_dir, "source_C.pt"))

    return netF, netB, netC

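# CrossEntropyLabelSmooth (loss.CrossEntropyLabelSmooth above) follows the
# standard label-smoothing formulation; a minimal sketch, with the mean
# reduction over the batch as an assumption:
class CrossEntropyLabelSmoothSketch(nn.Module):
    def __init__(self, num_classes, epsilon=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)  # B x K
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        return (-smoothed * log_probs).sum(dim=1).mean()
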
def split_target(args):
    train_bs = args.batch_size
    if args.dset == 's2m':
        train_target = mnist.MNIST('./data/mnist/', train=True, download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(32),
                                       transforms.Lambda(lambda x: x.convert("RGB")),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                   ]))
        train_target2 = mnist.MNIST_twice('./data/mnist/', train=True, download=True,
                                          transform=transforms.Compose([
                                              transforms.Resize(32),
                                              transforms.Lambda(lambda x: x.convert("RGB")),
                                              transforms.ToTensor(),
                                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                          ]))
        test_target = mnist.MNIST('./data/mnist/', train=False, download=True,
                                  transform=transforms.Compose([
                                      transforms.Resize(32),
                                      transforms.Lambda(lambda x: x.convert("RGB")),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                  ]))
    elif args.dset == 'u2m':
        train_target = mnist.MNIST('./data/mnist/', train=True, download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5,), (0.5,))
                                   ]))
        train_target2 = mnist.MNIST_twice('./data/mnist/', train=True, download=True,
                                          transform=transforms.Compose([
                                              transforms.ToTensor(),
                                              transforms.Normalize((0.5,), (0.5,))
                                          ]))
        test_target = mnist.MNIST('./data/mnist/', train=False, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5,), (0.5,))
                                  ]))
    elif args.dset == 'm2u':
        train_target = usps.USPS('./data/usps/', train=True, download=True,
                                 transform=transforms.Compose([
                                     transforms.ToTensor(),
                                     # transforms.Lambda(lambda x: _gaussian_blur(x, sigma=0.1)),
                                     transforms.Normalize((0.5,), (0.5,))
                                 ]))
        train_target2 = usps.USPS_twice('./data/usps/', train=True, download=True,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                            # transforms.Lambda(lambda x: _gaussian_blur(x, sigma=0.1)),
                                            transforms.Normalize((0.5,), (0.5,))
                                        ]))
        test_target = usps.USPS('./data/usps/', train=False, download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    # transforms.Lambda(lambda x: _gaussian_blur(x, sigma=0.1)),
                                    transforms.Normalize((0.5,), (0.5,))
                                ]))

    dset_loaders = {}
    dset_loaders["target_te"] = DataLoader(test_target, batch_size=train_bs, shuffle=False,
                                           num_workers=args.worker, drop_last=False)
    dset_loaders["target"] = DataLoader(train_target, batch_size=train_bs, shuffle=False,
                                        num_workers=args.worker, drop_last=False)
    dset_loaders["target2"] = DataLoader(train_target2, batch_size=train_bs, shuffle=False,
                                         num_workers=args.worker, drop_last=False)

    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    if args.model == 'source':
        modelpath = args.output_dir + "/source_F.pt"
        netF.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_B.pt"
        netB.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_C.pt"
        netC.load_state_dict(torch.load(modelpath))
    else:
        modelpath = args.output_dir + "/target_F_" + args.savename + ".pt"
        netF.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_B_" + args.savename + ".pt"
        netB.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_C_" + args.savename + ".pt"
        netC.load_state_dict(torch.load(modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    start_test = True
    with torch.no_grad():
        iter_test = iter(dset_loaders['target_te'])
        for i in range(len(dset_loaders['target_te'])):
            data = next(iter_test)
            # pdb.set_trace()
            inputs = data[0]
            labels = data[1]
            inputs = inputs.cuda()
            outputs = netC(netB(netF(inputs)))
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
    top_pred, predict = torch.max(all_output, 1)
    acc = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) * 100
    mean_ent = loss.Entropy(nn.Softmax(dim=1)(all_output))
    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
        args.dset + '_test', 0, 0, acc, mean_ent.mean())
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')

    start_test = True
    with torch.no_grad():
        iter_test = iter(dset_loaders['target'])
        for i in range(len(dset_loaders['target'])):
            data = next(iter_test)
            # pdb.set_trace()
            inputs = data[0]
            labels = data[1]
            inputs = inputs.cuda()
            outputs = netC(netB(netF(inputs)))
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
    top_pred, predict = torch.max(all_output, 1)
    acc = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) * 100
    mean_ent = loss.Entropy(nn.Softmax(dim=1)(all_output))
    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
        args.dset + '_train', 0, 0, acc, mean_ent.mean())
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')

    if args.ps == 0:
        est_p = (mean_ent < mean_ent.mean()).sum().item() / mean_ent.size(0)
        log_str = 'Task: {:.2f}'.format(est_p)
        print(log_str + '\n')
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        PS = est_p
    else:
        PS = args.ps

    if args.choice == "ent":
        value = mean_ent
    elif args.choice == "maxp":
        value = -top_pred
    elif args.choice == "marginp":
        pred, _ = torch.sort(all_output, 1)
        value = pred[:, 1] - pred[:, 0]
    else:
        value = torch.rand(len(mean_ent))

    predict = predict.numpy()
    train_idx = np.zeros(predict.shape)

    cls_k = args.class_num
    for c in range(cls_k):
        c_idx = np.where(predict == c)
        c_idx = c_idx[0]
        c_value = value[c_idx]
        _, idx_ = torch.sort(c_value)
        c_num = len(idx_)
        c_num_s = int(c_num * PS)
        # print(c, c_num, c_num_s)
        for ei in range(0, c_num_s):
            ee = c_idx[idx_[ei]]
            train_idx[ee] = 1

    train_target.targets = predict
    new_src = copy.deepcopy(train_target)
    new_tar = copy.deepcopy(train_target2)
    # pdb.set_trace()
    if args.dset == 'm2u':
        new_src.train_data = np.delete(new_src.train_data, np.where(train_idx == 0)[0], axis=0)
        new_src.train_labels = np.delete(new_src.train_labels, np.where(train_idx == 0)[0], axis=0)
        new_tar.train_data = np.delete(new_tar.train_data, np.where(train_idx == 1)[0], axis=0)
        new_tar.train_labels = np.delete(new_tar.train_labels, np.where(train_idx == 1)[0], axis=0)
    else:
        new_src.data = np.delete(new_src.data, np.where(train_idx == 0)[0], axis=0)
        new_src.targets = np.delete(new_src.targets, np.where(train_idx == 0)[0], axis=0)
        new_tar.data = np.delete(new_tar.data, np.where(train_idx == 1)[0], axis=0)
        new_tar.targets = np.delete(new_tar.targets, np.where(train_idx == 1)[0], axis=0)
    # pdb.set_trace()
    return new_src, new_tar

def copy_target_simp(args):
    dset_loaders = data_load(args)
    if args.net_src[0:3] == 'res':
        netF = network.ResBase(res_name=args.net_src).cuda()
    netC = network.feat_classifier_simpl(class_num=args.class_num, feat_dim=netF.in_features).cuda()

    args.modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    source_model = nn.Sequential(netF, netC).cuda()
    source_model.eval()

    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net, pretrain=True).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 0.1}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    ent_best = 1.0
    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // 10
    iter_num = 0

    model = nn.Sequential(netF, netB, netC).cuda()
    model.eval()

    # cache the (optionally top-k truncated) source predictions for the whole target set
    start_test = True
    with torch.no_grad():
        iter_test = iter(dset_loaders["target_te"])
        for i in range(len(dset_loaders["target_te"])):
            data = next(iter_test)
            inputs, labels = data[0], data[1]
            inputs = inputs.cuda()
            outputs = source_model(inputs)
            outputs = nn.Softmax(dim=1)(outputs)
            _, src_idx = torch.sort(outputs, 1, descending=True)
            if args.topk > 0:
                topk = np.min([args.topk, args.class_num])
                for j in range(outputs.size()[0]):
                    # spread the leftover mass uniformly over the non-top-k classes
                    outputs[j, src_idx[j, topk:]] = (1.0 - outputs[j, src_idx[j, :topk]].sum()) / (outputs.size()[1] - topk)
            if start_test:
                all_output = outputs.float()
                all_label = labels
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float()), 0)
                all_label = torch.cat((all_label, labels), 0)
    mem_P = all_output.detach()

    model.train()
    while iter_num < max_iter:
        if args.ema < 1.0 and iter_num > 0 and iter_num % interval_iter == 0:
            # refresh the teacher predictions with an exponential moving average
            model.eval()
            start_test = True
            with torch.no_grad():
                iter_test = iter(dset_loaders["target_te"])
                for i in range(len(dset_loaders["target_te"])):
                    data = next(iter_test)
                    inputs = data[0]
                    inputs = inputs.cuda()
                    outputs = model(inputs)
                    outputs = nn.Softmax(dim=1)(outputs)
                    if start_test:
                        all_output = outputs.float()
                        start_test = False
                    else:
                        all_output = torch.cat((all_output, outputs.float()), 0)
            mem_P = mem_P * args.ema + all_output.detach() * (1 - args.ema)
            model.train()

        try:
            inputs_target, y, tar_idx = next(iter_target)
        except (NameError, StopIteration):
            iter_target = iter(dset_loaders["target"])
            inputs_target, y, tar_idx = next(iter_target)

        if inputs_target.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter, power=1.5)
        inputs_target = inputs_target.cuda()

        with torch.no_grad():
            outputs_target_by_source = mem_P[tar_idx, :]
            _, src_idx = torch.sort(outputs_target_by_source, 1, descending=True)

        outputs_target = model(inputs_target)
        outputs_target = torch.nn.Softmax(dim=1)(outputs_target)
        # distill the cached source predictions into the new model
        classifier_loss = nn.KLDivLoss(reduction='batchmean')(outputs_target.log(), outputs_target_by_source)
        optimizer.zero_grad()
        entropy_loss = torch.mean(loss.Entropy(outputs_target))
        msoftmax = outputs_target.mean(dim=0)
        gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
        entropy_loss -= gentropy_loss
        classifier_loss += entropy_loss
        classifier_loss.backward()

        if args.mix > 0:
            alpha = 0.3
            lam = np.random.beta(alpha, alpha)
            index = torch.randperm(inputs_target.size()[0]).cuda()
            mixed_input = lam * inputs_target + (1 - lam) * inputs_target[index, :]
            mixed_output = (lam * outputs_target + (1 - lam) * outputs_target[index, :]).detach()
            update_batch_stats(model, False)
            outputs_target_m = model(mixed_input)
            update_batch_stats(model, True)
            outputs_target_m = torch.nn.Softmax(dim=1)(outputs_target_m)
            classifier_loss = args.mix * nn.KLDivLoss(reduction='batchmean')(outputs_target_m.log(), mixed_output)
            classifier_loss.backward()

        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            model.eval()
            acc_s_te, mean_ent = cal_acc(dset_loaders['test'], netF, netB, netC, False)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%, Ent = {:.4f}'.format(
                args.name, iter_num, max_iter, acc_s_te, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            model.train()

    torch.save(netF.state_dict(), osp.join(args.output_dir, "source_F.pt"))
    torch.save(netB.state_dict(), osp.join(args.output_dir, "source_B.pt"))
    torch.save(netC.state_dict(), osp.join(args.output_dir, "source_C.pt"))

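# update_batch_stats is called around the mixup forward pass above, apparently
# to keep BatchNorm running statistics from being polluted by mixed inputs. It
# is not defined here; a plausible sketch, assuming standard nn.BatchNorm layers
# (momentum 0.0 freezes the running-stat update; 0.1 is PyTorch's default):
def update_batch_stats_sketch(model, flag):
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            m.momentum = 0.1 if flag else 0.0
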
def split_target(args):
    test_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    txt_tar = open(args.t_dset_path).readlines()
    dset_loaders = {}
    test_set = ImageList(txt_tar, transform=test_transform)
    dset_loaders["target"] = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size * 3,
                                                         shuffle=False, num_workers=args.worker,
                                                         drop_last=False)

    netF = network.ResBase(res_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    if args.model == "source":
        modelpath = args.output_dir + "/source_F.pt"
        netF.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_B.pt"
        netB.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_C.pt"
        netC.load_state_dict(torch.load(modelpath))
    else:
        modelpath = args.output_dir + "/target_F_" + args.savename + ".pt"
        netF.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_B_" + args.savename + ".pt"
        netB.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_C_" + args.savename + ".pt"
        netC.load_state_dict(torch.load(modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    start_test = True
    with torch.no_grad():
        iter_test = iter(dset_loaders['target'])
        for i in range(len(dset_loaders['target'])):
            data = next(iter_test)
            inputs = data[0]
            labels = data[1]
            inputs = inputs.cuda()
            outputs = netC(netB(netF(inputs)))
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)

    top_pred, predict = torch.max(all_output, 1)
    acc = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) * 100
    mean_ent = loss.Entropy(nn.Softmax(dim=1)(all_output))

    if args.dset == 'VISDA-C':
        matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
        matrix = matrix[np.unique(all_label).astype(int), :]
        all_acc = matrix.diagonal() / matrix.sum(axis=1) * 100
        acc = all_acc.mean()
        aa = [str(np.round(i, 2)) for i in all_acc]
        acc_list = ' '.join(aa)
        print(acc_list)
        args.out_file.write(acc_list + '\n')
        args.out_file.flush()

    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
        args.name, 0, 0, acc, mean_ent.mean())
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')

    if args.ps == 0:
        est_p = (mean_ent < mean_ent.mean()).sum().item() / mean_ent.size(0)
        log_str = 'Task: {:.2f}'.format(est_p)
        print(log_str + '\n')
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        PS = est_p
    else:
        PS = args.ps

    if args.choice == "ent":
        value = mean_ent
    elif args.choice == "maxp":
        value = -top_pred
    elif args.choice == "marginp":
        pred, _ = torch.sort(all_output, 1)
        value = pred[:, 1] - pred[:, 0]
    else:
        value = torch.rand(len(mean_ent))

    ori_target = txt_tar.copy()
    new_tar = []
    new_src = []
    predict = predict.numpy()

    cls_k = args.class_num
    for c in range(cls_k):
        c_idx = np.where(predict == c)
        c_idx = c_idx[0]
        c_value = value[c_idx]
        _, idx_ = torch.sort(c_value)
        c_num = len(idx_)
        c_num_s = int(c_num * PS)
        for ei in range(0, c_num_s):
            ee = c_idx[idx_[ei]]
            reci = ori_target[ee].strip().split(' ')
            line = reci[0] + ' ' + str(c) + '\n'
            new_src.append(line)
        for ei in range(c_num_s, c_num):
            ee = c_idx[idx_[ei]]
            reci = ori_target[ee].strip().split(' ')
            line = reci[0] + ' ' + str(c) + '\n'
            new_tar.append(line)

    return new_src.copy(), new_tar.copy()

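# ImageList above is assumed to be the usual list-file dataset of SHOT-style
# code: each line of the txt file is "path label". A minimal sketch:
from PIL import Image
from torch.utils.data import Dataset

class ImageListSketch(Dataset):
    def __init__(self, image_list, transform=None):
        # parse "path label" lines into (path, int label) pairs
        self.imgs = [(l.split()[0], int(l.split()[1])) for l in image_list]
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = Image.open(path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, target
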
def train_source(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res' or args.net[0:3] == 'vgg':
        if args.net[0:3] == 'res':
            netF = network.ResBase(res_name=args.net).cuda()
        else:
            netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()  # classifier: bn
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()  # layer: wn

    if args.resume:
        args.modelpath = args.output_dir_src + '/source_F.pt'
        netF.load_state_dict(torch.load(args.modelpath))
        args.modelpath = args.output_dir_src + '/source_B.pt'
        netB.load_state_dict(torch.load(args.modelpath))
        args.modelpath = args.output_dir_src + '/source_C.pt'
        netC.load_state_dict(torch.load(args.modelpath))

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 0.1}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0.
    max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    print_loss_interval = 25
    interval_iter = 100
    iter_num = 0

    if args.net[0:3] == 'res' or args.net[0:3] == 'vgg':
        netF.train()
    netB.train()
    netC.train()

    while iter_num < max_iter:
        try:
            inputs_source, labels_source = next(iter_source)
        except:
            iter_source = iter(dset_loaders["source_tr"])
            inputs_source, labels_source = next(iter_source)
        if inputs_source.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
        outputs_source = netC(netB(netF(inputs_source)))
        classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num,
                                                  epsilon=args.smooth)(outputs_source, labels_source)
        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % print_loss_interval == 0:
            print("Iter:{:>4d}/{} | Classification loss on Source: {:.2f}".format(
                iter_num, max_iter, classifier_loss.item()))

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            netC.eval()
            if args.dset == 'VISDA-RSUT' or args.dset == 'VISDA-RSUT-50' or args.dset == 'VISDA-RSUT-10':
                # The small classes in VisDA-C (RSUT) still have relatively many samples,
                # so per-class average accuracy is safe to use.
                acc_s_te, acc_list, acc_cls_avg_te = cal_acc(dset_loaders['source_te'], netF, netB, netC,
                                                             per_class_flag=True, visda_flag=True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}%'.format(
                    args.name_src, iter_num, max_iter, acc_s_te, acc_cls_avg_te) + '\n' + acc_list
                cur_acc = acc_cls_avg_te
            else:
                if args.trte == 'stratified':
                    # Stratified cross-validation guarantees that every class appears in the
                    # validation set, so per-class average accuracy is safe to use.
                    acc_s_te, acc_cls_avg_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC,
                                                          per_class_flag=True, visda_flag=False)
                    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}%'.format(
                        args.name_src, iter_num, max_iter, acc_s_te, acc_cls_avg_te)
                    cur_acc = acc_cls_avg_te
                else:
                    # Conventional cross-validation may leave some classes out of the validation
                    # set, especially when the dataset has very small classes, e.g. Office-Home
                    # (RSUT) or DomainNet. Use overall accuracy to avoid a 'nan' per-class average.
                    acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC,
                                          per_class_flag=False, visda_flag=False)
                    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                        args.name_src, iter_num, max_iter, acc_s_te)
                    cur_acc = acc_s_te

            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if cur_acc >= acc_init and iter_num >= 3 * len(dset_loaders["source_tr"]):
                # skip the first 3 epochs: training is not stable yet
                acc_init = cur_acc
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()

            netF.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
    torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
    torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))
    return netF, netB, netC
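# The per-class average accuracy selected by per_class_flag=True above is, in most
# SHOT-style codebases, the mean of per-class recalls from a confusion matrix. A
# minimal hedged sketch (assuming sklearn is available; the repo's cal_acc may
# differ in detail):
import numpy as np
from sklearn.metrics import confusion_matrix

def per_class_avg_acc(y_true, y_pred):
    cm = confusion_matrix(y_true, y_pred)
    per_class = cm.diagonal() / cm.sum(axis=1)  # recall of each class
    return per_class.mean() * 100, per_class * 100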
def train_target(args):
    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir + '/source_F_val.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_B_val.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_C_val.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': args.lr}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': args.lr}]
    optimizer = optim.SGD(param_group, momentum=0.9, weight_decay=5e-4, nesterov=True)

    for epoch in tqdm(range(args.max_epoch), leave=False):
        iter_test = iter(dset_loaders["target"])
        netF.eval()
        netB.eval()
        mem_label = obtain_label(dset_loaders['target_te'], netF, netB, netC, args)
        mem_label = torch.from_numpy(mem_label).cuda()
        netF.train()
        netB.train()

        for _, (inputs_test, _, tar_idx) in tqdm(enumerate(iter_test), leave=False):
            if inputs_test.size(0) == 1:
                continue
            inputs_test = inputs_test.cuda()
            pred = mem_label[tar_idx]
            features_test = netB(netF(inputs_test))
            outputs_test = netC(features_test)
            classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num,
                                                      epsilon=0)(outputs_test, pred)

            softmax_out = nn.Softmax(dim=1)(outputs_test)
            im_loss = torch.mean(Entropy(softmax_out))
            msoftmax = softmax_out.mean(dim=0)
            im_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
            total_loss = im_loss + args.cls_par * classifier_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        netF.eval()
        netB.eval()
        acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
        log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
            args.dset, epoch + 1, args.max_epoch, acc * 100)
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        print(log_str + '\n')

    # torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F.pt"))
    # torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B.pt"))
    # torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C.pt"))
    return netF, netB, netC
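# The adaptation objective above is SHOT's information-maximization (IM) loss:
# minimize per-sample prediction entropy while maximizing the entropy of the
# batch-mean prediction, so outputs become confident yet globally diverse. A
# standalone sketch of the same computation:
import torch

def im_loss(logits, eps=1e-5):
    p = torch.softmax(logits, dim=1)
    ent = -(p * torch.log(p + eps)).sum(dim=1).mean()  # confidence term
    p_mean = p.mean(dim=0)
    div = -(p_mean * torch.log(p_mean + eps)).sum()    # diversity term
    return ent - div  # lower is better: confident and diverse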
def train_target_bait(args, zz=''):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()
    oldC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_src + '/source_F_' + str(zz) + '.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B_' + str(zz) + '.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C_' + str(zz) + '.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    oldC.load_state_dict(torch.load(args.modelpath))

    oldC.eval()
    netC.train()
    for k, v in oldC.named_parameters():
        v.requires_grad = False

    param_group = []
    param_group_c = []
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': args.lr * 0.01}]  # 0.1
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': args.lr * 0.1}]  # 1
    for k, v in netC.named_parameters():
        param_group_c += [{'params': v, 'lr': args.lr * 0.1}]  # 1
    optimizer = optim.SGD(param_group, momentum=0.9, weight_decay=5e-4, nesterov=True)
    optimizer_c = optim.SGD(param_group_c, momentum=0.9, weight_decay=5e-4, nesterov=True)

    netF.train()
    netB.train()

    iter_num = 0
    iter_target = iter(dset_loaders["target"])
    while iter_num < args.max_epoch * len(dset_loaders["target"]):
        try:
            inputs_test, _, tar_idx = next(iter_target)
        except:
            iter_target = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_target)
        if inputs_test.size(0) == 1:
            continue

        iter_num += 1
        inputs_test = inputs_test.cuda()
        batch_size = inputs_test.shape[0]

        # Step 1: update the bait classifier netC so that it disagrees with the
        # frozen anchor oldC on high-entropy (uncertain) samples.
        total_loss = 0
        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)
        outputs_test_old = oldC(features_test)
        softmax_out = nn.Softmax(dim=1)(outputs_test)
        softmax_out_old = nn.Softmax(dim=1)(outputs_test_old)

        loss_cast = loss.SKL(softmax_out, softmax_out_old).sum(dim=1)
        entropy_old = Entropy(softmax_out_old)
        indx = entropy_old.topk(int(batch_size * 0.5), largest=True)[-1]
        ones_mask = torch.ones(batch_size).cuda() * -1
        ones_mask[indx] = 1
        loss_cast = loss_cast * ones_mask
        total_loss -= torch.mean(loss_cast) * 10

        optimizer_c.zero_grad()
        total_loss.backward()
        optimizer_c.step()

        # Step 2: update the feature extractor so the two classifiers agree again,
        # plus a class-balance term on both classifiers' mean predictions.
        total_loss = 0
        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)
        softmax_out = nn.Softmax(dim=1)(outputs_test)
        outputs_test_old = oldC(features_test)
        softmax_out_old = nn.Softmax(dim=1)(outputs_test_old)

        msoftmax = softmax_out_old.mean(dim=0)
        cb_loss = torch.sum(msoftmax * torch.log(msoftmax + 1e-5))
        total_loss += cb_loss
        msoftmax = softmax_out.mean(dim=0)
        cb_loss = torch.sum(msoftmax * torch.log(msoftmax + 1e-5))
        total_loss += cb_loss

        loss_bite = (-softmax_out_old * torch.log(softmax_out + 1e-5)).sum(1) \
                    - (softmax_out * torch.log(softmax_out_old + 1e-5)).sum(1)
        total_loss += torch.mean(loss_bite)  # * 0.8

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if iter_num % int(args.interval * len(dset_loaders["target"])) == 0:
            netF.eval()
            netB.eval()
            netC.eval()
            acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, oldC,
                                    args.dset == "visda17")
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                args.name, iter_num, args.max_epoch * len(dset_loaders["target"]),
                acc) + '\n' + str(acc_list)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()
            netC.train()

    return netF, netB, netC
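# loss.SKL above is a symmetric KL-style divergence between the bait and anchor
# classifiers' softmax outputs. A plausible elementwise form, hedged because the
# repository's own definition may differ (e.g., in scaling or normalization):
import torch

def skl(p, q, eps=1e-5):
    # returns an elementwise (N, K) tensor; callers .sum(dim=1) over classes
    return 0.5 * (p * torch.log((p + eps) / (q + eps))
                  + q * torch.log((q + eps) / (p + eps)))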
def train_source(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 0.1}]  # 1
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 1}]  # 10
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 1}]  # 10
    optimizer = optim.SGD(param_group, momentum=0.9, weight_decay=5e-4, nesterov=True)

    acc_init = 0
    netF.train()
    netB.train()
    netC.train()

    iter_num = 0
    iter_source = iter(dset_loaders["source_tr"])
    while iter_num < args.max_epoch * len(dset_loaders["source_tr"]):
        try:
            inputs_source, labels_source = next(iter_source)
        except:
            iter_source = iter(dset_loaders["source_tr"])
            inputs_source, labels_source = next(iter_source)
        if inputs_source.size(0) == 1:
            continue

        iter_num += 1
        inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
        outputs_source = netC(netB(netF(inputs_source)))
        classifier_loss = loss.CrossEntropyLabelSmooth(
            num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source)

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % int(args.interval * len(dset_loaders["source_tr"])) == 0:
            netF.eval()
            netB.eval()
            netC.eval()
            acc_s_te, acc_list = cal_acc(dset_loaders['source_te'], netF, netB, netC,
                                         args.dset == "visda17")
            log_str = 'Task: {}, Iter:{}; Accuracy = {:.2f}%'.format(
                args.name_src, iter_num, acc_s_te) + '\n' + str(acc_list)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            if acc_s_te >= acc_init:
                acc_init = acc_s_te
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()
            netF.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, "source_F_val.pt"))
    torch.save(best_netB, osp.join(args.output_dir_src, "source_B_val.pt"))
    torch.save(best_netC, osp.join(args.output_dir_src, "source_C_val.pt"))
    return netF, netB, netC
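# CrossEntropyLabelSmooth, used throughout these training loops, mixes the one-hot
# target with a uniform distribution over classes. A minimal sketch consistent with
# the calls above (the repository's version may add a `weight` argument or other
# reduction options):
import torch
import torch.nn as nn

class CrossEntropyLabelSmoothSketch(nn.Module):
    def __init__(self, num_classes, epsilon=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        return (-targets * log_probs).sum(dim=1).mean()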
def train(args, txt_src, txt_tgt):
    ## set pre-process
    dset_loaders = data_load(args, txt_src, txt_tgt)
    # pdb.set_trace()
    max_len = max(len(dset_loaders["source"]), len(dset_loaders["target"]))
    max_iter = args.max_epoch * max_len
    interval_iter = max_iter // 10

    if args.dset == 'u2m':
        netG = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netG = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netG = network.DTNBase().cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netG.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    if args.model == 'source':
        modelpath = args.output_dir + "/source_F.pt"
        netG.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_B.pt"
        netB.load_state_dict(torch.load(modelpath))
    else:
        modelpath = args.output_dir + "/target_F_" + args.savename + ".pt"
        netG.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_B_" + args.savename + ".pt"
        netB.load_state_dict(torch.load(modelpath))

    netF = nn.Sequential(netB, netC)
    optimizer_g = optim.SGD(netG.parameters(), lr=args.lr * 0.1)
    optimizer_f = optim.SGD(netF.parameters(), lr=args.lr)
    base_network = nn.Sequential(netG, netF)

    source_loader_iter = iter(dset_loaders["source"])
    target_loader_iter = iter(dset_loaders["target"])

    list_acc = []
    best_ent = 100

    for iter_num in range(1, max_iter + 1):
        base_network.train()
        lr_scheduler(optimizer_g, init_lr=args.lr * 0.1, iter_num=iter_num, max_iter=max_iter)
        lr_scheduler(optimizer_f, init_lr=args.lr, iter_num=iter_num, max_iter=max_iter)
        try:
            inputs_source, labels_source = next(source_loader_iter)
        except:
            source_loader_iter = iter(dset_loaders["source"])
            inputs_source, labels_source = next(source_loader_iter)
        try:
            inputs_target, _, target_idx = next(target_loader_iter)
        except:
            target_loader_iter = iter(dset_loaders["target"])
            inputs_target, _, target_idx = next(target_loader_iter)

        targets_s = torch.zeros(args.batch_size, args.class_num).scatter_(
            1, labels_source.view(-1, 1), 1)
        inputs_s = inputs_source.cuda()
        targets_s = targets_s.cuda()
        inputs_t = inputs_target[0].cuda()
        inputs_t2 = inputs_target[1].cuda()

        with torch.no_grad():
            # compute guessed labels of unlabeled samples
            outputs_u = base_network(inputs_t)
            outputs_u2 = base_network(inputs_t2)
            p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2
            pt = p ** (1 / args.T)
            targets_u = pt / pt.sum(dim=1, keepdim=True)
            targets_u = targets_u.detach()

        ####################################################################
        all_inputs = torch.cat([inputs_s, inputs_t, inputs_t2], dim=0)
        all_targets = torch.cat([targets_s, targets_u, targets_u], dim=0)
        if args.alpha > 0:
            l = np.random.beta(args.alpha, args.alpha)
            l = max(l, 1 - l)
        else:
            l = 1
        idx = torch.randperm(all_inputs.size(0))
        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]
        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # interleave labeled and unlabeled samples between batches to get
        # correct batchnorm statistics
        mixed_input = list(torch.split(mixed_input, args.batch_size))
        mixed_input = utils.interleave(mixed_input, args.batch_size)
        # s  = [sa, sb, sc]
        # t1 = [t1a, t1b, t1c]
        # t2 = [t2a, t2b, t2c]
        # => s' = [sa, t1b, t2c]  t1' = [t1a, sb, t1c]  t2' = [t2a, t2b, sc]

        logits = [base_network(mixed_input[0])]
        for inp in mixed_input[1:]:
            logits.append(base_network(inp))

        # put interleaved samples back
        # [i[:, 0] for i in aa]
        logits = utils.interleave(logits, args.batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        train_criterion = utils.SemiLoss()
        Lx, Lu, w = train_criterion(logits_x, mixed_target[:args.batch_size],
                                    logits_u, mixed_target[args.batch_size:],
                                    iter_num, max_iter, args.lambda_u)
        loss = Lx + w * Lu

        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            base_network.eval()
            acc, py, score, y = cal_acc(dset_loaders["train"], base_network, flag=False)
            mean_ent = torch.mean(Entropy(score))
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
                args.dset + '_train', iter_num, max_iter, acc, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            acc, py, score, y = cal_acc(dset_loaders["test"], base_network, flag=False)
            mean_ent = torch.mean(Entropy(score))
            list_acc.append(acc)
            if best_ent > mean_ent:
                val_acc = acc
                best_ent = mean_ent
                best_y = y
                best_py = py
                best_score = score
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
                args.dset + '_test', iter_num, max_iter, acc, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

    idx = np.argmax(np.array(list_acc))
    max_acc = list_acc[idx]
    final_acc = list_acc[-1]
    log_str = '\n==========================================\n'
    log_str += '\nVal Acc = {:.2f}\nMax Acc = {:.2f}\nFin Acc = {:.2f}\n'.format(
        val_acc, max_acc, final_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()

    # torch.save(base_network.state_dict(), osp.join(args.output_dir, args.log + ".pt"))
    # sio.savemat(osp.join(args.output_dir, args.log + ".mat"),
    #             {'y': best_y.cpu().numpy(), 'py': best_py.cpu().numpy(),
    #              'score': best_score.cpu().numpy()})
    return base_network, py
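# utils.SemiLoss above follows MixMatch: cross-entropy on the (mixed) labeled batch,
# an L2 penalty on the unlabeled batch, and a weight w that ramps up linearly over
# training. A hedged sketch of that criterion (the repo's version may clamp or
# schedule w differently):
import torch
import torch.nn.functional as F

def semi_loss(logits_x, targets_x, logits_u, targets_u, iter_num, max_iter, lambda_u):
    Lx = -(targets_x * F.log_softmax(logits_x, dim=1)).sum(dim=1).mean()
    probs_u = torch.softmax(logits_u, dim=1)
    Lu = F.mse_loss(probs_u, targets_u)
    w = lambda_u * min(iter_num / max_iter, 1.0)  # linear ramp-up
    return Lx, Lu, w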
def train_target(args):
    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_B.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': args.lr}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': args.lr}]
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = len(dset_loaders["target"])
    # interval_iter = max_iter // args.interval
    iter_num = 0

    while iter_num < max_iter:
        optimizer.zero_grad()
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)
        if inputs_test.size(0) == 1:
            continue

        if iter_num % interval_iter == 0 and args.cls_par > 0:
            netF.eval()
            netB.eval()
            mem_label = obtain_label(dset_loaders['target_te'], netF, netB, netC, args)
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netB.train()

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_test = inputs_test.cuda()
        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        if args.cls_par > 0:
            pred = mem_label[tar_idx]
            classifier_loss = args.cls_par * nn.CrossEntropyLoss()(outputs_test, pred)
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            # if args.gent:
            #     msoftmax = softmax_out.mean(dim=0)
            #     entropy_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                args.dset, iter_num, max_iter, acc)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))
    return netF, netB, netC
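# op_copy and lr_scheduler, used by most of these loops, implement SHOT's polynomial
# learning-rate decay: each parameter group remembers its initial lr (lr0) and is
# rescaled by (1 + gamma * iter/max_iter)^(-power) every step. A hedged sketch that
# matches the plain call sites above (later functions in this file use variants that
# also take args, gamma, and power):
def op_copy(optimizer):
    for param_group in optimizer.param_groups:
        param_group['lr0'] = param_group['lr']  # remember the initial lr
    return optimizer

def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr0'] * decay
        param_group['weight_decay'] = 1e-3
        param_group['momentum'] = 0.9
        param_group['nesterov'] = True
    return optimizer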
def train_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    tt = 0
    iter_num = 0
    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval

    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)
        if inputs_test.size(0) == 1:
            continue

        if iter_num % interval_iter == 0:
            netF.eval()
            netB.eval()
            mem_label, ENT_THRESHOLD = obtain_label(dset_loaders['test'], netF, netB, netC, args)
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netB.train()

        inputs_test = inputs_test.cuda()
        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        pred = mem_label[tar_idx]
        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)
        softmax_out = nn.Softmax(dim=1)(outputs_test)

        outputs_test_known = outputs_test[pred < args.class_num, :]
        pred = pred[pred < args.class_num]
        if len(pred) == 0:
            # the whole batch was pseudo-labeled "unknown"; skip it
            print(tt)
            del features_test
            del outputs_test
            tt += 1
            continue

        if args.cls_par > 0:
            classifier_loss = nn.CrossEntropyLoss()(outputs_test_known, pred)
            classifier_loss *= args.cls_par
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out_known = nn.Softmax(dim=1)(outputs_test_known)
            entropy_loss = torch.mean(loss.Entropy(softmax_out_known))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            classifier_loss += entropy_loss * args.ent_par

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            acc_os1, acc_os2, acc_unknown = cal_acc(dset_loaders['test'], netF, netB, netC,
                                                    True, ENT_THRESHOLD)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}% / {:.2f}% / {:.2f}%'.format(
                args.name, iter_num, max_iter, acc_os2, acc_os1, acc_unknown)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))
    return netF, netB, netC
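# In the open-set (ODA) branch above, obtain_label returns an entropy threshold
# ENT_THRESHOLD, and samples whose prediction entropy exceeds it are pseudo-labeled
# as class args.class_num ("unknown"), which the loop then excludes from the
# classification loss. A hedged sketch of that rejection step (the repo normalizes
# entropy before thresholding; details may differ):
import torch

def reject_unknown(logits, ent_threshold, known_classes):
    p = torch.softmax(logits, dim=1)
    ent = -(p * torch.log(p + 1e-5)).sum(dim=1)
    ent = ent / torch.log(torch.tensor(float(known_classes)))  # normalized entropy
    pred = p.argmax(dim=1)
    pred[ent > ent_threshold] = known_classes  # extra index = "unknown"
    return pred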
def train_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    if not args.ssl == 0:
        netR = network.feat_classifier(type='linear', class_num=4,
                                       bottleneck_dim=2 * args.bottleneck).cuda()
        netR_dict, acc_rot = train_target_rot(args)
        netR.load_state_dict(netR_dict)

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath))
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    if not args.ssl == 0:
        for k, v in netR.named_parameters():
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        netR.train()
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0

    while iter_num < max_iter:
        optimizer.zero_grad()
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)
        if inputs_test.size(0) == 1:
            continue

        if iter_num % interval_iter == 0 and args.cls_par > 0:
            netF.eval()
            netB.eval()
            mem_label = obtain_label(dset_loaders['target_te'], netF, netB, netC, args)
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netB.train()

        inputs_test = inputs_test.cuda()
        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        if args.cls_par > 0:
            pred = mem_label[tar_idx]
        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        if args.cls_par > 0:
            classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)
            classifier_loss *= args.cls_par
            if iter_num < interval_iter and args.dset == "VISDA-C":
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss
        classifier_loss.backward()

        if not args.ssl == 0:
            r_labels_target = np.random.randint(0, 4, len(inputs_test))
            r_inputs_target = rotation.rotate_batch_with_labels(inputs_test, r_labels_target)
            r_labels_target = torch.from_numpy(r_labels_target).cuda()
            r_inputs_target = r_inputs_target.cuda()
            f_outputs = netB(netF(inputs_test))
            f_outputs = f_outputs.detach()
            f_r_outputs = netB(netF(r_inputs_target))
            r_outputs_target = netR(torch.cat((f_outputs, f_r_outputs), 1))
            rotation_loss = args.ssl * nn.CrossEntropyLoss()(r_outputs_target, r_labels_target)
            rotation_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            if args.dset == 'VISDA-C':
                acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te) + '\n' + acc_list
            else:
                acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))
    return netF, netB, netC
def pretrain_on_source(src_data_loader, src_data_loader_eval, output_dir):
    ## set base network
    if params.mode == 'u2m':
        netF = network.LeNetBase().cuda()
    elif params.mode == 'm2u':
        netF = network.LeNetBase().cuda()
    elif params.mode == 's2m':
        netF = network.DTNBase().cuda()
    netB = network.feat_bootleneck(type=params.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=params.bottleneck).cuda()
    netC = network.feat_classifier(type=params.layer, class_num=params.class_num,
                                   bottleneck_dim=params.bottleneck).cuda()

    param_group = []
    learning_rate = params.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    optimizer = optim.SGD(param_group, momentum=0.9, weight_decay=5e-4, nesterov=True)

    acc_init = 0
    out_file = open(os.path.join(output_dir, 'log_pretrain.txt'), 'w')
    for epoch in range(params.epochs):
        # scheduler.step()
        netF.train()
        netB.train()
        netC.train()
        iter_source = iter(src_data_loader)
        for _, (inputs_source, labels_source) in enumerate(iter_source):
            if inputs_source.size(0) == 1:
                continue
            inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
            outputs_source = netC(netB(netF(inputs_source)))
            classifier_loss = network.CrossEntropyLabelSmooth(
                num_classes=params.class_num, epsilon=params.smooth)(outputs_source, labels_source)
            optimizer.zero_grad()
            classifier_loss.backward()
            optimizer.step()

        netF.eval()
        netB.eval()
        netC.eval()
        acc_s_tr, _ = network.cal_acc(src_data_loader, netF, netB, netC)
        acc_s_te, _ = network.cal_acc(src_data_loader_eval, netF, netB, netC)
        log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%/ {:.2f}%'.format(
            params.mode, epoch + 1, params.epochs, acc_s_tr * 100, acc_s_te * 100)
        out_file.write(log_str + '\n')
        out_file.flush()
        print(log_str + '\n')

        if acc_s_te >= acc_init:
            acc_init = acc_s_te
            best_netF = netF.state_dict()
            best_netB = netB.state_dict()
            best_netC = netC.state_dict()

    torch.save(best_netF, os.path.join(output_dir, "source_F_val.pt"))
    torch.save(best_netB, os.path.join(output_dir, "source_B_val.pt"))
    torch.save(best_netC, os.path.join(output_dir, "source_C_val.pt"))
    return netF, netB, netC
def train_target_rot(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netR = network.feat_classifier(type='linear', class_num=4,
                                   bottleneck_dim=2 * args.bottleneck).cuda()

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    netF.eval()
    for k, v in netF.named_parameters():
        v.requires_grad = False
    modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    netB.eval()
    for k, v in netB.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netR.named_parameters():
        param_group += [{'params': v, 'lr': args.lr * 1}]
    netR.train()
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // 10
    iter_num = 0
    rot_acc = 0

    while iter_num < max_iter:
        optimizer.zero_grad()
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)
        if inputs_test.size(0) == 1:
            continue

        inputs_test = inputs_test.cuda()
        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        r_labels_target = np.random.randint(0, 4, len(inputs_test))
        r_inputs_target = rotation.rotate_batch_with_labels(inputs_test, r_labels_target)
        r_labels_target = torch.from_numpy(r_labels_target).cuda()
        r_inputs_target = r_inputs_target.cuda()
        f_outputs = netB(netF(inputs_test))
        f_r_outputs = netB(netF(r_inputs_target))
        r_outputs_target = netR(torch.cat((f_outputs, f_r_outputs), 1))
        rotation_loss = nn.CrossEntropyLoss()(r_outputs_target, r_labels_target)
        rotation_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netR.eval()
            acc_rot = cal_acc_rot(dset_loaders['target'], netF, netB, netR)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                args.name, iter_num, max_iter, acc_rot)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netR.train()
            if rot_acc < acc_rot:
                rot_acc = acc_rot
                best_netR = netR.state_dict()

    log_str = 'Best Accuracy = {:.2f}%'.format(rot_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')
    return best_netR, rot_acc
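# rotation.rotate_batch_with_labels above builds the 4-way rotation pretext task:
# each image is rotated by label * 90 degrees, and netR must recover the label from
# the concatenated original/rotated features. A minimal hedged sketch of such a
# helper using torch.rot90 (the repo's implementation may differ in detail):
import torch

def rotate_batch_with_labels_sketch(batch, labels):
    # batch: (N, C, H, W); labels: sequence of ints in {0, 1, 2, 3}
    images = [torch.rot90(img, k=int(k), dims=(1, 2)) for img, k in zip(batch, labels)]
    return torch.stack(images)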
def train_source(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.norm_layer == 'batchnorm':
        norm_layer = nn.BatchNorm2d
    elif args.norm_layer == 'groupnorm':
        def gn_helper(planes):
            return nn.GroupNorm(8, planes)
        norm_layer = gn_helper
    if args.net[0:3] == 'res':
        if '26' in args.net:
            netF = network.ResCifarBase(26, norm_layer=norm_layer)
            args.bottleneck = netF.in_features // 2
        else:
            netF = network.ResBase(res_name=args.net, args=args)
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net)

    if args.ssl_before_btn:
        netH = network.ssl_head(ssl_task=args.ssl_task, feature_dim=netF.in_features,
                                embedding_dim=args.embedding_dim)
    else:
        netH = network.ssl_head(ssl_task=args.ssl_task, feature_dim=args.bottleneck,
                                embedding_dim=args.embedding_dim)
    if args.bottleneck != 0:
        netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                       bottleneck_dim=args.bottleneck, norm_btn=args.norm_btn)
        netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                       bottleneck_dim=args.bottleneck, bias=args.classifier_bias,
                                       temp=args.angular_temp, args=args)
    else:
        netB = nn.Identity()
        netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                       bottleneck_dim=netF.in_features, bias=args.classifier_bias,
                                       temp=args.angular_temp, args=args)

    if args.dataparallel:
        netF = nn.DataParallel(netF).cuda()
        netH = nn.DataParallel(netH).cuda()
        netB = nn.DataParallel(netB).cuda()
        netC = nn.DataParallel(netC).cuda()
    else:
        netF.cuda()
        netH.cuda()
        netB.cuda()
        netC.cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        if args.separate_wd and ('bias' in k or 'norm' in k):
            param_group += [{'params': v, 'lr': learning_rate * 0.1, 'weight_decay': 0}]
        else:
            param_group += [{'params': v, 'lr': learning_rate * 0.1,
                             'weight_decay': args.weight_decay}]
    for net in [netH, netB, netC]:
        for k, v in net.named_parameters():
            if args.separate_wd and ('bias' in k or 'norm' in k):
                param_group += [{'params': v, 'lr': learning_rate, 'weight_decay': 0}]
            else:
                param_group += [{'params': v, 'lr': learning_rate,
                                 'weight_decay': args.weight_decay}]
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0
    if args.class_stratified:
        max_iter = args.max_epoch * len(dset_loaders["source_tr"].batch_sampler)
    else:
        max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    interval_iter = max_iter // 10
    iter_num = 0
    epoch = 0
    netF.train()
    netH.train()
    netB.train()
    netC.train()

    if args.use_focal_loss:
        cls_loss_fn = FocalLoss(alpha=args.focal_alpha, gamma=args.focal_gamma, reduction='mean')
    elif args.ce_weighting:
        w = torch.Tensor(args.ce_weight).cuda()
        w.requires_grad = False
        if args.smooth == 0:
            cls_loss_fn = nn.CrossEntropyLoss(weight=w).cuda()
        else:
            cls_loss_fn = CrossEntropyLabelSmooth(num_classes=args.class_num,
                                                  epsilon=args.smooth, weight=w).cuda()
    else:
        if args.smooth == 0:
            cls_loss_fn = nn.CrossEntropyLoss().cuda()
        else:
            cls_loss_fn = CrossEntropyLabelSmooth(num_classes=args.class_num,
                                                  epsilon=args.smooth).cuda()

    if args.ssl_task in ['simclr', 'crs']:
        if args.use_new_ntxent:
            ssl_loss_fn = SupConLoss(temperature=args.temperature,
                                     base_temperature=args.temperature).cuda()
        else:
            ssl_loss_fn = NTXentLoss(args.batch_size, args.temperature, True).cuda()
    elif args.ssl_task in ['supcon', 'crsc']:
        ssl_loss_fn = SupConLoss(temperature=args.temperature,
                                 base_temperature=args.temperature).cuda()
    elif args.ssl_task == 'ls_supcon':
        ssl_loss_fn = LabelSmoothedSCLLoss(args.batch_size, args.temperature,
                                           args.class_num, args.ssl_smooth)

    if args.cr_weight > 0:
        if args.cr_metric == 'cos':
            dist = nn.CosineSimilarity(dim=1).cuda()
        elif args.cr_metric == 'l1':
            dist = nn.PairwiseDistance(p=1).cuda()
        elif args.cr_metric == 'l2':
            dist = nn.PairwiseDistance(p=2).cuda()
        elif args.cr_metric == 'bce':
            dist = nn.BCEWithLogitsLoss(reduction='sum').cuda()
        elif args.cr_metric == 'kl':
            dist = nn.KLDivLoss(reduction='sum').cuda()
        elif args.cr_metric == 'js':
            dist = JSDivLoss(reduction='sum').cuda()

    use_second_pass = (args.ssl_task in ['simclr', 'supcon', 'ls_supcon']) and (args.ssl_weight > 0)
    use_third_pass = (args.cr_weight > 0) or \
                     (args.ssl_task in ['crsc', 'crs'] and args.ssl_weight > 0) or (args.cls3)

    while iter_num < max_iter:
        try:
            inputs_source, labels_source = next(iter_source)
        except:
            iter_source = iter(dset_loaders["source_tr"])
            if args.class_stratified:
                dset_loaders["source_tr"].batch_sampler.set_epoch(epoch)
                epoch += 1
            inputs_source, labels_source = next(iter_source)

        try:
            if inputs_source.size(0) == 1:
                continue
        except:
            if inputs_source[0].size(0) == 1:
                continue

        iter_num += 1
        lr_scheduler(args, optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source1 = None
        inputs_source2 = None
        inputs_source3 = None
        labels_source = labels_source.cuda()
        if args.layer in ['add_margin', 'arc_margin', 'sphere']:
            labels_forward = labels_source
        else:
            labels_forward = None

        if type(inputs_source) is list:
            inputs_source1 = inputs_source[0].cuda()
            inputs_source2 = inputs_source[1].cuda()
            if len(inputs_source) == 3:
                inputs_source3 = inputs_source[2].cuda()
        else:
            inputs_source1 = inputs_source.cuda()

        if inputs_source1 is not None:
            f1 = netF(inputs_source1)
            b1 = netB(f1)
            outputs_source = netC(b1, labels_forward)
            if use_second_pass:
                f2 = netF(inputs_source2)
                b2 = netB(f2)
            if use_third_pass:
                if args.sg3:
                    with torch.no_grad():
                        f3 = netF(inputs_source3)
                        b3 = netB(f3)
                        c3 = netC(b3, labels_forward)
                        conf = torch.max(F.softmax(c3, dim=1), dim=1)[0]
                else:
                    f3 = netF(inputs_source3)
                    b3 = netB(f3)
                    c3 = netC(b3, labels_forward)
                    conf = torch.max(F.softmax(c3, dim=1), dim=1)[0]
            if args.cr_weight > 0:
                if args.cr_site == 'feat':
                    f_hard = f1
                    f_weak = f3
                elif args.cr_site == 'btn':
                    f_hard = b1
                    f_weak = b3
                elif args.cr_site == 'cls':
                    f_hard = outputs_source
                    f_weak = c3
                if args.cr_metric in ['kl', 'js']:
                    f_hard = F.softmax(f_hard, dim=-1)
                if args.cr_metric in ['bce', 'kl', 'js']:
                    f_weak = F.softmax(f_weak, dim=-1)
        else:
            raise NotImplementedError

        classifier_loss = cls_loss_fn(outputs_source, labels_source)
        if args.cls3:
            classifier_loss += cls_loss_fn(c3, labels_source)

        if args.ssl_weight > 0:
            if args.ssl_before_btn:
                z1 = netH(f1, args.norm_feat)
                if use_second_pass:
                    z2 = netH(f2, args.norm_feat)
                if use_third_pass:
                    z3 = netH(f3, args.norm_feat)
            else:
                z1 = netH(b1, args.norm_feat)
                if use_second_pass:
                    z2 = netH(b2, args.norm_feat)
                if use_third_pass:
                    z3 = netH(b3, args.norm_feat)
            if args.ssl_task == 'simclr':
                if args.use_new_ntxent:
                    z = torch.cat([z1.unsqueeze(1), z2.unsqueeze(1)], dim=1)
                    ssl_loss = ssl_loss_fn(z)
                else:
                    ssl_loss = ssl_loss_fn(z1, z2)
            elif args.ssl_task == 'supcon':
                z = torch.cat([z1.unsqueeze(1), z2.unsqueeze(1)], dim=1)
                ssl_loss = ssl_loss_fn(z, labels=labels_source)
            elif args.ssl_task == 'ls_supcon':
                ssl_loss = ssl_loss_fn(z1, z2, labels_source)
            elif args.ssl_task == 'crsc':
                z = torch.cat([z1.unsqueeze(1), z3.unsqueeze(1)], dim=1)
                ssl_loss = ssl_loss_fn(z, labels_source)
            elif args.ssl_task == 'crs':
                if args.use_new_ntxent:
                    z = torch.cat([z1.unsqueeze(1), z3.unsqueeze(1)], dim=1)
                    ssl_loss = ssl_loss_fn(z)
                else:
                    ssl_loss = ssl_loss_fn(z1, z3)
        else:
            ssl_loss = torch.tensor(0.0).cuda()

        if args.cr_weight > 0:
            try:
                cr_loss = dist(f_hard[conf <= args.cr_threshold],
                               f_weak[conf <= args.cr_threshold]).mean()
                if args.cr_metric == 'cos':
                    cr_loss *= -1
            except:
                print('Error computing CR loss')
                cr_loss = torch.tensor(0.0).cuda()
        else:
            cr_loss = torch.tensor(0.0).cuda()

        if args.ent_weight > 0:
            softmax_out = nn.Softmax(dim=1)(outputs_source)
            entropy_loss = torch.mean(Entropy(softmax_out))
            classifier_loss += args.ent_weight * entropy_loss
        if args.gent_weight > 0:
            softmax_out = nn.Softmax(dim=1)(outputs_source)
            msoftmax = softmax_out.mean(dim=0)
            gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
            classifier_loss -= args.gent_weight * gentropy_loss

        loss = classifier_loss + args.ssl_weight * ssl_loss + args.cr_weight * cr_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netH.eval()
            netB.eval()
            netC.eval()
            if args.dset == 'visda-c':
                acc_s_te, acc_list = cal_acc(dset_loaders['source_te'], netF, netH, netB, netC,
                                             args, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name_src, iter_num, max_iter, acc_s_te) + '\n' + acc_list
            else:
                acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netH, netB, netC, args, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name_src, iter_num, max_iter, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if acc_s_te >= acc_init:
                acc_init = acc_s_te
                if args.dataparallel:
                    best_netF = netF.module.state_dict()
                    best_netH = netH.module.state_dict()
                    best_netB = netB.module.state_dict()
                    best_netC = netC.module.state_dict()
                else:
                    best_netF = netF.state_dict()
                    best_netH = netH.state_dict()
                    best_netB = netB.state_dict()
                    best_netC = netC.state_dict()
            netF.train()
            netH.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
    torch.save(best_netH, osp.join(args.output_dir_src, "source_H.pt"))
    torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
    torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))
    return netF, netH, netB, netC
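# NTXentLoss above is the SimCLR contrastive objective: for two batches of
# embeddings z1, z2, each sample's positive is its other view and the remaining
# 2N - 2 samples act as negatives. A compact hedged sketch (the repo's class-based
# version also takes batch_size and a cosine flag at construction):
import torch
import torch.nn.functional as F

def nt_xent(z1, z2, temperature=0.5):
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, d)
    sim = z @ z.t() / temperature                       # cosine similarities
    sim.fill_diagonal_(float('-inf'))                   # exclude self-pairs
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)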
def train_source(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 10}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 10}]
    optimizer = optim.SGD(param_group, momentum=0.9, weight_decay=5e-4, nesterov=True)

    acc_init = 0
    for epoch in tqdm(range(args.max_epoch), leave=False):
        netF.train()
        netB.train()
        netC.train()
        iter_source = iter(dset_loaders['source_tr'])
        for _, (inputs_source, labels_source) in tqdm(enumerate(iter_source), leave=False):
            if inputs_source.size(0) == 1:
                continue
            inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
            outputs_source = netC(netB(netF(inputs_source)))
            classifier_loss = loss.CrossEntropyLabelSmooth(
                num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source)
            optimizer.zero_grad()
            classifier_loss.backward()
            optimizer.step()

        if (epoch + 1) % 5 == 0 and args.trte == 'full':
            netF.eval()
            netB.eval()
            netC.eval()
            acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC,
                                  args.dset == 'visda17')
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                args.name_src, epoch + 1, args.max_epoch, acc_s_te * 100)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            best_netF = netF.state_dict()
            best_netB = netB.state_dict()
            best_netC = netC.state_dict()
            torch.save(best_netF, osp.join(args.output_dir_src, 'source_F_' + str(epoch + 1) + '.pt'))
            torch.save(best_netB, osp.join(args.output_dir_src, 'source_B_' + str(epoch + 1) + '.pt'))
            torch.save(best_netC, osp.join(args.output_dir_src, 'source_C_' + str(epoch + 1) + '.pt'))

        if args.trte == 'val':
            netF.eval()
            netB.eval()
            netC.eval()
            acc_s_tr, _ = cal_acc(dset_loaders['source_tr'], netF, netB, netC)
            acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%/ {:.2f}%'.format(
                args.name_src, epoch + 1, args.max_epoch, acc_s_tr * 100, acc_s_te * 100)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            if acc_s_te > acc_init:
                acc_init = acc_s_te
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()
                torch.save(best_netF, osp.join(args.output_dir_src, 'source_F_val.pt'))
                torch.save(best_netB, osp.join(args.output_dir_src, 'source_B_val.pt'))
                torch.save(best_netC, osp.join(args.output_dir_src, 'source_C_val.pt'))

    return netF, netB, netC
def test_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.norm_layer == 'batchnorm':
        norm_layer = nn.BatchNorm2d
    elif args.norm_layer == 'groupnorm':
        def gn_helper(planes):
            return nn.GroupNorm(8, planes)
        norm_layer = gn_helper
    if args.net[0:3] == 'res':
        if '26' in args.net:
            netF = network.ResCifarBase(26, norm_layer=norm_layer)
        else:
            netF = network.ResBase(res_name=args.net, args=args)
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net)

    if args.ssl_before_btn:
        netH = network.ssl_head(ssl_task=args.ssl_task, feature_dim=netF.in_features,
                                embedding_dim=args.embedding_dim)
    else:
        netH = network.ssl_head(ssl_task=args.ssl_task, feature_dim=args.bottleneck,
                                embedding_dim=args.embedding_dim)
    if args.bottleneck != 0:
        netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                       bottleneck_dim=args.bottleneck, norm_btn=args.norm_btn)
        netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                       bottleneck_dim=args.bottleneck,
                                       bias=args.classifier_bias, args=args)
    else:
        netB = nn.Identity()
        netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                       bottleneck_dim=netF.in_features,
                                       bias=args.classifier_bias, args=args)

    args.modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_H.pt'
    netH.load_state_dict(torch.load(args.modelpath))
    try:
        args.modelpath = args.output_dir_src + '/source_B.pt'
        netB.load_state_dict(torch.load(args.modelpath))
    except:
        print('Skipped loading btn for version compatibility')
    args.modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))

    if args.dataparallel:
        netF = nn.DataParallel(netF).cuda()
        netH = nn.DataParallel(netH).cuda()
        netB = nn.DataParallel(netB).cuda()
        netC = nn.DataParallel(netC).cuda()
    else:
        netF.cuda()
        netH.cuda()
        netB.cuda()
        netC.cuda()
    netF.eval()
    netH.eval()
    netB.eval()
    netC.eval()

    if args.da == 'oda':
        acc_os1, acc_os2, acc_unknown = cal_acc_oda(dset_loaders['test'], netF, netH, netB, netC)
        log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}% / {:.2f}% / {:.2f}%'.format(
            args.trte, args.name, acc_os2, acc_os1, acc_unknown)
    else:
        if args.dset in ['visda-c', 'CIFAR-10-C', 'CIFAR-100-C']:
            acc, acc_list = cal_acc(dset_loaders['test'], netF, netH, netB, netC, args, True)
            log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(
                args.trte, args.name, acc) + '\n' + acc_list
        else:
            acc, _ = cal_acc(dset_loaders['test'], netF, netH, netB, netC, args, False)
            log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(
                args.trte, args.name, acc)
    args.out_file.write(log_str)
    args.out_file.flush()
    print(log_str)
def train_target(args, zz=''):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_src + '/source_F_' + str(zz) + '.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B_' + str(zz) + '.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C_' + str(zz) + '.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    optimizer = optim.SGD(param_group, momentum=0.9, weight_decay=5e-4, nesterov=True)

    for epoch in tqdm(range(args.max_epoch), leave=False):
        netF.eval()
        netB.eval()
        mem_label = obtain_label(dset_loaders['test'], netF, netB, netC, args)
        mem_label = torch.from_numpy(mem_label).cuda()
        netF.train()
        netB.train()

        iter_test = iter(dset_loaders['target'])
        for _, (inputs_test, _, tar_idx) in tqdm(enumerate(iter_test), leave=False):
            if inputs_test.size(0) == 1:
                continue
            inputs_test = inputs_test.cuda()
            pred = mem_label[tar_idx]
            features_test = netB(netF(inputs_test))
            outputs_test = netC(features_test)
            classifier_loss = loss.CrossEntropyLabelSmooth(
                num_classes=args.class_num, epsilon=0)(outputs_test, pred)
            classifier_loss *= args.cls_par
            if args.ent:
                softmax_out = nn.Softmax(dim=1)(outputs_test)
                entropy_loss = torch.mean(loss.Entropy(softmax_out))
                if args.gent:
                    msoftmax = softmax_out.mean(dim=0)
                    gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                    entropy_loss -= gentropy_loss
                classifier_loss += entropy_loss * args.ent_par

            optimizer.zero_grad()
            classifier_loss.backward()
            optimizer.step()

        netF.eval()
        netB.eval()
        acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
        log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
            args.name, epoch + 1, args.max_epoch, acc * 100)
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        print(log_str + '\n')

    if args.issave:
        torch.save(netF.state_dict(), osp.join(args.output_dir, 'target_F_' + args.savename + '.pt'))
        torch.save(netB.state_dict(), osp.join(args.output_dir, 'target_B_' + args.savename + '.pt'))
        torch.save(netC.state_dict(), osp.join(args.output_dir, 'target_C_' + args.savename + '.pt'))
    return netF, netB, netC
def train_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.norm_layer == 'batchnorm':
        norm_layer = nn.BatchNorm2d
    elif args.norm_layer == 'groupnorm':
        def gn_helper(planes):
            return nn.GroupNorm(8, planes)
        norm_layer = gn_helper
    if args.net[0:3] == 'res':
        if '26' in args.net:
            netF = network.ResCifarBase(26, norm_layer=norm_layer)
            args.bottleneck = netF.in_features // 2
        else:
            netF = network.ResBase(res_name=args.net, args=args)
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net)

    # print(args.ssl_before_btn)
    if args.ssl_before_btn:
        netH = network.ssl_head(ssl_task=args.ssl_task, feature_dim=netF.in_features,
                                embedding_dim=args.embedding_dim)
    else:
        netH = network.ssl_head(ssl_task=args.ssl_task, feature_dim=args.bottleneck,
                                embedding_dim=args.embedding_dim)
    if args.bottleneck != 0:
        netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                       bottleneck_dim=args.bottleneck, norm_btn=args.norm_btn)
        if args.reset_running_stats and args.classifier == 'bn':
            netB.norm.running_mean.fill_(0.)
            netB.norm.running_var.fill_(1.)
        if args.reset_bn_params and args.classifier == 'bn':
            netB.norm.weight.data.fill_(1.)
            netB.norm.bias.data.fill_(0.)
        netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                       bottleneck_dim=args.bottleneck, bias=args.classifier_bias,
                                       temp=args.angular_temp, args=args)
    else:
        netB = nn.Identity()
        netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                       bottleneck_dim=netF.in_features, bias=args.classifier_bias,
                                       temp=args.angular_temp, args=args)

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath), strict=False)
    modelpath = args.output_dir_src + '/source_H.pt'
    netH.load_state_dict(torch.load(modelpath), strict=False)
    try:
        modelpath = args.output_dir_src + '/source_B.pt'
        netB.load_state_dict(torch.load(modelpath), strict=False)
    except:
        print('Skipped loading btn for version compatibility')
    modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath), strict=False)

    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    if args.dataparallel:
        netF = nn.DataParallel(netF).cuda()
        netH = nn.DataParallel(netH).cuda()
        netB = nn.DataParallel(netB).cuda()
        netC = nn.DataParallel(netC).cuda()
    else:
        netF.cuda()
        netH.cuda()
        netB.cuda()
        netC.cuda()

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    for k, v in netH.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False

    if args.ssl_task in ['simclr', 'crs']:
        ssl_loss_fn = NTXentLoss(args.batch_size, args.temperature, True).cuda()
    elif args.ssl_task in ['supcon', 'crsc']:
        ssl_loss_fn = SupConLoss(temperature=args.temperature,
                                 base_temperature=args.temperature).cuda()
    elif args.ssl_task == 'ls_supcon':
        ssl_loss_fn = LabelSmoothedSCLLoss(args.batch_size, args.temperature,
                                           args.class_num, args.ssl_smooth)

    if args.cr_weight > 0:
        if args.cr_metric == 'cos':
            dist = nn.CosineSimilarity(dim=1).cuda()
        elif args.cr_metric == 'l1':
            dist = nn.PairwiseDistance(p=1)
        elif args.cr_metric == 'l2':
            dist = nn.PairwiseDistance(p=2)
        elif args.cr_metric == 'bce':
            dist = nn.BCEWithLogitsLoss(reduction='sum').cuda()
        elif args.cr_metric == 'kl':
            dist = nn.KLDivLoss(reduction='sum').cuda()

    use_second_pass = (args.ssl_task in ['simclr', 'supcon', 'ls_supcon']) and (args.ssl_weight > 0)
    use_third_pass = (args.cr_weight > 0) or \
                     (args.ssl_task in ['crsc', 'crs'] and args.ssl_weight > 0) or (args.cls3)

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0
    centroid = None

    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)
        try:
            if inputs_test.size(0) == 1:
                continue
        except:
            if inputs_test[0].size(0) == 1:
                continue

        if iter_num % interval_iter == 0 and (args.cls_par > 0 or
                                              args.ssl_task in ['supcon', 'ls_supcon', 'crsc']):
            netF.eval()
            netH.eval()
            netB.eval()
            if centroid is None or args.recompute_centroid:
                mem_label, mem_conf, centroid, labelset = obtain_label(
                    dset_loaders['pl'], netF, netH, netB, netC, args)
                mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netH.train()
            netB.train()

        inputs_test1 = None
        inputs_test2 = None
        inputs_test3 = None
        pred = mem_label[tar_idx]
        if type(inputs_test) is list:
            inputs_test1 = inputs_test[0].cuda()
            inputs_test2 = inputs_test[1].cuda()
            if len(inputs_test) == 3:
                inputs_test3 = inputs_test[2].cuda()
        else:
            inputs_test1 = inputs_test.cuda()

        if args.layer in ['add_margin', 'arc_margin', 'sphere'] and args.use_margin_forward:
            labels_forward = pred
        else:
            labels_forward = None

        if inputs_test1 is not None:
            f1 = netF(inputs_test1)
            b1 = netB(f1)
            outputs_test = netC(b1, labels_forward)
            if use_second_pass:
                f2 = netF(inputs_test2)
                b2 = netB(f2)
            if use_third_pass:
                if args.sg3:
                    with torch.no_grad():
                        f3 = netF(inputs_test3)
                        b3 = netB(f3)
                        c3 = netC(b3, labels_forward)
                        conf = torch.max(F.softmax(c3, dim=1), dim=1)[0]
                else:
                    f3 = netF(inputs_test3)
                    b3 = netB(f3)
                    c3 = netC(b3, labels_forward)
                    conf = torch.max(F.softmax(c3, dim=1), dim=1)[0]

            iter_num += 1
            lr_scheduler(args, optimizer, iter_num=iter_num, max_iter=max_iter,
                         gamma=args.gamma, power=args.power)
            pred = compute_pl(args, b3, centroid, labelset)

            if args.cr_weight > 0:
                if args.cr_site == 'feat':
                    f_hard = f1
                    f_weak = f3
                elif args.cr_site == 'btn':
                    f_hard = b1
                    f_weak = b3
                elif args.cr_site == 'cls':
                    f_hard = outputs_test
                    f_weak = c3
                if args.cr_metric != 'cos':
                    f_hard = F.softmax(f_hard, dim=-1)
                    f_weak = F.softmax(f_weak, dim=-1)
        else:
            raise NotImplementedError

        if args.cls_par > 0:
            # with torch.no_grad():
            #     conf, _ = torch.max(F.softmax(outputs_test, dim=-1), dim=-1)
            #     conf = conf.cpu().numpy()
            conf_cls = mem_conf[tar_idx]
            # pred = mem_label[tar_idx]
            if args.cls_smooth > 0:
                classifier_loss = CrossEntropyLabelSmooth(
                    num_classes=args.class_num, epsilon=args.cls_smooth)(
                        outputs_test[conf_cls >= args.conf_threshold],
                        pred[conf_cls >= args.conf_threshold])
            else:
                classifier_loss = nn.CrossEntropyLoss()(
                    outputs_test[conf_cls >= args.conf_threshold],
                    pred[conf_cls >= args.conf_threshold])
            if args.cls3:
                if args.cls_smooth > 0:
                    classifier_loss = CrossEntropyLabelSmooth(
                        num_classes=args.class_num, epsilon=args.cls_smooth)(
                            c3[conf_cls >= args.conf_threshold],
                            pred[conf_cls >= args.conf_threshold])
                else:
                    classifier_loss = nn.CrossEntropyLoss()(
                        c3[conf_cls >= args.conf_threshold],
                        pred[conf_cls >= args.conf_threshold])
            classifier_loss *= args.cls_par
            if iter_num < interval_iter and args.dset == "visda-c":
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        if args.ssl_weight > 0:
            if args.ssl_before_btn:
                z1 = netH(f1, args.norm_feat)
                if use_second_pass:
                    z2 = netH(f2, args.norm_feat)
                if use_third_pass:
                    z3 = netH(f3, args.norm_feat)
            else:
                z1 = netH(b1, args.norm_feat)
                if use_second_pass:
                    z2 = netH(b2, args.norm_feat)
                if use_third_pass:
                    z3 = netH(b3, args.norm_feat)
            if args.ssl_task == 'simclr':
                ssl_loss = ssl_loss_fn(z1, z2)
            elif args.ssl_task == 'supcon':
                z = torch.cat([z1.unsqueeze(1), z2.unsqueeze(1)], dim=1)
                pl = mem_label[tar_idx]
                ssl_loss = ssl_loss_fn(z, pl)
            elif args.ssl_task == 'ls_supcon':
                pl = mem_label[tar_idx]
                ssl_loss = ssl_loss_fn(z1, z2, pl).squeeze()
            elif args.ssl_task == 'crsc':
                z = torch.cat([z1.unsqueeze(1), z3.unsqueeze(1)], dim=1)
                pl = mem_label[tar_idx]
                ssl_loss = ssl_loss_fn(z, pl)
            elif args.ssl_task == 'crs':
                ssl_loss = ssl_loss_fn(z1, z3)
            classifier_loss += args.ssl_weight * ssl_loss

        if args.cr_weight > 0:
            try:
                cr_loss = dist(f_hard[conf >= args.cr_threshold],
                               f_weak[conf >= args.cr_threshold]).mean()
                if args.cr_metric == 'cos':
                    cr_loss *= -1
            except:
                print('Error computing CR loss')
                cr_loss = torch.tensor(0.0).cuda()
            classifier_loss += args.cr_weight * cr_loss

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()
        centroid = update_centroid(args, b3, centroid, c3, labelset)

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netH.eval()
            netB.eval()
            if args.dset in ['visda-c', 'CIFAR-10-C', 'CIFAR-100-C']:
                acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netH, netB, netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te) + '\n' + acc_list
            else:
                acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netH, netB, netC, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netH.train()
            netB.train()

    if args.issave:
        if args.dataparallel:
            torch.save(netF.module.state_dict(),
                       osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
            torch.save(netH.module.state_dict(),
                       osp.join(args.output_dir, "target_H_" + args.savename + ".pt"))
            torch.save(netB.module.state_dict(),
                       osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
            torch.save(netC.module.state_dict(),
                       osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))
        else:
            torch.save(netF.state_dict(),
                       osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
            torch.save(netH.state_dict(),
                       osp.join(args.output_dir, "target_H_" + args.savename + ".pt"))
            torch.save(netB.state_dict(),
                       osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
            torch.save(netC.state_dict(),
                       osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))
    return netF, netH, netB, netC
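# obtain_label in the loops above follows SHOT's clustering refinement: class
# centroids are first built from softmax-weighted features, then each sample is
# reassigned to its nearest centroid and the centroids re-estimated. A hedged
# sketch of one refinement pass (SHOT's published code runs a couple of rounds
# and uses scipy's cosine cdist; details here may differ from this repo):
import torch
import torch.nn.functional as F

def cluster_pseudo_labels(features, logits, rounds=2):
    # features: (N, d) bottleneck outputs; logits: (N, K) classifier outputs
    probs = torch.softmax(logits, dim=1)                                  # (N, K)
    feats = F.normalize(features, dim=1)
    centroids = (probs.t() @ feats) / probs.sum(dim=0, keepdim=True).t()  # (K, d)
    for _ in range(rounds):
        dist = 1 - feats @ F.normalize(centroids, dim=1).t()              # cosine distance
        pred = dist.argmin(dim=1)
        onehot = F.one_hot(pred, logits.size(1)).float()
        centroids = (onehot.t() @ feats) / (onehot.sum(dim=0, keepdim=True).t() + 1e-8)
    return pred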