def cal_acc_rot(loader, netF, netB, netR):
    """Evaluate the rotation-prediction head on a full loader pass.

    For every batch, each image is rotated by a random multiple of 90 degrees
    (label in {0,1,2,3}); the head netR predicts the rotation from the
    concatenated bottleneck features of the original and rotated images.

    Args:
        loader: iterable DataLoader whose batches yield (inputs, ...) tuples.
        netF, netB: frozen-or-not feature extractor and bottleneck (composed as netB(netF(x))).
        netR: rotation classifier over concatenated features.

    Returns:
        Rotation-classification accuracy as a percentage (float in [0, 100]).
    """
    start_test = True
    with torch.no_grad():
        iter_test = iter(loader)
        for _ in range(len(loader)):
            # Fixed: iter_test.next() is Python-2 syntax; use the next() builtin.
            data = next(iter_test)
            inputs = data[0].cuda()
            # One random rotation label per image in the batch.
            r_labels = np.random.randint(0, 4, len(inputs))
            r_inputs = rotation.rotate_batch_with_labels(inputs, r_labels)
            r_labels = torch.from_numpy(r_labels)
            r_inputs = r_inputs.cuda()
            f_outputs = netB(netF(inputs))
            f_r_outputs = netB(netF(r_inputs))
            # Predict rotation from [features(original) || features(rotated)].
            r_outputs = netR(torch.cat((f_outputs, f_r_outputs), 1))
            if start_test:
                all_output = r_outputs.float().cpu()
                all_label = r_labels.float()
                start_test = False
            else:
                all_output = torch.cat((all_output, r_outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, r_labels.float()), 0)
    _, predict = torch.max(all_output, 1)
    accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
    return accuracy * 100
def train_target(args):
    """Source-free target-domain adaptation (SHOT-style) for a res/vgg backbone.

    Loads the pretrained source model (F/B/C), freezes the classifier C, and
    adapts F and B on unlabeled target data using pseudo-label cross-entropy
    (when args.cls_par > 0) plus an information-maximization loss
    (when args.ent / args.gent), optionally with a rotation-prediction
    self-supervised task (when args.ssl != 0).

    Returns:
        (netF, netB, netC) — the adapted feature extractor, bottleneck, and
        the (frozen) classifier.
    """
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    if args.ssl != 0:
        # Rotation head takes concatenated (original, rotated) bottleneck features.
        netR = network.feat_classifier(type='linear', class_num=4,
                                       bottleneck_dim=2 * args.bottleneck).cuda()
        netR_dict, acc_rot = train_target_rot(args)
        netR.load_state_dict(netR_dict)

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath))
    # The source classifier hypothesis is kept frozen during adaptation.
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    if args.ssl != 0:
        for k, v in netR.named_parameters():
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        netR.train()

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0
    # Fixed: initialize the iterator up front instead of relying on a bare
    # `except:` to catch the first-pass NameError; iterator exhaustion is now
    # the only exception handled.
    iter_test = iter(dset_loaders["target"])

    while iter_num < max_iter:
        optimizer.zero_grad()
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except StopIteration:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        # Skip singleton batches (BatchNorm cannot handle batch size 1 in train mode).
        if inputs_test.size(0) == 1:
            continue

        # Refresh pseudo-labels at every evaluation interval.
        if iter_num % interval_iter == 0 and args.cls_par > 0:
            netF.eval()
            netB.eval()
            mem_label = obtain_label(dset_loaders['target_te'], netF, netB, netC, args)
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netB.train()

        inputs_test = inputs_test.cuda()
        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        if args.cls_par > 0:
            pred = mem_label[tar_idx]

        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        if args.cls_par > 0:
            classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)
            classifier_loss *= args.cls_par
            # On VISDA-C the first interval's pseudo-labels are too noisy to use.
            if iter_num < interval_iter and args.dset == "VISDA-C":
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            # Information maximization: minimize per-sample entropy, and
            # (with gent) maximize the entropy of the mean prediction.
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        classifier_loss.backward()

        if args.ssl != 0:
            # Auxiliary rotation-prediction task; original-image features are
            # detached so this loss only shapes the rotated-branch gradients.
            r_labels_target = np.random.randint(0, 4, len(inputs_test))
            r_inputs_target = rotation.rotate_batch_with_labels(inputs_test, r_labels_target)
            r_labels_target = torch.from_numpy(r_labels_target).cuda()
            r_inputs_target = r_inputs_target.cuda()
            f_outputs = netB(netF(inputs_test))
            f_outputs = f_outputs.detach()
            f_r_outputs = netB(netF(r_inputs_target))
            r_outputs_target = netR(torch.cat((f_outputs, f_r_outputs), 1))
            rotation_loss = args.ssl * nn.CrossEntropyLoss()(r_outputs_target, r_labels_target)
            rotation_loss.backward()

        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            if args.dset == 'VISDA-C':
                acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, max_iter, acc_s_te) + '\n' + acc_list
            else:
                acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, max_iter, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netB, netC
def train_target_rot(args):
    """Train only the rotation-prediction head on target data.

    The source feature extractor (F) and bottleneck (B) are loaded and frozen;
    a fresh linear head netR learns to predict which multiple of 90 degrees an
    image was rotated by, from the concatenated features of the original and
    rotated images.

    Returns:
        (best_netR, rot_acc) — a snapshot of netR's state dict at the best
        evaluation accuracy, and that best accuracy in percent.
    """
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netR = network.feat_classifier(type='linear', class_num=4,
                                   bottleneck_dim=2 * args.bottleneck).cuda()

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    netF.eval()
    for k, v in netF.named_parameters():
        v.requires_grad = False
    modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    netB.eval()
    for k, v in netB.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netR.named_parameters():
        param_group += [{'params': v, 'lr': args.lr * 1}]
    netR.train()
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    def _snapshot_state(module):
        # state_dict() returns live references that later optimizer steps
        # mutate in place; clone so the "best" weights are truly preserved.
        return {k: v.clone() for k, v in module.state_dict().items()}

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // 10
    iter_num = 0
    rot_acc = 0
    # Fixed: best_netR was unbound until the first accuracy improvement;
    # initialize it so the function cannot raise NameError at return.
    best_netR = _snapshot_state(netR)
    # Fixed: initialize the iterator up front instead of relying on a bare
    # `except:` to catch the first-pass NameError.
    iter_test = iter(dset_loaders["target"])

    while iter_num < max_iter:
        optimizer.zero_grad()
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except StopIteration:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        # Skip singleton batches (BatchNorm cannot handle batch size 1 in train mode).
        if inputs_test.size(0) == 1:
            continue

        inputs_test = inputs_test.cuda()
        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        r_labels_target = np.random.randint(0, 4, len(inputs_test))
        r_inputs_target = rotation.rotate_batch_with_labels(inputs_test, r_labels_target)
        r_labels_target = torch.from_numpy(r_labels_target).cuda()
        r_inputs_target = r_inputs_target.cuda()
        f_outputs = netB(netF(inputs_test))
        f_r_outputs = netB(netF(r_inputs_target))
        r_outputs_target = netR(torch.cat((f_outputs, f_r_outputs), 1))
        rotation_loss = nn.CrossEntropyLoss()(r_outputs_target, r_labels_target)
        rotation_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netR.eval()
            acc_rot = cal_acc_rot(dset_loaders['target'], netF, netB, netR)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, max_iter, acc_rot)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netR.train()

            if rot_acc < acc_rot:
                rot_acc = acc_rot
                best_netR = _snapshot_state(netR)

    log_str = 'Best Accuracy = {:.2f}%'.format(rot_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')

    return best_netR, rot_acc
def train_target(args):
    """Source-free target adaptation for the single-module Res50 variant.

    NOTE(review): this redefines `train_target` — if both definitions live in
    the same module, this one shadows the res/vgg three-module version above;
    confirm they belong to separate scripts.

    netF here is a single network returning (features, logits); its "fc"
    classifier parameters are frozen while the rest adapts via pseudo-label
    cross-entropy (args.cls_par), information maximization (args.ent /
    args.gent), and an optional rotation self-supervision task (args.ssl > 0).

    Returns:
        The adapted netF.
    """
    dset_loaders = data_load(args)
    netF = network.Res50().cuda()

    if args.ssl > 0:
        # Rotation head over concatenated (original, rotated) 2048-d features.
        netR = network.feat_classifier(type='linear', class_num=4,
                                       bottleneck_dim=2 * 2048).cuda()
        netR_dict, acc_rot = train_target_rot(args)
        netR.load_state_dict(netR_dict)

    param_group = []
    for k, v in netF.named_parameters():
        # Freeze the classifier ("fc") layer; adapt everything else.
        if "fc" in k:
            v.requires_grad = False
        else:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
    if args.ssl > 0:
        for k, v in netR.named_parameters():
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        netR.train()

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0
    netF.train()
    # Fixed: initialize the iterator up front instead of relying on a bare
    # `except:` to catch the first-pass NameError; use next() (Python 3).
    iter_test = iter(dset_loaders["target"])

    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except StopIteration:
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        # Skip singleton batches (BatchNorm cannot handle batch size 1 in train mode).
        if inputs_test.size(0) == 1:
            continue

        # Refresh pseudo-labels at every evaluation interval.
        if iter_num % interval_iter == 0 and args.cls_par > 0:
            netF.eval()
            mem_label = obtain_label(dset_loaders['test'], netF, args)
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()

        inputs_test = inputs_test.cuda()
        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        features_test, outputs_test = netF(inputs_test)

        if args.cls_par > 0:
            pred = mem_label[tar_idx]
            classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)
            classifier_loss *= args.cls_par
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            # Information maximization: minimize per-sample entropy, and
            # (with gent) maximize the entropy of the mean prediction.
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            classifier_loss += entropy_loss * args.ent_par

        optimizer.zero_grad()
        classifier_loss.backward()
        del features_test, outputs_test

        if args.ssl > 0:
            # Auxiliary rotation-prediction loss; original-image features are
            # detached so only the rotated branch propagates gradients.
            r_labels_target = np.random.randint(0, 4, len(inputs_test))
            r_inputs_target = rotation.rotate_batch_with_labels(inputs_test, r_labels_target)
            r_labels_target = torch.from_numpy(r_labels_target).cuda()
            r_inputs_target = r_inputs_target.cuda()
            f_outputs, _ = netF(inputs_test)
            f_outputs = f_outputs.detach()
            f_r_outputs, _ = netF(r_inputs_target)
            r_outputs_target = netR(torch.cat((f_outputs, f_r_outputs), 1))
            rotation_loss = args.ssl * nn.CrossEntropyLoss()(r_outputs_target, r_labels_target)
            rotation_loss.backward()

        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            acc, ment = cal_acc(dset_loaders['test'], netF)
            # cal_acc in this variant returns a fraction; scale to percent for the log.
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.dset, iter_num, max_iter, acc * 100)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()

    if args.issave:
        torch.save(netF.state_dict(), osp.join(args.output_dir, "target_" + args.savename + ".pt"))

    return netF