Example #1
def update_sampler(sampler, model, loader, query, gallery, sub_set, vlad=True, gpu=None, sync_gather=False):
    if dist.get_rank() == 0:
        print("===> Start extracting features for sorting gallery")
    # Extract features once for the union of query and gallery images.
    features = extract_features(model, loader, sorted(list(set(query) | set(gallery))),
                                vlad=vlad, gpu=gpu, sync_gather=sync_gather)
    # Query-to-gallery distance matrix used to rank the gallery per query.
    distmat, _, _ = pairwise_distance(features, query, gallery)
    del features
    if dist.get_rank() == 0:
        print("===> Start sorting gallery")
    sampler.sort_gallery(distmat, sub_set)
    del distmat
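
What sort_gallery does with the distance matrix is specific to this project's sampler. As a rough, hypothetical illustration (not the repository's actual class), a minimal stand-in could cache a nearest-first gallery ordering for each query in the subset:

import torch

class GallerySorter:
    """Hypothetical stand-in for the sampler used above."""
    def __init__(self, num_queries):
        self.sorted_gallery = [None] * num_queries

    def sort_gallery(self, distmat, sub_set):
        # For each query in the current subset, cache the gallery indices
        # ordered from nearest to farthest in feature space.
        order = torch.argsort(distmat, dim=1)
        for qid in sub_set:
            self.sorted_gallery[qid] = order[qid].tolist()

sorter = GallerySorter(num_queries=3)
sorter.sort_gallery(torch.rand(3, 5), sub_set=[0, 2])
print(sorter.sorted_gallery)
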
Example #2
def update_sampler(sampler, model, loader, query, gallery, sub_set, rerank=False,
                   vlad=True, gpu=None, sync_gather=False, lambda_value=0.1):
    if dist.get_rank() == 0:
        print("===> Start extracting features for sorting gallery")
    features = extract_features(model, loader, sorted(list(set(query) | set(gallery))),
                                vlad=vlad, gpu=gpu, sync_gather=sync_gather)
    distmat, _, _ = pairwise_distance(features, query, gallery)
    if rerank:
        # Optionally refine the query-gallery distances by re-ranking,
        # which also needs query-query and gallery-gallery distances.
        distmat_qq, _, _ = pairwise_distance(features, query, query)
        distmat_gg, _, _ = pairwise_distance(features, gallery, gallery)
        distmat_jac = re_ranking(distmat.numpy(), distmat_qq.numpy(), distmat_gg.numpy(),
                                 k1=20, k2=1, lambda_value=lambda_value)
        distmat_jac = torch.from_numpy(distmat_jac)
        del distmat_qq, distmat_gg
    else:
        distmat_jac = distmat
    del features
    if dist.get_rank() == 0:
        print("===> Start sorting gallery")
    # Pass both the raw and the re-ranked distance matrices to the sampler.
    sampler.sort_gallery(distmat, distmat_jac, sub_set)
    del distmat, distmat_jac
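
The rerank branch appears to follow the standard k-reciprocal re-ranking recipe (the k1/k2/lambda_value parameters match it), where lambda_value weights the original distances in the final blend. A hedged usage sketch; sampler, model, train_extract_loader, dataset, subset, and args are assumed to be built elsewhere in the training script:

# Hypothetical call site enabling re-ranking.
update_sampler(sampler, model, train_extract_loader,
               dataset.q_train, dataset.db_train, subset.tolist(),
               rerank=True,        # also compute the re-ranked distance matrix
               vlad=True, gpu=args.gpu, sync_gather=args.sync_gather,
               lambda_value=0.1)   # weight of the original distances in the blend
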
Example #3
def main_worker(args):
    global start_epoch, best_recall5
    init_dist(args.launcher, args)
    synchronize()

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    if args.deterministic:
        cudnn.deterministic = True
        cudnn.benchmark = False

    print("Use GPU: {} for training, rank no.{} of world_size {}"
          .format(args.gpu, args.rank, args.world_size))

    if (args.rank==0):
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
        print("==========\nArgs:{}\n==========".format(args))

    # Create data loaders
    iters = args.iters if (args.iters>0) else None
    dataset, train_loader, val_loader, test_loader, sampler, train_extract_loader = get_data(args, iters)

    # Create model
    model = get_model(args)

    # Load from checkpoint
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        copy_state_dict(checkpoint['state_dict'], model)
        start_epoch = checkpoint['epoch']+1
        best_recall5 = checkpoint['best_recall5']
        if (args.rank==0):
            print("=> Start epoch {}  best recall5 {:.1%}"
                  .format(start_epoch, best_recall5))

    # Evaluator
    evaluator = Evaluator(model)
    if (args.rank==0):
        print("Test the initial model:")
    recalls = evaluator.evaluate(val_loader, sorted(list(set(dataset.q_val) | set(dataset.db_val))),
                        dataset.q_val, dataset.db_val, dataset.val_pos,
                        vlad=args.vlad, gpu=args.gpu, sync_gather=args.sync_gather)

    # Optimizer
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                                lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)

    # Trainer
    trainer = Trainer(model, margin=args.margin**0.5, gpu=args.gpu)
    if ((args.cache_size<args.tuple_size) or (args.cache_size>len(dataset.q_train))):
        args.cache_size = len(dataset.q_train)

    # Start training
    for epoch in range(start_epoch, args.epochs):
        sampler.set_epoch(args.seed+epoch)
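        # Enlarge the cache as training progresses; note that args.cache_size
        # is reassigned, so the 2**(epoch // step_size) factor compounds
        # across later epochs.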
        args.cache_size = args.cache_size * (2 ** (epoch // args.step_size))

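        # Deterministically shuffle and split the training queries into
        # cache_size chunks; the fixed seed keeps the split identical on
        # every rank (a standalone sketch of this pattern follows the example).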
        g = torch.Generator()
        g.manual_seed(args.seed+epoch)
        subset_indices = torch.randperm(len(dataset.q_train), generator=g).long().split(args.cache_size)

        for subid, subset in enumerate(subset_indices):
            update_sampler(sampler, model, train_extract_loader, dataset.q_train, dataset.db_train, subset.tolist(),
                            vlad=args.vlad, gpu=args.gpu, sync_gather=args.sync_gather)
            synchronize()
            trainer.train(epoch, subid, train_loader, optimizer,
                            train_iters=len(train_loader), print_freq=args.print_freq,
                            vlad=args.vlad, loss_type=args.loss_type)
            synchronize()

        if ((epoch+1)%args.eval_step==0 or (epoch==args.epochs-1)):
            recalls = evaluator.evaluate(val_loader, sorted(list(set(dataset.q_val) | set(dataset.db_val))),
                                    dataset.q_val, dataset.db_val, dataset.val_pos,
                                    vlad=args.vlad, gpu=args.gpu, sync_gather=args.sync_gather)

            is_best = recalls[1] > best_recall5
            best_recall5 = max(recalls[1], best_recall5)

            if (args.rank==0):
                save_checkpoint({
                    'state_dict': model.state_dict(),
                    'epoch': epoch,
                    'best_recall5': best_recall5,
                }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint'+str(epoch)+'.pth.tar'))
                print('\n * Finished epoch {:3d} recall@1: {:5.1%}  recall@5: {:5.1%}  recall@10: {:5.1%}  best@5: {:5.1%}{}\n'.
                      format(epoch, recalls[0], recalls[1], recalls[2], best_recall5, ' *' if is_best else ''))

        lr_scheduler.step()
        synchronize()

    # Final inference
    if (args.rank==0):
        print("Performing PCA reduction on the best model:")
    model.load_state_dict(load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))['state_dict'])
    pca_parameters_path = osp.join(args.logs_dir, 'pca_params_model_best.h5')
    pca = PCA(args.features, (not args.nowhiten), pca_parameters_path)
    dict_f = extract_features(model, train_extract_loader, sorted(list(set(dataset.q_train) | set(dataset.db_train))),
                                vlad=args.vlad, gpu=args.gpu, sync_gather=args.sync_gather)
    features = list(dict_f.values())
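    # Cap the PCA training set at 10000 randomly sampled feature vectors.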
    if (len(features)>10000):
        features = random.sample(features, 10000)
    features = torch.stack(features)
    if (args.rank==0):
        pca.train(features)
    synchronize()
    del features
    if (args.rank==0):
        print("Testing on Pitts30k-test:")
    evaluator.evaluate(test_loader, sorted(list(set(dataset.q_test) | set(dataset.db_test))),
                dataset.q_test, dataset.db_test, dataset.test_pos,
                vlad=args.vlad, pca=pca, gpu=args.gpu, sync_gather=args.sync_gather)
    synchronize()
    return
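
The per-epoch subset split in the training loop above is deterministic given the seed, so every distributed rank derives the same chunks. A standalone sketch of the same pattern, with made-up sizes:

import torch

seed, epoch = 42, 0                  # illustrative values
num_queries, cache_size = 7320, 1000

g = torch.Generator()
g.manual_seed(seed + epoch)          # same seed on every rank -> identical split
subset_indices = torch.randperm(num_queries, generator=g).split(cache_size)
for subid, subset in enumerate(subset_indices):
    print(subid, subset.shape[0])    # the last chunk may be smaller than cache_size
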
Example #4
def main_worker(args):
    init_dist(args.launcher, args)
    synchronize()
    cudnn.benchmark = True
    print("Use GPU: {} for testing, rank no.{} of world_size {}".format(
        args.gpu, args.rank, args.world_size))

    assert args.resume
    if (args.rank == 0):
        log_dir = osp.dirname(args.resume)
        sys.stdout = Logger(
            osp.join(log_dir, 'log_test_' + args.dataset + '.txt'))
        print("==========\nArgs:{}\n==========".format(args))

    # Create data loaders
    dataset, pitts_train, train_extract_loader, test_loader_q, test_loader_db = get_data(
        args)

    # Create model
    model = get_model(args)

    # Load from checkpoint
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        copy_state_dict(checkpoint['state_dict'], model)
        start_epoch = checkpoint['epoch']
        best_recall5 = checkpoint['best_recall5']
        if (args.rank == 0):
            print("=> Start epoch {}  best recall5 {:.1%}".format(
                start_epoch, best_recall5))

    # Evaluator
    evaluator = Evaluator(model)
    if (args.reduction):
        pca_parameters_path = osp.join(
            osp.dirname(args.resume),
            'pca_params_' + osp.basename(args.resume).split('.')[0] + '.h5')
        pca = PCA(args.features, (not args.nowhiten), pca_parameters_path)
        if (not osp.isfile(pca_parameters_path)):
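            # PCA parameters not found on disk; fit them on training features.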
            dict_f = extract_features(model,
                                      train_extract_loader,
                                      pitts_train,
                                      vlad=args.vlad,
                                      gpu=args.gpu,
                                      sync_gather=args.sync_gather)
            features = list(dict_f.values())
            if (len(features) > 10000):
                features = random.sample(features, 10000)
            features = torch.stack(features)
            if (args.rank == 0):
                pca.train(features)
            synchronize()
            del features
    else:
        pca = None

    if (args.rank == 0):
        print("Evaluate on the test set:")
    evaluator.evaluate(test_loader_q,
                       sorted(list(set(dataset.q_test)
                                   | set(dataset.db_test))),
                       dataset.q_test,
                       dataset.db_test,
                       dataset.test_pos,
                       gallery_loader=test_loader_db,
                       vlad=args.vlad,
                       pca=pca,
                       rerank=args.rerank,
                       gpu=args.gpu,
                       sync_gather=args.sync_gather,
                       nms=(args.dataset == 'tokyo'),
                       rr_topk=args.rr_topk,
                       lambda_value=args.lambda_value)
    synchronize()
    return
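
The PCA object above is constructed with a target dimensionality, a whitening flag, and a path for the fitted parameters. As a self-contained illustration of the underlying operation (not the repository's PCA class), features can be reduced and whitened with a plain SVD:

import torch

def pca_whiten(features, out_dim, eps=1e-8):
    # Center the data, project onto the top out_dim principal components,
    # and rescale each component to unit variance (whitening).
    mean = features.mean(dim=0, keepdim=True)
    centered = features - mean
    _, S, Vh = torch.linalg.svd(centered, full_matrices=False)
    projected = centered @ Vh[:out_dim].t()                # (n, out_dim)
    scale = S[:out_dim] / (features.shape[0] - 1) ** 0.5   # per-component std
    return projected / (scale + eps)

feats = torch.randn(100, 32)         # made-up feature matrix
print(pca_whiten(feats, 16).shape)   # torch.Size([100, 16])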