Example #1
import torch
import torchvision.transforms as transforms
import datasets  # assumed: project-local module providing CIFAR10Instance

transform_train = transforms.Compose([
    transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    transforms.RandomGrayscale(p=0.2),
    #transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = datasets.CIFAR10Instance(root='./data',
                                    train=True,
                                    download=True,
                                    transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=128,
                                          shuffle=True,
                                          num_workers=2)

testset = datasets.CIFAR10Instance(root='./data',
                                   train=False,
                                   download=True,
                                   transform=transform_test)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=100,
                                         shuffle=False,
                                         num_workers=2)
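
For readers without the project-local `datasets` module, a minimal sketch of equivalent loaders built from the stock torchvision CIFAR-10 dataset; note that this yields plain (image, label) batches rather than the extra per-instance indices the Instance variant typically adds.

import torch
import torchvision
import torchvision.transforms as transforms

# Same augmentation pipeline as above, applied to the stock CIFAR-10 dataset.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

images, labels = next(iter(trainloader))
print(images.shape)  # torch.Size([128, 3, 32, 32])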
Example #2
def get_dataloader(args, add_erasing):
    if 'cifar' in args.dataset or 'kitchen' in args.dataset:
        transform_train_list = [
            transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ]
        if add_erasing:
            transform_train_list.append(transforms.RandomErasing(p=1.0))
        transform_train = transforms.Compose(transform_train_list)

        if 'kitchen' in args.dataset:
            transform_test = transforms.Compose([
                transforms.Resize((32, 32), interpolation=2),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])
        else:
            transform_test = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])

    elif 'stl' in args.dataset:
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(size=96, scale=(0.2, 1.)),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    if args.dataset == 'cifar10':
        trainset = datasets.CIFAR10Instance(root='./data/CIFAR-10',
                                            train=True,
                                            download=True,
                                            transform=transform_train,
                                            two_imgs=args.two_imgs,
                                            three_imgs=args.three_imgs)
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            trainset)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=args.num_workers,
                                                  pin_memory=False,
                                                  sampler=train_sampler)

        testset = datasets.CIFAR10Instance(root='./data/CIFAR-10',
                                           train=False,
                                           download=True,
                                           transform=transform_test)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=100,
                                                 shuffle=False,
                                                 num_workers=2,
                                                 pin_memory=False)
        args.pool_len = 4
        ndata = trainset.__len__()

    elif args.dataset == 'cifar100':
        trainset = datasets.CIFAR100Instance(root='./data/CIFAR-100',
                                             train=True,
                                             download=True,
                                             transform=transform_train,
                                             two_imgs=args.two_imgs,
                                             three_imgs=args.three_imgs)
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            trainset)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=args.num_workers,
                                                  pin_memory=False,
                                                  sampler=train_sampler)

        testset = datasets.CIFAR100Instance(root='./data/CIFAR-100',
                                            train=False,
                                            download=True,
                                            transform=transform_test)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=100,
                                                 shuffle=False,
                                                 num_workers=2,
                                                 pin_memory=False)
        args.pool_len = 4
        ndata = trainset.__len__()

    elif args.dataset == 'stl10':
        trainset = datasets.STL10(root='./data/STL10',
                                  split='train',
                                  download=True,
                                  transform=transform_train,
                                  two_imgs=args.two_imgs,
                                  three_imgs=args.three_imgs)
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            trainset)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=args.num_workers,
                                                  pin_memory=False,
                                                  sampler=train_sampler)

        testset = datasets.STL10(root='./data/STL10',
                                 split='test',
                                 download=True,
                                 transform=transform_test)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=100,
                                                 shuffle=False,
                                                 num_workers=2,
                                                 pin_memory=False)
        args.pool_len = 7
        ndata = trainset.__len__()

    elif args.dataset == 'stl10-full':
        trainset = datasets.STL10(root='./data/STL10',
                                  split='train+unlabeled',
                                  download=True,
                                  transform=transform_train,
                                  two_imgs=args.two_imgs,
                                  three_imgs=args.three_imgs)
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            trainset)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=args.num_workers,
                                                  pin_memory=False,
                                                  sampler=train_sampler)

        labeledTrainset = datasets.STL10(root='./data/STL10',
                                         split='train',
                                         download=True,
                                         transform=transform_train,
                                         two_imgs=args.two_imgs)
        labeledTrain_sampler = torch.utils.data.distributed.DistributedSampler(
            labeledTrainset)
        labeledTrainloader = torch.utils.data.DataLoader(
            labeledTrainset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=False,
            sampler=labeledTrain_sampler)
        testset = datasets.STL10(root='./data/STL10',
                                 split='test',
                                 download=True,
                                 transform=transform_test)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=100,
                                                 shuffle=False,
                                                 num_workers=2,
                                                 pin_memory=False)
        args.pool_len = 7
        ndata = labeledTrainset.__len__()

    elif args.dataset == 'kitchen':
        trainset = datasets.CIFARImageFolder(root='./data/Kitchen-HC/train',
                                             train=True,
                                             transform=transform_train,
                                             two_imgs=args.two_imgs,
                                             three_imgs=args.three_imgs)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  num_workers=args.num_workers,
                                                  pin_memory=False)
        testset = datasets.CIFARImageFolder(root='./data/Kitchen-HC/test',
                                            train=False,
                                            transform=transform_test)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=100,
                                                 shuffle=False,
                                                 num_workers=2,
                                                 pin_memory=False)
        args.pool_len = 4
        ndata = trainset.__len__()

    return trainloader, testloader, ndata
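
A hedged usage sketch for get_dataloader: the argparse.Namespace below is hypothetical and only mirrors the attributes the function actually reads. The CIFAR and STL branches build a DistributedSampler, so a process group must already be initialized.

import argparse
import torch.distributed as dist

# Single-process "distributed" group so DistributedSampler can be constructed.
dist.init_process_group(backend='gloo', init_method='tcp://127.0.0.1:23456',
                        world_size=1, rank=0)

# Hypothetical namespace; attribute names follow the fields used inside get_dataloader.
args = argparse.Namespace(dataset='cifar10', batch_size=128, num_workers=4,
                          two_imgs=True, three_imgs=False)

trainloader, testloader, ndata = get_dataloader(args, add_erasing=False)
print(ndata, args.pool_len)  # pool_len is set inside get_dataloader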
Example #3
def main():
    global args, best_prec1
    args = parser.parse_args()

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](low_dim=args.low_dim)

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    '''
    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolderInstance(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2,1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    '''

    ### FOR MNIST
    '''
    train_trans = transforms.Compose([transforms.RandomResizedCrop(size=224, scale=(0.2,1)),
                                      transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
                                      transforms.RandomGrayscale(p=0.2),
                                      transforms.ToTensor()])

    val_trans = transforms.Compose([transforms.Resize(224),
                                    transforms.ToTensor()])#,
                                    #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

    '''

    ### FOR CIFAR
    train_trans = transforms.Compose([
        transforms.RandomResizedCrop(size=224, scale=(0.2, 1.)),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        transforms.RandomGrayscale(p=0.2),
        #transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))
    ])

    val_trans = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))
    ])

    train_dataset = datasets.CIFAR10Instance(
        subset_ratio=args.subset_ratio,
        classes_ratio=args.classes_ratio,
        batch_size=args.batch_size,
        root=args.data,
        train=True,
        transform=train_trans,
        download=True,
    )

    val_dataset = datasets.CIFAR10Instance(classes_ratio=args.classes_ratio,
                                           batch_size=args.batch_size,
                                           root=args.data,
                                           train=False,
                                           transform=val_trans,
                                           download=True)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    # The validation loader should be non-shuffled and must not reuse train_sampler,
    # which draws indices from the training set's index range.
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    '''
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolderInstance(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    '''

    # define lemniscate and loss function (criterion)
    ndata = train_dataset.__len__()
    if args.nce_k > 0:
        lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t,
                                args.nce_m).cuda()
        criterion = NCECriterion(ndata).cuda()
    else:
        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                                   args.nce_m).cuda()
        criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            lemniscate = checkpoint['lemniscate']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.evaluate:
        kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t,
            args.subset_ratio, args.classes_ratio)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, lemniscate, criterion, optimizer, epoch,
              args.subset_ratio, args.classes_ratio)

        # evaluate on validation set
        prec1 = NN(epoch,
                   model,
                   lemniscate,
                   train_loader,
                   val_loader,
                   classes_ratio=args.classes_ratio)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'lemniscate': lemniscate,
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, epoch)
    # evaluate KNN after last epoch
    kNN(0,
        model,
        lemniscate,
        train_loader,
        val_loader,
        200,
        args.nce_t,
        classes_ratio=args.classes_ratio)
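
adjust_learning_rate is referenced above but not shown; a minimal sketch under the assumption of an ImageNet-style step decay (the original script's schedule may differ), relying on the module-level args used throughout main():

def adjust_learning_rate(optimizer, epoch):
    """Assumed schedule: decay args.lr by 10x every 30 epochs (ImageNet-style)."""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr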
Example #4
def main(args):

    # Data
    print('==> Preparing data..')
    _size = 32
    transform_train = transforms.Compose([
        transforms.Resize(size=_size),
        transforms.RandomResizedCrop(size=_size, scale=(0.2, 1.)),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.Resize(size=_size),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    trainset = datasets.CIFAR10Instance(root='./data',
                                        train=True,
                                        download=True,
                                        transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=4)

    testset = datasets.CIFAR10Instance(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=4)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')
    ndata = trainset.__len__()

    print('==> Building model..')
    net = models.__dict__['ResNet18'](low_dim=args.low_dim)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        net = torch.nn.DataParallel(net,
                                    device_ids=range(
                                        torch.cuda.device_count()))
        cudnn.benchmark = True

    criterion = ICRcriterion()
    # define loss function: inner product loss within each mini-batch
    uel_criterion = BatchCriterion(args.batch_m, args.batch_t, args.batch_size,
                                   ndata)

    net.to(device)
    criterion.to(device)
    uel_criterion.to(device)
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    if args.test_only or len(args.resume) > 0:
        # Load checkpoint.
        model_path = 'checkpoint/' + args.resume
        print('==> Resuming from checkpoint..')
        assert os.path.isdir(
            args.model_dir), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(model_path)
        net.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']

    # define lemniscate
    if args.test_only and len(args.resume) > 0:

        trainFeatures, feature_index = compute_feature(trainloader, net,
                                                       len(trainset), args)
        lemniscate = LinearAverage(torch.tensor(trainFeatures), args.low_dim,
                                   ndata, args.nce_t, args.nce_m)

    else:

        lemniscate = LinearAverage(torch.tensor([]), args.low_dim, ndata,
                                   args.nce_t, args.nce_m)
    lemniscate.to(device)

    # define optimizer
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=5e-4)
    # optimizer2 = torch.optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

    # test acc
    if args.test_only:
        acc = kNN(0,
                  net,
                  trainloader,
                  testloader,
                  200,
                  args.batch_t,
                  ndata,
                  low_dim=args.low_dim)
        exit(0)

    if len(args.resume) > 0:
        # best_acc was already loaded from the checkpoint; resume from the next epoch
        start_epoch = start_epoch + 1
    else:
        best_acc = 0  # best test accuracy
        start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    icr2 = ICRDiscovery(ndata)

    # init_cluster_num = 20000
    for round in range(5):
        for epoch in range(start_epoch, 200):
            #### get Features

            # trainFeatures come from trainloader with shuffle=True, so feature_index
            # maps each feature row back to its original dataset index
            trainFeatures, feature_index = compute_feature(
                trainloader, net, len(trainset), args)

            if round == 0:
                y = -1 * math.log10(ndata) / 200 * epoch + math.log10(ndata)
                cluster_num = int(math.pow(10, y))
                if cluster_num <= args.nmb_cluster:
                    cluster_num = args.nmb_cluster

                print('cluster number: ' + str(cluster_num))

                ###clustering algorithm to use
                # faiss cluster
                deepcluster = clustering.__dict__[args.clustering](
                    int(cluster_num))

                #### Features to clustering
                clustering_loss = deepcluster.cluster(trainFeatures,
                                                      feature_index,
                                                      verbose=args.verbose)

                L = np.array(deepcluster.images_lists)
                image_dict = deepcluster.images_dict

                print('create ICR ...')
                # icr = ICRDiscovery(ndata)

                # if args.test_only and len(args.resume) > 0:
                # icr = cluster_assign(icr, L, trainFeatures, feature_index, trainset,
                # cluster_ratio + epoch*((1-cluster_ratio)/250))
                icrtime = time.time()

                # icr = cluster_assign(epoch, L, trainFeatures, feature_index, 1, 1)
                if epoch < args.warm_epoch:
                    icr = cluster_assign(epoch, L, trainFeatures,
                                         feature_index, args.cluster_ratio, 1)
                else:
                    icr = PreScore(epoch, L, image_dict, trainFeatures,
                                   feature_index, trainset, args.high_ratio,
                                   args.cluster_ratio, args.alpha, args.beta)

                print('calculate ICR time is: {}'.format(time.time() -
                                                         icrtime))
                writer.add_scalar('icr_time', (time.time() - icrtime),
                                  epoch + round * 200)

            else:
                cluster_num = args.nmb_cluster
                print('cluster number: ' + str(cluster_num))

                ###clustering algorithm to use
                # faiss cluster
                deepcluster = clustering.__dict__[args.clustering](
                    int(cluster_num))

                #### Features to clustering
                clustering_loss = deepcluster.cluster(trainFeatures,
                                                      feature_index,
                                                      verbose=args.verbose)

                L = np.array(deepcluster.images_lists)
                image_dict = deepcluster.images_dict

                print('create ICR ...')
                # icr = ICRDiscovery(ndata)

                # if args.test_only and len(args.resume) > 0:
                # icr = cluster_assign(icr, L, trainFeatures, feature_index, trainset,
                # cluster_ratio + epoch*((1-cluster_ratio)/250))
                icrtime = time.time()

                # icr = cluster_assign(epoch, L, trainFeatures, feature_index, 1, 1)
                icr = PreScore(epoch, L, image_dict, trainFeatures,
                               feature_index, trainset, args.high_ratio,
                               args.cluster_ratio, args.alpha, args.beta)

                print('calculate ICR time is: {}'.format(time.time() -
                                                         icrtime))
                writer.add_scalar('icr_time', (time.time() - icrtime),
                                  epoch + round * 200)

            # else:
            #     icr = cluster_assign(icr, L, trainFeatures, feature_index, trainset, 0.2 + epoch*0.004)

            # print(icr.neighbours)

            icr2 = train(epoch, net, optimizer, lemniscate, criterion,
                         uel_criterion, trainloader, icr, icr2,
                         args.stage_update, args.lr, device, round)

            print('----------Evaluation---------')
            start = time.time()
            acc = kNN(0,
                      net,
                      trainloader,
                      testloader,
                      200,
                      args.batch_t,
                      ndata,
                      low_dim=args.low_dim)
            print("Evaluation Time: '{}'s".format(time.time() - start))

            writer.add_scalar('nn_acc', acc, epoch + round * 200)

            if acc > best_acc:
                print('Saving..')
                state = {
                    'net': net.state_dict(),
                    'acc': acc,
                    'epoch': epoch,
                }
                if not os.path.isdir(args.model_dir):
                    os.mkdir(args.model_dir)
                torch.save(state,
                           './checkpoint/ckpt_best_round_{}.t7'.format(round))

                best_acc = acc

            state = {
                'net': net.state_dict(),
                'acc': acc,
                'epoch': epoch,
            }
            torch.save(state,
                       './checkpoint/ckpt_last_round_{}.t7'.format(round))

            print(
                '[Round]: {} [Epoch]: {} \t accuracy: {}% \t (best acc: {}%)'.
                format(round, epoch, acc, best_acc))
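
The round-0 branch above shrinks the number of clusters exponentially from roughly ndata down to args.nmb_cluster over 200 epochs; a small standalone illustration of that schedule (example values only, not taken from the original run):

import math

ndata, nmb_cluster = 50000, 1000  # example values; nmb_cluster comes from args
for epoch in (0, 50, 100, 150, 199):
    y = -1 * math.log10(ndata) / 200 * epoch + math.log10(ndata)
    cluster_num = max(int(math.pow(10, y)), nmb_cluster)
    print(epoch, cluster_num)
# cluster_num starts near ndata and decays toward nmb_cluster, which acts as a floor.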