コード例 #1
0
def extract_recalls(data,
                    data_root,
                    width,
                    net,
                    checkpoint,
                    dim,
                    batch_size,
                    nThreads,
                    pool_feature,
                    gallery_eq_query,
                    model=None,
                    epoch=0,
                    org_feature=False,
                    save_txt="",
                    args=None):
    """Extract embeddings and report Recall@K plus NMI for one evaluation pass.

    Features come from ``Model2Feature``. Every argument below except
    ``gallery_eq_query``, ``epoch`` and ``save_txt`` is forwarded to it
    unchanged.

    Args:
        data: Dataset identifier, also forwarded to ``Recall_at_ks``.
        data_root, width, net, checkpoint, dim, batch_size, nThreads,
        pool_feature, model, org_feature: Forwarded to ``Model2Feature``.
        gallery_eq_query: When exactly ``True``, 1 is subtracted from the
            diagonal of the query/gallery similarity matrix. This presumably
            keeps a sample from matching itself when query and gallery are
            the same set.
        epoch: Epoch number for the printed summary. Also forwarded to
            ``Recall_at_ks``.
        save_txt: Accepted for interface compatibility; unused here.
        args: Forwarded to ``Model2Feature`` and ``Recall_at_ks``.

    Returns:
        tuple: ``(recall_ks, nmi)``.
            - ``recall_ks`` is the value returned by ``Recall_at_ks``
              (array-like; it must support ``.tolist()``).
            - ``nmi`` is the score computed by ``NMI`` over the gallery
              features.
    """
    gallery_feature, gallery_labels, query_feature, query_labels = Model2Feature(
        data=data, root=data_root, width=width, net=net, checkpoint=checkpoint,
        dim=dim, batch_size=batch_size, nThreads=nThreads,
        pool_feature=pool_feature, model=model, org_feature=org_feature,
        args=args)

    sim_mat = pairwise_similarity(query_feature, gallery_feature)
    if gallery_eq_query is True:
        # Build the identity on sim_mat's device/dtype so the subtraction also
        # works when the similarity matrix is not a default-dtype CPU tensor.
        sim_mat = sim_mat - torch.eye(sim_mat.size(0),
                                      device=sim_mat.device,
                                      dtype=sim_mat.dtype)

    recall_ks = Recall_at_ks(sim_mat,
                             query_ids=query_labels,
                             gallery_ids=gallery_labels,
                             data=data,
                             args=args,
                             epoch=epoch)

    # Cluster the gallery into as many clusters as there are distinct labels.
    labels = [x.item() for x in gallery_labels]
    nmi = NMI(gallery_feature, gallery_labels, n_cluster=len(set(labels)))
    print(recall_ks, nmi)
    result = '  '.join(['%.4f' % k for k in (recall_ks.tolist() + [nmi])])

    print('Epoch-%d' % epoch, result)
    return recall_ks, nmi
コード例 #2
0
def main(args):
    """Evaluate a saved checkpoint with Recall@K and display a similarity map.

    Args:
        args: Parsed command-line namespace. Reads ``resume``,
            ``pool_feature``, ``data``, ``data_root``, ``net``,
            ``batch_size``, ``nThreads`` and ``gallery_eq_query``.

    Side effects:
        Prints labels, features, recalls and timing. Opens a matplotlib
        window showing the gallery-vs-gallery similarity matrix.
    """
    batch_time = AverageMeter()
    end = time.time()

    checkpoint = load_checkpoint(args.resume)
    print('pool_features:', args.pool_feature)
    epoch = checkpoint['epoch']

    # Extract gallery/query embeddings from the loaded checkpoint.
    gallery_feature, gallery_labels, query_feature, query_labels = Model2Feature(
        data=args.data, root=args.data_root, net=args.net, checkpoint=checkpoint,
        batch_size=args.batch_size, nThreads=args.nThreads,
        pool_feature=args.pool_feature)

    # Pairwise query/gallery similarity.
    sim_mat = pairwise_similarity(query_feature, gallery_feature)
    if args.gallery_eq_query is True:
        # Lower self-similarity on the diagonal. Match sim_mat's device/dtype
        # so this does not fail on non-CPU tensors.
        sim_mat = sim_mat - torch.eye(sim_mat.size(0),
                                      device=sim_mat.device,
                                      dtype=sim_mat.dtype)

    print('labels', query_labels)
    print('feature:', gallery_feature)

    recall_ks = Recall_at_ks(sim_mat,
                             query_ids=query_labels,
                             gallery_ids=gallery_labels,
                             data=args.data)

    result = '  '.join(['%.4f' % k for k in recall_ks])
    print('Epoch-%d' % epoch, result)

    # Wall-clock time of the whole evaluation above (a single update).
    batch_time.update(time.time() - end)
    print('Epoch-%d\t' % epoch,
          'Time {batch_time.avg:.3f}\t'.format(batch_time=batch_time))

    # Plotting dependencies are imported lazily; only this visual check needs them.
    import matplotlib.pyplot as plt
    import torchvision
    import numpy as np

    # Draw the gallery-vs-gallery similarity matrix as an image. make_grid is
    # expected to return a (C, H, W) tensor, hence the transpose to (H, W, C)
    # for imshow. detach().cpu() makes .numpy() safe for GPU/grad tensors.
    similarity = torch.mm(gallery_feature, gallery_feature.t())
    img = torchvision.utils.make_grid(similarity.detach().cpu()).numpy()
    plt.imshow(np.transpose(img, (1, 2, 0)))
    plt.show()
コード例 #3
0
def test(args):
    """Print train-vs-train Recall@K for the latest checkpoint in ``args.resume``.

    Each training sample is queried against the whole training set. The
    diagonal of the similarity matrix is lowered by 1, presumably so a sample
    does not retrieve itself.

    Args:
        args: Parsed namespace. Reads ``resume``, ``data``, ``data_root``,
            ``width``, ``net``, ``dim``, ``batch_size``, ``nThreads`` and
            ``pool_feature``.

    Returns:
        None. If ``load_latest`` finds no checkpoint, a message is printed
        and the function returns early.
    """
    checkpoint = load_latest(args.resume)
    if checkpoint is None:
        print('{} is not avaible! Exit!'.format(args.resume))
        return

    epoch = checkpoint['epoch']
    # Only the training split is evaluated; the test-split outputs are unused.
    train_feature, train_labels, _test_feature, _test_labels = Model2Feature(
        data=args.data, root=args.data_root, width=args.width, net=args.net,
        checkpoint=checkpoint, dim=args.dim, batch_size=args.batch_size,
        nThreads=args.nThreads, pool_feature=args.pool_feature)

    # train-train pairwise similarity
    sim_mat = pairwise_similarity(train_feature, train_feature)
    sim_mat = sim_mat - torch.eye(sim_mat.size(0),
                                  device=sim_mat.device,
                                  dtype=sim_mat.dtype)
    recall_ks, ks = Recall_at_ks(sim_mat,
                                 query_ids=train_labels,
                                 gallery_ids=train_labels,
                                 data=args.data)
    result = '  '.join(
        ['top@%d:%.4f' % (k, rc) for k, rc in zip(ks, recall_ks)])
    print('Epoch-%d' % epoch, result)
コード例 #4
0
def eval(ckp_path=None, model=None):
    """Compute query/gallery Recall@K using the settings from ``Config()``.

    Weights come from exactly one source:
    - If ``ckp_path`` is given, it is loaded with
      ``load_checkpoint(ckp_path, args)``.
    - Otherwise ``model`` is evaluated in place. It is switched to eval mode
      for feature extraction, and its previous train/eval mode is restored
      afterwards, even if evaluation raises.

    Note: this name shadows the builtin ``eval``; it is kept for caller
    compatibility.

    Args:
        ckp_path: Checkpoint path, or None to evaluate ``model``.
        model: In-memory model (``torch.nn.Module``) used when ``ckp_path``
            is None.

    Returns:
        Whatever ``Recall_at_ks`` returns for the query/gallery split.

    Raises:
        ValueError: If both ``ckp_path`` and ``model`` are None.
    """
    args = Config()
    # None means "loaded from disk; no mode to restore".
    was_training = None
    if ckp_path is not None:
        checkpoint = load_checkpoint(ckp_path, args)
    else:
        if model is None:
            raise ValueError('eval() requires either ckp_path or model')
        checkpoint = model
        was_training = checkpoint.training
        checkpoint.eval()

    try:
        gallery_feature, gallery_labels, query_feature, query_labels = Model2Feature(
            data=args.data, model=checkpoint, batch_size=args.batch_size,
            nThreads=args.nThreads, pool_feature=args.pool_feature)

        sim_mat = pairwise_similarity(query_feature, gallery_feature)
        if args.gallery_eq_query is True:
            # Lower self-similarity on the diagonal; match sim_mat's device/dtype.
            sim_mat = sim_mat - torch.eye(sim_mat.size(0),
                                          device=sim_mat.device,
                                          dtype=sim_mat.dtype)

        recall_ks = Recall_at_ks(sim_mat,
                                 query_ids=query_labels,
                                 gallery_ids=gallery_labels,
                                 data=args.data)
    finally:
        # Hand the caller's model back in the mode it arrived in.
        if was_training is not None:
            checkpoint.train(was_training)
    return recall_ks
コード例 #5
0
ファイル: test.py プロジェクト: yuanmengzhixing/Deep_metric
    data_loader = torch.utils.data.DataLoader(data.test,
                                              batch_size=128,
                                              shuffle=False,
                                              drop_last=False)
else:
    data = DataSet.create(args.data, test=False)
    data_loader = torch.utils.data.DataLoader(data.train,
                                              batch_size=128,
                                              shuffle=False,
                                              drop_last=False)

# Embed every sample from data_loader; the same set serves as query and gallery.
features, labels = extract_features(model,
                                    data_loader,
                                    print_freq=1e5,
                                    metric=None)

num_class = len(set(labels))

# Negated pairwise distance is used as the similarity score.
sim_mat = -pairwise_distance(features)
# NOTE(review): this checks 'product', while a sibling script checks
# 'products' -- confirm which dataset key is correct.
if args.data == 'product':
    result = Recall_at_ks_products(sim_mat,
                                   query_ids=labels,
                                   gallery_ids=labels)
    # The return type of Recall_at_ks_products is not visible here, so it is
    # reported as-is rather than being dropped without output.
    print('Epoch-%s' % name, result)
else:
    result = Recall_at_ks(sim_mat, query_ids=labels, gallery_ids=labels)
    result = '  '.join(['%.4f' % r for r in result])
    print('Epoch-%s' % name, result)
コード例 #6
0
ファイル: test.py プロジェクト: CH-Liang/DSDML
                    '-j',
                    default=16,
                    type=int,
                    metavar='N',
                    help='number of data loading threads (default: 2)')
# ast.literal_eval lets the flag be given as a Python literal, e.g. True/False.
parser.add_argument('--pool_feature',
                    type=ast.literal_eval,
                    default=False,
                    required=False,
                    help='if True extract feature from the last pool layer')

args = parser.parse_args()
checkpoint = load_checkpoint(args.resume)
print(args.pool_feature)
epoch = checkpoint['epoch']

# Extract gallery and query embeddings with the checkpointed network.
gallery_feature, gallery_labels, query_feature, query_labels = \
    Model2Feature(data=args.data, root=args.data_root, width=args.width, net=args.net, checkpoint=checkpoint,
                   dim=args.dim, batch_size=args.batch_size, nThreads=args.nThreads, pool_feature=args.pool_feature)

sim_mat = pairwise_similarity(query_feature, gallery_feature)
if args.gallery_eq_query is True:
    # Lower each sample's self-similarity on the diagonal. The identity is
    # built on sim_mat's device/dtype so this also works off the CPU.
    sim_mat = sim_mat - torch.eye(sim_mat.size(0),
                                  device=sim_mat.device,
                                  dtype=sim_mat.dtype)

recall_ks = Recall_at_ks(sim_mat,
                         query_ids=query_labels,
                         gallery_ids=gallery_labels,
                         data=args.data)

result = '  '.join(['%.4f' % k for k in recall_ks])
print('Epoch-%d' % epoch, result)
コード例 #7
0
ファイル: test.py プロジェクト: hyzcn/Deep_metric
    data = DataSet.create(args.data, train=False)
    data_loader = torch.utils.data.DataLoader(data.test,
                                              batch_size=8,
                                              shuffle=False,
                                              drop_last=False)
else:
    print('  train %s***%s' % (args.data, name))
    data = DataSet.create(args.data, test=False)
    data_loader = torch.utils.data.DataLoader(data.train,
                                              batch_size=8,
                                              shuffle=False,
                                              drop_last=False)

# Embed the selected split; the same samples act as both query and gallery.
features, labels = extract_features(
    model, data_loader, print_freq=999, metric=None)

# Similarity is the negated pairwise distance (the original note referenced
# GoogLeNet pool-5 features).
sim_mat = -pairwise_distance(features)

# The products dataset has a dedicated recall routine; all others share the
# generic one.
_recall_fn = Recall_at_ks_products if args.data == 'products' else Recall_at_ks
print(_recall_fn(sim_mat, query_ids=labels, gallery_ids=labels))
コード例 #8
0
# coding=utf-8
from __future__ import absolute_import, print_function
import argparse
import torch
from torch.backends import cudnn
from evaluations import extract_features, pairwise_distance, pairwise_similarity
from evaluations import Recall_at_ks, Recall_at_ks_products, Recall_at_ks_shop
import models
import DataSet
import os
import numpy as np
# Module-level setup for a JD retrieval check: copy two sample images locally,
# load a VGG embedding checkpoint, and build the 'jd' dataset.

# Let cuDNN benchmark convolution algorithms and pick the fastest ones.
cudnn.benchmark = True
# Restrict the process to GPU 0. NOTE(review): torch is already imported above;
# this only takes effect if CUDA has not been initialized yet -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Hard-coded sample images on a specific machine.
im1 = '/opt/intern/users/xunwang/jd-comp/images/P/img/jfs/t18085/239/1572160811/242071/9e3b6d97/5ad06c21Nd73ffab7.jpg'
im2 = '/opt/intern/users/xunwang/jd-comp/images/P/img/jfs/t17857/121/1655327696/242539/1771960e/5ad06c69N5b34d078.jpg'
from PIL import Image
# im1/im2 are rebound from paths to PIL Image objects. NOTE(review): these
# images are never closed explicitly.
im1 = Image.open(im1)
im2 = Image.open(im2)
# Save copies into the current working directory.
im1.save('1.jpg')
im2.save('2.jpg')
# Hard-coded checkpoint path.
r = '/opt/intern/users/xunwang/checkpoints/bin/jd/512-BN-alpha40/135_model.pth'
PATH = r
# Build a 512-d VGG embedding model without pretrained weights.
model = models.create('vgg', dim=512, pretrained=False)
# Wrapped in DataParallel *before* loading, presumably because the saved
# state_dict keys carry the DataParallel 'module.' prefix -- confirm.
model = torch.nn.DataParallel(model)
# torch.load unpickles the file; only use it on trusted checkpoints.
model.load_state_dict(torch.load(PATH))
model = model.cuda()
data = DataSet.create('jd')
data_loader = torch.utils.data.DataLoader(data.gallery,
                                          batch_size=64,
                                          shuffle=False,