def extract_recalls(data, data_root, width, net, checkpoint, dim, batch_size,
                    nThreads, pool_feature, gallery_eq_query, model=None,
                    epoch=0, org_feature=False, save_txt="", args=None):
    """Extract embeddings and print Recall@K plus NMI for one epoch.

    Features and labels for the gallery and query splits come from
    ``Model2Feature``; the query-vs-gallery similarity matrix is scored with
    ``Recall_at_ks`` and the gallery embeddings are scored with ``NMI``.

    Nothing is returned: the metrics are only printed. ``save_txt`` is
    accepted but not used by this function.
    """
    extracted = Model2Feature(
        data=data, root=data_root, width=width, net=net,
        checkpoint=checkpoint, dim=dim, batch_size=batch_size,
        nThreads=nThreads, pool_feature=pool_feature, model=model,
        org_feature=org_feature, args=args)
    gallery_feature, gallery_labels, query_feature, query_labels = extracted

    sim_mat = pairwise_similarity(query_feature, gallery_feature)
    if gallery_eq_query is True:
        # Subtract the identity from the (square) similarity matrix;
        # presumably this keeps a sample from retrieving itself -- confirm.
        identity = torch.eye(sim_mat.size(0))
        sim_mat = sim_mat - identity

    recall_ks = Recall_at_ks(sim_mat, query_ids=query_labels,
                             gallery_ids=gallery_labels, data=data,
                             args=args, epoch=epoch)

    # One cluster per distinct gallery label.
    distinct_labels = {lbl.item() for lbl in gallery_labels}
    nmi = NMI(gallery_feature, gallery_labels, n_cluster=len(distinct_labels))
    print(recall_ks, nmi)

    # Recall values followed by NMI, each formatted to four decimals.
    metrics = recall_ks.tolist() + [nmi]
    result = ' '.join('%.4f' % value for value in metrics)
    print('Epoch-%d' % epoch, result)
def main(args):
    """Evaluate a saved checkpoint and display the gallery similarity map.

    Loads the checkpoint at ``args.resume``, extracts gallery/query features
    with ``Model2Feature``, prints Recall@K and the elapsed time, and finally
    shows the gallery-vs-gallery similarity matrix with matplotlib.

    Args:
        args: namespace providing ``resume``, ``pool_feature``, ``data``,
            ``data_root``, ``net``, ``batch_size``, ``nThreads`` and
            ``gallery_eq_query``.
    """
    batch_time = AverageMeter()
    end = time.time()

    checkpoint = load_checkpoint(args.resume)
    print('pool_features:', args.pool_feature)
    epoch = checkpoint['epoch']

    gallery_feature, gallery_labels, query_feature, query_labels = \
        Model2Feature(data=args.data, root=args.data_root, net=args.net,
                      checkpoint=checkpoint, batch_size=args.batch_size,
                      nThreads=args.nThreads, pool_feature=args.pool_feature)

    # Pairwise similarity between every query and every gallery sample.
    sim_mat = pairwise_similarity(query_feature, gallery_feature)
    if args.gallery_eq_query is True:
        # Subtract the identity from the square similarity matrix;
        # presumably so a sample does not match itself -- confirm.
        sim_mat = sim_mat - torch.eye(sim_mat.size(0))

    print('labels', query_labels)
    print('feature:', gallery_feature)

    recall_ks = Recall_at_ks(sim_mat, query_ids=query_labels,
                             gallery_ids=gallery_labels, data=args.data)
    result = ' '.join(['%.4f' % k for k in recall_ks])
    print('Epoch-%d' % epoch, result)

    batch_time.update(time.time() - end)
    print('Epoch-%d\t' % epoch,
          'Time {batch_time.avg:.3f}\t'.format(batch_time=batch_time))

    # Plotting dependencies are imported here, as in the original, so they
    # are only required when this function actually runs.
    import matplotlib.pyplot as plt
    import torchvision
    import numpy as np

    # Draw the gallery-vs-gallery similarity matrix as an image.
    similarity = torch.mm(gallery_feature, gallery_feature.t())
    # Tensor.numpy() fails for CUDA tensors and for tensors that require
    # grad, so detach and move to host memory before converting.
    img = torchvision.utils.make_grid(similarity).detach().cpu().numpy()
    # make_grid returns channel-first data; imshow expects channel-last.
    plt.imshow(np.transpose(img, (1, 2, 0)))
    plt.show()
def test(args):
    """Print Recall@K for retrieval within the first split of Model2Feature.

    Loads the latest checkpoint under ``args.resume``; if none is found a
    message is printed and the function returns ``None`` without evaluating.
    Only the first feature/label pair returned by ``Model2Feature`` is used:
    it is compared against itself, with the identity subtracted from the
    similarity matrix.

    Args:
        args: namespace providing ``resume``, ``data``, ``data_root``,
            ``width``, ``net``, ``dim``, ``batch_size``, ``nThreads`` and
            ``pool_feature``.
    """
    checkpoint = load_latest(args.resume)
    # Identity check: ``== None`` can be overridden by the object's __eq__.
    if checkpoint is None:
        print('{} is not avaible! Exit!'.format(args.resume))
        return

    epoch = checkpoint['epoch']

    # The second feature/label pair is not needed for this evaluation.
    train_feature, train_labels, _, _ = \
        Model2Feature(data=args.data, root=args.data_root, width=args.width,
                      net=args.net, checkpoint=checkpoint, dim=args.dim,
                      batch_size=args.batch_size, nThreads=args.nThreads,
                      pool_feature=args.pool_feature)

    # train-train pairwise similarity
    sim_mat = pairwise_similarity(train_feature, train_feature)
    # Query and gallery are the same set, so subtract the identity
    # (presumably to stop a sample from retrieving itself -- confirm).
    sim_mat = sim_mat - torch.eye(sim_mat.size(0))

    recall_ks, ks = Recall_at_ks(sim_mat, query_ids=train_labels,
                                 gallery_ids=train_labels, data=args.data)
    result = ' '.join(
        ['top@%d:%.4f' % (k, rc) for k, rc in zip(ks, recall_ks)])
    print('Epoch-%d' % epoch, result)
def eval(ckp_path=None, model=None):
    """Compute Recall@K for a checkpoint on disk or an in-memory model.

    Configuration is read from ``Config()``. When ``ckp_path`` is given the
    model is loaded with ``load_checkpoint``; otherwise ``model`` is used
    directly and is switched back to training mode before returning, even if
    evaluation raises.

    Args:
        ckp_path: path of a checkpoint to load, or ``None`` to use ``model``.
        model: model to evaluate when ``ckp_path`` is ``None``.

    Returns:
        The value returned by ``Recall_at_ks``.
    """
    args = Config()
    use_given_model = ckp_path is None
    if use_given_model:
        checkpoint = model
    else:
        checkpoint = load_checkpoint(ckp_path, args)

    checkpoint.eval()
    try:
        gallery_feature, gallery_labels, query_feature, query_labels = \
            Model2Feature(data=args.data, model=checkpoint,
                          batch_size=args.batch_size, nThreads=args.nThreads,
                          pool_feature=args.pool_feature)

        sim_mat = pairwise_similarity(query_feature, gallery_feature)
        if args.gallery_eq_query is True:
            # Subtract the identity from the square similarity matrix;
            # presumably so a sample does not match itself -- confirm.
            sim_mat = sim_mat - torch.eye(sim_mat.size(0))

        recall_ks = Recall_at_ks(sim_mat, query_ids=query_labels,
                                 gallery_ids=gallery_labels, data=args.data)
    finally:
        # A caller-supplied model goes back to training mode whether or not
        # evaluation succeeded, so a failure does not leave it in eval mode.
        if use_given_model:
            checkpoint.train()
    return recall_ks
data_loader = torch.utils.data.DataLoader(data.test, batch_size=128, shuffle=False, drop_last=False) else: data = DataSet.create(args.data, test=False) data_loader = torch.utils.data.DataLoader(data.train, batch_size=128, shuffle=False, drop_last=False) features, labels = extract_features(model, data_loader, print_freq=1e5, metric=None) num_class = len(set(labels)) sim_mat = -pairwise_distance(features) if args.data == 'product': result = Recall_at_ks_products(sim_mat, query_ids=labels, gallery_ids=labels) else: result = Recall_at_ks(sim_mat, query_ids=labels, gallery_ids=labels) result = ['%.4f' % r for r in result] temp = ' ' result = temp.join(result) print('Epoch-%s' % name, result)
'-j', default=16, type=int, metavar='N', help='number of data loading threads (default: 2)') parser.add_argument('--pool_feature', type=ast.literal_eval, default=False, required=False, help='if True extract feature from the last pool layer') args = parser.parse_args() checkpoint = load_checkpoint(args.resume) print(args.pool_feature) epoch = checkpoint['epoch'] gallery_feature, gallery_labels, query_feature, query_labels = \ Model2Feature(data=args.data, root=args.data_root, width=args.width, net=args.net, checkpoint=checkpoint, dim=args.dim, batch_size=args.batch_size, nThreads=args.nThreads, pool_feature=args.pool_feature) sim_mat = pairwise_similarity(query_feature, gallery_feature) if args.gallery_eq_query is True: sim_mat = sim_mat - torch.eye(sim_mat.size(0)) recall_ks = Recall_at_ks(sim_mat, query_ids=query_labels, gallery_ids=gallery_labels, data=args.data) result = ' '.join(['%.4f' % k for k in recall_ks]) print('Epoch-%d' % epoch, result)
data = DataSet.create(args.data, train=False) data_loader = torch.utils.data.DataLoader(data.test, batch_size=8, shuffle=False, drop_last=False) else: print(' train %s***%s' % (args.data, name)) data = DataSet.create(args.data, test=False) data_loader = torch.utils.data.DataLoader(data.train, batch_size=8, shuffle=False, drop_last=False) features, labels = extract_features(model, data_loader, print_freq=999, metric=None) # print('embedding dimension is:', len(features[0])) # print('test data size is :', len(labels)) # num_class = len(set(labels)) # print('number of classes is :', num_class) # print('compute the NMI index:', NMI(features, labels, n_cluster=num_class)) # print(len(features)) # sim_mat = pairwise_similarity(features) # to google net pooling-5 sim_mat = -pairwise_distance(features) if args.data == 'products': print(Recall_at_ks_products(sim_mat, query_ids=labels, gallery_ids=labels)) else: print(Recall_at_ks(sim_mat, query_ids=labels, gallery_ids=labels))
# coding=utf-8 from __future__ import absolute_import, print_function import argparse import torch from torch.backends import cudnn from evaluations import extract_features, pairwise_distance, pairwise_similarity from evaluations import Recall_at_ks, Recall_at_ks_products, Recall_at_ks_shop import models import DataSet import os import numpy as np cudnn.benchmark = True os.environ["CUDA_VISIBLE_DEVICES"] = "0" im1 = '/opt/intern/users/xunwang/jd-comp/images/P/img/jfs/t18085/239/1572160811/242071/9e3b6d97/5ad06c21Nd73ffab7.jpg' im2 = '/opt/intern/users/xunwang/jd-comp/images/P/img/jfs/t17857/121/1655327696/242539/1771960e/5ad06c69N5b34d078.jpg' from PIL import Image im1 = Image.open(im1) im2 = Image.open(im2) im1.save('1.jpg') im2.save('2.jpg') r = '/opt/intern/users/xunwang/checkpoints/bin/jd/512-BN-alpha40/135_model.pth' PATH = r model = models.create('vgg', dim=512, pretrained=False) model = torch.nn.DataParallel(model) model.load_state_dict(torch.load(PATH)) model = model.cuda() data = DataSet.create('jd') data_loader = torch.utils.data.DataLoader(data.gallery, batch_size=64, shuffle=False,