def test_cross_dataset(config_file, test_dataset, **kwargs):
    """Evaluate a model trained on ``cfg.DATASETS.NAMES`` against ``test_dataset``.

    Merges ``config_file`` and the ``kwargs`` overrides into the global ``cfg``
    (then freezes it), loads checkpoint ``cfg.TEST.LOAD_EPOCH`` from
    ``cfg.OUTPUT_DIR``, extracts features for the validation split of
    ``test_dataset`` and logs the mAP and CMC rank-1/5/10 returned by
    ``evaluation``.

    Args:
        config_file: config file merged into ``cfg``.
        test_dataset: name of the target dataset, passed to the downloader
            and to ``data_loader``.
        **kwargs: config overrides; every ``key, value`` pair is forwarded to
            ``cfg.merge_from_list``.
    """
    cfg.merge_from_file(config_file)
    if kwargs:
        # merge_from_list expects a flat [key1, value1, key2, value2, ...] list.
        cfg.merge_from_list([item for pair in kwargs.items() for item in pair])
    cfg.freeze()

    # Source dataset: only needed for the classifier size of the trained model.
    PersonReID_Dataset_Downloader(cfg.DATASETS.STORE_DIR, cfg.DATASETS.NAMES)
    _, _, _, num_classes = data_loader(cfg, cfg.DATASETS.NAMES)
    # Target dataset: supplies the query/gallery split that is evaluated.
    PersonReID_Dataset_Downloader(cfg.DATASETS.STORE_DIR, test_dataset)
    _, val_loader, num_query, _ = data_loader(cfg, test_dataset)

    # Named so it does not shadow the module-level re-ranking function.
    use_re_ranking = cfg.RE_RANKING
    log_name = cfg.DATASETS.NAMES + '->' + test_dataset
    if not use_re_ranking:
        logger = make_logger("Reid_Baseline", cfg.OUTPUT_DIR, log_name)
        logger.info("Test Results:")
    else:
        logger = make_logger("Reid_Baseline", cfg.OUTPUT_DIR,
                             log_name + '_re-ranking')
        logger.info("Re-Ranking Test Results:")

    device = torch.device(cfg.DEVICE)
    # Build the network exactly as train() does, so the evaluated architecture
    # matches the one whose weights are loaded.
    model = getattr(models, cfg.MODEL.NAME)(
        num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.POOL)
    model.load(cfg.OUTPUT_DIR, cfg.TEST.LOAD_EPOCH)
    # Placement and eval mode are loop-invariant: set them once.
    model = model.to(device).eval()

    all_feats = []
    all_pids = []
    all_camids = []
    since = time.time()
    with torch.no_grad():
        for images, pids, camids in tqdm(val_loader, desc='Feature Extraction',
                                         leave=False):
            all_feats.append(model(images.to(device)))
            all_pids.extend(np.asarray(pids))
            all_camids.extend(np.asarray(camids))

    cmc, mAP = evaluation(all_feats, all_pids, all_camids, num_query,
                          use_re_ranking)
    logger.info("mAP: {:.1%}".format(mAP))
    for r in [1, 5, 10]:
        logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
    test_time = time.time() - since
    logger.info('Testing complete in {:.0f}m {:.0f}s'.format(
        test_time // 60, test_time % 60))
def train(config_file, **kwargs):
    """Train ``cfg.MODEL.NAME`` on ``cfg.DATASETS.NAMES``.

    Merges ``config_file`` and the ``kwargs`` overrides into the global ``cfg``
    (then freezes it). Every ``cfg.SOLVER.CHECKPOINT_PERIOD`` epochs it saves
    the model to ``cfg.OUTPUT_DIR``. Every ``cfg.SOLVER.EVAL_PERIOD`` epochs it
    logs the mAP and CMC rank-1/5/10 on the validation split.

    Args:
        config_file: config file merged into ``cfg``.
        **kwargs: config overrides; every ``key, value`` pair is forwarded to
            ``cfg.merge_from_list``.
    """
    cfg.merge_from_file(config_file)
    if kwargs:
        # merge_from_list expects a flat [key1, value1, key2, value2, ...] list.
        cfg.merge_from_list([item for pair in kwargs.items() for item in pair])
    cfg.freeze()

    PersonReID_Dataset_Downloader(cfg.DATASETS.STORE_DIR, cfg.DATASETS.NAMES)

    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logger = make_logger("Reid_Baseline", output_dir, 'log')
    logger.info("Using {} GPUS".format(1))
    logger.info("Loaded configuration file {}".format(config_file))
    logger.info("Running with config:\n{}".format(cfg))

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    device = torch.device(cfg.DEVICE)
    epochs = cfg.SOLVER.MAX_EPOCHS

    train_loader, val_loader, num_query, num_classes = data_loader(
        cfg, cfg.DATASETS.NAMES)

    model = getattr(models, cfg.MODEL.NAME)(
        num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.POOL)
    # Move the model to the device *before* building the optimizer, as the
    # torch.optim documentation recommends.
    model.to(device)
    optimizer = make_optimizer(cfg, model)
    scheduler = make_scheduler(cfg, optimizer)
    loss_fn = make_loss(cfg)

    logger.info("Start training")
    since = time.time()
    for epoch in range(epochs):
        count = 0
        running_loss = 0.0
        running_acc = 0
        # Validation below switches to eval mode, so re-enter train mode at
        # the start of every epoch.
        model.train()
        for images, labels in tqdm(train_loader, desc='Iteration', leave=False):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            scores, feats = model(images)
            loss = loss_fn(scores, feats, labels)
            loss.backward()
            optimizer.step()

            count += 1
            running_loss += loss.item()
            # Accuracy uses the first score output, scores[0]; presumably the
            # classifier logits (batch x classes) -- TODO confirm against model.
            running_acc += (
                scores[0].max(1)[1] == labels).float().mean().item()

        # Guard the averages against an empty train_loader.
        batches = max(count, 1)
        logger.info(
            "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
            .format(epoch + 1, count, len(train_loader),
                    running_loss / batches, running_acc / batches,
                    scheduler.get_lr()[0]))
        scheduler.step()

        if (epoch + 1) % checkpoint_period == 0:
            # Save from CPU, then restore the device placement for what follows.
            model.cpu()
            model.save(output_dir, epoch + 1)
            model.to(device)

        # Validation
        if (epoch + 1) % eval_period == 0:
            model.eval()
            all_feats = []
            all_pids = []
            all_camids = []
            with torch.no_grad():
                for images, pids, camids in tqdm(
                        val_loader, desc='Feature Extraction', leave=False):
                    all_feats.append(model(images.to(device)))
                    all_pids.extend(np.asarray(pids))
                    all_camids.extend(np.asarray(camids))

            cmc, mAP = evaluation(all_feats, all_pids, all_camids, num_query)
            logger.info("Validation Results - Epoch: {}".format(epoch + 1))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(
                    r, cmc[r - 1]))

    time_elapsed = time.time() - since
    logger.info('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    logger.info('-' * 10)
def __init__(self, config_file, epoch_label, **kwargs):
    """Create, or load a cached, query-vs-gallery matching file for a model.

    The validation set is split into two parts -- query (probe) and gallery
    (to be searched) -- based on ``num_query``.

    If ``<cfg.OUTPUT_DIR>/<cfg.MODEL.NAME>_epo<epoch_label>.mtch`` exists, it
    is loaded. Otherwise features are extracted with the checkpoint
    ``epoch_label``, and the file is written as an ``np.savez`` archive.

    The archive's ``ranked_matches`` entry is a ``num_query x num_gallery``
    int32 matrix M. ``M[i, j] == 1`` iff the gallery item at rank j (ascending
    squared Euclidean distance) has the same pid as query i.

    Args:
        config_file: config file merged into ``cfg``.
        epoch_label: checkpoint epoch to load; also part of the file name.
        **kwargs: config overrides; every ``key, value`` pair is forwarded to
            ``cfg.merge_from_list``.

    Sets:
        self.cfg, self._filepath, self.dataset (the validation dataset).
        self.data: a dict with keys q_pids, g_pids, q_camids, g_camids,
            ranked_matches and indices.

    Raises:
        OSError: if ``cfg.OUTPUT_DIR`` does not exist.
        NotImplementedError: if the file must be created while
            ``cfg.RE_RANKING`` is set.
    """
    cfg.merge_from_file(config_file)
    if kwargs:
        # merge_from_list expects a flat [key1, value1, key2, value2, ...] list.
        cfg.merge_from_list([item for pair in kwargs.items() for item in pair])
    cfg.freeze()
    self.cfg = cfg
    device = torch.device(cfg.DEVICE)
    output_dir = cfg.OUTPUT_DIR

    if not os.path.exists(output_dir):
        raise OSError('Output directory does not exist.')

    save_filename = cfg.MODEL.NAME + '_epo%s.mtch' % epoch_label
    self._filepath = os.path.join(output_dir, save_filename)

    if os.path.exists(self._filepath):
        print('Loading matches file...')
        # Read all arrays eagerly and close the archive. A bare np.load of an
        # npz keeps the file handle open. A dict also matches the type the
        # creation branch stores.
        with np.load(self._filepath) as archive:
            self.data = dict(archive)
        _, val_loader, _, _ = data_loader(cfg, cfg.DATASETS.NAMES)
        self.dataset = val_loader.dataset
        print('Matches loaded.')
        return

    # Fail fast, before downloading data and extracting features.
    if cfg.RE_RANKING:
        raise NotImplementedError()

    print('Creating matches file...')
    PersonReID_Dataset_Downloader(cfg.DATASETS.STORE_DIR, cfg.DATASETS.NAMES)
    _, val_loader, num_query, num_classes = data_loader(
        cfg, cfg.DATASETS.NAMES)

    # Build the network exactly as train() does, so the evaluated architecture
    # matches the one whose weights are loaded.
    model = getattr(models, cfg.MODEL.NAME)(
        num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.POOL)
    model.load(output_dir, epoch_label)
    model = model.to(device).eval()

    all_feats = []
    all_pids = []
    all_camids = []
    with torch.no_grad():
        for images, pids, camids in tqdm(val_loader, desc='Feature Extraction',
                                         leave=False):
            all_feats.append(model(images.to(device)))
            all_pids.extend(np.asarray(pids))
            all_camids.extend(np.asarray(camids))
    all_feats = torch.cat(all_feats, dim=0)

    # The first num_query items are queries; the rest form the gallery.
    qf = all_feats[:num_query]
    q_pids = np.asarray(all_pids[:num_query])
    q_camids = np.asarray(all_camids[:num_query])
    gf = all_feats[num_query:]
    g_pids = np.asarray(all_pids[num_query:])
    g_camids = np.asarray(all_camids[num_query:])

    # Squared Euclidean distance: ||q||^2 + ||g||^2 - 2 * q . g
    m, n = qf.shape[0], gf.shape[0]
    distmat = (torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n)
               + torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t())
    distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
    distmat = distmat.cpu().numpy()
    indices = np.argsort(distmat, axis=1)
    ranked_matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(
        np.int32)

    data = {
        'q_pids': q_pids,
        'g_pids': g_pids,
        'q_camids': q_camids,
        'g_camids': g_camids,
        'ranked_matches': ranked_matches,
        'indices': indices,
    }
    # Save as .mtch: writing through a file object keeps the name unchanged
    # (no .npz suffix is appended).
    with open(self._filepath, 'wb') as f:
        np.savez(f, **data)
    print('Matches created.')
    self.data = data
    self.dataset = val_loader.dataset
def test(config_file, **kwargs):
    """Evaluate a trained model and save a top-k retrieval plot per query.

    Merges ``config_file`` and the ``kwargs`` overrides into the global ``cfg``
    (then freezes it) and loads checkpoint ``cfg.TEST.LOAD_EPOCH`` from
    ``cfg.OUTPUT_DIR``. It then extracts validation features and ranks the
    gallery for every query, using squared Euclidean distance or, when
    ``cfg.RE_RANKING`` is set, ``re_ranking``. Finally it writes
    ``./show/<query index>.jpg`` showing the query followed by its top
    matches, excluding gallery items with the query's pid and camid.

    Args:
        config_file: config file merged into ``cfg``.
        **kwargs: config overrides; every ``key, value`` pair is forwarded to
            ``cfg.merge_from_list``.
    """
    import os  # local import: the module-level import list is not visible here

    cfg.merge_from_file(config_file)
    if kwargs:
        # merge_from_list expects a flat [key1, value1, key2, value2, ...] list.
        cfg.merge_from_list([item for pair in kwargs.items() for item in pair])
    cfg.freeze()

    # Must NOT be named ``re_ranking``: that shadowed the re-ranking function
    # called below and turned the call into ``bool(...)`` -> TypeError.
    use_re_ranking = cfg.RE_RANKING
    PersonReID_Dataset_Downloader(cfg.DATASETS.STORE_DIR, cfg.DATASETS.NAMES)
    if not use_re_ranking:
        logger = make_logger("Reid_Baseline", cfg.OUTPUT_DIR, 'result')
        logger.info("Test Results:")
    else:
        logger = make_logger("Reid_Baseline", cfg.OUTPUT_DIR,
                             'result_re-ranking')
        logger.info("Re-Ranking Test Results:")

    device = torch.device(cfg.DEVICE)
    _, val_loader, num_query, num_classes = data_loader(
        cfg, cfg.DATASETS.NAMES)
    # Build the network exactly as train() does, so the evaluated architecture
    # matches the one whose weights are loaded.
    model = getattr(models, cfg.MODEL.NAME)(
        num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.POOL)
    model.load(cfg.OUTPUT_DIR, cfg.TEST.LOAD_EPOCH)
    model = model.to(device).eval()

    all_feats = []
    all_pids = []
    all_camids = []
    all_imgs = []  # CPU copies of the input batches, kept for plotting
    with torch.no_grad():
        for images, pids, camids in tqdm(val_loader, desc='Feature Extraction',
                                         leave=False):
            all_imgs.extend(images.numpy())
            all_feats.append(model(images.to(device)))
            all_pids.extend(np.asarray(pids))
            all_camids.extend(np.asarray(camids))
    all_feats = torch.cat(all_feats, dim=0)

    # The first num_query items are queries; the rest form the gallery.
    qf = all_feats[:num_query]
    q_pids = np.asarray(all_pids[:num_query])
    q_camids = np.asarray(all_camids[:num_query])
    q_imgs = all_imgs[:num_query]
    gf = all_feats[num_query:]
    g_pids = np.asarray(all_pids[num_query:])
    g_camids = np.asarray(all_camids[num_query:])
    g_imgs = all_imgs[num_query:]

    if not use_re_ranking:
        # Squared Euclidean distance: ||q||^2 + ||g||^2 - 2 * q . g
        m, n = qf.shape[0], gf.shape[0]
        distmat = (torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n)
                   + torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t())
        distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
        distmat = distmat.cpu().numpy()
    else:
        print('Calculating Distance')
        qf_np = qf.detach().cpu().numpy()
        gf_np = gf.detach().cpu().numpy()
        # Raw dot products, as in the original; presumably re_ranking
        # expects these inputs -- verify against its definition.
        q_g_dist = np.dot(qf_np, gf_np.T)
        q_q_dist = np.dot(qf_np, qf_np.T)
        g_g_dist = np.dot(gf_np, gf_np.T)
        print('Re-ranking:')
        distmat = re_ranking(q_g_dist, q_q_dist, g_g_dist)
    indices = np.argsort(distmat, axis=1)

    mean = cfg.INPUT.PIXEL_MEAN
    std = cfg.INPUT.PIXEL_STD
    top_k = 7
    os.makedirs('./show', exist_ok=True)

    def _to_displayable(chw):
        # Undo input normalization: CHW array -> HWC image clipped to [0, 1].
        return np.clip(chw.transpose(1, 2, 0) * std + mean, 0.0, 1.0)

    for i in range(num_query):
        order = indices[i]
        # Remove gallery samples that have the same pid and camid as the query.
        remove = (g_pids[order] == q_pids[i]) & (g_camids[order] == q_camids[i])
        ranked = order[~remove]

        # One fresh figure per query, closed after saving. Reusing one global
        # figure stacked images from every query and grew memory.
        fig = plt.figure()
        fig.suptitle("top{} query".format(top_k), fontsize=15)
        ax = fig.add_subplot(1, top_k + 1, 1)
        ax.imshow(_to_displayable(q_imgs[i]))
        # Guard: fewer than top_k candidates may survive the filter.
        for j in range(min(top_k, len(ranked))):
            ax = fig.add_subplot(1, top_k + 1, j + 2)
            ax.imshow(_to_displayable(g_imgs[ranked[j]]))
        fig.savefig("./show/{}.jpg".format(i))
        plt.close(fig)
    logger.info('Testing complete')