Exemple #1
0
        print(log)
        logger.write(log + '\n')
        sys.stdout.flush()
        #logger.flush()

    # Stop once the round counter (incremented at the end of this loop body)
    # reaches cfg.max_round. The enclosing loop header is outside this view.
    if round == cfg.max_round:
        break

    # Epoch loop for this round. `lr` is seeded with cfg.lr_base only so the
    # first `lr > 0` test passes; each iteration replaces it with the value
    # lr_handler.update returns for the current epoch.
    # NOTE(review): the `lr > 0` test sees the lr returned for the *previous*
    # epoch, so an epoch whose scheduled lr is <= 0 is still trained once
    # before the loop exits — confirm this is intended.
    epoch = 0
    lr = cfg.lr_base
    while lr > 0 and epoch < cfg.max_epoch:
        lr = lr_handler.update(epoch)
        loss = train(epoch, net, trainloader, optimizer, npc, criterion, rlb,
                     lr)
        # Compute `memory` from trainloader using the *test* loader's
        # transform, then score with kNN (200 is presumably the neighbour
        # count K — confirm against kNN's signature).
        memory, _ = compute_memory(net, trainloader,
                                   testloader.dataset.transform, cfg.device)
        acc = kNN(net, memory, trainloader, trainlabels, testloader, 200,
                  cfg.npc_temperature, cfg.device)
        if acc > best_acc:
            best_acc = acc
            # Deep copy so later training steps do not alter the snapshot.
            best_net_wts = copy.deepcopy(net.state_dict())
        # Incremented before logging, so the logged epoch number is 1-based.
        epoch += 1
        log = '[%04d-%04d]\tloss:%2.12f acc:%2.12f best:%2.12f' % (
            round + 1, epoch, loss, acc, best_acc)
        print(log)
        logger.write(log + '\n')
        sys.stdout.flush()
        #logger.flush()
    # NOTE: `round` shadows the Python builtin of the same name.
    round += 1
# Save the weight snapshot held in best_net_wts to <sess_dir>/checkpoint.dict.
# NOTE(review): best_net_wts is only assigned when an epoch improves on
# best_acc; unless it is initialised before this view, a run with no
# improvement raises NameError here.
torch.save(best_net_wts, os.path.join(cfg.sess_dir, 'checkpoint.dict'))
Exemple #2
0
def _make_checkpoint_state(epoch, arch, model, lemniscate, best_prec1,
                           optimizer):
    """Return the dict that ``main`` passes to ``save_checkpoint``.

    Args:
        epoch: Number of epochs completed, stored under ``'epoch'``.
        arch: Model architecture name.
        model: Network whose ``state_dict()`` is stored.
        lemniscate: Memory-bank object, stored as the object itself (not a
            state dict), matching what the resume path reads back.
        best_prec1: Best validation precision seen so far.
        optimizer: Optimizer whose ``state_dict()`` is stored.
    """
    return {
        'epoch': epoch,
        'arch': arch,
        'state_dict': model.state_dict(),
        'lemniscate': lemniscate,
        'best_prec1': best_prec1,
        'optimizer': optimizer.state_dict(),
    }


def main():
    """Train the embedding model in rounds with a reliable-neighbour search.

    Parses command-line arguments and builds the model. It creates train/val
    loaders from ``args.data``/train and ``args.data``/val, a
    ``LinearAverage`` memory, a ``ReliableSearch`` module and an SGD
    optimizer. It optionally resumes from ``args.resume``.

    With ``args.evaluate`` it runs a single kNN evaluation and returns.
    Otherwise, for each round in ``[args.start_round, args.rounds)``:

    * Every round after round 0 recomputes the memory and updates ``rlb``.
    * It trains for epochs ``[args.start_epoch, args.epochs)``, evaluating
      with ``NN`` and checkpointing after every epoch.
    * It saves a per-round checkpoint under ``ckpts/``.
    * It prints kNN top-1/top-5.

    Sets the module globals ``args`` and ``best_prec1``. Unless resuming,
    ``best_prec1`` is read before this function assigns it, so it must already
    exist at module level (not visible here).
    """
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](low_dim=args.low_dim)

    # For AlexNet/VGG only the `features` submodule is wrapped in DataParallel.
    if args.arch.startswith(('alexnet', 'vgg')):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolderInstance(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    # Ground-truth labels, only handed to rlb.update below.
    train_labels = torch.tensor(train_dataset.targets).long().cuda()
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=None)

    val_dataset = datasets.ImageFolderInstance(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define lemniscate and loss function (criterion)
    ndata = len(train_dataset)
    lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                               args.nce_m).cuda()
    rlb = ReliableSearch(ndata, args.low_dim, args.threshold_1,
                         args.threshold_2, args.batch_size).cuda()
    criterion = ReliableCrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE: 'lemniscate' is a pickled object, so this needs a full
            # (non weights_only) unpickle; only load trusted checkpoints.
            checkpoint = torch.load(args.resume)
            # Training restarts from epoch 0 with the restored weights; the
            # checkpoint's 'epoch' is only printed. rlb is not checkpointed.
            args.start_epoch = 0
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            lemniscate = checkpoint['lemniscate']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.evaluate:
        kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
        return

    for rnd in range(args.start_round, args.rounds):

        if rnd > 0:
            # Refresh the memory with the current model and update rlb from
            # it; the returned statistics are only printed.
            memory = recompute_memory(model, lemniscate, train_loader,
                                      val_loader, args.batch_size,
                                      args.workers)
            (num_reliable_1, consistency_1, num_reliable_2,
             consistency_2) = rlb.update(memory, train_labels)
            print(
                'Round [%02d/%02d]\tReliable1: %.12f\tReliable2: %.12f\tConsistency1: %.12f\tConsistency2: %.12f'
                % (rnd, args.rounds, num_reliable_1, num_reliable_2,
                   consistency_1, consistency_2))

        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch)

            # train for one epoch
            train(train_loader, model, lemniscate, rlb, criterion, optimizer,
                  epoch)

            # evaluate on validation set
            prec1 = NN(epoch, model, lemniscate, train_loader, val_loader)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                _make_checkpoint_state(epoch + 1, args.arch, model,
                                       lemniscate, best_prec1, optimizer),
                is_best)

        # Per-round snapshot. Use args.epochs rather than the loop variable:
        # `epoch` is unbound when the epoch range is empty, and otherwise its
        # final value + 1 already equals args.epochs.
        save_checkpoint(
            _make_checkpoint_state(args.epochs, args.arch, model, lemniscate,
                                   best_prec1, optimizer),
            is_best=False,
            filename='ckpts/%02d-checkpoint.pth.tar' % (rnd + 1))

        # evaluate KNN after last epoch
        top1, top5 = kNN(0, model, lemniscate, train_loader, val_loader, 200,
                         args.nce_t)
        print('Round [%02d/%02d]\tTop1: %.2f\tTop5: %.2f' %
              (rnd + 1, args.rounds, top1, top5))
Exemple #3
0
def main():
    """Train with an ``ANsDiscovery`` module over several rounds and save.

    Seeds torch/numpy and builds the network, ``NonParametricClassifier``,
    ``ANsDiscovery``, both criteria and a Nesterov SGD optimizer. With
    ``args.test_only`` or a non-empty ``args.resume`` it restores state from
    ``args.model_dir + args.resume``. With ``args.test_only`` it then prints
    kNN accuracy and exits.

    Otherwise it trains ``args.rounds`` rounds of ``args.epochs`` epochs,
    calling ``ANs_discovery.update`` before every round after round 0. It
    prints kNN accuracy after each epoch and finally writes the last state to
    ``./checkpoint/ckpt_embed.t7``.
    """
    args = config()

    # fix random seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    trainset, trainloader, testset, testloader = preprocess(args)
    ntrain = len(trainset)
    # Ground-truth labels; only passed to ANs_discovery.update below.
    cheat_labels = torch.tensor(trainset.targets).long().to(args.device)
    net = models.__dict__['ResNet18withSobel'](low_dim=args.low_dim)
    npc = NonParametricClassifier(args.low_dim, ntrain, args.npc_t, args.npc_m)
    ANs_discovery = ANsDiscovery(ntrain, args.ANs_select_rate, args.ANs_size,
                                 args.device)
    criterion = Criterion_SAND(args.batch_m, args.batch_t, args.batch_size,
                               args.device)
    criterion2 = UELoss()
    optimizer = torch.optim.SGD(net.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    if args.device == 'cuda':
        net = torch.nn.DataParallel(net,
                                    device_ids=range(
                                        torch.cuda.device_count()))
        cudnn.benchmark = True

    net.to(args.device)
    npc.to(args.device)
    ANs_discovery.to(args.device)
    criterion.to(args.device)
    criterion2.to(args.device)

    best_acc = 0
    if args.test_only or len(args.resume) > 0:
        model_path = args.model_dir + args.resume
        print('==> Resuming from checkpoint..')
        assert os.path.isdir(
            args.model_dir), 'Error: no checkpoint directory found!'
        # map_location lets CUDA-saved checkpoints load on CPU-only hosts.
        checkpoint = torch.load(model_path, map_location=args.device)
        net.load_state_dict(checkpoint['net'])
        npc.load_state_dict(checkpoint['npc'])
        # This script saves ANs_discovery.state_dict(), so restore it into the
        # existing module. Rebinding the name to the raw dict would make the
        # later ANs_discovery.update(r, npc, cheat_labels) call dict.update
        # and fail. A checkpoint holding the whole pickled object still works.
        saved_ans = checkpoint['ANs_discovery']
        if isinstance(saved_ans, dict):
            ANs_discovery.load_state_dict(saved_ans)
        else:
            ANs_discovery = saved_ans
        best_acc = checkpoint['acc']
        # NOTE(review): the checkpoint's 'round'/'epoch' are not used;
        # training below always restarts at round 0.

    if args.test_only:
        acc = kNN(net,
                  npc,
                  trainloader,
                  testloader,
                  K=200,
                  sigma=0.1,
                  recompute_memory=False,
                  device=args.device)
        print("accuracy: %.2f\n" % (acc * 100))
        sys.exit(0)

    for r in range(args.rounds):
        if r > 0:
            ANs_discovery.update(r, npc, cheat_labels)

        for epoch in range(args.epochs):
            train(r, epoch, net, trainloader, optimizer, npc, criterion,
                  criterion2, ANs_discovery, args.device)
            acc = kNN(net,
                      npc,
                      trainloader,
                      testloader,
                      K=200,
                      sigma=0.1,
                      recompute_memory=False,
                      device=args.device)
            print("accuracy: %.2f\n" % (acc * 100))

            if acc > best_acc:
                best_acc = acc
            print("best accuracy: %.2f\n" % (best_acc * 100))

    # Saves the *final* state; 'acc' is the last epoch's accuracy.
    state = {
        'net': net.state_dict(),
        'npc': npc.state_dict(),
        'ANs_discovery': ANs_discovery.state_dict(),
        'acc': acc,
        'round': r,
        'epoch': epoch
    }
    os.makedirs('checkpoint', exist_ok=True)
    torch.save(state, './checkpoint/ckpt_embed.t7')
Exemple #4
0
def main():
    """Fine-tune from ``checkpoint/init.t7`` with a per-round graph structure.

    Seeds torch/numpy and builds the network, ``NonParametricClassifier``,
    criterion and a Nesterov SGD optimizer. It always loads
    ``checkpoint/init.t7`` and then, if ``args.resume`` is set,
    ``args.model_dir + args.resume``. With ``args.test_only`` it prints kNN
    accuracy and exits.

    Otherwise, for rounds ``1 .. args.rounds - 1`` it builds a fresh
    ``GraphStructure`` from ``npc`` and trains one epoch. Whenever kNN
    accuracy improves, it saves the state to
    ``./checkpoint/<args.structure>_cur_reset.t7``. The last accuracy of each
    round is appended to ``cur_acc`` and printed. The process exits with
    status 0 at the end.
    """
    args = config()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    trainset, trainloader, testset, testloader = preprocess(args)
    ntrain = len(trainset)

    net = models.__dict__['ResNet18withSobel'](low_dim=args.low_dim)
    npc = NonParametricClassifier(args.structure, args.low_dim, ntrain,
                                  args.npc_t, args.npc_m, args.device)
    criterion = Criterion()
    optimizer = torch.optim.SGD(net.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    if args.device == 'cuda':
        net = torch.nn.DataParallel(net,
                                    device_ids=range(
                                        torch.cuda.device_count()))
        cudnn.benchmark = True

    net.to(args.device)
    npc.to(args.device)
    criterion.to(args.device)

    # Every run (including --test_only) starts from this initialisation.
    # map_location lets CUDA-saved checkpoints load on CPU-only hosts.
    print('==> init by self loop..')
    checkpoint = torch.load('checkpoint/init.t7', map_location=args.device)
    net.load_state_dict(checkpoint['net'])
    npc.load_state_dict(checkpoint['npc'])

    if args.resume:
        model_path = args.model_dir + args.resume
        print('==> Resuming from checkpoint..')
        assert os.path.isdir(
            args.model_dir), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(model_path, map_location=args.device)
        net.load_state_dict(checkpoint['net'])
        npc.load_state_dict(checkpoint['npc'])

    if args.test_only:
        acc = kNN(net,
                  npc,
                  trainloader,
                  testloader,
                  K=200,
                  sigma=0.1,
                  recompute_memory=False,
                  device=args.device)
        print("accuracy: %.2f\n" % (acc * 100))
        sys.exit(0)

    best_acc = 0
    cur_acc = []
    # NOTE(review): rounds start at 1 and each round runs exactly one epoch;
    # args.epochs is not consulted. Confirm this is intended.
    for r in range(1, args.rounds):
        # r >= 1 here, so a fresh structure is built from npc every round.
        structure = GraphStructure(args.structure, ntrain, args.low_dim,
                                   args.batch_size, args.neighbor_size,
                                   args.device)
        structure.to(args.device)
        structure.update(npc)

        for epoch in range(1):
            train(r, epoch, trainloader, net, npc, structure, criterion,
                  optimizer, args.device)
            acc = kNN(net,
                      npc,
                      trainloader,
                      testloader,
                      K=200,
                      sigma=0.1,
                      recompute_memory=False,
                      device=args.device)
            print("accuracy: %.2f" % (acc * 100))

            if acc > best_acc:
                print("state saving...")
                state = {
                    'net': net.state_dict(),
                    'npc': npc.state_dict(),
                    'structure': structure.state_dict(),
                    'acc': acc,
                    'round': r,
                    'epoch': epoch
                }
                os.makedirs('checkpoint', exist_ok=True)
                torch.save(
                    state,
                    './checkpoint/{}_cur_reset.t7'.format(args.structure))
                best_acc = acc
            print("best accuracy: %.2f\n" % (best_acc * 100))
        cur_acc.append(acc)
        print(cur_acc)
    sys.exit(0)