def test_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
    log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(args.trte, args.name, acc)

    args.out_file.write(log_str)
    args.out_file.flush()
    print(log_str)
def test_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_ori + "/target_F_" + args.savename + ".pt"
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_ori + "/target_B_" + args.savename + ".pt"
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_ori + "/target_C_" + args.savename + ".pt"
    netC.load_state_dict(torch.load(args.modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    acc, y, py = cal_acc(dset_loaders['test'], netF, netB, netC)
    log_str = '\nTask: {}, Accuracy = {:.2f}%'.format(args.name, acc * 100)

    args.out_file.write(log_str)
    args.out_file.flush()
    print(log_str)
    return y, py
def test_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()
    wandb.watch(netF)
    wandb.watch(netB)
    wandb.watch(netC)

    args.modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))

    if args.test_tar_only:
        # replace and test on target-trained model
        args.modelpath = args.output_dir_src + '/target_F_par_0.3.pt'
        netF.load_state_dict(torch.load(args.modelpath))
        args.modelpath = args.output_dir_src + '/target_B_par_0.3.pt'
        netB.load_state_dict(torch.load(args.modelpath))
        args.modelpath = args.output_dir_src + '/target_C_par_0.3.pt'
        netC.load_state_dict(torch.load(args.modelpath))

    netF.eval()
    netB.eval()
    netC.eval()

    if args.da == 'oda':
        acc_os1, acc_os2, acc_unknown = cal_acc_oda(dset_loaders['test'], netF, netB, netC)
        log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}% / {:.2f}% / {:.2f}%'.format(
            args.trte, args.name, acc_os2, acc_os1, acc_unknown)
    else:
        if args.dset == 'VISDA18' or args.dset == 'VISDA-C':
            acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True)
            log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(
                args.trte, args.name, acc) + '\n' + acc_list
            wandb.log({"test_accuracy": acc})
        else:
            acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
            log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(
                args.trte, args.name, acc)

    args.out_file.write(log_str)
    args.out_file.flush()
    print(log_str)
def test_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res' or args.net[0:3] == 'vgg':
        if args.net[0:3] == 'res':
            netF = network.ResBase(res_name=args.net).cuda()
        else:
            netF = network.VGGBase(vgg_name=args.net).cuda()
        netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                       bottleneck_dim=args.bottleneck).cuda()
        netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                       bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    if args.dset == 'VISDA-RSUT' or args.dset == 'VISDA-RSUT-50' or args.dset == 'VISDA-RSUT-10':
        # For VisDA, print acc of each class.
        if args.net[0:3] == 'res' or args.net[0:3] == 'vgg':
            acc, acc_list, acc_cls_avg = cal_acc(dset_loaders['test'], netF=netF, netB=netB, netC=netC,
                                                 per_class_flag=True, visda_flag=True)
            log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}%'.format(
                args.trte, args.name, acc, acc_cls_avg) + '\n' + acc_list
    else:
        # For Home, DomainNet, no need to print acc of each class.
        if args.net[0:3] == 'res' or args.net[0:3] == 'vgg':
            acc, acc_cls_avg, _ = cal_acc(dset_loaders['test'], netF=netF, netB=netB, netC=netC,
                                          per_class_flag=True, visda_flag=False)
            log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}%'.format(
                args.trte, args.name, acc, acc_cls_avg)

    args.out_file.write(log_str)
    args.out_file.flush()
    print(log_str)
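# Note: `cal_acc` with per_class_flag=True is defined elsewhere in this repo.
# A minimal sketch of the class-averaged accuracy it reports (assuming standard
# per-class recall averaging; the real helper's signature and returns may differ):
def per_class_avg_acc(all_pred, all_label, class_num):
    accs = []
    for c in range(class_num):
        mask = (all_label == c)
        if mask.sum() == 0:
            continue  # skip classes absent from this split to avoid NaNs
        accs.append((all_pred[mask] == c).float().mean() * 100)
    return torch.stack(accs).mean().item()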
def train_target(args):
    dset_loaders, txt_tar = data_load(args)
    dsets = dict()
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath))

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    for k, v in netC.named_parameters():
        v.requires_grad = False

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    iter_per_epoch = len(dset_loaders["target"])
    print("Iter per epoch: {}".format(iter_per_epoch))
    interval_iter = max_iter // args.interval
    iter_num = 0

    if args.paral:
        netF = torch.nn.DataParallel(netF)
        netB = torch.nn.DataParallel(netB)
        netC = torch.nn.DataParallel(netC)

    netF.train()
    netB.train()
    netC.eval()

    if args.scd_lamb:
        # specify the hyperparameter for secondary label correction manually
        scd_lamb_init = args.scd_lamb
    else:
        if args.dset[0:5] == "VISDA":
            scd_lamb_init = 0.1
        elif args.dset[0:11] == "office-home":
            scd_lamb_init = 0.2
            if args.s == 3 and args.t == 2:
                scd_lamb_init *= 0.1
        elif args.dset[0:9] == "domainnet":
            scd_lamb_init = 0.02
    scd_lamb = scd_lamb_init

    while iter_num < max_iter:
        k = 1.0
        k_s = 0.6

        if iter_num % interval_iter == 0 and args.cls_par > 0:  # interval_iter = iters per epoch
            netF.eval()
            netB.eval()
            label_prob_dict = obtain_label(dset_loaders['test'], netF, netB, netC, args)
            mem_label, pseudo_lb_prob = label_prob_dict['primary_lb'], label_prob_dict['primary_lb_prob']
            mem_label, pseudo_lb_prob = torch.from_numpy(mem_label).cuda(), torch.from_numpy(pseudo_lb_prob).cuda()
            if args.scd_label:
                second_label, second_prob = label_prob_dict['secondary_lb'], label_prob_dict['secondary_lb_prob']
                second_label, second_prob = torch.from_numpy(second_label).cuda(), torch.from_numpy(second_prob).cuda()
            if args.third_label:
                third_label, third_prob = label_prob_dict['third_lb'], label_prob_dict['third_lb_prob']
                third_label, third_prob = torch.from_numpy(third_label).cuda(), torch.from_numpy(third_prob).cuda()
            if args.fourth_label:
                fourth_label, fourth_prob = label_prob_dict['fourth_lb'], label_prob_dict['fourth_lb_prob']
                fourth_label, fourth_prob = torch.from_numpy(fourth_label).cuda(), torch.from_numpy(fourth_prob).cuda()
            if args.topk_ent:
                all_entropy = label_prob_dict['entropy']
                all_entropy = torch.from_numpy(all_entropy)

            if args.dset[0:5] == "VISDA":
                if iter_num // iter_per_epoch < 1:
                    k = 0.6
                elif iter_num // iter_per_epoch < 2:
                    k = 0.7
                elif iter_num // iter_per_epoch < 3:
                    k = 0.8
                elif iter_num // iter_per_epoch < 4:
                    k = 0.9
                else:
                    k = 1.0
                if iter_num // iter_per_epoch >= 8:
                    scd_lamb *= 0.1
            elif args.dset[0:11] == "office-home" or args.dset[0:9] == "domainnet":
                if iter_num // iter_per_epoch < 2:
                    k = 0.2
                elif iter_num // iter_per_epoch < 4:
                    k = 0.4
                elif iter_num // iter_per_epoch < 8:
                    k = 0.6
                elif iter_num // iter_per_epoch < 12:
                    k = 0.8

            if args.topk:
                dsets["target"] = ImageList_PA(txt_tar, mem_label, pseudo_lb_prob,
                                               k_low=k, k_up=None, transform=image_train())
                dset_loaders["target"] = DataLoader(dsets["target"], batch_size=args.batch_size,
                                                    shuffle=True, num_workers=args.worker, drop_last=False)
            if args.topk_ent:
                dsets["target"] = ImageList_PA(txt_tar, mem_label, -1.0 * all_entropy,
                                               k_low=k, k_up=None, transform=image_train())
                dset_loaders["target"] = DataLoader(dsets["target"], batch_size=args.batch_size,
                                                    shuffle=True, num_workers=args.worker, drop_last=False)
            if args.scd_label:
                # 2nd label threshold: prob top 60%
                dsets["target_scd"] = ImageList_PA(txt_tar, second_label, second_prob,
                                                   k_low=k_s, k_up=None, transform=image_train())
                dset_loaders["target_scd"] = DataLoader(dsets["target_scd"], batch_size=args.batch_size,
                                                        shuffle=True, num_workers=args.worker, drop_last=False)
            if args.third_label:
                # 3rd label threshold: prob top 60%
                dsets["target_third"] = ImageList_PA(txt_tar, third_label, third_prob,
                                                     k_low=k_s, k_up=None, transform=image_train())
                dset_loaders["target_third"] = DataLoader(dsets["target_third"], batch_size=args.batch_size,
                                                          shuffle=True, num_workers=args.worker, drop_last=False)
            if args.fourth_label:
                # 4th label threshold: prob top 60%
                dsets["target_fourth"] = ImageList_PA(txt_tar, fourth_label, fourth_prob,
                                                      k_low=k_s, k_up=None, transform=image_train())
                dset_loaders["target_fourth"] = DataLoader(dsets["target_fourth"], batch_size=args.batch_size,
                                                           shuffle=True, num_workers=args.worker, drop_last=False)

            netF.train()
            netB.train()

        try:
            inputs_test, _, tar_idx = next(iter_test)  # tar_idx: chosen indices in current iter
        except:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)
        if inputs_test.size(0) == 1:
            continue

        if args.scd_label:
            try:
                inputs_test_scd, _, tar_idx_scd = next(iter_test_scd)
            except:
                iter_test_scd = iter(dset_loaders["target_scd"])
                inputs_test_scd, _, tar_idx_scd = next(iter_test_scd)
            if inputs_test_scd.size(0) == 1:
                continue
        if args.third_label:
            try:
                inputs_test_third, _, tar_idx_third = next(iter_test_third)
            except:
                iter_test_third = iter(dset_loaders["target_third"])
                inputs_test_third, _, tar_idx_third = next(iter_test_third)
            if inputs_test_third.size(0) == 1:
                continue
        if args.fourth_label:
            try:
                inputs_test_fourth, _, tar_idx_fourth = next(iter_test_fourth)
            except:
                iter_test_fourth = iter(dset_loaders["target_fourth"])
                inputs_test_fourth, _, tar_idx_fourth = next(iter_test_fourth)
            if inputs_test_fourth.size(0) == 1:
                continue

        iter_num += 1
        inputs_test = inputs_test.cuda()
        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        if args.scd_label:
            inputs_test_scd = inputs_test_scd.cuda()
            if inputs_test_scd.ndim == 3:
                inputs_test_scd = inputs_test_scd.unsqueeze(0)
            features_test_scd = netB(netF(inputs_test_scd))
            outputs_test_scd = netC(features_test_scd)
            first_prob_of_scd = pseudo_lb_prob[tar_idx_scd]
            scd_prob = second_prob[tar_idx_scd]
            if not args.no_mask:
                mask = (scd_prob / first_prob_of_scd.float()).clamp(max=1.0)
            else:
                mask = torch.ones_like(scd_prob).cuda()
        if args.third_label:
            inputs_test_third = inputs_test_third.cuda()
            if inputs_test_third.ndim == 3:
                inputs_test_third = inputs_test_third.unsqueeze(0)
            features_test_third = netB(netF(inputs_test_third))
            outputs_test_third = netC(features_test_third)
            first_prob_of_third = pseudo_lb_prob[tar_idx_third]
            thi_prob = third_prob[tar_idx_third]
            mask_third = (thi_prob / first_prob_of_third.float()).clamp(max=1.0)
        if args.fourth_label:
            inputs_test_fourth = inputs_test_fourth.cuda()
            if inputs_test_fourth.ndim == 3:
                inputs_test_fourth = inputs_test_fourth.unsqueeze(0)
            features_test_fourth = netB(netF(inputs_test_fourth))
            outputs_test_fourth = netC(features_test_fourth)
            first_prob_of_fourth = pseudo_lb_prob[tar_idx_fourth]
            fth_prob = fourth_prob[tar_idx_fourth]
            mask_fourth = (fth_prob / first_prob_of_fourth.float()).clamp(max=1.0)

        if args.intra_dense or args.inter_sep:
            intra_dist = torch.zeros(1).cuda()
            inter_dist = torch.zeros(1).cuda()
            pred = mem_label[tar_idx]
            same_first = True
            diff_first = True
            cos = nn.CosineSimilarity(dim=1, eps=1e-6)
            for i in range(pred.size(0)):
                for j in range(i, pred.size(0)):
                    # dist = torch.norm(features_test[i] - features_test[j])
                    dist = 0.5 * (1 - cos(features_test[i].unsqueeze(0), features_test[j].unsqueeze(0)))
                    if pred[i].item() == pred[j].item():
                        if same_first:
                            intra_dist = dist.unsqueeze(0)
                            same_first = False
                        else:
                            intra_dist = torch.cat((intra_dist, dist.unsqueeze(0)))
                    else:
                        if diff_first:
                            inter_dist = dist.unsqueeze(0)
                            diff_first = False
                        else:
                            inter_dist = torch.cat((inter_dist, dist.unsqueeze(0)))
            intra_dist = torch.mean(intra_dist)
            inter_dist = torch.mean(inter_dist)

        if args.cls_par > 0:
            pred = mem_label[tar_idx]
            classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)  # self-train by pseudo label
            classifier_loss *= args.cls_par
            if args.scd_label:
                pred_scd = second_label[tar_idx_scd]
                classifier_loss_scd = nn.CrossEntropyLoss(reduction='none')(outputs_test_scd, pred_scd)  # self-train by pseudo label
                classifier_loss_scd = torch.mean(mask * classifier_loss_scd)
                classifier_loss_scd *= args.cls_par
                classifier_loss += classifier_loss_scd * scd_lamb
            if args.third_label:
                pred_third = third_label[tar_idx_third]
                classifier_loss_third = nn.CrossEntropyLoss(reduction='none')(outputs_test_third, pred_third)  # self-train by pseudo label
                classifier_loss_third = torch.mean(mask_third * classifier_loss_third)
                classifier_loss_third *= args.cls_par
                classifier_loss += classifier_loss_third * scd_lamb  # TODO: better weighting is possible
            if args.fourth_label:
                pred_fourth = fourth_label[tar_idx_fourth]
                classifier_loss_fourth = nn.CrossEntropyLoss(reduction='none')(outputs_test_fourth, pred_fourth)  # self-train by pseudo label
                classifier_loss_fourth = torch.mean(mask_fourth * classifier_loss_fourth)
                classifier_loss_fourth *= args.cls_par
                classifier_loss += classifier_loss_fourth * scd_lamb  # TODO: better weighting is possible
            if iter_num < interval_iter and (args.dset == "VISDA-C" or args.dset == "VISDA-RSUT"
                                             or args.dset == 'VISDA-RSUT-50' or args.dset == 'VISDA-RSUT-10'):
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))  # minimize local entropy
            if args.scd_label:
                softmax_out_scd = nn.Softmax(dim=1)(outputs_test_scd)
                entropy_loss_scd = torch.mean(mask * loss.Entropy(softmax_out_scd))  # minimize local entropy
            if args.third_label:
                softmax_out_third = nn.Softmax(dim=1)(outputs_test_third)
                entropy_loss_third = torch.mean(mask_third * loss.Entropy(softmax_out_third))  # minimize local entropy
            if args.fourth_label:
                softmax_out_fourth = nn.Softmax(dim=1)(outputs_test_fourth)
                entropy_loss_fourth = torch.mean(mask_fourth * loss.Entropy(softmax_out_fourth))  # minimize local entropy
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss  # maximize global entropy
                if args.scd_label:
                    msoftmax_scd = softmax_out_scd.mean(dim=0)
                    gentropy_loss_scd = torch.sum(-msoftmax_scd * torch.log(msoftmax_scd + args.epsilon))
                    entropy_loss_scd -= gentropy_loss_scd  # maximize global entropy
                if args.third_label:
                    msoftmax_third = softmax_out_third.mean(dim=0)
                    gentropy_loss_third = torch.sum(-msoftmax_third * torch.log(msoftmax_third + args.epsilon))
                    entropy_loss_third -= gentropy_loss_third  # maximize global entropy
                if args.fourth_label:
                    msoftmax_fourth = softmax_out_fourth.mean(dim=0)
                    gentropy_loss_fourth = torch.sum(-msoftmax_fourth * torch.log(msoftmax_fourth + args.epsilon))
                    entropy_loss_fourth -= gentropy_loss_fourth  # maximize global entropy
            im_loss = entropy_loss * args.ent_par
            if args.scd_label:
                im_loss += entropy_loss_scd * args.ent_par * scd_lamb
            if args.third_label:
                im_loss += entropy_loss_third * args.ent_par * scd_lamb  # TODO: better weighting is possible
            if args.fourth_label:
                im_loss += entropy_loss_fourth * args.ent_par * scd_lamb  # TODO: better weighting is possible
            classifier_loss += im_loss

        if args.intra_dense:
            classifier_loss += args.lamb_intra * intra_dist.squeeze()
        if args.inter_sep:
            classifier_loss += args.lamb_inter * inter_dist.squeeze()

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            if args.dset == 'VISDA-RSUT' or args.dset == 'VISDA-RSUT-50' or args.dset == 'VISDA-RSUT-10':
                # For VisDA, print the acc of each class; visda_flag -> need cls avg acc.
                acc_s_te, acc_list, acc_cls_avg = cal_acc(dset_loaders['test'], netF, netB, netC, visda_flag=True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te, acc_cls_avg) + '\n' + acc_list
            else:
                # In the imbalanced setting, use per-class average acc as the metric.
                # For Office-Home, DomainNet, no need to print the acc of each class.
                acc_s_te, acc_cls_avg, _ = cal_acc(dset_loaders['test'], netF, netB, netC, visda_flag=False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te, acc_cls_avg)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netB, netC
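# `loss.Entropy` used above computes the per-sample Shannon entropy of the
# softmax outputs. A minimal SHOT-style sketch for reference (the epsilon
# constant is an assumption):
def Entropy(input_):
    epsilon = 1e-5
    entropy = -input_ * torch.log(input_ + epsilon)
    return torch.sum(entropy, dim=1)  # one entropy value per sample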
def train_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.norm_layer == 'batchnorm':
        norm_layer = nn.BatchNorm2d
    elif args.norm_layer == 'groupnorm':
        def gn_helper(planes):
            return nn.GroupNorm(8, planes)
        norm_layer = gn_helper
    if args.net[0:3] == 'res':
        if '26' in args.net:
            netF = network.ResCifarBase(26, norm_layer=norm_layer)
            args.bottleneck = netF.in_features // 2
        else:
            netF = network.ResBase(res_name=args.net, args=args)
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net)

    # print(args.ssl_before_btn)
    if args.ssl_before_btn:
        netH = network.ssl_head(ssl_task=args.ssl_task, feature_dim=netF.in_features,
                                embedding_dim=args.embedding_dim)
    else:
        netH = network.ssl_head(ssl_task=args.ssl_task, feature_dim=args.bottleneck,
                                embedding_dim=args.embedding_dim)

    if args.bottleneck != 0:
        netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                       bottleneck_dim=args.bottleneck, norm_btn=args.norm_btn)
        if args.reset_running_stats and args.classifier == 'bn':
            netB.norm.running_mean.fill_(0.)
            netB.norm.running_var.fill_(1.)
        if args.reset_bn_params and args.classifier == 'bn':
            netB.norm.weight.data.fill_(1.)
            netB.norm.bias.data.fill_(0.)
        netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                       bottleneck_dim=args.bottleneck, bias=args.classifier_bias,
                                       temp=args.angular_temp, args=args)
    else:
        netB = nn.Identity()
        netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                       bottleneck_dim=netF.in_features, bias=args.classifier_bias,
                                       temp=args.angular_temp, args=args)

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath), strict=False)
    modelpath = args.output_dir_src + '/source_H.pt'
    netH.load_state_dict(torch.load(modelpath), strict=False)
    try:
        modelpath = args.output_dir_src + '/source_B.pt'
        netB.load_state_dict(torch.load(modelpath), strict=False)
    except:
        print('Skipped loading btn for version compatibility')
    modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath), strict=False)

    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    if args.dataparallel:
        netF = nn.DataParallel(netF).cuda()
        netH = nn.DataParallel(netH).cuda()
        netB = nn.DataParallel(netB).cuda()
        netC = nn.DataParallel(netC).cuda()
    else:
        netF.cuda()
        netH.cuda()
        netB.cuda()
        netC.cuda()

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    for k, v in netH.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False

    if args.ssl_task in ['simclr', 'crs']:
        ssl_loss_fn = NTXentLoss(args.batch_size, args.temperature, True).cuda()
    elif args.ssl_task in ['supcon', 'crsc']:
        ssl_loss_fn = SupConLoss(temperature=args.temperature, base_temperature=args.temperature).cuda()
    elif args.ssl_task == 'ls_supcon':
        ssl_loss_fn = LabelSmoothedSCLLoss(args.batch_size, args.temperature, args.class_num, args.ssl_smooth)

    if args.cr_weight > 0:
        if args.cr_metric == 'cos':
            dist = nn.CosineSimilarity(dim=1).cuda()
        elif args.cr_metric == 'l1':
            dist = nn.PairwiseDistance(p=1)
        elif args.cr_metric == 'l2':
            dist = nn.PairwiseDistance(p=2)
        elif args.cr_metric == 'bce':
            dist = nn.BCEWithLogitsLoss(reduction='sum').cuda()
        elif args.cr_metric == 'kl':
            dist = nn.KLDivLoss(reduction='sum').cuda()

    use_second_pass = (args.ssl_task in ['simclr', 'supcon', 'ls_supcon']) and (args.ssl_weight > 0)
    use_third_pass = (args.cr_weight > 0) or (args.ssl_task in ['crsc', 'crs'] and args.ssl_weight > 0) or (args.cls3)

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0
    centroid = None

    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        try:
            if inputs_test.size(0) == 1:
                continue
        except:
            if inputs_test[0].size(0) == 1:
                continue

        if iter_num % interval_iter == 0 and (args.cls_par > 0
                                              or args.ssl_task in ['supcon', 'ls_supcon', 'crsc']):
            netF.eval()
            netH.eval()
            netB.eval()
            if centroid is None or args.recompute_centroid:
                mem_label, mem_conf, centroid, labelset = obtain_label(
                    dset_loaders['pl'], netF, netH, netB, netC, args)
                mem_label = torch.from_numpy(mem_label).cuda()
            else:
                pass
            netF.train()
            netH.train()
            netB.train()

        inputs_test1 = None
        inputs_test2 = None
        inputs_test3 = None
        pred = mem_label[tar_idx]

        if type(inputs_test) is list:
            inputs_test1 = inputs_test[0].cuda()
            inputs_test2 = inputs_test[1].cuda()
            if len(inputs_test) == 3:
                inputs_test3 = inputs_test[2].cuda()
        else:
            inputs_test1 = inputs_test.cuda()

        if args.layer in ['add_margin', 'arc_margin', 'sphere'] and args.use_margin_forward:
            labels_forward = pred
        else:
            labels_forward = None

        if inputs_test is not None:
            f1 = netF(inputs_test1)
            b1 = netB(f1)
            outputs_test = netC(b1, labels_forward)
        if use_second_pass:
            f2 = netF(inputs_test2)
            b2 = netB(f2)
        if use_third_pass:
            if args.sg3:
                with torch.no_grad():
                    f3 = netF(inputs_test3)
                    b3 = netB(f3)
                    c3 = netC(b3, labels_forward)
                    conf = torch.max(F.softmax(c3, dim=1), dim=1)[0]
            else:
                f3 = netF(inputs_test3)
                b3 = netB(f3)
                c3 = netC(b3, labels_forward)
                conf = torch.max(F.softmax(c3, dim=1), dim=1)[0]

        iter_num += 1
        lr_scheduler(args, optimizer, iter_num=iter_num, max_iter=max_iter,
                     gamma=args.gamma, power=args.power)

        pred = compute_pl(args, b3, centroid, labelset)

        if args.cr_weight > 0:
            if args.cr_site == 'feat':
                f_hard = f1
                f_weak = f3
            elif args.cr_site == 'btn':
                f_hard = b1
                f_weak = b3
            elif args.cr_site == 'cls':
                f_hard = outputs_test
                f_weak = c3
                if args.cr_metric != 'cos':
                    f_hard = F.softmax(f_hard, dim=-1)
                    f_weak = F.softmax(f_weak, dim=-1)
            else:
                raise NotImplementedError

        if args.cls_par > 0:
            # with torch.no_grad():
            #     conf, _ = torch.max(F.softmax(outputs_test, dim=-1), dim=-1)
            #     conf = conf.cpu().numpy()
            conf_cls = mem_conf[tar_idx]
            # pred = mem_label[tar_idx]
            if args.cls_smooth > 0:
                classifier_loss = CrossEntropyLabelSmooth(
                    num_classes=args.class_num, epsilon=args.cls_smooth)(
                        outputs_test[conf_cls >= args.conf_threshold],
                        pred[conf_cls >= args.conf_threshold])
            else:
                classifier_loss = nn.CrossEntropyLoss()(
                    outputs_test[conf_cls >= args.conf_threshold],
                    pred[conf_cls >= args.conf_threshold])
            if args.cls3:
                if args.cls_smooth > 0:
                    classifier_loss = CrossEntropyLabelSmooth(
                        num_classes=args.class_num, epsilon=args.cls_smooth)(
                            c3[conf_cls >= args.conf_threshold],
                            pred[conf_cls >= args.conf_threshold])
                else:
                    classifier_loss = nn.CrossEntropyLoss()(
                        c3[conf_cls >= args.conf_threshold],
                        pred[conf_cls >= args.conf_threshold])
            classifier_loss *= args.cls_par
            if iter_num < interval_iter and args.dset == "visda-c":
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        if args.ssl_weight > 0:
            if args.ssl_before_btn:
                z1 = netH(f1, args.norm_feat)
                if use_second_pass:
                    z2 = netH(f2, args.norm_feat)
                if use_third_pass:
                    z3 = netH(f3, args.norm_feat)
            else:
                z1 = netH(b1, args.norm_feat)
                if use_second_pass:
                    z2 = netH(b2, args.norm_feat)
                if use_third_pass:
                    z3 = netH(b3, args.norm_feat)
            if args.ssl_task == 'simclr':
                ssl_loss = ssl_loss_fn(z1, z2)
            elif args.ssl_task == 'supcon':
                z = torch.cat([z1.unsqueeze(1), z2.unsqueeze(1)], dim=1)
                pl = mem_label[tar_idx]
                ssl_loss = ssl_loss_fn(z, pl)
            elif args.ssl_task == 'ls_supcon':
                pl = mem_label[tar_idx]
                ssl_loss = ssl_loss_fn(z1, z2, pl).squeeze()
            elif args.ssl_task == 'crsc':
                z = torch.cat([z1.unsqueeze(1), z3.unsqueeze(1)], dim=1)
                pl = mem_label[tar_idx]
                ssl_loss = ssl_loss_fn(z, pl)
            elif args.ssl_task == 'crs':
                ssl_loss = ssl_loss_fn(z1, z3)
            classifier_loss += args.ssl_weight * ssl_loss

        if args.cr_weight > 0:
            try:
                cr_loss = dist(f_hard[conf >= args.cr_threshold],
                               f_weak[conf >= args.cr_threshold]).mean()
                if args.cr_metric == 'cos':
                    cr_loss *= -1
            except:
                print('Error computing CR loss')
                cr_loss = torch.tensor(0.0).cuda()
            classifier_loss += args.cr_weight * cr_loss

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        centroid = update_centroid(args, b3, centroid, c3, labelset)

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netH.eval()
            netB.eval()
            if args.dset in ['visda-c', 'CIFAR-10-C', 'CIFAR-100-C']:
                acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netH, netB, netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te) + '\n' + acc_list
            else:
                acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netH, netB, netC, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netH.train()
            netB.train()

    if args.issave:
        if args.dataparallel:
            torch.save(netF.module.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
            torch.save(netH.module.state_dict(), osp.join(args.output_dir, "target_H_" + args.savename + ".pt"))
            torch.save(netB.module.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
            torch.save(netC.module.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))
        else:
            torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
            torch.save(netH.state_dict(), osp.join(args.output_dir, "target_H_" + args.savename + ".pt"))
            torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
            torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netH, netB, netC
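# `compute_pl` and `update_centroid` are imported from elsewhere. A rough,
# hypothetical sketch of nearest-centroid pseudo-labeling on bottleneck
# features (names and the cosine-distance choice are assumptions, not this
# repo's actual code):
def compute_pl_sketch(feats, centroid, labelset):
    feats = F.normalize(feats, dim=1)
    cents = F.normalize(centroid, dim=1)
    dist = 1 - feats @ cents.t()          # cosine distance to each centroid
    return labelset[dist.argmin(dim=1)]   # map centroid index back to class id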
def test_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.norm_layer == 'batchnorm':
        norm_layer = nn.BatchNorm2d
    elif args.norm_layer == 'groupnorm':
        def gn_helper(planes):
            return nn.GroupNorm(8, planes)
        norm_layer = gn_helper
    if args.net[0:3] == 'res':
        if '26' in args.net:
            netF = network.ResCifarBase(26, norm_layer=norm_layer)
        else:
            netF = network.ResBase(res_name=args.net, args=args)
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net)

    if args.ssl_before_btn:
        netH = network.ssl_head(ssl_task=args.ssl_task, feature_dim=netF.in_features,
                                embedding_dim=args.embedding_dim)
    else:
        netH = network.ssl_head(ssl_task=args.ssl_task, feature_dim=args.bottleneck,
                                embedding_dim=args.embedding_dim)
    if args.bottleneck != 0:
        netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                       bottleneck_dim=args.bottleneck, norm_btn=args.norm_btn)
        netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                       bottleneck_dim=args.bottleneck, bias=args.classifier_bias, args=args)
    else:
        netB = nn.Identity()
        netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                       bottleneck_dim=netF.in_features, bias=args.classifier_bias, args=args)

    args.modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_H.pt'
    netH.load_state_dict(torch.load(args.modelpath))
    try:
        args.modelpath = args.output_dir_src + '/source_B.pt'
        netB.load_state_dict(torch.load(args.modelpath))
    except:
        print('Skipped loading btn for version compatibility')
    args.modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))

    if args.dataparallel:
        netF = nn.DataParallel(netF).cuda()
        netH = nn.DataParallel(netH).cuda()
        netB = nn.DataParallel(netB).cuda()
        netC = nn.DataParallel(netC).cuda()
    else:
        netF.cuda()
        netH.cuda()
        netB.cuda()
        netC.cuda()

    netF.eval()
    netH.eval()
    netB.eval()
    netC.eval()

    if args.da == 'oda':
        acc_os1, acc_os2, acc_unknown = cal_acc_oda(dset_loaders['test'], netF, netH, netB, netC)
        log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}% / {:.2f}% / {:.2f}%'.format(
            args.trte, args.name, acc_os2, acc_os1, acc_unknown)
    else:
        if args.dset in ['visda-c', 'CIFAR-10-C', 'CIFAR-100-C']:
            acc, acc_list = cal_acc(dset_loaders['test'], netF, netH, netB, netC, args, True)
            log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(
                args.trte, args.name, acc) + '\n' + acc_list
        else:
            acc, _ = cal_acc(dset_loaders['test'], netF, netH, netB, netC, args, False)
            log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(
                args.trte, args.name, acc)

    args.out_file.write(log_str)
    args.out_file.flush()
    print(log_str)
def train_source(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.norm_layer == 'batchnorm':
        norm_layer = nn.BatchNorm2d
    elif args.norm_layer == 'groupnorm':
        def gn_helper(planes):
            return nn.GroupNorm(8, planes)
        norm_layer = gn_helper
    if args.net[0:3] == 'res':
        if '26' in args.net:
            netF = network.ResCifarBase(26, norm_layer=norm_layer)
            args.bottleneck = netF.in_features // 2
        else:
            netF = network.ResBase(res_name=args.net, args=args)
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net)

    if args.ssl_before_btn:
        netH = network.ssl_head(ssl_task=args.ssl_task, feature_dim=netF.in_features,
                                embedding_dim=args.embedding_dim)
    else:
        netH = network.ssl_head(ssl_task=args.ssl_task, feature_dim=args.bottleneck,
                                embedding_dim=args.embedding_dim)
    if args.bottleneck != 0:
        netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                       bottleneck_dim=args.bottleneck, norm_btn=args.norm_btn)
        netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                       bottleneck_dim=args.bottleneck, bias=args.classifier_bias,
                                       temp=args.angular_temp, args=args)
    else:
        netB = nn.Identity()
        netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                       bottleneck_dim=netF.in_features, bias=args.classifier_bias,
                                       temp=args.angular_temp, args=args)

    if args.dataparallel:
        netF = nn.DataParallel(netF).cuda()
        netH = nn.DataParallel(netH).cuda()
        netB = nn.DataParallel(netB).cuda()
        netC = nn.DataParallel(netC).cuda()
    else:
        netF.cuda()
        netH.cuda()
        netB.cuda()
        netC.cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        if args.separate_wd and ('bias' in k or 'norm' in k):
            param_group += [{'params': v, 'lr': learning_rate * 0.1, 'weight_decay': 0}]
        else:
            param_group += [{'params': v, 'lr': learning_rate * 0.1, 'weight_decay': args.weight_decay}]
    for k, v in netH.named_parameters():
        if args.separate_wd and ('bias' in k or 'norm' in k):
            param_group += [{'params': v, 'lr': learning_rate, 'weight_decay': 0}]
        else:
            param_group += [{'params': v, 'lr': learning_rate, 'weight_decay': args.weight_decay}]
    for k, v in netB.named_parameters():
        if args.separate_wd and ('bias' in k or 'norm' in k):
            param_group += [{'params': v, 'lr': learning_rate, 'weight_decay': 0}]
        else:
            param_group += [{'params': v, 'lr': learning_rate, 'weight_decay': args.weight_decay}]
    for k, v in netC.named_parameters():
        if args.separate_wd and ('bias' in k or 'norm' in k):
            param_group += [{'params': v, 'lr': learning_rate, 'weight_decay': 0}]
        else:
            param_group += [{'params': v, 'lr': learning_rate, 'weight_decay': args.weight_decay}]

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0
    if args.class_stratified:
        max_iter = args.max_epoch * len(dset_loaders["source_tr"].batch_sampler)
    else:
        max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    interval_iter = max_iter // 10
    iter_num = 0
    epoch = 0

    netF.train()
    netH.train()
    netB.train()
    netC.train()

    if args.use_focal_loss:
        cls_loss_fn = FocalLoss(alpha=args.focal_alpha, gamma=args.focal_gamma, reduction='mean')
    else:
        if args.ce_weighting:
            w = torch.Tensor(args.ce_weight).cuda()
            w.requires_grad = False
            if args.smooth == 0:
                cls_loss_fn = nn.CrossEntropyLoss(weight=w).cuda()
            else:
                cls_loss_fn = CrossEntropyLabelSmooth(num_classes=args.class_num,
                                                      epsilon=args.smooth, weight=w).cuda()
        else:
            if args.smooth == 0:
                cls_loss_fn = nn.CrossEntropyLoss().cuda()
            else:
                cls_loss_fn = CrossEntropyLabelSmooth(num_classes=args.class_num,
                                                      epsilon=args.smooth).cuda()

    if args.ssl_task in ['simclr', 'crs']:
        if args.use_new_ntxent:
            ssl_loss_fn = SupConLoss(temperature=args.temperature, base_temperature=args.temperature).cuda()
        else:
            ssl_loss_fn = NTXentLoss(args.batch_size, args.temperature, True).cuda()
    elif args.ssl_task in ['supcon', 'crsc']:
        ssl_loss_fn = SupConLoss(temperature=args.temperature, base_temperature=args.temperature).cuda()
    elif args.ssl_task == 'ls_supcon':
        ssl_loss_fn = LabelSmoothedSCLLoss(args.batch_size, args.temperature, args.class_num, args.ssl_smooth)

    if args.cr_weight > 0:
        if args.cr_metric == 'cos':
            dist = nn.CosineSimilarity(dim=1).cuda()
        elif args.cr_metric == 'l1':
            dist = nn.PairwiseDistance(p=1).cuda()
        elif args.cr_metric == 'l2':
            dist = nn.PairwiseDistance(p=2).cuda()
        elif args.cr_metric == 'bce':
            dist = nn.BCEWithLogitsLoss(reduction='sum').cuda()
        elif args.cr_metric == 'kl':
            dist = nn.KLDivLoss(reduction='sum').cuda()
        elif args.cr_metric == 'js':
            dist = JSDivLoss(reduction='sum').cuda()

    use_second_pass = (args.ssl_task in ['simclr', 'supcon', 'ls_supcon']) and (args.ssl_weight > 0)
    use_third_pass = (args.cr_weight > 0) or (args.ssl_task in ['crsc', 'crs'] and args.ssl_weight > 0) or (args.cls3)

    while iter_num < max_iter:
        try:
            inputs_source, labels_source = next(iter_source)
        except:
            iter_source = iter(dset_loaders["source_tr"])
            if args.class_stratified:
                dset_loaders["source_tr"].batch_sampler.set_epoch(epoch)
                epoch += 1
            inputs_source, labels_source = next(iter_source)

        try:
            if inputs_source.size(0) == 1:
                continue
        except:
            if inputs_source[0].size(0) == 1:
                continue

        iter_num += 1
        lr_scheduler(args, optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source1 = None
        inputs_source2 = None
        inputs_source3 = None
        labels_source = labels_source.cuda()

        if args.layer in ['add_margin', 'arc_margin', 'sphere']:
            labels_forward = labels_source
        else:
            labels_forward = None

        if type(inputs_source) is list:
            inputs_source1 = inputs_source[0].cuda()
            inputs_source2 = inputs_source[1].cuda()
            if len(inputs_source) == 3:
                inputs_source3 = inputs_source[2].cuda()
        else:
            inputs_source1 = inputs_source.cuda()

        if inputs_source1 is not None:
            f1 = netF(inputs_source1)
            b1 = netB(f1)
            outputs_source = netC(b1, labels_forward)
        if use_second_pass:
            f2 = netF(inputs_source2)
            b2 = netB(f2)
        if use_third_pass:
            if args.sg3:
                with torch.no_grad():
                    f3 = netF(inputs_source3)
                    b3 = netB(f3)
                    c3 = netC(b3, labels_forward)
                    conf = torch.max(F.softmax(c3, dim=1), dim=1)[0]
            else:
                f3 = netF(inputs_source3)
                b3 = netB(f3)
                c3 = netC(b3, labels_forward)
                conf = torch.max(F.softmax(c3, dim=1), dim=1)[0]

        if args.cr_weight > 0:
            if args.cr_site == 'feat':
                f_hard = f1
                f_weak = f3
            elif args.cr_site == 'btn':
                f_hard = b1
                f_weak = b3
            elif args.cr_site == 'cls':
                f_hard = outputs_source
                f_weak = c3
                if args.cr_metric in ['kl', 'js']:
                    f_hard = F.softmax(f_hard, dim=-1)
                if args.cr_metric in ['bce', 'kl', 'js']:
                    f_weak = F.softmax(f_weak, dim=-1)
            else:
                raise NotImplementedError

        classifier_loss = cls_loss_fn(outputs_source, labels_source)
        if args.cls3:
            classifier_loss += cls_loss_fn(c3, labels_source)

        if args.ssl_weight > 0:
            if args.ssl_before_btn:
                z1 = netH(f1, args.norm_feat)
                if use_second_pass:
                    z2 = netH(f2, args.norm_feat)
                if use_third_pass:
                    z3 = netH(f3, args.norm_feat)
            else:
                z1 = netH(b1, args.norm_feat)
                if use_second_pass:
                    z2 = netH(b2, args.norm_feat)
                if use_third_pass:
                    z3 = netH(b3, args.norm_feat)
            if args.ssl_task == 'simclr':
                if args.use_new_ntxent:
                    z = torch.cat([z1.unsqueeze(1), z2.unsqueeze(1)], dim=1)
                    ssl_loss = ssl_loss_fn(z)
                else:
                    ssl_loss = ssl_loss_fn(z1, z2)
            elif args.ssl_task == 'supcon':
                z = torch.cat([z1.unsqueeze(1), z2.unsqueeze(1)], dim=1)
                ssl_loss = ssl_loss_fn(z, labels=labels_source)
            elif args.ssl_task == 'ls_supcon':
                ssl_loss = ssl_loss_fn(z1, z2, labels_source)
            elif args.ssl_task == 'crsc':
                z = torch.cat([z1.unsqueeze(1), z3.unsqueeze(1)], dim=1)
                ssl_loss = ssl_loss_fn(z, labels_source)
            elif args.ssl_task == 'crs':
                if args.use_new_ntxent:
                    z = torch.cat([z1.unsqueeze(1), z3.unsqueeze(1)], dim=1)
                    ssl_loss = ssl_loss_fn(z)
                else:
                    ssl_loss = ssl_loss_fn(z1, z3)
        else:
            ssl_loss = torch.tensor(0.0).cuda()

        if args.cr_weight > 0:
            try:
                cr_loss = dist(f_hard[conf <= args.cr_threshold],
                               f_weak[conf <= args.cr_threshold]).mean()
                if args.cr_metric == 'cos':
                    cr_loss *= -1
            except:
                print('Error computing CR loss')
                cr_loss = torch.tensor(0.0).cuda()
        else:
            cr_loss = torch.tensor(0.0).cuda()

        if args.ent_weight > 0:
            softmax_out = nn.Softmax(dim=1)(outputs_source)
            entropy_loss = torch.mean(Entropy(softmax_out))
            classifier_loss += args.ent_weight * entropy_loss
        if args.gent_weight > 0:
            softmax_out = nn.Softmax(dim=1)(outputs_source)
            msoftmax = softmax_out.mean(dim=0)
            gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
            classifier_loss -= args.gent_weight * gentropy_loss

        loss = classifier_loss + args.ssl_weight * ssl_loss + args.cr_weight * cr_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netH.eval()
            netB.eval()
            netC.eval()
            if args.dset == 'visda-c':
                acc_s_te, acc_list = cal_acc(dset_loaders['source_te'], netF, netH, netB, netC, args, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name_src, iter_num, max_iter, acc_s_te) + '\n' + acc_list
            else:
                acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netH, netB, netC, args, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name_src, iter_num, max_iter, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if acc_s_te >= acc_init:
                acc_init = acc_s_te
                if args.dataparallel:
                    best_netF = netF.module.state_dict()
                    best_netH = netH.module.state_dict()
                    best_netB = netB.module.state_dict()
                    best_netC = netC.module.state_dict()
                else:
                    best_netF = netF.state_dict()
                    best_netH = netH.state_dict()
                    best_netB = netB.state_dict()
                    best_netC = netC.state_dict()

            netF.train()
            netH.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
    torch.save(best_netH, osp.join(args.output_dir_src, "source_H.pt"))
    torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
    torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))

    return netF, netH, netB, netC
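# `CrossEntropyLabelSmooth` follows the usual label-smoothing recipe. A minimal
# sketch (assuming the common SHOT-style formulation; the repo's own class may
# take extra arguments such as `weight`):
class CrossEntropyLabelSmoothSketch(nn.Module):
    def __init__(self, num_classes, epsilon=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        # soft targets: (1 - eps) * one-hot + eps / K
        soft = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        return (-soft * log_probs).mean(0).sum()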
def train_target_rot(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netR = network.feat_classifier(type='linear', class_num=4,
                                   bottleneck_dim=2 * args.bottleneck).cuda()

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    netF.eval()
    for k, v in netF.named_parameters():
        v.requires_grad = False
    modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    netB.eval()
    for k, v in netB.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netR.named_parameters():
        param_group += [{'params': v, 'lr': args.lr * 1}]
    netR.train()
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // 10
    iter_num = 0
    rot_acc = 0

    while iter_num < max_iter:
        optimizer.zero_grad()
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)
        if inputs_test.size(0) == 1:
            continue

        inputs_test = inputs_test.cuda()
        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        r_labels_target = np.random.randint(0, 4, len(inputs_test))
        r_inputs_target = rotation.rotate_batch_with_labels(inputs_test, r_labels_target)
        r_labels_target = torch.from_numpy(r_labels_target).cuda()
        r_inputs_target = r_inputs_target.cuda()

        f_outputs = netB(netF(inputs_test))
        f_r_outputs = netB(netF(r_inputs_target))
        r_outputs_target = netR(torch.cat((f_outputs, f_r_outputs), 1))

        rotation_loss = nn.CrossEntropyLoss()(r_outputs_target, r_labels_target)
        rotation_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netR.eval()
            acc_rot = cal_acc_rot(dset_loaders['target'], netF, netB, netR)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, max_iter, acc_rot)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netR.train()

            if rot_acc < acc_rot:
                rot_acc = acc_rot
                best_netR = netR.state_dict()

    log_str = 'Best Accuracy = {:.2f}%'.format(rot_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')

    return best_netR, rot_acc
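# `rotation.rotate_batch_with_labels` implements the rotation pretext task:
# each image is rotated by label * 90 degrees. A self-contained sketch
# (the repo's own helper may be implemented differently):
def rotate_batch_with_labels_sketch(batch, labels):
    images = [torch.rot90(img, k=int(lb), dims=(1, 2))  # rotate CHW image spatially
              for img, lb in zip(batch, labels)]
    return torch.stack(images)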
def split_target(args):
    test_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    test_set = ImageList(open(args.tu_dset_path).readlines(), transform=test_transform)

    dset_loaders = {}
    dset_loaders["test"] = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size * 3,
                                                       shuffle=False, num_workers=args.worker,
                                                       drop_last=False)

    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    modelpath = args.output_dir + "/target_F_" + args.savename + ".pt"
    netF.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir + "/target_B_" + args.savename + ".pt"
    netB.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir + "/target_C_" + args.savename + ".pt"
    netC.load_state_dict(torch.load(modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    start_test = True
    with torch.no_grad():
        iter_test = iter(dset_loaders['test'])
        for i in range(len(dset_loaders['test'])):
            data = next(iter_test)
            inputs = data[0]
            labels = data[1]
            inputs = inputs.cuda()
            outputs = netC(netB(netF(inputs)))
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)

    top_pred, predict = torch.max(all_output, 1)
    acc = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) * 100
    mean_ent = loss.Entropy(nn.Softmax(dim=1)(all_output))

    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
        args.name, 0, 0, acc, mean_ent.mean())
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')

    if args.ps == 0:
        est_p = (mean_ent < mean_ent.mean()).sum().item() / mean_ent.size(0)
        log_str = 'Task: {:.2f}'.format(est_p)
        print(log_str + '\n')
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        PS = est_p
    else:
        PS = args.ps

    if args.choice == "ent":
        value = mean_ent
    elif args.choice == "maxp":
        value = -top_pred
    elif args.choice == "marginp":
        pred, _ = torch.sort(all_output, 1)
        value = pred[:, 1] - pred[:, 0]
    else:
        value = torch.rand(len(mean_ent))

    ori_target = open(args.tu_dset_path).readlines()
    new_tar = []
    new_src = []
    predict = predict.numpy()
    cls_k = args.class_num
    for c in range(cls_k):
        c_idx = np.where(predict == c)
        c_idx = c_idx[0]
        c_value = value[c_idx]
        _, idx_ = torch.sort(c_value)
        c_num = len(idx_)
        c_num_s = int(c_num * PS)
        # the PS fraction with the lowest uncertainty becomes the pseudo source
        for ei in range(0, c_num_s):
            ee = c_idx[idx_[ei]]
            reci = ori_target[ee].strip().split(' ')
            line = reci[0] + ' ' + str(c) + '\n'
            new_src.append(line)
        # the remainder stays in the target split
        for ei in range(c_num_s, c_num):
            ee = c_idx[idx_[ei]]
            reci = ori_target[ee].strip().split(' ')
            line = reci[0] + ' ' + str(c) + '\n'
            new_tar.append(line)

    return new_src.copy(), new_tar.copy()
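# A hedged usage sketch of the split above: persist the confidence-based
# division for a later semi-supervised stage. The file names here are
# illustrative, not this repo's actual output paths.
def save_split_lists(args):
    new_src, new_tar = split_target(args)
    with open(osp.join(args.output_dir, 'pseudo_source_list.txt'), 'w') as f:
        f.writelines(new_src)  # low-entropy samples with trusted pseudo labels
    with open(osp.join(args.output_dir, 'pseudo_target_list.txt'), 'w') as f:
        f.writelines(new_tar)  # remaining samples, labels still uncertain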
def train_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF_list = [network.ResBase(res_name=args.net).cuda() for i in range(len(args.src))]
    elif args.net[0:3] == 'vgg':
        netF_list = [network.VGGBase(vgg_name=args.net).cuda() for i in range(len(args.src))]

    w = 2 * torch.rand((len(args.src),)) - 1
    print(w)

    netB_list = [network.feat_bottleneck(type=args.classifier, feature_dim=netF_list[i].in_features,
                                         bottleneck_dim=args.bottleneck).cuda() for i in range(len(args.src))]
    netC_list = [network.feat_classifier(type=args.layer, class_num=args.class_num,
                                         bottleneck_dim=args.bottleneck).cuda() for i in range(len(args.src))]
    netG_list = [network.scalar(w[i]).cuda() for i in range(len(args.src))]

    param_group = []
    for i in range(len(args.src)):
        modelpath = args.output_dir_src[i] + '/source_F.pt'
        print(modelpath)
        netF_list[i].load_state_dict(torch.load(modelpath))
        netF_list[i].eval()
        for k, v in netF_list[i].named_parameters():
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]

        modelpath = args.output_dir_src[i] + '/source_B.pt'
        print(modelpath)
        netB_list[i].load_state_dict(torch.load(modelpath))
        netB_list[i].eval()
        for k, v in netB_list[i].named_parameters():
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]

        modelpath = args.output_dir_src[i] + '/source_C.pt'
        print(modelpath)
        netC_list[i].load_state_dict(torch.load(modelpath))
        netC_list[i].eval()
        for k, v in netC_list[i].named_parameters():
            v.requires_grad = False

        for k, v in netG_list[i].named_parameters():
            param_group += [{'params': v, 'lr': args.lr}]

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0
    c = 0

    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)
        if inputs_test.size(0) == 1:
            continue

        if iter_num % interval_iter == 0 and args.cls_par > 0:
            initc = []
            all_feas = []
            for i in range(len(args.src)):
                netF_list[i].eval()
                netB_list[i].eval()
                temp1, temp2 = obtain_label(dset_loaders['target_'], netF_list[i],
                                            netB_list[i], netC_list[i], args)
                temp1 = torch.from_numpy(temp1).cuda()
                temp2 = torch.from_numpy(temp2).cuda()
                initc.append(temp1)
                all_feas.append(temp2)
                netF_list[i].train()
                netB_list[i].train()

        inputs_test = inputs_test.cuda()
        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        outputs_all = torch.zeros(len(args.src), inputs_test.shape[0], args.class_num)
        weights_all = torch.ones(inputs_test.shape[0], len(args.src))
        outputs_all_w = torch.zeros(inputs_test.shape[0], args.class_num)
        init_ent = torch.zeros(1, len(args.src))

        for i in range(len(args.src)):
            features_test = netB_list[i](netF_list[i](inputs_test))
            outputs_test = netC_list[i](features_test)
            softmax_ = nn.Softmax(dim=1)(outputs_test)
            ent_loss = torch.mean(loss.Entropy(softmax_))
            init_ent[:, i] = ent_loss
            weights_test = netG_list[i](features_test)
            outputs_all[i] = outputs_test
            weights_all[:, i] = weights_test.squeeze()

        z = torch.sum(weights_all, dim=1)
        z = z + 1e-16
        weights_all = torch.transpose(torch.transpose(weights_all, 0, 1) / z, 0, 1)
        outputs_all = torch.transpose(outputs_all, 0, 1)

        z_ = torch.sum(weights_all, dim=0)
        z_2 = torch.sum(weights_all)
        z_ = z_ / z_2

        for i in range(inputs_test.shape[0]):
            outputs_all_w[i] = torch.matmul(torch.transpose(outputs_all[i], 0, 1), weights_all[i])

        if args.cls_par > 0:
            initc_ = torch.zeros(initc[0].size()).cuda()
            temp = all_feas[0]
            all_feas_ = torch.zeros(temp[tar_idx, :].size()).cuda()
            for i in range(len(args.src)):
                initc_ = initc_ + z_[i] * initc[i].float()
                src_fea = all_feas[i]
                all_feas_ = all_feas_ + z_[i] * src_fea[tar_idx, :]
            dd = torch.cdist(all_feas_.float(), initc_.float(), p=2)
            pred_label = dd.argmin(dim=1)
            pred_label = pred_label.int()
            pred = pred_label.long()
            classifier_loss = args.cls_par * nn.CrossEntropyLoss()(outputs_all_w, pred.cpu())
        else:
            classifier_loss = torch.tensor(0.0)

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_all_w)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                entropy_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            for i in range(len(args.src)):
                netF_list[i].eval()
                netB_list[i].eval()
            acc, _ = cal_acc_multi(dset_loaders['test'], netF_list, netB_list, netC_list, netG_list, args)
            log_str = 'Iter:{}/{}; Accuracy = {:.2f}%'.format(iter_num, max_iter, acc)
            print(log_str + '\n')

    for i in range(len(args.src)):
        torch.save(netF_list[i].state_dict(), osp.join(args.output_dir, "target_F_" + str(i) + "_" + args.savename + ".pt"))
        torch.save(netB_list[i].state_dict(), osp.join(args.output_dir, "target_B_" + str(i) + "_" + args.savename + ".pt"))
        torch.save(netC_list[i].state_dict(), osp.join(args.output_dir, "target_C_" + str(i) + "_" + args.savename + ".pt"))
        torch.save(netG_list[i].state_dict(), osp.join(args.output_dir, "target_G_" + str(i) + "_" + args.savename + ".pt"))
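# `network.scalar` wraps each source model's aggregation weight. A rough sketch
# assuming a DECISION-style learnable scalar squashed through a sigmoid (the
# actual module may differ):
class scalar_sketch(nn.Module):
    def __init__(self, init_weight):
        super().__init__()
        self.w = nn.Parameter(torch.tensor(float(init_weight)))

    def forward(self, x):
        # same learned weight for every sample in the batch, kept in (0, 1)
        return torch.sigmoid(self.w * torch.ones(x.size(0), 1, device=x.device))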
def train_source(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res' or args.net[0:3] == 'vgg':
        if args.net[0:3] == 'res':
            netF = network.ResBase(res_name=args.net).cuda()
        else:
            netF = network.VGGBase(vgg_name=args.net).cuda()
        netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                       bottleneck_dim=args.bottleneck).cuda()  # classifier: bn
        netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                       bottleneck_dim=args.bottleneck).cuda()  # layer: wn

    if args.resume:
        args.modelpath = args.output_dir_src + '/source_F.pt'
        netF.load_state_dict(torch.load(args.modelpath))
        args.modelpath = args.output_dir_src + '/source_B.pt'
        netB.load_state_dict(torch.load(args.modelpath))
        args.modelpath = args.output_dir_src + '/source_C.pt'
        netC.load_state_dict(torch.load(args.modelpath))

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 0.1}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0.
    max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    print_loss_interval = 25
    interval_iter = 100
    iter_num = 0

    if args.net[0:3] == 'res' or args.net[0:3] == 'vgg':
        netF.train()
        netB.train()
        netC.train()

    while iter_num < max_iter:
        try:
            inputs_source, labels_source = next(iter_source)
        except:
            iter_source = iter(dset_loaders["source_tr"])
            inputs_source, labels_source = next(iter_source)
        if inputs_source.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
        outputs_source = netC(netB(netF(inputs_source)))
        classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num,
                                                  epsilon=args.smooth)(outputs_source, labels_source)

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % print_loss_interval == 0:
            print("Iter:{:>4d}/{} | Classification loss on Source: {:.2f}".format(
                iter_num, max_iter, classifier_loss.item()))

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            netC.eval()
            if args.dset == 'VISDA-RSUT' or args.dset == 'VISDA-RSUT-50' or args.dset == 'VISDA-RSUT-10':
                # The small classes in VisDA-C (RSUT) still have relatively many samples.
                # Safe to use per-class average accuracy.
                acc_s_te, acc_list, acc_cls_avg_te = cal_acc(dset_loaders['source_te'], netF, netB, netC,
                                                             per_class_flag=True, visda_flag=True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}%'.format(
                    args.name_src, iter_num, max_iter, acc_s_te, acc_cls_avg_te) + '\n' + acc_list
                cur_acc = acc_cls_avg_te
            else:
                if args.trte == 'stratified':
                    # Stratified cross-validation ensures the presence of every class in the validation set.
                    # Safe to use per-class average accuracy.
                    acc_s_te, acc_cls_avg_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC,
                                                          per_class_flag=True, visda_flag=False)
                    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}%'.format(
                        args.name_src, iter_num, max_iter, acc_s_te, acc_cls_avg_te)
                    cur_acc = acc_cls_avg_te
                else:
                    # Conventional cross-validation may leave certain classes out of the validation set,
                    # esp. when the dataset includes some very small classes, e.g., Office-Home (RSUT), DomainNet.
                    # Use overall accuracy to avoid the 'nan' issue.
                    acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC,
                                          per_class_flag=False, visda_flag=False)
                    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                        args.name_src, iter_num, max_iter, acc_s_te)
                    cur_acc = acc_s_te
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if cur_acc >= acc_init and iter_num >= 3 * len(dset_loaders["source_tr"]):
                # skip the first 3 epochs: training is not stable yet
                acc_init = cur_acc
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()

            netF.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
    torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
    torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))

    return netF, netB, netC
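# `op_copy` and `lr_scheduler` come from the shared training utilities. A
# minimal SHOT-style sketch (the momentum/weight-decay constants below are
# assumptions, not necessarily this repo's exact values):
def op_copy_sketch(optimizer):
    for param_group in optimizer.param_groups:
        param_group['lr0'] = param_group['lr']  # cache each group's initial lr
    return optimizer

def lr_scheduler_sketch(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    decay = (1 + gamma * iter_num / max_iter) ** (-power)  # polynomial decay
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr0'] * decay
        param_group['weight_decay'] = 1e-3
        param_group['momentum'] = 0.9
        param_group['nesterov'] = True
    return optimizer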
def train_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    wandb.watch(netB)
    wandb.watch(netC)
    wandb.watch(netF)

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath))

    # the source classifier (hypothesis) stays frozen during target adaptation
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0

    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        if inputs_test.size(0) == 1:
            continue

        if iter_num % interval_iter == 0 and args.cls_par > 0:
            netF.eval()
            netB.eval()
            # refresh pseudo-labels on the full target set
            mem_label = obtain_label(dset_loaders['test'], netF, netB, netC, args)
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netB.train()

        inputs_test = inputs_test.cuda()

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        if args.cls_par > 0:
            pred = mem_label[tar_idx]
            # cross-entropy between predictions and the cached pseudo-labels
            classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)
            classifier_loss *= args.cls_par
            if iter_num < interval_iter and (args.dset == "VISDA18" or args.dset == 'VISDA-C'):
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))  # expectation over the batch
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss
            wandb.log({
                "target_classifier_loss": classifier_loss,
                "info_max_loss": im_loss
            })

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            if args.dset == 'VISDA18' or args.dset == 'VISDA-C':
                acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te) + '\n' + acc_list
                wandb.log({"accuracy": acc_s_te})  # true test accuracy
            else:
                acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(netF.state_dict(),
                   osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(netB.state_dict(),
                   osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(netC.state_dict(),
                   osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netB, netC
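# The adaptation objective above combines an entropy term with a diversity term
# (the information-maximization, or IM, loss). A standalone sketch of that
# computation, factored out of the training loop for clarity; `epsilon` guards
# the log exactly as args.epsilon does above:

def information_maximization_loss(logits, epsilon=1e-5):
    """IM objective: mean per-sample entropy (pushes confident predictions)
    minus the entropy of the mean prediction (pushes balanced class usage)."""
    p = nn.Softmax(dim=1)(logits)
    ent = torch.mean(-(p * torch.log(p + epsilon)).sum(dim=1))  # minimize
    p_mean = p.mean(dim=0)
    div = -(p_mean * torch.log(p_mean + epsilon)).sum()         # maximize
    return ent - div

# equivalent to the in-loop computation:
# im_loss = information_maximization_loss(outputs_test) * args.ent_par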
def train_distill(args):
    dset_loaders = data_load(args)
    # load the source-adapted teacher models, one per source domain
    if args.net[0:3] == 'res':
        netF_list = [network.ResBase(res_name=args.net).cuda() for i in range(len(args.src))]
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF_list = [network.VGGBase(vgg_name=args.net).cuda() for i in range(len(args.src))]
        netF = network.VGGBase(vgg_name=args.net).cuda()

    netB_list = [network.feat_bottleneck(type=args.classifier,
                                         feature_dim=netF_list[i].in_features,
                                         bottleneck_dim=args.bottleneck).cuda()
                 for i in range(len(args.src))]
    netC_list = [network.feat_classifier(type=args.layer, class_num=args.class_num,
                                         bottleneck_dim=args.bottleneck).cuda()
                 for i in range(len(args.src))]
    netG_list = [network.scalar(1).cuda() for i in range(len(args.src))]  # per-source weights

    for i in range(len(args.src)):
        modelpath = args.output_dir_src + '/target_F_' + str(i) + '_par_0.3.pt'
        netF_list[i].load_state_dict(torch.load(modelpath))
        netF_list[i].eval()
        netF_list[i].cuda()
        for k, v in netF_list[i].named_parameters():
            v.requires_grad = False

        modelpath = args.output_dir_src + '/target_B_' + str(i) + '_par_0.3.pt'
        netB_list[i].load_state_dict(torch.load(modelpath))
        netB_list[i].eval()
        netB_list[i].cuda()
        for k, v in netB_list[i].named_parameters():
            v.requires_grad = False

        modelpath = args.output_dir_src + '/target_C_' + str(i) + '_par_0.3.pt'
        netC_list[i].load_state_dict(torch.load(modelpath))
        netC_list[i].eval()
        netC_list[i].cuda()
        for k, v in netC_list[i].named_parameters():
            v.requires_grad = False

        modelpath = args.output_dir_src + '/target_G_' + str(i) + '_par_0.3.pt'
        netG_list[i].load_state_dict(torch.load(modelpath))
        netG_list[i].eval()
        netG_list[i].cuda()
        for k, v in netG_list[i].named_parameters():
            v.requires_grad = False

    # create the student
    netB = network.feat_bottleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0
    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // 10
    iter_num = 0

    netF.train()
    netB.train()
    netC.train()

    while iter_num < max_iter:
        try:
            inputs = next(iter_source)
        except:
            iter_source = iter(dset_loaders["target"])
            inputs = next(iter_source)
        inputs = inputs[0]

        if inputs.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        # the teacher ensemble produces the pseudo-labels for distillation
        labels, logits = get_labels(inputs, netF_list, netB_list, netC_list, netG_list)
        inputs, labels, logits = inputs.cuda(), labels.cuda(), logits.cuda()
        labels, logits = labels.detach(), logits.detach()

        outputs = netC(netB(netF(inputs)))
        classifier_loss = nn.CrossEntropyLoss()(outputs, labels)

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            netC.eval()
            acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                args.tgt, iter_num, max_iter, acc_s_te)
            # args.out_file.write(log_str + '\n')
            # args.out_file.flush()
            print(log_str + '\n')

            if acc_s_te >= acc_init:
                acc_init = acc_s_te
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()

            netF.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
    torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
    torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))
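# train_distill calls get_labels(...) to turn the frozen teachers into pseudo-labels
# for the student. The repo's own get_labels is not shown here; the sketch below is
# a hypothetical reconstruction under the assumption that each netG_list[i] returns a
# per-sample (B, 1) confidence weight, and may differ from the actual implementation:

@torch.no_grad()
def get_labels(inputs, netF_list, netB_list, netC_list, netG_list):
    """Weighted teacher-ensemble pseudo-labelling (illustrative sketch)."""
    inputs = inputs.cuda()
    probs_sum = 0
    weight_sum = 0
    for netF_i, netB_i, netC_i, netG_i in zip(netF_list, netB_list, netC_list, netG_list):
        feats = netB_i(netF_i(inputs))
        probs = nn.Softmax(dim=1)(netC_i(feats))
        w = netG_i(feats)                    # assumed (B, 1) per-source confidence
        probs_sum = probs_sum + w * probs    # broadcasts over the class dimension
        weight_sum = weight_sum + w
    probs = probs_sum / weight_sum           # weighted average of teacher predictions
    labels = probs.argmax(dim=1)             # hard distillation targets
    return labels.cpu(), probs.cpu()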
def train_target(args):
    dset_loaders = data_load(args)
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath))

    # freeze the source classifier during adaptation
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0
    iter_sw = int(max_iter / 2.0)  # switch point passed to obtain_label

    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        if inputs_test.size(0) == 1:
            continue

        if iter_num % interval_iter == 0 and args.cls_par > 0:
            netF.eval()
            netB.eval()
            mem_label_soft, mtx_infor_nh, feas_FC = obtain_label(
                dset_loaders['test'], netF, netB, netC, args, iter_num, iter_sw)
            mem_label_soft = torch.from_numpy(mem_label_soft).cuda()
            feas_all = feas_FC[0]
            ops_all = feas_FC[1]
            netF.train()
            netB.train()

        inputs_test = inputs_test.cuda()

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        features_F_self = netF(inputs_test)
        # weighted aggregation of each sample's nearest-neighbour features
        features_F_nh = get_mtx_sam_wgt_nh(feas_all, mtx_infor_nh, tar_idx)
        features_F_nh = features_F_nh.cuda()
        features_F_mix = 0.8 * features_F_self + 0.2 * features_F_nh

        outputs_test_mix = netC(netB(features_F_mix))
        ops_test_self = netC(netB(features_F_self))
        outputs_test_nh = netC(netB(features_F_nh))

        if args.cls_par > 0:
            log_probs = nn.LogSoftmax(dim=1)(outputs_test_mix)
            targets = mem_label_soft[tar_idx]
            loss_soft = (-targets * log_probs).sum(dim=1)  # soft-label cross-entropy
            classifier_loss = loss_soft.mean()
            classifier_loss *= args.cls_par
            if iter_num < interval_iter and args.dset == "VISDA-C":
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test_mix)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            if args.dset == 'VISDA-C':
                acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te) + '\n' + acc_list
            else:
                acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(netF.state_dict(),
                   osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(netB.state_dict(),
                   osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(netC.state_dict(),
                   osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netB, netC
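# Unlike the hard-pseudo-label variant earlier, this routine trains against soft
# targets: H(q, p) = -sum_c q_c log p_c averaged over the batch. The in-loop
# expression factored out as a standalone helper, for reference:

def soft_label_cross_entropy(logits, soft_targets):
    """Soft-label cross-entropy, matching the loss_soft computation above."""
    log_probs = nn.LogSoftmax(dim=1)(logits)
    return (-soft_targets * log_probs).sum(dim=1).mean()

# classifier_loss = soft_label_cross_entropy(outputs_test_mix, mem_label_soft[tar_idx])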
def train_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))

    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    tt = 0
    iter_num = 0
    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval

    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        if inputs_test.size(0) == 1:
            continue

        if iter_num % interval_iter == 0:
            netF.eval()
            netB.eval()
            mem_label, ENT_THRESHOLD = obtain_label(dset_loaders['test'], netF, netB, netC, args)
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netB.train()

        inputs_test = inputs_test.cuda()
        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        pred = mem_label[tar_idx]
        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)
        softmax_out = nn.Softmax(dim=1)(outputs_test)

        # drop samples pseudo-labelled as 'unknown' (index >= class_num)
        outputs_test_known = outputs_test[pred < args.class_num, :]
        pred = pred[pred < args.class_num]

        if len(pred) == 0:
            print(tt)
            del features_test
            del outputs_test
            tt += 1
            continue

        if args.cls_par > 0:
            classifier_loss = nn.CrossEntropyLoss()(outputs_test_known, pred)
            classifier_loss *= args.cls_par
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out_known = nn.Softmax(dim=1)(outputs_test_known)
            entropy_loss = torch.mean(loss.Entropy(softmax_out_known))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            classifier_loss += entropy_loss * args.ent_par

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            acc_os1, acc_os2, acc_unknown = cal_acc(dset_loaders['test'], netF, netB, netC,
                                                    True, ENT_THRESHOLD)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}% / {:.2f}% / {:.2f}%'.format(
                args.name, iter_num, max_iter, acc_os2, acc_os1, acc_unknown)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(netF.state_dict(),
                   osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(netB.state_dict(),
                   osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(netC.state_dict(),
                   osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netB, netC
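# In this open-set variant, obtain_label additionally returns ENT_THRESHOLD, and
# pseudo-labels at or above args.class_num denote 'unknown' samples that are
# excluded from the classification loss. A sketch of the decision rule this
# implies; the repo's obtain_label derives the threshold itself (e.g. from the
# entropy distribution), so passing it in here is an assumption:

@torch.no_grad()
def split_known_unknown(logits, ent_threshold, class_num):
    """Route high-entropy predictions to the extra 'unknown' index (sketch)."""
    p = nn.Softmax(dim=1)(logits)
    ent = -(p * torch.log(p + 1e-5)).sum(dim=1)
    ent = ent / np.log(class_num)            # normalize to [0, 1]
    pred = p.argmax(dim=1)
    pred[ent > ent_threshold] = class_num    # class_num acts as the 'unknown' label
    return pred, ent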
def train(args, txt_src, txt_tgt):
    ## set pre-process
    dset_loaders = data_load(args, txt_src, txt_tgt)
    ## set base network
    if args.net[0:3] == 'res':
        netG = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netG = network.VGGBase(vgg_name=args.net).cuda()

    max_len = max(len(dset_loaders["source"]), len(dset_loaders["target"]))
    max_iter = args.max_epoch * max_len
    interval_iter = max_iter // 10

    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netG.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    if args.model == "source":
        modelpath = args.output_dir + "/source_F.pt"
        netG.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_B.pt"
        netB.load_state_dict(torch.load(modelpath))
    else:
        modelpath = args.output_dir + "/target_F_" + args.savename + ".pt"
        netG.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_B_" + args.savename + ".pt"
        netB.load_state_dict(torch.load(modelpath))

    if len(args.gpu_id.split(',')) > 1:
        netG = nn.DataParallel(netG)

    netF = nn.Sequential(netB, netC)  # here netF is the head (bottleneck + classifier)
    optimizer_g = optim.SGD(netG.parameters(), lr=args.lr * 0.1)
    optimizer_f = optim.SGD(netF.parameters(), lr=args.lr)

    base_network = nn.Sequential(netG, netF)

    source_loader_iter = iter(dset_loaders["source"])
    target_loader_iter = iter(dset_loaders["target"])

    list_acc = []
    best_ent = 100

    for iter_num in range(1, max_iter + 1):
        base_network.train()
        lr_scheduler(optimizer_g, init_lr=args.lr * 0.1, iter_num=iter_num, max_iter=max_iter)
        lr_scheduler(optimizer_f, init_lr=args.lr, iter_num=iter_num, max_iter=max_iter)

        try:
            inputs_source, labels_source = next(source_loader_iter)
        except:
            source_loader_iter = iter(dset_loaders["source"])
            inputs_source, labels_source = next(source_loader_iter)
        try:
            inputs_target, _, target_idx = next(target_loader_iter)
        except:
            target_loader_iter = iter(dset_loaders["target"])
            inputs_target, _, target_idx = next(target_loader_iter)

        targets_s = torch.zeros(args.batch_size, args.class_num).scatter_(
            1, labels_source.view(-1, 1), 1)
        inputs_s = inputs_source.cuda()
        targets_s = targets_s.cuda()
        inputs_t = inputs_target[0].cuda()   # two augmented views of each target image
        inputs_t2 = inputs_target[1].cuda()

        with torch.no_grad():
            # compute guessed labels of the unlabeled samples
            outputs_u = base_network(inputs_t)
            outputs_u2 = base_network(inputs_t2)
            p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2
            pt = p ** (1 / args.T)           # sharpen the averaged prediction
            targets_u = pt / pt.sum(dim=1, keepdim=True)
            targets_u = targets_u.detach()

        ####################################################################
        all_inputs = torch.cat([inputs_s, inputs_t, inputs_t2], dim=0)
        all_targets = torch.cat([targets_s, targets_u, targets_u], dim=0)
        if args.alpha > 0:
            l = np.random.beta(args.alpha, args.alpha)
            l = max(l, 1 - l)
        else:
            l = 1
        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]
        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # interleave labeled and unlabeled samples between batches
        # to get correct batchnorm statistics
        mixed_input = list(torch.split(mixed_input, args.batch_size))
        mixed_input = utils.interleave(mixed_input, args.batch_size)
        # s  = [sa, sb, sc]
        # t1 = [t1a, t1b, t1c]
        # t2 = [t2a, t2b, t2c]
        # => s' = [sa, t1b, t2c]  t1' = [t1a, sb, t1c]  t2' = [t2a, t2b, sc]

        logits = base_network(mixed_input[0])
        logits = [logits]
        for inp in mixed_input[1:]:
            temp = base_network(inp)
            logits.append(temp)

        # put the interleaved samples back
        logits = utils.interleave(logits, args.batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        train_criterion = utils.SemiLoss()
        Lx, Lu, w = train_criterion(logits_x, mixed_target[:args.batch_size],
                                    logits_u, mixed_target[args.batch_size:],
                                    iter_num, max_iter, args.lambda_u)
        loss = Lx + w * Lu

        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            base_network.eval()
            if args.dset == 'VISDA-C':
                acc, py, score, y, tacc = cal_acc(dset_loaders["test"], base_network, flag=True)
                print(tacc)
                args.out_file.write(tacc + '\n')
                args.out_file.flush()
                _ent = Entropy(score)
                mean_ent = 0
                for ci in range(args.class_num):
                    if _ent[py == ci].size(0) > 0:
                        mean_ent += _ent[py == ci].mean()
                mean_ent /= args.class_num
            else:
                acc, py, score, y = cal_acc(dset_loaders["test"], base_network, flag=False)
                mean_ent = torch.mean(Entropy(score))

            list_acc.append(acc)
            if best_ent > mean_ent:
                val_acc = acc
                best_ent = mean_ent
                best_y = y
                best_py = py
                best_score = score

            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
                args.name, iter_num, max_iter, acc, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

    idx = np.argmax(np.array(list_acc))
    max_acc = list_acc[idx]
    final_acc = list_acc[-1]

    log_str = '\n==========================================\n'
    log_str += '\nVal Acc = {:.2f}\nMax Acc = {:.2f}\nFin Acc = {:.2f}\n'.format(
        val_acc, max_acc, final_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()

    return base_network, py
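# The MixMatch-style stage above depends on utils.SemiLoss and utils.interleave,
# which are defined elsewhere in the repo. A sketch of a SemiLoss matching the
# call signature used above: soft cross-entropy on the labeled chunk, MSE between
# probabilities and guessed targets on the unlabeled chunk, and a weight w that
# ramps up linearly to lambda_u. The repo's version may differ in detail.

class SemiLoss:
    """MixMatch-style semi-supervised loss (illustrative sketch)."""

    def __call__(self, logits_x, targets_x, logits_u, targets_u,
                 iter_num, max_iter, lambda_u):
        # supervised term: soft-label cross-entropy on the (mixed) labeled chunk
        Lx = -(targets_x * nn.LogSoftmax(dim=1)(logits_x)).sum(dim=1).mean()
        # unsupervised term: L2 consistency with the guessed soft targets
        probs_u = torch.softmax(logits_u, dim=1)
        Lu = torch.mean((probs_u - targets_u) ** 2)
        # linear ramp-up of the unsupervised weight
        w = lambda_u * float(np.clip(iter_num / max_iter, 0.0, 1.0))
        return Lx, Lu, w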
def train_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    if not args.ssl == 0:
        # auxiliary rotation-prediction head (4 classes: 0/90/180/270 degrees)
        netR = network.feat_classifier(type='linear', class_num=4,
                                       bottleneck_dim=2 * args.bottleneck).cuda()
        netR_dict, acc_rot = train_target_rot(args)
        netR.load_state_dict(netR_dict)

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath))

    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    if not args.ssl == 0:
        for k, v in netR.named_parameters():
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        netR.train()

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0

    while iter_num < max_iter:
        optimizer.zero_grad()
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        if inputs_test.size(0) == 1:
            continue

        if iter_num % interval_iter == 0 and args.cls_par > 0:
            netF.eval()
            netB.eval()
            mem_label = obtain_label(dset_loaders['target_te'], netF, netB, netC, args)
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netB.train()

        inputs_test = inputs_test.cuda()

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        if args.cls_par > 0:
            pred = mem_label[tar_idx]

        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        if args.cls_par > 0:
            classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)
            classifier_loss *= args.cls_par
            if iter_num < interval_iter and args.dset == "VISDA-C":
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        classifier_loss.backward()

        if not args.ssl == 0:
            # self-supervised rotation prediction on the same batch
            r_labels_target = np.random.randint(0, 4, len(inputs_test))
            r_inputs_target = rotation.rotate_batch_with_labels(inputs_test, r_labels_target)
            r_labels_target = torch.from_numpy(r_labels_target).cuda()
            r_inputs_target = r_inputs_target.cuda()
            f_outputs = netB(netF(inputs_test))
            f_outputs = f_outputs.detach()
            f_r_outputs = netB(netF(r_inputs_target))
            r_outputs_target = netR(torch.cat((f_outputs, f_r_outputs), 1))
            rotation_loss = args.ssl * nn.CrossEntropyLoss()(r_outputs_target, r_labels_target)
            rotation_loss.backward()

        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            if args.dset == 'VISDA-C':
                acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te) + '\n' + acc_list
            else:
                acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(netF.state_dict(),
                   osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(netB.state_dict(),
                   osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(netC.state_dict(),
                   osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netB, netC
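# The SSL branch above relies on rotation.rotate_batch_with_labels from the repo.
# A minimal sketch of what such a function can look like, assuming CHW image
# tensors and labels in {0, 1, 2, 3} meaning multiples of 90 degrees; the repo's
# actual rotation module may implement this differently:

def rotate_batch_with_labels(batch, labels):
    """Rotate each image in `batch` by labels[i] * 90 degrees (sketch)."""
    rotated = [torch.rot90(img, k=int(k), dims=(1, 2))  # rotate the H/W plane
               for img, k in zip(batch, labels)]
    return torch.stack(rotated)

# usage matching the training loop above:
# r_labels = np.random.randint(0, 4, len(inputs_test))
# r_inputs = rotate_batch_with_labels(inputs_test, r_labels)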
def train_source(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 0.1}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0
    max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    interval_iter = max_iter // 10
    iter_num = 0

    netF.train()
    netB.train()
    netC.train()

    while iter_num < max_iter:
        try:
            inputs_source, labels_source = next(iter_source)
        except:
            iter_source = iter(dset_loaders["source_tr"])
            inputs_source, labels_source = next(iter_source)

        if inputs_source.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
        outputs_source = netC(netB(netF(inputs_source)))
        classifier_loss = CrossEntropyLabelSmooth(
            num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source)

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            netC.eval()
            acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                args.name_src, iter_num, max_iter, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if acc_s_te >= acc_init:
                acc_init = acc_s_te
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()

            netF.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
    torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
    torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))

    return netF, netB, netC
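# Every routine above pairs optim.SGD with op_copy and lr_scheduler. A sketch of
# the usual pattern: op_copy records each param group's initial learning rate,
# and lr_scheduler applies a polynomial decay from it. The exact hyperparameters
# (gamma, power, weight decay, momentum) are assumptions here. Note that the
# MixMatch-style train() above uses a variant with an extra init_lr argument.

def op_copy(optimizer):
    """Remember each param group's initial lr so the scheduler can decay from it."""
    for param_group in optimizer.param_groups:
        param_group['lr0'] = param_group['lr']
    return optimizer

def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    """Polynomial decay: lr = lr0 * (1 + gamma * iter/max_iter) ** (-power)."""
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr0'] * decay
        param_group['weight_decay'] = 1e-3
        param_group['momentum'] = 0.9
        param_group['nesterov'] = True
    return optimizer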