# Imports assumed from the top of the original file (not shown in this
# excerpt). `network` and `loss` are repo-local modules; `digit_load`,
# `cal_acc`, `op_copy`, and `obtain_label` are helpers defined earlier
# in the same file.
import os.path as osp

import torch
import torch.nn as nn
import torch.optim as optim

import network
import loss


def test_target(args):
    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()
    elif args.dset == 'm2mm':
        netF = network.DTNBase().cuda()
    # note: 'bootleneck' is the repo's own spelling of this class name
    netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    args.modelpath = args.output_dir + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_B.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    netF.eval()
    netB.eval()
    netC.eval()

    acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
    log_str = 'Task: {}, Accuracy = {:.2f}%'.format(args.dset, acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')
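# `cal_acc` is defined earlier in the file and not shown in this excerpt.
# As a reading aid, a minimal sketch of what such a routine typically
# computes (the repo's exact version may differ): accuracy in percent plus
# the mean prediction entropy, which is why callers unpack two values.
def cal_acc_sketch(loader, netF, netB, netC):
    correct, total, ent_sum = 0, 0, 0.0
    with torch.no_grad():
        for batch in loader:
            inputs, labels = batch[0].cuda(), batch[1]
            probs = nn.Softmax(dim=1)(netC(netB(netF(inputs)))).cpu()
            correct += (probs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
            ent_sum += (-probs * torch.log(probs + 1e-5)).sum().item()
    return 100.0 * correct / total, ent_sum / total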
def train_target(args):
    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()
    elif args.dset == 'm2mm':
        netF = network.DTNBase_c().cuda()
    elif args.dset == 's2u':
        netF = network.DTNBase_c().cuda()
    netB = network.feat_bootleneck_c().cuda()
    netC = network.feat_classifier_c().cuda()

    args.modelpath = args.output_dir + '/source_F.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_B.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_C.pt'
    netC.load_state_dict(torch.load(args.modelpath))

    # freeze the source classifier; only the feature extractor and
    # bottleneck are adapted on the target domain
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': args.lr}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': args.lr}]
    # optimizer = optim.SGD(param_group)
    optimizer = optim.Adam(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = len(dset_loaders["target"])
    # interval_iter = max_iter // args.interval
    iter_num = 0

    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except (NameError, StopIteration):
            # first pass, or the loader is exhausted: restart it
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        if inputs_test.size(0) == 1:
            # BatchNorm cannot normalize a single-sample batch
            continue

        if iter_num % interval_iter == 0 and args.cls_par > 0:
            # periodically refresh the pseudo-labels on the target split
            netF.eval()
            netB.eval()
            mem_label = obtain_label(dset_loaders['target_te'], netF, netB, netC, args)
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netB.train()

        iter_num += 1
        # lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_test = inputs_test.cuda()
        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        if args.cls_par > 0:
            pred = mem_label[tar_idx]
            classifier_loss = args.cls_par * nn.CrossEntropyLoss()(outputs_test, pred)
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                entropy_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                args.dset, iter_num, max_iter, acc)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            # switch back to training mode before the next iteration
            netF.train()
            netB.train()

    if args.issave:
        torch.save(netF.state_dict(),
                   osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(netB.state_dict(),
                   osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(netC.state_dict(),
                   osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netB, netC
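# For reference, the information-maximization (IM) term used in the
# `args.ent` branch above, written as a standalone helper. A minimal
# sketch: the name `im_loss_sketch` is ours, and it inlines what
# `loss.Entropy` computes (per-sample entropy -sum_c p_c log p_c).
def im_loss_sketch(outputs, gent=True, eps=1e-5):
    softmax_out = nn.Softmax(dim=1)(outputs)
    # entropy minimization: make each individual prediction confident
    ent = torch.mean(torch.sum(-softmax_out * torch.log(softmax_out + eps), dim=1))
    if gent:
        # diversity: the batch-mean prediction should have high entropy,
        # which discourages collapsing onto a single class
        msoftmax = softmax_out.mean(dim=0)
        ent -= torch.sum(-msoftmax * torch.log(msoftmax + eps))
    return ent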
def train_source(args):
    dset_loaders = digit_load(args)
    ## set base network
    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()
    elif args.dset == 'm2mm':
        netF = network.DTNBase_c().cuda()
    elif args.dset == 's2u':
        netF = network.DTNBase_c().cuda()
    netB = network.feat_bootleneck_c().cuda()
    netC = network.feat_classifier_c().cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    # optimizer = optim.SGD(param_group)
    optimizer = optim.Adam(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0
    max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    interval_iter = max_iter // 10
    iter_num = 0

    netF.train()
    netB.train()
    netC.train()

    while iter_num < max_iter:
        try:
            inputs_source, labels_source = next(iter_source)
        except (NameError, StopIteration):
            # first pass, or the loader is exhausted: restart it
            iter_source = iter(dset_loaders["source_tr"])
            inputs_source, labels_source = next(iter_source)

        if inputs_source.size(0) == 1:
            # BatchNorm cannot normalize a single-sample batch
            continue

        iter_num += 1
        # lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
        outputs_source = netC(netB(netF(inputs_source)))
        classifier_loss = loss.CrossEntropyLabelSmooth(
            num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source)

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            netC.eval()
            acc_s_tr, _ = cal_acc(dset_loaders['source_tr'], netF, netB, netC)
            acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%/ {:.2f}%'.format(
                args.dset, iter_num, max_iter, acc_s_tr, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            # checkpoint whenever the source-test accuracy improves
            if acc_s_te >= acc_init:
                acc_init = acc_s_te
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()
                torch.save(best_netF, osp.join(args.output_dir, "source_F.pt"))
                torch.save(best_netB, osp.join(args.output_dir, "source_B.pt"))
                torch.save(best_netC, osp.join(args.output_dir, "source_C.pt"))

            # switch back to training mode before the next iteration
            netF.train()
            netB.train()
            netC.train()

    return netF, netB, netC
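# `loss.CrossEntropyLabelSmooth` is defined in the repo's `loss` module and
# not shown here. A minimal sketch of the standard formulation such a class
# implements (one-hot targets mixed with the uniform distribution at weight
# epsilon); the class below is an illustration, not the repo's exact code.
class CrossEntropyLabelSmoothSketch(nn.Module):
    def __init__(self, num_classes, epsilon=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        # one-hot encode, then smooth toward the uniform distribution
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        soft = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        return (-soft * log_probs).sum(dim=1).mean()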
# Results kept from the original notes (column meaning not given in this excerpt):
# sigma 0.25: 0.932 0.842 0     0     0     0     0
#       0.5 : 0.808 0.538 0.245 0     0     0     0
#       1   : 0.262 0.122 0.049 0.020 0.002 0.000 0.000

if __name__ == "__main__":
    # `args` is assumed to be built by an argparse block earlier in the file
    # (not shown in this excerpt).
    # load the base classifier
    device = 'cuda:0'
    args.dset = 'm2mm'
    # args.dset = 's2u'
    print(args.dset, args.sigma)

    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()
    elif args.dset == 'm2mm':
        netF = network.DTNBase_c().cuda()
    elif args.dset == 's2u':
        netF = network.DTNBase_c().cuda()
    netB = network.feat_bootleneck_c().cuda()
    netC = network.feat_classifier_c().cuda()

    args.output_dir = osp.join(args.output, 'seed' + str(args.seed), args.dset)
    args.modelpath = args.output_dir + '/target_F_par_0.3.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/target_B_par_0.3.pt'
    netB.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/target_C_par_0.3.pt'
    netC.load_state_dict(torch.load(args.modelpath))
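# The `__main__` block reads many attributes from `args` that are defined
# outside this excerpt. For readers who want to run it standalone, a minimal
# sketch of the required namespace; every value below is an illustrative
# assumption, not the original script's defaults.
from types import SimpleNamespace

def build_args_sketch(log_path='log.txt'):
    return SimpleNamespace(
        dset='m2mm', sigma=0.25, seed=2020,             # task, noise level, seed
        lr=1e-3, max_epoch=30, interval=10,             # optimization schedule
        class_num=10, smooth=0.1,                       # classes, label smoothing
        cls_par=0.3, ent=True, gent=True, ent_par=1.0,  # pseudo-label and IM loss weights
        bottleneck=256, classifier='bn', layer='wn',    # network shape options
        output='ckps_digits', issave=True, savename='par_0.3',
        out_file=open(log_path, 'w'),
    )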