def run(config, num_checkpoint, epoch_end, output_filename):
    """Average the last `num_checkpoint` checkpoints (SWA), refresh BN stats, and save."""
    task = get_task(config)
    preprocess_opt = task.get_preprocess_opt()
    dataloader = get_dataloader(config, 'train',
                                get_transform(config, 'dev', **preprocess_opt))

    model = task.get_model()
    checkpoints = get_checkpoints(config, num_checkpoint, epoch_end)
    print('checkpoints:')
    print('\n'.join(checkpoints))

    utils.checkpoint.load_checkpoint(model, None, checkpoints[0])
    for i, checkpoint in enumerate(checkpoints[1:]):
        model2 = get_task(config).get_model()
        last_epoch, _ = utils.checkpoint.load_checkpoint(model2, None, checkpoint)
        swa.moving_average(model, model2, 1. / (i + 2))

    with torch.no_grad():
        swa.bn_update(dataloader, model)

    output_name = '{}.{}.{:03d}'.format(output_filename, num_checkpoint, last_epoch)
    print('save {}'.format(output_name))
    utils.checkpoint.save_checkpoint(config, model, None, 0, 0, name=output_name,
                                     weights_dict={'state_dict': model.state_dict()})
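# --- Hedged sketch (not part of the repo) --------------------------------------
# swa.moving_average above is assumed to blend the freshly loaded checkpoint into
# the running average in place, roughly like the torchcontrib SWA helper. With
# alpha = 1 / (i + 2) every checkpoint ends up with equal weight. Names below are
# illustrative only.
def moving_average_sketch(net1, net2, alpha=1.0):
    """In-place update: net1 <- (1 - alpha) * net1 + alpha * net2."""
    for p1, p2 in zip(net1.parameters(), net2.parameters()):
        p1.data.mul_(1.0 - alpha).add_(p2.data, alpha=alpha)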
def run(config, num_checkpoint, epoch_end, output_filename):
    dataloader = get_dataloader(config, split='val', transform=None)

    model = get_model(config).cuda()
    checkpoints = get_checkpoints(config, num_checkpoint, epoch_end)

    utils.checkpoint.load_checkpoint(config, model, checkpoints[0])
    for i, checkpoint in enumerate(checkpoints[1:]):
        model2 = get_model(config).cuda()
        last_epoch, _, _ = utils.checkpoint.load_checkpoint(config, model2, checkpoint)
        swa.moving_average(model, model2, 1. / (i + 2))

    with torch.no_grad():
        swa.bn_update(dataloader, model)

    # output_name = '{}.{}.{:03d}'.format(output_filename, num_checkpoint, last_epoch)
    # print('save {}'.format(output_name))
    utils.checkpoint.save_checkpoint(config, model, None, None, epoch_end,
                                     weights_dict={'state_dict': model.state_dict()},
                                     name=output_filename)
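# --- Hedged sketch (not part of the repo) --------------------------------------
# get_checkpoints(config, num_checkpoint, epoch_end) is assumed to return the paths
# of the last `num_checkpoint` epoch snapshots up to `epoch_end`, oldest first, so
# the loops above average them in training order. The directory layout and file
# naming below are guesses for illustration only.
import os

def get_checkpoints_sketch(config, num_checkpoint, epoch_end):
    ckpt_dir = os.path.join(config.train.dir, 'checkpoint')  # assumed layout
    epochs = range(epoch_end - num_checkpoint + 1, epoch_end + 1)
    return [os.path.join(ckpt_dir, 'epoch_{:04d}.pth'.format(e)) for e in epochs]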
def run(args):
    """Average checkpoints with SWA (uniform) or EMA (constant alpha), then refresh BN stats."""
    df = pd.read_csv(args.df_path)
    df_train = df[df['fold'] != args.fold]

    model = get_model(args).cuda()
    dataloader = get_dataloader(args.data_dir, df_train, 'train',
                                args.pretrain, args.batch_size)
    checkpoints = get_checkpoints(args)

    # signature: args, model, ckpt_name, checkpoint=None, optimizer=None
    checkpoint.load_checkpoint(args, model, None, checkpoint=checkpoints[0])
    for i, ckpt in enumerate(checkpoints[1:]):
        print(i, ckpt)
        model2 = get_model(args).cuda()
        last_epoch, _ = checkpoint.load_checkpoint(args, model2, None, checkpoint=ckpt)
        if args.ema is None:
            swa.moving_average(model, model2, 1. / (i + 2))
        else:
            swa.moving_average(model, model2, args.ema)

    with torch.no_grad():
        swa.bn_update(dataloader, model)

    if args.ema is not None:
        output_name = f'model_ema_{len(checkpoints)}'
    else:
        output_name = f'model_swa_{len(checkpoints)}'
    print('save {}'.format(output_name))
    checkpoint.save_checkpoint(args, model, None, 0, 0, name=output_name,
                               weights_dict={'state_dict': model.state_dict()})
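# Note on the two averaging modes above: with alpha = 1 / (i + 2) every checkpoint
# ends up with the same final weight 1 / N (plain SWA). With a constant --ema alpha
# the weights decay geometrically instead; e.g. for alpha = 0.3 and three updates
# the first checkpoint keeps 0.7 ** 3 = 0.343 of the total weight.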
def run(config, num_checkpoint, epoch_end, output_filename):
    dataloader = get_dataloader(config, 'train', get_transform(config, 'val'))

    model = get_model(config)
    if torch.cuda.is_available():
        model = model.cuda()
    checkpoints = get_checkpoints(config, num_checkpoint, epoch_end)

    utils.checkpoint.load_checkpoint(model, None, checkpoints[0])
    for i, checkpoint in enumerate(checkpoints[1:]):
        model2 = get_model(config)
        if torch.cuda.is_available():
            model2 = model2.cuda()
        last_epoch, _ = utils.checkpoint.load_checkpoint(model2, None, checkpoint)
        swa.moving_average(model, model2, 1. / (i + 2))

    with torch.no_grad():
        swa.bn_update(dataloader, model)

    output_name = '{}.{}.{:03d}'.format(output_filename, num_checkpoint, last_epoch)
    print('save {}'.format(output_name))
    utils.checkpoint.save_checkpoint(config, model, None, 0, 0, name=output_name,
                                     weights_dict={'state_dict': model.state_dict()})
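# --- Hedged sketch (not part of the repo) --------------------------------------
# swa.bn_update is assumed to follow the standard SWA recipe: reset BatchNorm
# running statistics and re-estimate them with one pass over the training loader
# (the callers above already wrap it in torch.no_grad()). The real helper may
# unpack batches differently; this is illustration only.
import torch

def bn_update_sketch(loader, model, device='cuda'):
    momenta = {}
    for m in model.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()
            momenta[m] = m.momentum
    if not momenta:
        return
    model.train()
    n = 0
    for batch in loader:
        inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
        inputs = inputs.to(device)
        b = inputs.size(0)
        for m in momenta:
            m.momentum = b / float(n + b)  # cumulative moving average over batches
        model(inputs)
        n += b
    for m, mom in momenta.items():
        m.momentum = mom  # restore the original momenta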
def inference_samples(args, model, transform, batch_size, query_txt, query_dir,
                      gallery_dir, save_dir, k1=20, k2=6, p=0.3, use_rerank=False,
                      use_flip=False, max_rank=200, bn_keys=[]):
    """Extract MGN features for query/gallery, apply optional adaBN/DBA/aQE/pseudo-labeling,
    rank the gallery, and dump the top-`max_rank` results to JSON (plus the distmat as pickle)."""
    print("==>load data info..")
    if query_txt != "":
        query_list = list()
        with open(query_txt, 'r') as f:
            lines = f.readlines()
            for i, line in enumerate(lines):
                data = line.split(" ")
                image_name = data[0].split("/")[1]
                img_file = os.path.join(query_dir, image_name)
                query_list.append(img_file)
    else:
        query_list = [os.path.join(query_dir, x) for x in os.listdir(query_dir)]
    gallery_list = [os.path.join(gallery_dir, x) for x in os.listdir(gallery_dir)]
    query_num = len(query_list)
    if args.save_fname != '':
        print(query_list[:10])
        query_list = sorted(query_list)
        print(query_list[:10])
        gallery_list = sorted(gallery_list)

    print("==>build dataloader..")
    image_set = ImageDataset(query_list + gallery_list, transform)
    dataloader = DataLoader(image_set, sampler=SequentialSampler(image_set),
                            batch_size=batch_size, num_workers=4)
    bn_dataloader = DataLoader(image_set, sampler=RandomSampler(image_set),
                               batch_size=batch_size, num_workers=4, drop_last=True)

    print("==>model inference..")
    model = model.to(device)
    if args.adabn and len(bn_keys) > 0:
        print("==> using adabn for specific bn layers")
        specific_bn_update(model, bn_dataloader, cumulative=not args.adabn_emv,
                           bn_keys=bn_keys)
    elif args.adabn:
        print("==> using adabn for all bn layers")
        bn_update(model, bn_dataloader, cumulative=not args.adabn_emv)
    model.eval()

    feats = []
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            data = batch
            data = data.cuda()
            if use_flip:
                ff = torch.FloatTensor(data.size(0), 2048 * 2).zero_()
                for i in range(2):
                    # flip
                    if i == 1:
                        data = data.index_select(
                            3, torch.arange(data.size(3) - 1, -1, -1).long().to('cuda'))
                    outputs = model(data)
                    f = outputs.data.cpu()
                    # cat
                    if i == 0:
                        ff[:, :2048] = f
                    if i == 1:
                        ff[:, 2048:] = f
            else:
                ff = model(data).data.cpu()
            feats.append(ff)

    all_feature = torch.cat(feats, dim=0)

    # DBA
    if args.dba:
        k2 = args.dba_k2
        alpha = args.dba_alpha
        assert alpha < 0
        print("==>using DBA k2:{} alpha:{}".format(k2, alpha))
        st = time.time()
        # [todo] heap sort
        distmat = euclidean_dist(all_feature, all_feature)
        # initial_rank = distmat.numpy().argsort(axis=1)
        initial_rank = np.argpartition(distmat.numpy(), range(1, k2 + 1))

        all_feature = all_feature.numpy()
        V_qe = np.zeros_like(all_feature, dtype=np.float32)
        weights = np.logspace(0, alpha, k2).reshape((-1, 1))
        with tqdm(total=len(all_feature)) as pbar:
            for i in range(len(all_feature)):
                V_qe[i, :] = np.mean(all_feature[initial_rank[i, :k2], :] * weights, axis=0)
                pbar.update(1)
        # import pdb;pdb.set_trace()
        all_feature = V_qe
        del V_qe
        all_feature = torch.from_numpy(all_feature)
        fnorm = torch.norm(all_feature, p=2, dim=1, keepdim=True)
        all_feature = all_feature.div(fnorm.expand_as(all_feature))
        print("DBA cost:", time.time() - st)

    # aQE: weight query expansion
    if args.aqe:
        k2 = args.aqe_k2
        alpha = args.aqe_alpha
        print("==>using weight query expansion k2: {} alpha: {}".format(k2, alpha))
        st = time.time()
        all_feature = F.normalize(all_feature, p=2, dim=1)
        # all_feature = all_feature.numpy()
        # all_features = []
        # with tqdm(total=len(all_feature)) as pbar:
        #     for i in range(len(all_feature)):
        #         all_features.append(aqe_func(all_feature[i],all_feature=all_feature,k2=k2,alpha=alpha))
        #         pbar.update(1)
        # all_feature = np.stack(all_features,axis=0)
        # all_feature = torch.from_numpy(all_feature)
        # all_feature = F.normalize(all_feature, p=2, dim=1)
        # fast by gpu
        all_feature = aqe_func_gpu(all_feature, k2, alpha, len_slice=2000)
        print("aQE cost:", time.time() - st)

    print('feature shape:', all_feature.size())

    if args.pseudo:
        print("==> using pseudo eps:{} minPoints:{} maxpoints:{}".format(
            args.pseudo_eps, args.pseudo_minpoints, args.pseudo_maxpoints))
        st = time.time()
        all_feature = F.normalize(all_feature, p=2, dim=1)
        if args.pseudo_visual:
            all_distmat, kdist = get_sparse_distmat(all_feature,
                                                    eps=args.pseudo_eps + 0.1,
                                                    len_slice=2000, use_gpu=True,
                                                    dist_k=args.pseudo_minpoints)
            plt.plot(list(range(len(kdist))), np.sort(kdist), linewidth=0.5)
            plt.savefig('test_kdist.png')
            plt.savefig(save_dir + 'test_kdist.png')
        else:
            all_distmat = get_sparse_distmat(all_feature, eps=args.pseudo_eps + 0.1,
                                             len_slice=2000, use_gpu=True)
        # print(all_distmat.todense()[0])
        pseudolabels = predict_pseudo_label(all_distmat, args.pseudo_eps,
                                            args.pseudo_minpoints,
                                            args.pseudo_maxpoints,
                                            args.pseudo_algorithm)
        print("pseudo cost: {}s".format(time.time() - st))
        print("pseudo id cnt:", len(pseudolabels))
        print("pseudo img cnt:", len([x for k, v in pseudolabels.items() for x in v]))

        # save
        all_list = query_list + gallery_list
        save_path = args.pseudo_savepath
        pid = args.pseudo_startid
        camid = 0
        for k, v in pseudolabels.items():
            os.makedirs(os.path.join(save_path, str(pid)), exist_ok=True)
            for _index in pseudolabels[k]:
                filename = all_list[_index].split("/")[-1]
                new_filename = str(pid) + "_c" + str(camid) + ".png"
                shutil.copy(all_list[_index],
                            os.path.join(save_path, str(pid), new_filename))
                camid += 1
            pid += 1

    gallery_feat = all_feature[query_num:]
    query_feat = all_feature[:query_num]

    if use_rerank:
        print("==>use re_rank")
        st = time.time()
        distmat = re_rank(query_feat, gallery_feat, k1, k2, p)
        print("re_rank cost:", time.time() - st)
    else:
        st = time.time()
        weights = [1, 1, 1, 1 / 2, 1 / 2, 1 / 3, 1 / 3, 1 / 3]
        print('==> using mgn_euclidean_dist')
        print('==> using weights:', weights)
        # distmat = euclidean_dist(query_feat, gallery_feat)
        distmat = None
        if use_flip:
            for i in range(2):
                dist = mgn_euclidean_dist(query_feat[:, i * 2048:(i + 1) * 2048],
                                          gallery_feat[:, i * 2048:(i + 1) * 2048],
                                          norm=True, weights=weights)
                if distmat is None:
                    distmat = dist / 2.0
                else:
                    distmat += dist / 2.0
        else:
            distmat = mgn_euclidean_dist(query_feat, gallery_feat,
                                         norm=True, weights=weights)
        print("euclidean_dist cost:", time.time() - st)

    distmat = distmat.numpy()
    num_q, num_g = distmat.shape

    print("==>saving..")
    if args.post:
        qfnames = [fname.split('/')[-1] for fname in query_list]
        gfnames = [fname.split('/')[-1] for fname in gallery_list]
        st = time.time()
        print("post json using top_per:", args.post_top_per)
        res_dict = get_post_json(distmat, qfnames, gfnames, args.post_top_per)
        print("post cost:", time.time() - st)
    else:
        # [todo] fast test
        print("==>sorting..")
        st = time.time()
        indices = np.argsort(distmat, axis=1)
        print("argsort cost:", time.time() - st)
        # print(indices[:2, :max_rank])
        # st = time.time()
        # indices = np.argpartition( distmat, range(1,max_rank+1))
        # print("argpartition cost:",time.time()-st)
        # print(indices[:2, :max_rank])

        max_200_indices = indices[:, :max_rank]
        res_dict = dict()
        for q_idx in range(num_q):
            filename = query_list[q_idx].split('/')[-1]
            max_200_files = [gallery_list[i].split('/')[-1]
                             for i in max_200_indices[q_idx]]
            res_dict[filename] = max_200_files

    if args.dba:
        save_fname = 'mgnsub_dba.json'
    elif args.aqe:
        save_fname = 'mgnsub_aqe.json'
    else:
        save_fname = 'mgnsub.json'
    if use_rerank:
        save_fname = 'rerank_' + save_fname
    if args.adabn:
        if args.adabn_all:
            save_fname = 'adabnall_' + save_fname
        else:
            save_fname = 'adabn_' + save_fname
    if use_flip:
        save_fname = 'flip_' + save_fname
    if args.post:
        save_fname = 'post_' + save_fname
    save_fname = args.save_fname + save_fname
    print('savefname:', save_fname)

    with open(save_dir + save_fname, 'w', encoding='utf-8') as f:
        json.dump(res_dict, f)
    with open(save_dir + save_fname.replace('.json', '.pkl'), 'wb') as fid:
        pickle.dump(distmat, fid, -1)
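# --- Hedged sketch (not part of the repo) --------------------------------------
# aqe_func_gpu above is assumed to implement alpha query expansion: each feature is
# replaced by a similarity-weighted mean of its k2 nearest neighbours, with weights
# sim ** alpha, processed in slices to bound GPU memory. Illustration only.
import torch
import torch.nn.functional as F

def aqe_sketch(feats, k2=5, alpha=3.0, len_slice=2000, device='cuda'):
    feats = F.normalize(feats, p=2, dim=1).to(device)
    out = torch.zeros_like(feats)
    for s in range(0, feats.size(0), len_slice):
        block = feats[s:s + len_slice]
        sims = block @ feats.t()                        # cosine similarity to all features
        topv, topi = sims.topk(k2, dim=1)               # k2 nearest neighbours per row
        w = topv.clamp(min=0).pow(alpha).unsqueeze(-1)  # weights sim ** alpha
        out[s:s + len_slice] = (feats[topi] * w).sum(1) / w.sum(1).clamp(min=1e-12)
    return out.cpu()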
def inference_val(args, model, dataloader, num_query, save_dir, k1=20, k2=6, p=0.3,
                  use_rerank=False, use_flip=False, n_randperm=0, bn_keys=[]):
    model = model.to(device)
    if args.adabn and len(bn_keys) > 0:
        print("==> using adabn for specific bn layers")
        specific_bn_update(model, dataloader, cumulative=not args.adabn_emv,
                           bn_keys=bn_keys)
    elif args.adabn:
        print("==> using adabn for all bn layers")
        bn_update(model, dataloader, cumulative=not args.adabn_emv)
    model.eval()

    feats, pids, camids = [], [], []
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            data, pid, camid, _ = batch
            data = data.cuda()
            if use_flip:
                ff = torch.FloatTensor(data.size(0), 2048 * 2).zero_()
                for i in range(2):
                    # flip
                    if i == 1:
                        data = data.index_select(
                            3, torch.arange(data.size(3) - 1, -1, -1).long().to('cuda'))
                    outputs = model(data)
                    f = outputs.data.cpu()
                    # cat
                    if i == 0:
                        ff[:, :2048] = f
                    if i == 1:
                        ff[:, 2048:] = f
                # ff = F.normalize(ff, p=2, dim=1)
            else:
                ff = model(data).data.cpu()
                # ff = F.normalize(ff, p=2, dim=1)
            feats.append(ff)
            pids.append(pid)
            camids.append(camid)

    all_feature = torch.cat(feats, dim=0)
    # all_feature = all_feature[:,:1024+512]
    pids = torch.cat(pids, dim=0)
    camids = torch.cat(camids, dim=0)

    # DBA
    if args.dba:
        k2 = args.dba_k2
        alpha = args.dba_alpha
        assert alpha < 0
        print("==>using DBA k2:{} alpha:{}".format(k2, alpha))
        st = time.time()
        # [todo] heap sort
        distmat = euclidean_dist(all_feature, all_feature)
        # initial_rank = distmat.numpy().argsort(axis=1)
        initial_rank = np.argpartition(distmat.numpy(), range(1, k2 + 1))

        all_feature = all_feature.numpy()
        V_qe = np.zeros_like(all_feature, dtype=np.float32)
        weights = np.logspace(0, alpha, k2).reshape((-1, 1))
        with tqdm(total=len(all_feature)) as pbar:
            for i in range(len(all_feature)):
                V_qe[i, :] = np.mean(all_feature[initial_rank[i, :k2], :] * weights, axis=0)
                pbar.update(1)
        # import pdb;pdb.set_trace()
        all_feature = V_qe
        del V_qe
        all_feature = torch.from_numpy(all_feature)
        fnorm = torch.norm(all_feature, p=2, dim=1, keepdim=True)
        all_feature = all_feature.div(fnorm.expand_as(all_feature))
        print("DBA cost:", time.time() - st)

    # aQE: weight query expansion
    if args.aqe:
        k2 = args.aqe_k2
        alpha = args.aqe_alpha
        print("==>using weight query expansion k2: {} alpha: {}".format(k2, alpha))
        st = time.time()
        all_feature = F.normalize(all_feature, p=2, dim=1)
        # all_feature = all_feature.numpy()
        # all_features = []
        # with tqdm(total=len(all_feature)) as pbar:
        #     for i in range(len(all_feature)):
        #         all_features.append(aqe_func(all_feature[i],all_feature=all_feature,k2=k2,alpha=alpha))
        #         pbar.update(1)
        # all_feature = np.stack(all_features,axis=0)
        # all_feature = torch.from_numpy(all_feature)
        # all_feature = F.normalize(all_feature, p=2, dim=1)

        # norm_feature = F.normalize(all_feature, p=2, dim=1)
        # norm_feature = norm_feature.numpy()
        # all_feature = all_feature.numpy()
        # all_features = []
        # with tqdm(total=len(all_feature)) as pbar:
        #     for i in range(len(all_feature)):
        #         all_features.append(aqe_nonorm_func(norm_feature[i],all_norm_feature_T=norm_feature.T ,all_feature=all_feature,k2=k2,alpha=alpha))
        #         pbar.update(1)

        # part 2hour for val
        # part_norm_feat = []
        # for i in range(8):
        #     norm_feature = F.normalize(all_feature[:,i*256:(i+1)*256], p=2, dim=1)
        #     part_norm_feat.append(norm_feature)
        # norm_feature = torch.cat(part_norm_feat,dim=1)
        # norm_feature = norm_feature.numpy()
        # all_feature = all_feature.numpy()
        # all_features = []
        # with tqdm(total=len(all_feature)) as pbar:
        #     for i in range(len(all_feature)):
        #         all_features.append(mgn_aqe_func(norm_feature[i],all_norm_feature_T=norm_feature.T ,all_feature=all_feature,k2=k2,alpha=alpha))
        #         pbar.update(1)
        # all_feature = np.stack(all_features,axis=0)
        # all_feature = torch.from_numpy(all_feature)
        # all_feature = F.normalize(all_feature, p=2, dim=1)
        all_feature = aqe_func_gpu(all_feature, k2, alpha, len_slice=2000)
        print("aQE cost:", time.time() - st)
        # import pdb;pdb.set_trace()

    print('feature shape:', all_feature.size())

    # for k1 in range(5,10,2):
    #     for k2 in range(2,5,1):
    #         for l in range(5,8):
    #             p = l*0.1
    if n_randperm <= 0:
        k2 = args.k2
        gallery_feat = all_feature[num_query:]
        query_feat = all_feature[:num_query]
        query_pid = pids[:num_query]
        query_camid = camids[:num_query]
        gallery_pid = pids[num_query:]
        gallery_camid = camids[num_query:]

        distmat = None
        if use_rerank:
            print('==> using rerank')
            distmat = re_rank(query_feat, gallery_feat, args.k1, args.k2, p)
        else:
            # if args.aqe:
            #     print('==> using euclidean_dist')
            #     distmat = euclidean_dist(query_feat, gallery_feat)
            # else:
            weights = [1, 1, 1, 1 / 2, 1 / 2, 1 / 3, 1 / 3, 1 / 3]
            print('==> using mgn_euclidean_dist')
            print('==> using weights:', weights)
            if use_flip:
                for i in range(2):
                    dist = mgn_euclidean_dist(query_feat[:, i * 2048:(i + 1) * 2048],
                                              gallery_feat[:, i * 2048:(i + 1) * 2048],
                                              norm=True, weights=weights)
                    if distmat is None:
                        distmat = dist / 2.0
                    else:
                        distmat += dist / 2.0
            else:
                distmat = mgn_euclidean_dist(query_feat, gallery_feat,
                                             norm=True, weights=weights)

        cmc, mAP, _ = eval_func(distmat, query_pid.numpy(), gallery_pid.numpy(),
                                query_camid.numpy(), gallery_camid.numpy())
    else:
        k2 = args.k2
        torch.manual_seed(0)
        cmc = 0
        mAP = 0
        for i in range(n_randperm):
            index = torch.randperm(all_feature.size()[0])

            query_feat = all_feature[index][:num_query]
            gallery_feat = all_feature[index][num_query:]
            query_pid = pids[index][:num_query]
            query_camid = camids[index][:num_query]
            gallery_pid = pids[index][num_query:]
            gallery_camid = camids[index][num_query:]

            if use_rerank:
                print('==> using rerank')
                st = time.time()
                distmat = re_rank(query_feat, gallery_feat, args.k1, args.k2, p)
                print("re_rank cost:", time.time() - st)
            else:
                print('==> using euclidean_dist')
                st = time.time()
                # distmat = euclidean_dist(query_feat, gallery_feat)
                # weights = [1,1,1,1,1,1,1,1]
                weights = [1, 1, 1, 1 / 2, 1 / 2, 1 / 3, 1 / 3, 1 / 3]
                # weights = [2,1,1,1/2,1/2,1/3,1/3,1/3]
                print('==> using mgn_euclidean_dist')
                print('==> using weights:', weights)
                # distmat = euclidean_dist(query_feat, gallery_feat)
                distmat = None
                if use_flip:
                    for i in range(2):
                        dist = mgn_euclidean_dist(query_feat[:, i * 2048:(i + 1) * 2048],
                                                  gallery_feat[:, i * 2048:(i + 1) * 2048],
                                                  norm=True, weights=weights)
                        if distmat is None:
                            distmat = dist / 2.0
                        else:
                            distmat += dist / 2.0
                else:
                    distmat = mgn_euclidean_dist(query_feat, gallery_feat,
                                                 norm=True, weights=weights)
                print("euclidean_dist cost:", time.time() - st)

            _cmc, _mAP, _ = eval_func(distmat, query_pid.numpy(), gallery_pid.numpy(),
                                      query_camid.numpy(), gallery_camid.numpy())
            cmc += _cmc / n_randperm
            mAP += _mAP / n_randperm

    print('Validation Result:')
    if use_rerank:
        print(str(k1) + " - " + str(k2) + " - " + str(p))
    print('mAP: {:.2%}'.format(mAP))
    for r in [1, 5, 10]:
        print('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))
    print('average of mAP and rank1: {:.2%}'.format((mAP + cmc[0]) / 2.0))

    with open(save_dir + 'eval.txt', 'a') as f:
        if use_rerank:
            f.write('==> using rerank\n')
            f.write(str(k1) + " - " + str(k2) + " - " + str(p) + "\n")
        else:
            f.write('==> using euclidean_dist\n')
        f.write('mAP: {:.2%}'.format(mAP) + "\n")
        for r in [1, 5, 10]:
            f.write('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]) + "\n")
        f.write('average of mAP and rank1: {:.2%}\n'.format((mAP + cmc[0]) / 2.0))
        f.write('------------------------------------------\n')
        f.write('------------------------------------------\n')
        f.write('\n\n')
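# --- Hedged sketch (not part of the repo) --------------------------------------
# Pairwise euclidean_dist as used above, via the standard expansion
# ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y; mgn_euclidean_dist presumably applies
# the same idea per 256-d part with the per-part weights printed above.
def euclidean_dist_sketch(x, y):
    xx = x.pow(2).sum(dim=1, keepdim=True)       # (m, 1)
    yy = y.pow(2).sum(dim=1, keepdim=True).t()   # (1, n)
    dist = xx + yy - 2.0 * x @ y.t()
    return dist.clamp(min=1e-12).sqrt()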
def inference_samples(args, model, transform, batch_size, query_txt, query_dir,
                      gallery_dir, save_dir, k1=20, k2=6, p=0.3, use_rerank=False,
                      use_flip=False, max_rank=200, bn_keys=[]):
    """Extract normalized features for query/gallery, then either generate pseudo-label
    folders (optionally split by histogram label) or rank the gallery and dump a
    submission JSON plus the distance matrix."""
    print("==>load data info..")
    if query_txt != "":
        query_list = list()
        with open(query_txt, 'r') as f:
            lines = f.readlines()
            for i, line in enumerate(lines):
                data = line.split(" ")
                image_name = data[0].split("/")[1]
                img_file = os.path.join(query_dir, image_name)
                query_list.append(img_file)
    else:
        query_list = [os.path.join(query_dir, x) for x in os.listdir(query_dir)]
    gallery_list = [os.path.join(gallery_dir, x) for x in os.listdir(gallery_dir)]
    query_num = len(query_list)
    if args.save_fname != '':
        print(query_list[:10])
        query_list = sorted(query_list)
        print(query_list[:10])
        gallery_list = sorted(gallery_list)

    print("==>build dataloader..")
    image_set = ImageDataset(query_list + gallery_list, transform)
    dataloader = DataLoader(image_set, sampler=SequentialSampler(image_set),
                            batch_size=batch_size, num_workers=6)
    bn_dataloader = DataLoader(image_set, sampler=RandomSampler(image_set),
                               batch_size=batch_size, num_workers=6, drop_last=True)

    print("==>model inference..")
    model = model.to(device)
    if args.adabn and len(bn_keys) > 0:
        print("==> using adabn for specific bn layers")
        specific_bn_update(model, bn_dataloader, cumulative=not args.adabn_emv,
                           bn_keys=bn_keys)
    elif args.adabn:
        print("==> using adabn for all bn layers")
        bn_update(model, bn_dataloader, cumulative=not args.adabn_emv)
    model.eval()

    feats = []
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            data = batch
            data = data.cuda()
            if use_flip:
                ff = torch.FloatTensor(data.size(0), 2048 * 2).zero_()
                for i in range(2):
                    # flip
                    if i == 1:
                        data = data.index_select(
                            3, torch.arange(data.size(3) - 1, -1, -1).long().to('cuda'))
                    outputs = model(data)
                    f = outputs.data.cpu()
                    # cat
                    if i == 0:
                        ff[:, :2048] = F.normalize(f, p=2, dim=1)
                    if i == 1:
                        ff[:, 2048:] = F.normalize(f, p=2, dim=1)
                ff = F.normalize(ff, p=2, dim=1)
            else:
                ff = model(data).data.cpu()
                ff = F.normalize(ff, p=2, dim=1)
            feats.append(ff)

    all_feature = torch.cat(feats, dim=0)

    # DBA
    if args.dba:
        k2 = args.dba_k2
        alpha = args.dba_alpha
        assert alpha < 0
        print("==>using DBA k2:{} alpha:{}".format(k2, alpha))
        st = time.time()
        # [todo] heap sort
        distmat = euclidean_dist(all_feature, all_feature)
        # initial_rank = distmat.numpy().argsort(axis=1)
        initial_rank = np.argpartition(distmat.numpy(), range(1, k2 + 1))

        all_feature = all_feature.numpy()
        V_qe = np.zeros_like(all_feature, dtype=np.float32)
        weights = np.logspace(0, alpha, k2).reshape((-1, 1))
        with tqdm(total=len(all_feature)) as pbar:
            for i in range(len(all_feature)):
                V_qe[i, :] = np.mean(all_feature[initial_rank[i, :k2], :] * weights, axis=0)
                pbar.update(1)
        # import pdb;pdb.set_trace()
        all_feature = V_qe
        del V_qe
        all_feature = torch.from_numpy(all_feature)
        fnorm = torch.norm(all_feature, p=2, dim=1, keepdim=True)
        all_feature = all_feature.div(fnorm.expand_as(all_feature))
        print("DBA cost:", time.time() - st)

    # aQE: weight query expansion
    if args.aqe:
        k2 = args.aqe_k2
        alpha = args.aqe_alpha
        print("==>using weight query expansion k2: {} alpha: {}".format(k2, alpha))
        st = time.time()
        # fast by gpu
        all_feature = aqe_func_gpu(all_feature, k2, alpha, len_slice=2000)
        print("aQE cost:", time.time() - st)

    print('feature shape:', all_feature.size())

    if args.pseudo:
        print("==> using pseudo eps:{} minPoints:{} maxpoints:{}".format(
            args.pseudo_eps, args.pseudo_minpoints, args.pseudo_maxpoints))
        st = time.time()
        all_feature = F.normalize(all_feature, p=2, dim=1)

        if args.pseudo_hist:
            print("==> predict histlabel...")
            img_filenames = query_list + gallery_list
            img_idx = list(range(len(img_filenames)))
            imgs = {'filename': img_filenames, 'identity': img_idx}
            df_img = pd.DataFrame(imgs)
            hist_labels = mmcv.track_parallel_progress(img_hist_predictor,
                                                       df_img['filename'], 6)
            print("hist label describe..")
            unique_hist_labels = sorted(list(set(hist_labels)))

            hl_idx = []
            hl_query_infos = []
            hl_gallery_infos = []
            for label_idx in range(len(unique_hist_labels)):
                hl_query_infos.append([])
                hl_gallery_infos.append([])
                hl_idx.append([])

            with tqdm(total=len(img_filenames)) as pbar:
                for idx, info in enumerate(img_filenames):
                    for label_idx in range(len(unique_hist_labels)):
                        if hist_labels[idx] == unique_hist_labels[label_idx]:
                            if idx < len(query_list):
                                hl_query_infos[label_idx].append(info)
                            else:
                                hl_gallery_infos[label_idx].append(info)
                            hl_idx[label_idx].append(idx)
                    pbar.update(1)

            for label_idx in range(len(unique_hist_labels)):
                print('hist_label:', unique_hist_labels[label_idx],
                      ' query number:', len(hl_query_infos[label_idx]))
                print('hist_label:', unique_hist_labels[label_idx],
                      ' gallery number:', len(hl_gallery_infos[label_idx]))
                print('hist_label:', unique_hist_labels[label_idx], ' q+g number:',
                      len(hl_query_infos[label_idx]) + len(hl_gallery_infos[label_idx]))
                print('hist_label:', unique_hist_labels[label_idx],
                      ' idx q+g number:', len(hl_idx[label_idx]))

            # pseudo
            pid = args.pseudo_startid
            camid = 0
            all_list = query_list + gallery_list
            save_path = args.pseudo_savepath
            pseudo_eps = args.pseudo_eps
            pseudo_minpoints = args.pseudo_minpoints
            for label_idx in range(len(unique_hist_labels)):
                # if label_idx == 0:
                #     pseudo_eps = 0.6
                # else:
                #     pseudo_eps = 0.75
                if label_idx == 0:
                    pseudo_eps = 0.65
                else:
                    pseudo_eps = 0.80
                feature = all_feature[hl_idx[label_idx]]
                img_list = [all_list[idx] for idx in hl_idx[label_idx]]
                print("==> get sparse distmat!")
                if args.pseudo_visual:
                    all_distmat, kdist = get_sparse_distmat(feature,
                                                            eps=pseudo_eps + 0.05,
                                                            len_slice=2000, use_gpu=True,
                                                            dist_k=pseudo_minpoints)
                    plt.plot(list(range(len(kdist))), np.sort(kdist), linewidth=0.5)
                    plt.savefig('test_kdist_hl{}_eps{}_{}.png'.format(
                        label_idx, pseudo_eps, pseudo_minpoints))
                    plt.savefig(save_dir + 'test_kdist_hl{}_eps{}_{}.png'.format(
                        label_idx, pseudo_eps, pseudo_minpoints))
                else:
                    all_distmat = get_sparse_distmat(feature, eps=pseudo_eps + 0.05,
                                                     len_slice=2000, use_gpu=True)
                print("==> predict pseudo label!")
                pseudolabels = predict_pseudo_label(all_distmat, pseudo_eps,
                                                    pseudo_minpoints,
                                                    args.pseudo_maxpoints,
                                                    args.pseudo_algorithm)
                print("==> using pseudo eps:{} minPoints:{} maxpoints:{}".format(
                    pseudo_eps, pseudo_minpoints, args.pseudo_maxpoints))
                print("pseudo cost: {}s".format(time.time() - st))
                print("pseudo id cnt:", len(pseudolabels))
                print("pseudo img cnt:",
                      len([x for k, v in pseudolabels.items() for x in v]))

                if label_idx == 0:
                    sf = 1
                else:
                    sf = 1
                sample_id_cnt = 0
                sample_file_cnt = 0
                nignore_query = 0
                for i, (k, v) in enumerate(pseudolabels.items()):
                    if i % sf != 0:
                        continue
                    # query_cnt = 0
                    # for _index in pseudolabels[k]:
                    #     if _index<len(query_list):
                    #         query_cnt += 1
                    # if query_cnt>=2:
                    #     nignore_query += 1
                    #     continue
                    os.makedirs(os.path.join(save_path, str(pid)), exist_ok=True)
                    for _index in pseudolabels[k]:
                        filename = img_list[_index].split("/")[-1]
                        new_filename = str(pid) + "_c" + str(camid) + ".png"
                        shutil.copy(img_list[_index],
                                    os.path.join(save_path, str(pid), new_filename))
                        camid += 1
                        sample_file_cnt += 1
                    sample_id_cnt += 1
                    pid += 1
                print("pseudo ignore id cnt:", nignore_query)
                print("sample id cnt:", sample_id_cnt)
                print("sample file cnt:", sample_file_cnt)
        else:
            if args.pseudo_visual:
                all_distmat, kdist = get_sparse_distmat(all_feature,
                                                        eps=args.pseudo_eps + 0.05,
                                                        len_slice=2000, use_gpu=True,
                                                        dist_k=args.pseudo_minpoints)
                plt.plot(list(range(len(kdist))), np.sort(kdist), linewidth=0.5)
                plt.savefig('test_kdist.png')
                plt.savefig(save_dir + 'test_kdist.png')
            else:
                all_distmat = get_sparse_distmat(all_feature, eps=args.pseudo_eps + 0.05,
                                                 len_slice=2000, use_gpu=True)
            # print(all_distmat.todense()[0])
            pseudolabels = predict_pseudo_label(all_distmat, args.pseudo_eps,
                                                args.pseudo_minpoints,
                                                args.pseudo_maxpoints,
                                                args.pseudo_algorithm)
            print("pseudo cost: {}s".format(time.time() - st))
            print("pseudo id cnt:", len(pseudolabels))
            print("pseudo img cnt:",
                  len([x for k, v in pseudolabels.items() for x in v]))

            # save
            all_list = query_list + gallery_list
            save_path = args.pseudo_savepath
            pid = args.pseudo_startid
            camid = 0
            nignore_query = 0
            for k, v in pseudolabels.items():
                os.makedirs(os.path.join(save_path, str(pid)), exist_ok=True)
                # [filter] skip clusters that contain too many query images
                query_cnt = 0
                for _index in pseudolabels[k]:
                    if _index < len(query_list):
                        query_cnt += 1
                if query_cnt >= 4:
                    nignore_query += 1
                    continue
                for _index in pseudolabels[k]:
                    filename = all_list[_index].split("/")[-1]
                    new_filename = str(pid) + "_c" + str(camid) + ".png"
                    shutil.copy(all_list[_index],
                                os.path.join(save_path, str(pid), new_filename))
                    camid += 1
                pid += 1
            print("pseudo ignore id cnt:", nignore_query)
    else:
        gallery_feat = all_feature[query_num:]
        query_feat = all_feature[:query_num]

        if use_rerank:
            print("==>use re_rank")
            st = time.time()
            k2 = args.k2
            # distmat = re_rank(query_feat, gallery_feat, k1, k2, p)
            num_query = len(query_feat)
            print("using k1:{} k2:{} lambda:{}".format(args.k1, args.k2, p))
            distmat = re_ranking_batch_gpu(torch.cat([query_feat, gallery_feat], dim=0),
                                           num_query, args.k1, args.k2, p)
            print("re_rank cost:", time.time() - st)
        else:
            print("==>use euclidean_dist")
            st = time.time()
            distmat = euclidean_dist(query_feat, gallery_feat)
            print("euclidean_dist cost:", time.time() - st)

        distmat = distmat.numpy()
        num_q, num_g = distmat.shape

        print("==>saving..")
        if args.post:
            qfnames = [fname.split('/')[-1] for fname in query_list]
            gfnames = [fname.split('/')[-1] for fname in gallery_list]
            st = time.time()
            print("post json using top_per:", args.post_top_per)
            res_dict = get_post_json(distmat, qfnames, gfnames, args.post_top_per)
            print("post cost:", time.time() - st)
        else:
            # [todo] fast test
            print("==>sorting..")
            st = time.time()
            indices = np.argsort(distmat, axis=1)
            print("argsort cost:", time.time() - st)
            # print(indices[:2, :max_rank])
            # st = time.time()
            # indices = np.argpartition( distmat, range(1,max_rank+1))
            # print("argpartition cost:",time.time()-st)
            # print(indices[:2, :max_rank])

            max_200_indices = indices[:, :max_rank]
            res_dict = dict()
            for q_idx in range(num_q):
                filename = query_list[q_idx].split('/')[-1]
                max_200_files = [gallery_list[i].split('/')[-1]
                                 for i in max_200_indices[q_idx]]
                res_dict[filename] = max_200_files

        if args.dba:
            save_fname = 'sub_dba.json'
        elif args.aqe:
            save_fname = 'sub_aqe.json'
        else:
            save_fname = 'sub.json'
        if use_rerank:
            save_fname = 'rerank_' + save_fname
        if args.adabn:
            if args.adabn_all:
                save_fname = 'adabnall_' + save_fname
            else:
                save_fname = 'adabn_' + save_fname
        if use_flip:
            save_fname = 'flip_' + save_fname
        if args.post:
            save_fname = 'post_' + save_fname
        save_fname = args.save_fname + save_fname
        print('savefname:', save_fname)

        with open(save_dir + save_fname, 'w', encoding='utf-8') as f:
            json.dump(res_dict, f)
        with open(save_dir + save_fname.replace('.json', '.pkl'), 'wb') as fid:
            pickle.dump(distmat, fid, -1)
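# --- Hedged sketch (not part of the repo) --------------------------------------
# predict_pseudo_label above is assumed to cluster the (sparse, precomputed)
# distance matrix with DBSCAN and return {cluster_id: [sample indices]}, keeping
# only clusters whose size lies in [minpoints, maxpoints]; the pseudo_algorithm
# argument of the real function is not modelled here.
from collections import defaultdict
from sklearn.cluster import DBSCAN

def predict_pseudo_label_sketch(distmat, eps, minpoints, maxpoints):
    labels = DBSCAN(eps=eps, min_samples=minpoints,
                    metric='precomputed').fit_predict(distmat)
    clusters = defaultdict(list)
    for idx, lab in enumerate(labels):
        if lab != -1:  # -1 marks DBSCAN noise
            clusters[lab].append(idx)
    return {k: v for k, v in clusters.items() if minpoints <= len(v) <= maxpoints}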
def inference_val(args, model, dataloader, num_query, save_dir, k1=20, k2=6, p=0.3,
                  use_rerank=False, use_flip=False, n_randperm=0, bn_keys=[]):
    model = model.to(device)
    if args.adabn and len(bn_keys) > 0:
        print("==> using adabn for specific bn layers")
        specific_bn_update(model, dataloader, cumulative=not args.adabn_emv,
                           bn_keys=bn_keys)
    elif args.adabn:
        print("==> using adabn for all bn layers")
        bn_update(model, dataloader, cumulative=not args.adabn_emv)
    model.eval()

    feats, pids, camids = [], [], []
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            data, pid, camid, _ = batch
            data = data.cuda()
            if use_flip:
                ff = torch.FloatTensor(data.size(0), 2048 * 2).zero_()
                for i in range(2):
                    # flip
                    if i == 1:
                        data = data.index_select(
                            3, torch.arange(data.size(3) - 1, -1, -1).long().to('cuda'))
                    outputs = model(data)
                    f = outputs.data.cpu()
                    # cat
                    if i == 0:
                        ff[:, :2048] = F.normalize(f, p=2, dim=1)
                    if i == 1:
                        ff[:, 2048:] = F.normalize(f, p=2, dim=1)
                ff = F.normalize(ff, p=2, dim=1)
                # ff = torch.FloatTensor(data.size(0), 2048).zero_()
                # for i in range(2):
                #     if i == 1:
                #         data = data.index_select(3, torch.arange(data.size(3) - 1, -1, -1).long().to('cuda'))
                #     outputs = model(data)
                #     f = outputs.data.cpu()
                #     ff = ff + f
                # fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                # ff = ff.div(fnorm.expand_as(ff))
            else:
                ff = model(data).data.cpu()
                ff = F.normalize(ff, p=2, dim=1)
            feats.append(ff)
            pids.append(pid)
            camids.append(camid)

    all_feature = torch.cat(feats, dim=0)
    # all_feature = all_feature[:,:1024+512]
    pids = torch.cat(pids, dim=0)
    camids = torch.cat(camids, dim=0)

    # DBA
    if args.dba:
        k2 = args.dba_k2
        alpha = args.dba_alpha
        assert alpha < 0
        print("==>using DBA k2:{} alpha:{}".format(k2, alpha))
        st = time.time()
        # [todo] heap sort
        distmat = euclidean_dist(all_feature, all_feature)
        # initial_rank = distmat.numpy().argsort(axis=1)
        initial_rank = np.argpartition(distmat.numpy(), range(1, k2 + 1))

        all_feature = all_feature.numpy()
        V_qe = np.zeros_like(all_feature, dtype=np.float32)
        weights = np.logspace(0, alpha, k2).reshape((-1, 1))
        with tqdm(total=len(all_feature)) as pbar:
            for i in range(len(all_feature)):
                V_qe[i, :] = np.mean(all_feature[initial_rank[i, :k2], :] * weights, axis=0)
                pbar.update(1)
        # import pdb;pdb.set_trace()
        all_feature = V_qe
        del V_qe
        all_feature = torch.from_numpy(all_feature)
        fnorm = torch.norm(all_feature, p=2, dim=1, keepdim=True)
        all_feature = all_feature.div(fnorm.expand_as(all_feature))
        print("DBA cost:", time.time() - st)

    # aQE: weight query expansion
    if args.aqe:
        k2 = args.aqe_k2
        alpha = args.aqe_alpha
        print("==>using weight query expansion k2: {} alpha: {}".format(k2, alpha))
        st = time.time()
        # [todo] remove norm; normalize is used to make sure the most similar one is itself
        # all_feature = F.normalize(all_feature, p=2, dim=1)
        # sims = torch.mm(all_feature, all_feature.t()).numpy()
        # # [todo] heap sort
        # # initial_rank = sims.argsort(axis=1)[:,::-1]
        # initial_rank = np.argpartition(-sims,range(1,k2+1))
        # all_feature = all_feature.numpy()
        # V_qe = np.zeros_like(all_feature,dtype=np.float32)
        # # [todo] update query feature only?
        # with tqdm(total=len(all_feature)) as pbar:
        #     for i in range(len(all_feature)):
        #         # get weights from similarity
        #         weights = sims[i,initial_rank[i,:k2]].reshape((-1,1))
        #         # weights = (weights-weights.min())/(weights.max()-weights.min())
        #         weights = np.power(weights,alpha)
        #         # import pdb;pdb.set_trace()
        #         V_qe[i,:] = np.mean(all_feature[initial_rank[i,:k2],:]*weights,axis=0)
        #         pbar.update(1)
        # # import pdb;pdb.set_trace()
        # all_feature = V_qe
        # del V_qe
        # all_feature = torch.from_numpy(all_feature)
        # all_feature = F.normalize(all_feature, p=2, dim=1)

        # func = functools.partial(aqe_func,all_feature=all_feature,k2=k2,alpha=alpha)
        # all_features = mmcv.track_parallel_progress(func, all_feature, 6)

        # cpu
        # all_feature = F.normalize(all_feature, p=2, dim=1)
        # all_feature = all_feature.numpy()
        # all_features = []
        # with tqdm(total=len(all_feature)) as pbar:
        #     for i in range(len(all_feature)):
        #         all_features.append(aqe_func(all_feature[i],all_feature=all_feature,k2=k2,alpha=alpha))
        #         pbar.update(1)
        # all_feature = np.stack(all_features,axis=0)
        # all_feature = torch.from_numpy(all_feature)
        # all_feature = F.normalize(all_feature, p=2, dim=1)
        all_feature = aqe_func_gpu(all_feature, k2, alpha, len_slice=2000)
        print("aQE cost:", time.time() - st)
        # import pdb;pdb.set_trace()

    if args.pseudo:
        print("==> using pseudo eps:{} minPoints:{} maxpoints:{}".format(
            args.pseudo_eps, args.pseudo_minpoints, args.pseudo_maxpoints))
        st = time.time()
        # cal sparse distmat
        all_feature = F.normalize(all_feature, p=2, dim=1)
        # all_distmat = euclidean_dist(all_feature, all_feature).numpy()
        # print(all_distmat[0])
        # pred1 = predict_pseudo_label(all_distmat,args.pseudo_eps,args.pseudo_minpoints,args.pseudo_maxpoints,args.pseudo_algorithm)
        # print(list(pred1.keys())[:10])
        if args.pseudo_visual:
            all_distmat, kdist = get_sparse_distmat(all_feature,
                                                    eps=args.pseudo_eps + 0.1,
                                                    len_slice=2000, use_gpu=True,
                                                    dist_k=args.pseudo_minpoints)
            plt.plot(list(range(len(kdist))), np.sort(kdist), linewidth=0.5)
            plt.savefig('eval_kdist.png')
            plt.savefig(save_dir + 'eval_kdist.png')
        else:
            all_distmat = get_sparse_distmat(all_feature, eps=args.pseudo_eps + 0.1,
                                             len_slice=2000, use_gpu=True)
        # print(all_distmat.todense()[0])
        pseudolabels = predict_pseudo_label(all_distmat, args.pseudo_eps,
                                            args.pseudo_minpoints,
                                            args.pseudo_maxpoints,
                                            args.pseudo_algorithm)
        print("pseudo cost: {}s".format(time.time() - st))
        print("pseudo id cnt:", len(pseudolabels))
        print("pseudo img cnt:", len([x for k, v in pseudolabels.items() for x in v]))
        print("pseudo cost: {}s".format(time.time() - st))
        # print(list(pred.keys())[:10])

    print('feature shape:', all_feature.size())

    # for k1 in range(5,10,2):
    #     for k2 in range(2,5,1):
    #         for l in range(5,8):
    #             p = l*0.1
    if n_randperm <= 0:
        k2 = args.k2
        gallery_feat = all_feature[num_query:]
        query_feat = all_feature[:num_query]
        query_pid = pids[:num_query]
        query_camid = camids[:num_query]
        gallery_pid = pids[num_query:]
        gallery_camid = camids[num_query:]

        if use_rerank:
            print('==> using rerank')
            # distmat = re_rank(query_feat, gallery_feat, k1, k2, p)
            distmat = re_ranking_batch_gpu(torch.cat([query_feat, gallery_feat], dim=0),
                                           num_query, args.k1, args.k2, p)
        else:
            print('==> using euclidean_dist')
            distmat = euclidean_dist(query_feat, gallery_feat)

        cmc, mAP, _ = eval_func(distmat, query_pid.numpy(), gallery_pid.numpy(),
                                query_camid.numpy(), gallery_camid.numpy())
    else:
        k2 = args.k2
        torch.manual_seed(0)
        cmc = 0
        mAP = 0
        for i in range(n_randperm):
            index = torch.randperm(all_feature.size()[0])

            query_feat = all_feature[index][:num_query]
            gallery_feat = all_feature[index][num_query:]
            query_pid = pids[index][:num_query]
            query_camid = camids[index][:num_query]
            gallery_pid = pids[index][num_query:]
            gallery_camid = camids[index][num_query:]

            if use_rerank:
                print('==> using rerank')
                st = time.time()
                # distmat = re_rank(query_feat, gallery_feat, k1, k2, p)
                distmat = re_ranking_batch_gpu(torch.cat([query_feat, gallery_feat], dim=0),
                                               num_query, args.k1, args.k2, p)
                print("re_rank cost:", time.time() - st)
            else:
                print('==> using euclidean_dist')
                st = time.time()
                distmat = euclidean_dist(query_feat, gallery_feat)
                print("euclidean_dist cost:", time.time() - st)

            _cmc, _mAP, _ = eval_func(distmat, query_pid.numpy(), gallery_pid.numpy(),
                                      query_camid.numpy(), gallery_camid.numpy())
            cmc += _cmc / n_randperm
            mAP += _mAP / n_randperm

    print('Validation Result:')
    if use_rerank:
        print(str(k1) + " - " + str(k2) + " - " + str(p))
    print('mAP: {:.2%}'.format(mAP))
    for r in [1, 5, 10]:
        print('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))
    print('average of mAP and rank1: {:.2%}'.format((mAP + cmc[0]) / 2.0))

    with open(save_dir + 'eval.txt', 'a') as f:
        if use_rerank:
            f.write('==> using rerank\n')
            f.write(str(k1) + " - " + str(k2) + " - " + str(p) + "\n")
        else:
            f.write('==> using euclidean_dist\n')
        f.write('mAP: {:.2%}'.format(mAP) + "\n")
        for r in [1, 5, 10]:
            f.write('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]) + "\n")
        f.write('average of mAP and rank1: {:.2%}\n'.format((mAP + cmc[0]) / 2.0))
        f.write('------------------------------------------\n')
        f.write('------------------------------------------\n')
        f.write('\n\n')
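# --- Hedged sketch (not part of the repo) --------------------------------------
# get_sparse_distmat as used above is assumed to compute pairwise distances slice
# by slice (optionally on the GPU), keep only entries below eps in a scipy sparse
# matrix, and optionally also return each row's dist_k-th smallest distance for
# the eps "knee" plots saved above. Shapes and names follow the call sites only.
import numpy as np
import torch
import torch.nn.functional as F
from scipy.sparse import csr_matrix, vstack

def get_sparse_distmat_sketch(feats, eps, len_slice=2000, use_gpu=True, dist_k=-1):
    dev = 'cuda' if use_gpu and torch.cuda.is_available() else 'cpu'
    feats = F.normalize(feats, p=2, dim=1).to(dev)
    blocks, kdist = [], []
    for s in range(0, feats.size(0), len_slice):
        d = torch.cdist(feats[s:s + len_slice], feats)     # (slice, N) euclidean distances
        if dist_k > 0:
            kdist.append(d.kthvalue(dist_k, dim=1).values.cpu().numpy())
        d = torch.where(d <= eps, d, torch.zeros_like(d))   # drop distant pairs
        blocks.append(csr_matrix(d.cpu().numpy()))
    distmat = vstack(blocks)
    return (distmat, np.concatenate(kdist)) if dist_k > 0 else distmat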