Example #1
# Assumed imports for the examples on this page: torchvision's transforms,
# torch.utils.data.DataLoader, and the repository's local svhn/usps/mnist
# dataset modules.
def digit_load(args):
    train_bs = args.batch_size
    if args.dset == 's':
        train_source = svhn.SVHN('./data/svhn/', split='train', download=True,
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ]))
        test_source = svhn.SVHN('./data/svhn/', split='test', download=True,
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ]))  
    elif args.dset == 'u':
        train_source = usps.USPS('./data/usps/', train=True, download=True,
                transform=transforms.Compose([
                    transforms.RandomCrop(28, padding=4),
                    transforms.RandomRotation(10),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]))
        test_source = usps.USPS('./data/usps/', train=False, download=True,
                transform=transforms.Compose([
                    transforms.RandomCrop(28, padding=4),
                    transforms.RandomRotation(10),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]))    
    elif args.dset == 'm':
        train_source = mnist.MNIST('./data/mnist/', train=True, download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]))
        test_source = mnist.MNIST('./data/mnist/', train=False, download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]))

    dset_loaders = {}
    dset_loaders["train"] = DataLoader(train_source, batch_size=train_bs, shuffle=True, 
        num_workers=args.worker, drop_last=False)
    dset_loaders["test"] = DataLoader(test_source, batch_size=train_bs*2, shuffle=False, 
        num_workers=args.worker, drop_last=False)
    return dset_loaders
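
A minimal way to exercise digit_load, assuming an argparse-style args object with the three fields the function reads (dset, batch_size, worker); the namespace and its values below are hypothetical:

from types import SimpleNamespace

args = SimpleNamespace(dset='m', batch_size=64, worker=4)  # hypothetical values
loaders = digit_load(args)
images, labels = next(iter(loaders["train"]))
print(images.shape)  # e.g. torch.Size([64, 1, 28, 28]) for MNIST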

Example #2

    # The opening of this example was truncated in the listing; the branch
    # header and the start of this Compose are reconstructed by analogy with
    # the 'usps' branch below.
    if target_dataset_name == 'mnist_m':
        img_transform_target = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        test_list = os.path.join(target_image_root, 'mnist_m_test_labels.txt')

        dataset_target_test = GetLoader(
            data_root=os.path.join(target_image_root, 'mnist_m_test'),
            data_list=test_list,
            transform=img_transform_target
        )
    elif target_dataset_name == 'usps':
        args.base_classifier = "outputs/garbage1_s2u/best_model.pt"
        dataset_target_test = usps.USPS('./data/usps/', train=False, download=True,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.5,), (0.5,))
                                        ]))


    target_test_loader = torch.utils.data.DataLoader(
        dataset=dataset_target_test,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8
    )


    checkpoint = torch.load(args.base_classifier)
    base_classifier = CNN(in_channels=3, target=True).to(device)
    base_classifier.load_state_dict(checkpoint['model'])
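
The example is cut off after the checkpoint restore; below is a sketch of the evaluation loop that would plausibly follow, assuming the CNN's forward pass returns class logits:

base_classifier.eval()
n_correct, n_total = 0, 0
with torch.no_grad():
    for imgs, labels in target_test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        preds = base_classifier(imgs).argmax(dim=1)  # predicted class per image
        n_correct += (preds == labels).sum().item()
        n_total += labels.size(0)
print('target test accuracy: {:.2f}%'.format(100.0 * n_correct / n_total))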

Example #3

def test(dataset_name, model_root):
    cuda = True
    cudnn.benchmark = True
    batch_size = 128
    image_size = 28
    alpha = 0
    if model_root == 'model_mm':
        assert dataset_name in ['MNIST', 'mnist_m']
        image_root = os.path.join('dataset', dataset_name)

        """load data"""

        img_transform_source = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.1307,), std=(0.3081,))
        ])

        img_transform_target = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        if dataset_name == 'mnist_m':
            test_list = os.path.join(image_root, 'mnist_m_test_labels.txt')

            dataset = GetLoader(
                data_root=os.path.join(image_root, 'mnist_m_test'),
                data_list=test_list,
                transform=img_transform_target
            )
        else:
            dataset = datasets.MNIST(
                root='dataset',
                train=False,
                transform=img_transform_source,
            )
    elif model_root == 'model_su':
        if dataset_name == 'svhn':
            dataset = svhn.SVHN('./data/svhn/', split='test', download=True,
                                transform=transforms.Compose([
                                    transforms.Resize(28),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                ]))
        elif dataset_name == 'usps':
            dataset = usps.USPS('./data/usps/', train=False, download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,))
                                ]))
    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8
    )

    """ test """

    my_net = torch.load(os.path.join(
        model_root, 'mnist_mnistm_model_epoch_current.pth'
    ))
    my_net = my_net.eval()

    if cuda:
        my_net = my_net.cuda()

    len_dataloader = len(dataloader)
    data_target_iter = iter(dataloader)

    i = 0
    n_total = 0
    n_correct = 0

    while i < len_dataloader:

        # test model using target data
        data_target = next(data_target_iter)
        t_img, t_label = data_target

        batch_size = len(t_label)

        if cuda:
            t_img = t_img.cuda()
            t_label = t_label.cuda()

        class_output, _ = my_net(input_data=t_img, alpha=alpha)  # discard the domain output
        pred = class_output.data.max(1, keepdim=True)[1]
        n_correct += pred.eq(t_label.data.view_as(pred)).cpu().sum()
        n_total += batch_size

        i += 1

    accu = n_correct.item() * 1.0 / n_total

    return accu
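
A hypothetical call, assuming the checkpoint referenced inside the function exists on disk:

acc = test('mnist_m', 'model_mm')  # evaluate the MNIST -> MNIST-M model
print('accuracy: {:.4f}'.format(acc))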

Example #4

def split_target(args):
    train_bs = args.batch_size
    if args.dset == 's2m':
        train_target = mnist.MNIST(
            './data/mnist/',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.Resize(32),
                transforms.Lambda(lambda x: x.convert("RGB")),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]))
        train_target2 = mnist.MNIST_twice(
            './data/mnist/',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.Resize(32),
                transforms.Lambda(lambda x: x.convert("RGB")),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]))
        test_target = mnist.MNIST(
            './data/mnist/',
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.Resize(32),
                transforms.Lambda(lambda x: x.convert("RGB")),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]))
    elif args.dset == 'u2m':
        train_target = mnist.MNIST('./data/mnist/',
                                   train=True,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, ), (0.5, ))
                                   ]))
        train_target2 = mnist.MNIST_twice('./data/mnist/',
                                          train=True,
                                          download=True,
                                          transform=transforms.Compose([
                                              transforms.ToTensor(),
                                              transforms.Normalize((0.5, ),
                                                                   (0.5, ))
                                          ]))
        test_target = mnist.MNIST('./data/mnist/',
                                  train=False,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, ), (0.5, ))
                                  ]))
    elif args.dset == 'm2u':
        train_target = usps.USPS(
            './data/usps/',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # transforms.Lambda(lambda x: _gaussian_blur(x, sigma=0.1)),
                transforms.Normalize((0.5, ), (0.5, ))
            ]))
        train_target2 = usps.USPS_twice(
            './data/usps/',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # transforms.Lambda(lambda x: _gaussian_blur(x, sigma=0.1)),
                transforms.Normalize((0.5, ), (0.5, ))
            ]))
        test_target = usps.USPS(
            './data/usps/',
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # transforms.Lambda(lambda x: _gaussian_blur(x, sigma=0.1)),
                transforms.Normalize((0.5, ), (0.5, ))
            ]))
    dset_loaders = {}
    dset_loaders["target_te"] = DataLoader(test_target,
                                           batch_size=train_bs,
                                           shuffle=False,
                                           num_workers=args.worker,
                                           drop_last=False)
    dset_loaders["target"] = DataLoader(train_target,
                                        batch_size=train_bs,
                                        shuffle=False,
                                        num_workers=args.worker,
                                        drop_last=False)
    dset_loaders["target2"] = DataLoader(train_target2,
                                         batch_size=train_bs,
                                         shuffle=False,
                                         num_workers=args.worker,
                                         drop_last=False)

    if args.dset == 'u2m':
        netF = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netF = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netF = network.DTNBase().cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    if args.model == 'source':
        modelpath = args.output_dir + "/source_F.pt"
        netF.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_B.pt"
        netB.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_C.pt"
        netC.load_state_dict(torch.load(modelpath))
    else:
        modelpath = args.output_dir + "/target_F_" + args.savename + ".pt"
        netF.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_B_" + args.savename + ".pt"
        netB.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_C_" + args.savename + ".pt"
        netC.load_state_dict(torch.load(modelpath))

    netF.eval()
    netB.eval()
    netC.eval()

    start_test = True
    with torch.no_grad():
        iter_test = iter(dset_loaders['target_te'])
        for i in range(len(dset_loaders['target_te'])):
            data = next(iter_test)
            # pdb.set_trace()
            inputs = data[0]
            labels = data[1]
            inputs = inputs.cuda()
            outputs = netC(netB(netF(inputs)))
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
    top_pred, predict = torch.max(all_output, 1)
    acc = torch.sum(
        torch.squeeze(predict).float() == all_label).item() / float(
            all_label.size()[0]) * 100
    mean_ent = loss.Entropy(nn.Softmax(dim=1)(all_output))
    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
        args.dset + '_test', 0, 0, acc, mean_ent.mean())
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')

    start_test = True
    with torch.no_grad():
        iter_test = iter(dset_loaders['target'])
        for i in range(len(dset_loaders['target'])):
            data = next(iter_test)
            # pdb.set_trace()
            inputs = data[0]
            labels = data[1]
            inputs = inputs.cuda()
            outputs = netC(netB(netF(inputs)))
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
    top_pred, predict = torch.max(all_output, 1)
    acc = torch.sum(
        torch.squeeze(predict).float() == all_label).item() / float(
            all_label.size()[0]) * 100
    mean_ent = loss.Entropy(nn.Softmax(dim=1)(all_output))

    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
        args.dset + '_train', 0, 0, acc, mean_ent.mean())
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')

    if args.ps == 0:
        est_p = (mean_ent < mean_ent.mean()).sum().item() / mean_ent.size(0)
        log_str = 'Task: {:.2f}'.format(est_p)
        print(log_str + '\n')
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        PS = est_p
    else:
        PS = args.ps

    if args.choice == "ent":
        value = mean_ent
    elif args.choice == "maxp":
        value = -top_pred
    elif args.choice == "marginp":
        pred, _ = torch.sort(all_output, 1)
        value = pred[:, 1] - pred[:, 0]
    else:
        value = torch.rand(len(mean_ent))

    predict = predict.numpy()
    train_idx = np.zeros(predict.shape)

    cls_k = args.class_num
    for c in range(cls_k):
        c_idx = np.where(predict == c)
        c_idx = c_idx[0]
        c_value = value[c_idx]

        _, idx_ = torch.sort(c_value)
        c_num = len(idx_)
        c_num_s = int(c_num * PS)
        # print(c, c_num, c_num_s)

        for ei in range(0, c_num_s):
            ee = c_idx[idx_[ei]]
            train_idx[ee] = 1

    train_target.targets = predict
    new_src = copy.deepcopy(train_target)
    new_tar = copy.deepcopy(train_target2)

    # pdb.set_trace()

    if args.dset == 'm2u':

        new_src.train_data = np.delete(new_src.train_data,
                                       np.where(train_idx == 0)[0],
                                       axis=0)
        new_src.train_labels = np.delete(new_src.train_labels,
                                         np.where(train_idx == 0)[0],
                                         axis=0)

        new_tar.train_data = np.delete(new_tar.train_data,
                                       np.where(train_idx == 1)[0],
                                       axis=0)
        new_tar.train_labels = np.delete(new_tar.train_labels,
                                         np.where(train_idx == 1)[0],
                                         axis=0)

    else:

        new_src.data = np.delete(new_src.data,
                                 np.where(train_idx == 0)[0],
                                 axis=0)
        new_src.targets = np.delete(new_src.targets,
                                    np.where(train_idx == 0)[0],
                                    axis=0)

        new_tar.data = np.delete(new_tar.data,
                                 np.where(train_idx == 1)[0],
                                 axis=0)
        new_tar.targets = np.delete(new_tar.targets,
                                    np.where(train_idx == 1)[0],
                                    axis=0)

    # pdb.set_trace()

    return new_src, new_tar
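
In the surrounding training script, the two returned datasets would feed a pseudo-source/pseudo-target stage; a hypothetical hookup through the data_load helper shown in the next example:

new_src, new_tar = split_target(args)
dset_loaders = data_load(args, new_src, new_tar)  # see Example #5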

Example #5

def data_load(args, txt_src, txt_tgt):

    if args.dset == 's2m':
        train_target = mnist.MNIST(
            './data/mnist/',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.Resize(32),
                transforms.Lambda(lambda x: x.convert("RGB")),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]))
        test_target = mnist.MNIST(
            './data/mnist/',
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.Resize(32),
                transforms.Lambda(lambda x: x.convert("RGB")),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]))
    elif args.dset == 'u2m':
        train_target = mnist.MNIST('./data/mnist/',
                                   train=True,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, ), (0.5, ))
                                   ]))
        test_target = mnist.MNIST('./data/mnist/',
                                  train=False,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, ), (0.5, ))
                                  ]))
    elif args.dset == 'm2u':
        train_target = usps.USPS('./data/usps/',
                                 train=True,
                                 download=True,
                                 transform=transforms.Compose([
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, ), (0.5, ))
                                 ]))
        test_target = usps.USPS('./data/usps/',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, ), (0.5, ))
                                ]))

    dset_loaders = {}
    dset_loaders["train"] = DataLoader(train_target,
                                       batch_size=args.batch_size * 2,
                                       shuffle=False,
                                       num_workers=args.worker,
                                       drop_last=False)
    dset_loaders["test"] = DataLoader(test_target,
                                      batch_size=args.batch_size * 2,
                                      shuffle=False,
                                      num_workers=args.worker,
                                      drop_last=False)
    dset_loaders["source"] = DataLoader(txt_src,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.worker,
                                        drop_last=True)
    dset_loaders["target"] = DataLoader(txt_tgt,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.worker,
                                        drop_last=True)

    return dset_loaders
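
A hypothetical iteration over the returned loaders, reusing new_src/new_tar from Example #4; the "target" loader wraps a *_twice dataset, which is assumed here to yield (view1, view2, label) triples:

loaders = data_load(args, new_src, new_tar)
src_x, src_y = next(iter(loaders["source"]))           # shuffled, drop_last=True
tgt_v1, tgt_v2, tgt_y = next(iter(loaders["target"]))  # two augmented views per image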

Example #6

def run(args):
    args.logdir = args.logdir + args.mode
    args.trained = args.trained + args.mode + '/best_model.pt'
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
    logger = get_logger(os.path.join(args.logdir, 'main.log'))
    logger.info(args)

    # data
    # source_transform = transforms.Compose([
    #     # transforms.Grayscale(),
    #     transforms.ToTensor()]
    # )
    # target_transform = transforms.Compose([
    #     transforms.Resize(32),
    #     transforms.ToTensor(),
    #     transforms.Lambda(lambda x: x.repeat(3, 1, 1))
    # ])
    # source_dataset_train = SVHN(
    #     './input', 'train', transform=source_transform, download=True)
    # target_dataset_train = MNIST(
    #     './input', train=True, transform=target_transform, download=True)
    # target_dataset_test = MNIST(
    #     './input', train=False, transform=target_transform, download=True)
    # source_train_loader = DataLoader(
    #     source_dataset_train, args.batch_size, shuffle=True,
    #     drop_last=True,
    #     num_workers=args.n_workers)
    # target_train_loader = DataLoader(
    #     target_dataset_train, args.batch_size, shuffle=True,
    #     drop_last=True,
    #     num_workers=args.n_workers)
    # target_test_loader = DataLoader(
    #     target_dataset_test, args.batch_size, shuffle=False,
    #     num_workers=args.n_workers)
    batch_size = 128
    if args.mode == 'm2mm':
        source_dataset_name = 'MNIST'
        target_dataset_name = 'mnist_m'
        source_image_root = os.path.join('dataset', source_dataset_name)
        target_image_root = os.path.join('dataset', target_dataset_name)
        image_size = 28
        img_transform_source = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.1307, ), std=(0.3081, ))
        ])

        img_transform_target = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        dataset_source = datasets.MNIST(root='dataset',
                                        train=True,
                                        transform=img_transform_source,
                                        download=True)

        train_list = os.path.join(target_image_root,
                                  'mnist_m_train_labels.txt')

        dataset_target_train = GetLoader(data_root=os.path.join(
            target_image_root, 'mnist_m_train'),
                                         data_list=train_list,
                                         transform=img_transform_target)

        test_list = os.path.join(target_image_root, 'mnist_m_test_labels.txt')

        dataset_target_test = GetLoader(data_root=os.path.join(
            target_image_root, 'mnist_m_test'),
                                        data_list=test_list,
                                        transform=img_transform_target)
    elif args.mode == 's2u':
        dataset_source = svhn.SVHN('./data/svhn/',
                                   split='train',
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(28),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5))
                                   ]))

        dataset_target_train = usps.USPS('./data/usps/',
                                         train=True,
                                         download=True,
                                         transform=transforms.Compose([
                                             transforms.ToTensor(),
                                             transforms.Normalize((0.5, ),
                                                                  (0.5, ))
                                         ]))
        dataset_target_test = usps.USPS('./data/usps/',
                                        train=False,
                                        download=True,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.5, ),
                                                                 (0.5, ))
                                        ]))
        source_dataset_name = 'svhn'
        target_dataset_name = 'usps'

    source_train_loader = torch.utils.data.DataLoader(dataset=dataset_source,
                                                      batch_size=batch_size,
                                                      shuffle=True,
                                                      num_workers=8)

    target_train_loader = torch.utils.data.DataLoader(
        dataset=dataset_target_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8)

    target_test_loader = torch.utils.data.DataLoader(
        dataset=dataset_target_test,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8)
    # source CNN: restored from a pretrained checkpoint when one exists
    source_cnn = CNN(in_channels=args.in_channels).to(args.device)
    if os.path.isfile(args.trained):
        print("loading pretrained source model")
        c = torch.load(args.trained)
        source_cnn.load_state_dict(c['model'])
        logger.info('Loaded `{}`'.format(args.trained))
    else:
        print("pretrained source model not found; using random initialization")

    # train target CNN
    target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device)
    target_cnn.load_state_dict(source_cnn.state_dict())
    discriminator = Discriminator(args=args).to(args.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(target_cnn.encoder.parameters(), lr=args.lr)
    # optimizer = optim.Adam(
    #     target_cnn.encoder.parameters(),
    #     lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
    d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr)
    # d_optimizer = optim.Adam(
    #     discriminator.parameters(),
    #     lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
    train_target_cnn(source_cnn,
                     target_cnn,
                     discriminator,
                     criterion,
                     optimizer,
                     d_optimizer,
                     source_train_loader,
                     target_train_loader,
                     target_test_loader,
                     args=args)
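
run expects a fully populated args object; a hypothetical minimal invocation setting only the attributes the function body reads directly (values invented for illustration; the train_target_cnn call may consume additional fields):

from types import SimpleNamespace

args = SimpleNamespace(mode='s2u',
                       logdir='logs/',
                       trained='outputs/',
                       in_channels=3,
                       lr=2e-4,
                       device='cuda')
run(args)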