def digit_load(args):
    """Build source-domain train/test DataLoaders for the digit benchmarks.

    Args:
        args: namespace providing
            batch_size (int): training batch size,
            dset (str): 's' = SVHN, 'u' = USPS, 'm' = MNIST,
            worker (int): DataLoader worker count.

    Returns:
        dict: {"train": DataLoader (shuffled, batch_size),
               "test":  DataLoader (unshuffled, 2 * batch_size)}.

    Raises:
        ValueError: if args.dset is not one of 's', 'u', 'm'.
    """
    train_bs = args.batch_size
    if args.dset == 's':
        # SVHN is RGB: resize to 32x32 and normalize every channel to [-1, 1].
        tf = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        train_source = svhn.SVHN('./data/svhn/', split='train', download=True,
                                 transform=tf)
        test_source = svhn.SVHN('./data/svhn/', split='test', download=True,
                                transform=tf)
    elif args.dset == 'u':
        # NOTE(review): the random crop/rotation augmentation is applied to
        # the *test* split as well.  This reproduces the original behaviour,
        # but is unusual for evaluation -- confirm it is intentional.
        tf = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        train_source = usps.USPS('./data/usps/', train=True, download=True,
                                 transform=tf)
        test_source = usps.USPS('./data/usps/', train=False, download=True,
                                transform=tf)
    elif args.dset == 'm':
        tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        train_source = mnist.MNIST('./data/mnist/', train=True, download=True,
                                   transform=tf)
        test_source = mnist.MNIST('./data/mnist/', train=False, download=True,
                                  transform=tf)
    else:
        # Fail fast: the original fell through and raised NameError below.
        raise ValueError("unknown dset: {!r}".format(args.dset))

    dset_loaders = {
        "train": DataLoader(train_source, batch_size=train_bs, shuffle=True,
                            num_workers=args.worker, drop_last=False),
        "test": DataLoader(test_source, batch_size=train_bs * 2, shuffle=False,
                           num_workers=args.worker, drop_last=False),
    }
    return dset_loaders
def _forward_dataset(loader, netF, netB, netC):
    """Run the frozen network netC(netB(netF(x))) over every batch in loader.

    Returns:
        (all_output, all_label): logits concatenated on CPU as float tensors,
        and the corresponding labels cast to float.
    """
    # Replaces the Python-2 style `iter(loader)` + `.next()` of the original:
    # `.next()` was removed from modern PyTorch DataLoader iterators.
    start = True
    with torch.no_grad():
        for data in loader:
            inputs = data[0].cuda()
            labels = data[1]
            outputs = netC(netB(netF(inputs)))
            if start:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
    return all_output, all_label


def split_target(args):
    """Pseudo-label the target digit set and split it into two subsets.

    Loads the target-domain dataset for args.dset ('s2m' | 'u2m' | 'm2u'),
    evaluates the saved model (source or adapted, per args.model) on it, and
    per predicted class keeps the fraction PS of samples with the lowest
    selection value (entropy / max-prob / margin / random, per args.choice).

    Returns:
        (new_src, new_tar): deep copies of the target dataset where new_src
        contains the selected (confident, pseudo-labelled) samples and
        new_tar contains the remaining samples.

    Side effects: writes accuracy/entropy log lines to args.out_file and
    stdout; loads model weights from args.output_dir.
    """
    train_bs = args.batch_size

    # --- target datasets: plain, "_twice" (two augmented views), and test ---
    if args.dset == 's2m':
        tf = transforms.Compose([
            transforms.Resize(32),
            transforms.Lambda(lambda x: x.convert("RGB")),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        train_target = mnist.MNIST('./data/mnist/', train=True, download=True,
                                   transform=tf)
        train_target2 = mnist.MNIST_twice('./data/mnist/', train=True,
                                          download=True, transform=tf)
        test_target = mnist.MNIST('./data/mnist/', train=False, download=True,
                                  transform=tf)
    elif args.dset == 'u2m':
        tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ])
        train_target = mnist.MNIST('./data/mnist/', train=True, download=True,
                                   transform=tf)
        train_target2 = mnist.MNIST_twice('./data/mnist/', train=True,
                                          download=True, transform=tf)
        test_target = mnist.MNIST('./data/mnist/', train=False, download=True,
                                  transform=tf)
    elif args.dset == 'm2u':
        tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ])
        train_target = usps.USPS('./data/usps/', train=True, download=True,
                                 transform=tf)
        train_target2 = usps.USPS_twice('./data/usps/', train=True,
                                        download=True, transform=tf)
        test_target = usps.USPS('./data/usps/', train=False, download=True,
                                transform=tf)

    # All three loaders are deliberately unshuffled so sample indices stay
    # aligned with the selection mask built below.
    dset_loaders = {
        "target_te": DataLoader(test_target, batch_size=train_bs,
                                shuffle=False, num_workers=args.worker,
                                drop_last=False),
        "target": DataLoader(train_target, batch_size=train_bs,
                             shuffle=False, num_workers=args.worker,
                             drop_last=False),
        "target2": DataLoader(train_target2, batch_size=train_bs,
                              shuffle=False, num_workers=args.worker,
                              drop_last=False),
    }

    # --- networks: feature extractor, bottleneck, classifier ---
    if args.dset == 's2m':
        netF = network.DTNBase().cuda()
    else:  # 'u2m' and 'm2u' both use the LeNet feature extractor
        netF = network.LeNetBase().cuda()
    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    # --- load checkpoints (source model or adapted target model) ---
    if args.model == 'source':
        ckpts = [(netF, "source_F.pt"), (netB, "source_B.pt"),
                 (netC, "source_C.pt")]
    else:
        ckpts = [(netF, "target_F_" + args.savename + ".pt"),
                 (netB, "target_B_" + args.savename + ".pt"),
                 (netC, "target_C_" + args.savename + ".pt")]
    for net, fname in ckpts:
        net.load_state_dict(torch.load(args.output_dir + "/" + fname))
    netF.eval()
    netB.eval()
    netC.eval()

    # --- evaluate on the target test split (logging only) ---
    all_output, all_label = _forward_dataset(dset_loaders['target_te'],
                                             netF, netB, netC)
    _, predict = torch.max(all_output, 1)
    acc = torch.sum(
        torch.squeeze(predict).float() == all_label).item() / float(
            all_label.size()[0]) * 100
    mean_ent = loss.Entropy(nn.Softmax(dim=1)(all_output))
    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
        args.dset + '_test', 0, 0, acc, mean_ent.mean())
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')

    # --- evaluate on the target train split (source of pseudo-labels) ---
    all_output, all_label = _forward_dataset(dset_loaders['target'],
                                             netF, netB, netC)
    top_pred, predict = torch.max(all_output, 1)
    acc = torch.sum(
        torch.squeeze(predict).float() == all_label).item() / float(
            all_label.size()[0]) * 100
    mean_ent = loss.Entropy(nn.Softmax(dim=1)(all_output))
    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
        args.dset + '_train', 0, 0, acc, mean_ent.mean())
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')

    # --- per-class selection ratio PS ---
    if args.ps == 0:
        # Estimate the ratio as the fraction of below-mean-entropy samples.
        est_p = (mean_ent < mean_ent.mean()).sum().item() / mean_ent.size(0)
        log_str = 'Task: {:.2f}'.format(est_p)
        print(log_str + '\n')
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        PS = est_p
    else:
        PS = args.ps

    # --- per-sample selection value (lower = more confident) ---
    if args.choice == "ent":
        value = mean_ent
    elif args.choice == "maxp":
        value = -top_pred
    elif args.choice == "marginp":
        # NOTE(review): columns 0 and 1 of the ascending sort are the two
        # *smallest* logits; a top-2 margin would be pred[:, -1] - pred[:, -2].
        # Kept as in the original -- confirm which margin is intended.
        pred, _ = torch.sort(all_output, 1)
        value = pred[:, 1] - pred[:, 0]
    else:
        value = torch.rand(len(mean_ent))

    # --- build the binary mask: 1 = selected for the pseudo-source set ---
    predict = predict.numpy()
    train_idx = np.zeros(predict.shape)
    for c in range(args.class_num):
        c_idx = np.where(predict == c)[0]
        c_value = value[c_idx]
        _, idx_ = torch.sort(c_value)  # ascending: most confident first
        c_num_s = int(len(idx_) * PS)
        for ei in range(0, c_num_s):
            train_idx[c_idx[idx_[ei]]] = 1

    # Overwrite labels with pseudo-labels, then split by the mask.
    # NOTE(review): for 'm2u' the split below edits train_data/train_labels,
    # so setting .targets may not propagate pseudo-labels there -- verify
    # against the USPS dataset class.
    train_target.targets = predict
    new_src = copy.deepcopy(train_target)
    new_tar = copy.deepcopy(train_target2)
    drop_from_src = np.where(train_idx == 0)[0]  # src keeps selected samples
    drop_from_tar = np.where(train_idx == 1)[0]  # tar keeps the remainder
    if args.dset == 'm2u':
        new_src.train_data = np.delete(new_src.train_data, drop_from_src,
                                       axis=0)
        new_src.train_labels = np.delete(new_src.train_labels, drop_from_src,
                                         axis=0)
        new_tar.train_data = np.delete(new_tar.train_data, drop_from_tar,
                                       axis=0)
        new_tar.train_labels = np.delete(new_tar.train_labels, drop_from_tar,
                                         axis=0)
    else:
        new_src.data = np.delete(new_src.data, drop_from_src, axis=0)
        new_src.targets = np.delete(new_src.targets, drop_from_src, axis=0)
        new_tar.data = np.delete(new_tar.data, drop_from_tar, axis=0)
        new_tar.targets = np.delete(new_tar.targets, drop_from_tar, axis=0)
    return new_src, new_tar
def data_load(args, txt_src, txt_tgt):
    """Assemble the four DataLoaders used during target adaptation.

    Args:
        args: namespace with dset ('s2m' | 'u2m' | 'm2u'), batch_size and
            worker attributes.
        txt_src: dataset serving as the (pseudo-)source split.
        txt_tgt: dataset serving as the remaining target split.

    Returns:
        dict with keys:
            "train"  -- full target train set, unshuffled, 2x batch size,
            "test"   -- target test set, unshuffled, 2x batch size,
            "source" -- txt_src, shuffled, drops the last partial batch,
            "target" -- txt_tgt, shuffled, drops the last partial batch.
    """
    if args.dset == 's2m':
        # MNIST rendered as RGB at 32x32 to match the SVHN-trained network.
        tf = transforms.Compose([
            transforms.Resize(32),
            transforms.Lambda(lambda x: x.convert("RGB")),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        tr_set = mnist.MNIST('./data/mnist/', train=True, download=True,
                             transform=tf)
        te_set = mnist.MNIST('./data/mnist/', train=False, download=True,
                             transform=tf)
    elif args.dset == 'u2m':
        tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ])
        tr_set = mnist.MNIST('./data/mnist/', train=True, download=True,
                             transform=tf)
        te_set = mnist.MNIST('./data/mnist/', train=False, download=True,
                             transform=tf)
    elif args.dset == 'm2u':
        tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ])
        tr_set = usps.USPS('./data/usps/', train=True, download=True,
                           transform=tf)
        te_set = usps.USPS('./data/usps/', train=False, download=True,
                           transform=tf)

    bs = args.batch_size
    return {
        "train": DataLoader(tr_set, batch_size=bs * 2, shuffle=False,
                            num_workers=args.worker, drop_last=False),
        "test": DataLoader(te_set, batch_size=bs * 2, shuffle=False,
                           num_workers=args.worker, drop_last=False),
        "source": DataLoader(txt_src, batch_size=bs, shuffle=True,
                             num_workers=args.worker, drop_last=True),
        "target": DataLoader(txt_tgt, batch_size=bs, shuffle=True,
                             num_workers=args.worker, drop_last=True),
    }