def main_worker(args):
    """Distributed evaluation entry point.

    Loads a trained checkpoint, optionally trains/loads PCA-whitening
    parameters for dimensionality reduction, and evaluates retrieval
    recall on the test split. Only rank 0 redirects stdout to a log file
    and prints progress; all ranks participate in feature extraction.

    Args:
        args: parsed command-line namespace. ``args.resume`` must point to
            a trained checkpoint; ``args.reduction`` toggles PCA.

    Raises:
        ValueError: if ``args.resume`` is empty — evaluation is impossible
            without trained weights.
    """
    init_dist(args.launcher, args)
    synchronize()
    cudnn.benchmark = True
    print("Use GPU: {} for testing, rank no.{} of world_size {}".format(
        args.gpu, args.rank, args.world_size))

    # Validate explicitly instead of `assert`: asserts are stripped under
    # `python -O`, which would let an empty --resume slip through.
    if not args.resume:
        raise ValueError("a trained checkpoint (--resume) is required for testing")

    if args.rank == 0:
        # Log next to the checkpoint so test logs stay with the model.
        log_dir = osp.dirname(args.resume)
        sys.stdout = Logger(
            osp.join(log_dir, 'log_test_' + args.dataset + '.txt'))
        print("==========\nArgs:{}\n==========".format(args))

    # Create data loaders
    dataset, pitts_train, train_extract_loader, test_loader_q, test_loader_db = \
        get_data(args)

    # Create model
    model = get_model(args)

    # Load from checkpoint (guaranteed non-empty by the check above, so the
    # original redundant `if args.resume:` guard is dropped).
    checkpoint = load_checkpoint(args.resume)
    copy_state_dict(checkpoint['state_dict'], model)
    start_epoch = checkpoint['epoch']
    best_recall5 = checkpoint['best_recall5']
    if args.rank == 0:
        print("=> Start epoch {} best recall5 {:.1%}".format(
            start_epoch, best_recall5))

    # Evaluator
    evaluator = Evaluator(model)

    if args.reduction:
        # PCA parameters are cached on disk next to the checkpoint; train
        # them once from (a sample of) the training features if missing.
        pca_parameters_path = osp.join(
            osp.dirname(args.resume),
            'pca_params_' + osp.basename(args.resume).split('.')[0] + '.h5')
        pca = PCA(args.features, (not args.nowhiten), pca_parameters_path)
        if not osp.isfile(pca_parameters_path):
            dict_f = extract_features(model, train_extract_loader, pitts_train,
                                      vlad=args.vlad, gpu=args.gpu,
                                      sync_gather=args.sync_gather)
            features = list(dict_f.values())
            # Cap the PCA training set at 10k features to bound memory/compute.
            if len(features) > 10000:
                features = random.sample(features, 10000)
            features = torch.stack(features)
            if args.rank == 0:
                pca.train(features)
            synchronize()
            del features
    else:
        pca = None

    if args.rank == 0:
        print("Evaluate on the test set:")
    evaluator.evaluate(test_loader_q,
                       sorted(list(set(dataset.q_test) | set(dataset.db_test))),
                       dataset.q_test, dataset.db_test, dataset.test_pos,
                       gallery_loader=test_loader_db,
                       vlad=args.vlad, pca=pca, rerank=args.rerank,
                       gpu=args.gpu, sync_gather=args.sync_gather,
                       # NMS is only applied for the Tokyo dataset.
                       nms=(args.dataset == 'tokyo'),
                       rr_topk=args.rr_topk, lambda_value=args.lambda_value)
    synchronize()
    return
def main_worker(args):
    """Distributed training entry point.

    Seeds RNGs, builds loaders/model/optimizer, runs the epoch loop with
    periodic hard-example cache refreshes and validation, then performs
    PCA reduction on the best checkpoint and a final test evaluation.
    Only rank 0 logs, saves checkpoints, and trains the PCA; all ranks
    take part in feature extraction and training.

    Args:
        args: parsed command-line namespace (seeds, loader sizes, optimizer
            hyper-parameters, eval cadence, log directory, ...).

    Side effects:
        Reads/updates module-level ``start_epoch`` and ``best_recall5``
        (set at module scope and overwritten when resuming).
    """
    global start_epoch, best_recall5
    init_dist(args.launcher, args)
    synchronize()

    # Seed every RNG source for reproducibility across ranks.
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    if args.deterministic:
        # Deterministic cuDNN kernels; disables benchmark autotuning.
        cudnn.deterministic = True
        cudnn.benchmark = False

    print("Use GPU: {} for training, rank no.{} of world_size {}"
          .format(args.gpu, args.rank, args.world_size))

    if args.rank == 0:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
        print("==========\nArgs:{}\n==========".format(args))

    # Create data loaders
    iters = args.iters if args.iters > 0 else None
    dataset, train_loader, val_loader, test_loader, sampler, train_extract_loader = \
        get_data(args, iters)

    # Create model
    model = get_model(args)

    # Load from checkpoint
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        copy_state_dict(checkpoint['state_dict'], model)
        # Resume *after* the saved epoch.
        start_epoch = checkpoint['epoch'] + 1
        best_recall5 = checkpoint['best_recall5']
        if args.rank == 0:
            print("=> Start epoch {} best recall5 {:.1%}"
                  .format(start_epoch, best_recall5))

    # Evaluator
    evaluator = Evaluator(model)
    if args.rank == 0:
        print("Test the initial model:")
    recalls = evaluator.evaluate(
        val_loader,
        sorted(list(set(dataset.q_val) | set(dataset.db_val))),
        dataset.q_val, dataset.db_val, dataset.val_pos,
        vlad=args.vlad, gpu=args.gpu, sync_gather=args.sync_gather)

    # Optimizer: SGD over trainable parameters only; LR halves every
    # `step_size` epochs.
    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.step_size, gamma=0.5)

    # Trainer (triplet margin is squared in the loss, hence the sqrt here —
    # presumably to match a squared-distance formulation; verify in Trainer).
    trainer = Trainer(model, margin=args.margin**0.5, gpu=args.gpu)

    # Clamp the hard-example cache to a sane size.
    if (args.cache_size < args.tuple_size) or (args.cache_size > len(dataset.q_train)):
        args.cache_size = len(dataset.q_train)

    # BUGFIX: the original re-assigned
    #     args.cache_size = args.cache_size * (2 ** (epoch // args.step_size))
    # inside the epoch loop, so the doubling compounded on the previous
    # epoch's already-scaled value (super-exponential growth, and a
    # different schedule when resuming). Compute each epoch's size from the
    # fixed base instead: cache doubles once per `step_size` epochs.
    base_cache_size = args.cache_size

    # Start training
    for epoch in range(start_epoch, args.epochs):
        # Distinct, reproducible shuffling per epoch across ranks.
        sampler.set_epoch(args.seed + epoch)
        args.cache_size = min(
            base_cache_size * (2 ** (epoch // args.step_size)),
            len(dataset.q_train))

        # Reproducible per-epoch partition of training queries into subsets.
        g = torch.Generator()
        g.manual_seed(args.seed + epoch)
        subset_indices = torch.randperm(
            len(dataset.q_train), generator=g).long().split(args.cache_size)

        for subid, subset in enumerate(subset_indices):
            # Refresh the sampler's hard-example cache for this subset.
            update_sampler(sampler, model, train_extract_loader,
                           dataset.q_train, dataset.db_train, subset.tolist(),
                           vlad=args.vlad, gpu=args.gpu,
                           sync_gather=args.sync_gather)
            synchronize()
            trainer.train(epoch, subid, train_loader, optimizer,
                          train_iters=len(train_loader),
                          print_freq=args.print_freq,
                          vlad=args.vlad, loss_type=args.loss_type)
            synchronize()

        # Periodic validation + checkpointing (always on the last epoch).
        if ((epoch + 1) % args.eval_step == 0) or (epoch == args.epochs - 1):
            recalls = evaluator.evaluate(
                val_loader,
                sorted(list(set(dataset.q_val) | set(dataset.db_val))),
                dataset.q_val, dataset.db_val, dataset.val_pos,
                vlad=args.vlad, gpu=args.gpu, sync_gather=args.sync_gather)

            # recalls[1] is recall@5 — the model-selection metric.
            is_best = recalls[1] > best_recall5
            best_recall5 = max(recalls[1], best_recall5)

            if args.rank == 0:
                save_checkpoint({
                    'state_dict': model.state_dict(),
                    'epoch': epoch,
                    'best_recall5': best_recall5,
                }, is_best, fpath=osp.join(
                    args.logs_dir, 'checkpoint' + str(epoch) + '.pth.tar'))
                print('\n * Finished epoch {:3d} recall@1: {:5.1%} recall@5: {:5.1%} recall@10: {:5.1%} best@5: {:5.1%}{}\n'.
                      format(epoch, recalls[0], recalls[1], recalls[2],
                             best_recall5, ' *' if is_best else ''))

        lr_scheduler.step()
        synchronize()

    # final inference
    if args.rank == 0:
        print("Performing PCA reduction on the best model:")
    model.load_state_dict(
        load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))['state_dict'])
    pca_parameters_path = osp.join(args.logs_dir, 'pca_params_model_best.h5')
    pca = PCA(args.features, (not args.nowhiten), pca_parameters_path)
    dict_f = extract_features(
        model, train_extract_loader,
        sorted(list(set(dataset.q_train) | set(dataset.db_train))),
        vlad=args.vlad, gpu=args.gpu, sync_gather=args.sync_gather)
    features = list(dict_f.values())
    # Cap the PCA training set at 10k features to bound memory/compute.
    if len(features) > 10000:
        features = random.sample(features, 10000)
    features = torch.stack(features)
    if args.rank == 0:
        pca.train(features)
    synchronize()
    del features

    if args.rank == 0:
        print("Testing on Pitts30k-test:")
    evaluator.evaluate(
        test_loader,
        sorted(list(set(dataset.q_test) | set(dataset.db_test))),
        dataset.q_test, dataset.db_test, dataset.test_pos,
        vlad=args.vlad, pca=pca, gpu=args.gpu, sync_gather=args.sync_gather)
    synchronize()
    return
def main_worker(args):
    """Extract local descriptors and k-means cluster centroids for VLAD init.

    Samples ``nPerImage`` L2-normalized local descriptors from each of
    ``nIm`` images (50k descriptors total), stores them in an HDF5 file,
    runs k-means with ``args.num_clusters`` clusters, and saves the
    resulting centroids into the same file
    (``<arch>_<dataset>_<num_clusters>_desc_cen.hdf5`` under
    ``args.logs_dir``).

    Args:
        args: parsed command-line namespace (arch, dataset, num_clusters,
            batch_size, seed, logs_dir, optional resume checkpoint, ...).
    """
    cudnn.benchmark = True
    print("==========\nArgs:{}\n==========".format(args))

    # Descriptor budget: 100 descriptors per image, enough images to reach
    # 50,000 descriptors in total.
    nDescriptors = 50000
    nPerImage = 100
    nIm = math.ceil(nDescriptors / nPerImage)

    # Create data loaders
    dataset, data_loader = get_data(args, nIm)

    # Create model
    model = get_model(args)
    # `.module` implies the model is wrapped (e.g. DataParallel/DDP);
    # feature_dim is the per-location descriptor dimensionality.
    encoder_dim = model.module.feature_dim

    # Load from resume
    if args.resume:
        print('Loading weights from {}'.format(args.resume))
        checkpoint = load_checkpoint(args.resume)
        copy_state_dict(checkpoint['state_dict'], model)

    if not osp.exists(osp.join(args.logs_dir)):
        os.makedirs(osp.join(args.logs_dir))
    initcache = osp.join(
        args.logs_dir,
        args.arch + '_' + args.dataset + '_' + str(args.num_clusters) + '_desc_cen.hdf5')

    # mode='w' truncates any existing cache file.
    with h5py.File(initcache, mode='w') as h5:
        with torch.no_grad():
            model.eval()
            print('====> Extracting Descriptors')
            # Pre-sized on-disk matrix: one row per sampled descriptor.
            dbFeat = h5.create_dataset("descriptors",
                                       [nDescriptors, encoder_dim],
                                       dtype=np.float32)

            # enumerate(..., 1): iteration counts from 1 for the progress log.
            for iteration, (input, _, _, _, _) in enumerate(data_loader, 1):
                input = input.cuda()
                # Output is viewed as (B, encoder_dim, num_locations) then
                # permuted to (B, num_locations, encoder_dim) — so the model
                # output must flatten to encoder_dim x spatial locations.
                image_descriptors = model(input)
                # normalization is IMPORTANT!
                image_descriptors = F.normalize(image_descriptors, p=2, dim=1).view(
                    input.size(0), encoder_dim, -1).permute(0, 2, 1)

                # Row offset of this batch within the descriptors dataset.
                batchix = (iteration - 1) * args.batch_size * nPerImage
                for ix in range(image_descriptors.size(0)):
                    # sample different location for each image in batch
                    # NOTE(review): replace=False requires at least nPerImage
                    # spatial locations per image — confirm input resolution.
                    sample = np.random.choice(image_descriptors.size(1),
                                              nPerImage, replace=False)
                    startix = batchix + ix * nPerImage
                    dbFeat[startix:startix + nPerImage, :] = image_descriptors[
                        ix, sample, :].detach().cpu().numpy()

                if (iteration % args.print_freq == 0) or (len(data_loader) <= args.print_freq):
                    print("==> Batch ({}/{})".format(
                        iteration, math.ceil(nIm / args.batch_size)), flush=True)
                # Free GPU memory before the next batch.
                del input, image_descriptors

            print('====> Clustering')
            niter = 100
            # dbFeat[...] loads the full 50k x encoder_dim matrix into memory.
            kmeans = KMeans(n_clusters=args.num_clusters, max_iter=niter,
                            random_state=args.seed).fit(dbFeat[...])

            print('====> Storing centroids', kmeans.cluster_centers_.shape)
            h5.create_dataset('centroids', data=kmeans.cluster_centers_)
            print('====> Done!')