def main(args):
    """Evaluate the checkpoint at ``args.resume`` and visualize gallery similarity.

    Steps:
      1. Load the checkpoint and read its ``'epoch'`` entry.
      2. Extract gallery/query features and labels with ``Model2Feature``.
      3. Build the query-vs-gallery similarity matrix, print the values
         returned by ``Recall_at_ks`` and the elapsed wall time.
      4. Show the gallery-vs-gallery similarity
         (``gallery_feature @ gallery_feature.T``) as an image via matplotlib.

    Args:
        args: Parsed command-line namespace. Reads ``resume``, ``pool_feature``,
            ``data``, ``data_root``, ``net``, ``batch_size``, ``nThreads`` and
            ``gallery_eq_query``.
    """
    batch_time = AverageMeter()
    end = time.time()

    checkpoint = load_checkpoint(args.resume)
    print('pool_features:', args.pool_feature)
    epoch = checkpoint['epoch']

    gallery_feature, gallery_labels, query_feature, query_labels = Model2Feature(
        data=args.data, root=args.data_root, net=args.net, checkpoint=checkpoint,
        batch_size=args.batch_size, nThreads=args.nThreads,
        pool_feature=args.pool_feature)

    # Pairwise similarity between queries and gallery items
    # (presumably shaped (num_query, num_gallery) -- TODO confirm).
    sim_mat = pairwise_similarity(query_feature, gallery_feature)
    if args.gallery_eq_query is True:
        # Query and gallery are the same set: subtract 1 on the diagonal to
        # demote each item's match with itself.
        sim_mat = sim_mat - torch.eye(sim_mat.size(0))

    print('labels', query_labels)
    print('feature:', gallery_feature)

    recall_ks = Recall_at_ks(sim_mat,
                             query_ids=query_labels,
                             gallery_ids=gallery_labels,
                             data=args.data)

    result = '  '.join(['%.4f' % k for k in recall_ks])
    print('Epoch-%d' % epoch, result)
    batch_time.update(time.time() - end)

    print('Epoch-%d\t' % epoch,
          'Time {batch_time.avg:.3f}\t'.format(batch_time=batch_time))

    # Plotting dependencies are only needed for the visualization below, so
    # they are imported lazily here.
    import matplotlib.pyplot as plt
    import torchvision
    import numpy as np

    # Gallery-vs-gallery similarity: inner products of the gallery features.
    similarity = torch.mm(gallery_feature, gallery_feature.t())

    # Draw the similarity matrix: make_grid produces a channel-first image
    # tensor, which is transposed to channel-last for imshow.
    # NOTE(review): .numpy() needs a CPU tensor -- confirm Model2Feature
    # returns CPU features.
    img = torchvision.utils.make_grid(similarity).numpy()
    plt.imshow(np.transpose(img, (1, 2, 0)))
    plt.show()
# Example no. 2
def test(args):
    """Print train-vs-train retrieval Recall@K for the latest checkpoint.

    Loads the most recent checkpoint via ``load_latest(args.resume)``. If none
    is found, a message is printed and the function returns ``None`` without
    evaluating. Otherwise the training-split features are compared against
    themselves and each ``top@K`` recall is printed with the checkpoint epoch.

    Args:
        args: Parsed command-line namespace. Reads ``resume``, ``data``,
            ``data_root``, ``width``, ``net``, ``dim``, ``batch_size``,
            ``nThreads`` and ``pool_feature``.
    """
    checkpoint = load_latest(args.resume)
    # Identity check: does not depend on how the checkpoint object defines __eq__.
    if checkpoint is None:
        print('{} is not avaible! Exit!'.format(args.resume))
        return

    epoch = checkpoint['epoch']
    # Only the training split is evaluated here; the test split is discarded.
    train_feature, train_labels, _, _ = Model2Feature(
        data=args.data, root=args.data_root, width=args.width, net=args.net,
        checkpoint=checkpoint, dim=args.dim, batch_size=args.batch_size,
        nThreads=args.nThreads, pool_feature=args.pool_feature)

    # Train-vs-train pairwise similarity. Query and gallery are the same set,
    # so subtract 1 on the diagonal to demote each sample's match with itself.
    sim_mat = pairwise_similarity(train_feature, train_feature)
    sim_mat = sim_mat - torch.eye(sim_mat.size(0))

    # This Recall_at_ks variant returns the recalls together with their K values.
    recall_ks, ks = Recall_at_ks(sim_mat,
                                 query_ids=train_labels,
                                 gallery_ids=train_labels,
                                 data=args.data)
    result = '  '.join(
        'top@%d:%.4f' % (k, rc) for k, rc in zip(ks, recall_ks))
    print('Epoch-%d' % epoch, result)
# Example no. 3
def extract_recalls(data,
                    data_root,
                    width,
                    net,
                    checkpoint,
                    dim,
                    batch_size,
                    nThreads,
                    pool_feature,
                    gallery_eq_query,
                    model=None,
                    epoch=0,
                    org_feature=False,
                    save_txt="",
                    args=None):
    """Compute, print and return retrieval Recall@K and clustering NMI.

    Features for the gallery and query splits are extracted with
    ``Model2Feature``. Recall@K is computed from the query-vs-gallery
    similarity matrix, and NMI is computed by clustering the gallery features
    into as many clusters as there are distinct gallery labels.

    Args:
        data, data_root, width, net, checkpoint, dim, batch_size, nThreads,
        pool_feature, model, org_feature, args: Forwarded to ``Model2Feature``
            (``data``, ``args`` and ``epoch`` are also forwarded to
            ``Recall_at_ks``).
        gallery_eq_query: When exactly ``True``, the query set is treated as
            the gallery set and each item's self-similarity is lowered by 1.
        epoch: Epoch number used in the printed summary.
        save_txt: Unused; kept for interface compatibility.

    Returns:
        ``(recall_ks, nmi)`` -- the value returned by ``Recall_at_ks`` and the
        NMI score. Earlier callers that ignored the (previously ``None``)
        return value are unaffected.
    """
    gallery_feature, gallery_labels, query_feature, query_labels = Model2Feature(
        data=data, root=data_root, width=width, net=net, checkpoint=checkpoint,
        dim=dim, batch_size=batch_size, nThreads=nThreads,
        pool_feature=pool_feature, model=model, org_feature=org_feature,
        args=args)

    sim_mat = pairwise_similarity(query_feature, gallery_feature)
    if gallery_eq_query is True:
        # Same set on both sides: demote each item's match with itself.
        sim_mat = sim_mat - torch.eye(sim_mat.size(0))

    recall_ks = Recall_at_ks(sim_mat,
                             query_ids=query_labels,
                             gallery_ids=gallery_labels,
                             data=data,
                             args=args,
                             epoch=epoch)

    # The number of distinct gallery classes sets the cluster count for NMI.
    # Labels are assumed to be 0-d tensors (they expose .item()).
    num_classes = len({label.item() for label in gallery_labels})
    nmi = NMI(gallery_feature, gallery_labels, n_cluster=num_classes)
    print(recall_ks, nmi)

    # recall_ks is assumed to be array-like with .tolist() (numpy/torch).
    result = '  '.join('%.4f' % k for k in recall_ks.tolist() + [nmi])
    print('Epoch-%d' % epoch, result)
    return recall_ks, nmi
def eval(ckp_path=None, model=None):  # NOTE: shadows the builtin ``eval``
    """Compute retrieval Recall@K using the settings from ``Config()``.

    Args:
        ckp_path: Path passed to ``load_checkpoint(ckp_path, args)``. When
            given, the loaded object is used as the model.
        model: Model to evaluate when ``ckp_path`` is None. It is put in eval
            mode for feature extraction, and its previous train/eval mode is
            restored afterwards, even if evaluation raises.

    Returns:
        The value returned by ``Recall_at_ks`` for the query/gallery split.
    """
    args = Config()
    # None means "leave the mode alone" (checkpoint loaded from disk);
    # True/False records the caller's mode so it can be restored.
    restore_train_mode = None
    if ckp_path is not None:
        checkpoint = load_checkpoint(ckp_path, args)
    else:
        checkpoint = model
        # Objects without a `training` flag fall back to the old behavior of
        # always switching back to train mode.
        restore_train_mode = getattr(checkpoint, 'training', True)
        checkpoint.eval()

    try:
        gallery_feature, gallery_labels, query_feature, query_labels = Model2Feature(
            data=args.data, model=checkpoint, batch_size=args.batch_size,
            nThreads=args.nThreads, pool_feature=args.pool_feature)

        sim_mat = pairwise_similarity(query_feature, gallery_feature)
        if args.gallery_eq_query is True:
            # Same set on both sides: demote each item's match with itself.
            sim_mat = sim_mat - torch.eye(sim_mat.size(0))

        recall_ks = Recall_at_ks(sim_mat,
                                 query_ids=query_labels,
                                 gallery_ids=gallery_labels,
                                 data=args.data)
    finally:
        # Put a caller-supplied training model back into train mode, also on error.
        if restore_train_mode:
            checkpoint.train()
    return recall_ks
# Example no. 5
args = parser.parse_args()

checkpoint = load_checkpoint(args.resume)
print(args.pool_feature)

epoch = checkpoint['epoch']
print('Training Epoch:', epoch)

gallery_feature, gallery_labels, query_feature, query_labels, img_name, img_name_shuffled = \
    Model2Feature(data=args.data, root=args.data_root, width=args.width, net=args.net,
                  checkpoint=checkpoint, Retrieval_visualization=args.Retrieval_visualization,
                  dim=args.dim, batch_size=args.batch_size, nThreads=args.nThreads,
                  pool_feature=args.pool_feature)

# In retrieval-visualization mode only one fixed query is scored. The
# similarity matrix then has a single row, so the query labels passed to the
# metrics are sliced to the same single query.
if args.Retrieval_visualization:
    specific_query = 599  # the index of a query image
    sim_mat = pairwise_similarity(
        query_feature[specific_query:specific_query + 1, :], gallery_feature)  # query * gallery
    eval_query_labels = query_labels[specific_query:specific_query + 1]
else:
    sim_mat = pairwise_similarity(query_feature, gallery_feature)  # query * gallery
    eval_query_labels = query_labels

if args.gallery_eq_query is True:
    if args.Retrieval_visualization:
        # The single row belongs to query `specific_query`, so its self-match
        # is column `specific_query`. torch.eye(1) would instead broadcast and
        # lower every entry of the row by 1, leaving the self-match on top.
        self_match = torch.zeros_like(sim_mat)
        self_match[0, specific_query] = 1
        sim_mat = sim_mat - self_match
    else:
        # Demote each item's match with itself (diagonal of query x gallery).
        sim_mat = sim_mat - torch.eye(sim_mat.size(0))

### for retrieval visual. the query image is given and fixed, args.Retrieval_visualization = True
if args.Retrieval_visualization:
    topK_visual(sim_mat, img_name, img_name_shuffled, specific_query,
                query_ids=eval_query_labels, gallery_ids=gallery_labels, data=args.data)

recall_ks = Recall_at_ks(sim_mat, img_name, query_ids=eval_query_labels,
                         gallery_ids=gallery_labels, data=args.data)

result = '  '.join(['%.4f' % k for k in recall_ks])
print('Epoch-%d' % epoch, result)
# Example no. 6
                                                 drop_last=False)
    query_loader = torch.utils.data.DataLoader(data.query,
                                               batch_size=64,
                                               shuffle=False,
                                               drop_last=False)

    gallery_feature, gallery_labels = extract_features(model,
                                                       gallery_loader,
                                                       print_freq=1e5,
                                                       metric=None)
    query_feature, query_labels = extract_features(model,
                                                   query_loader,
                                                   print_freq=1e5,
                                                   metric=None)

    sim_mat = pairwise_similarity(x=query_feature, y=gallery_feature)
    result = Recall_at_ks_shop(sim_mat,
                               query_ids=query_labels,
                               gallery_ids=gallery_labels)

elif args.data == 'jd':
    if args.test == 1:
        data_loader = torch.utils.data.DataLoader(data.gallery,
                                                  batch_size=64,
                                                  shuffle=False,
                                                  drop_last=False)
    else:
        data = DataSet.create(args.data)
        data_loader = torch.utils.data.DataLoader(data.gallery,
                                                  batch_size=64,
                                                  shuffle=False,
# Example no. 7
                    '-j',
                    default=16,
                    type=int,
                    metavar='N',
                    help='number of data loading threads (default: 2)')
parser.add_argument(
    '--pool_feature',
    type=ast.literal_eval,
    default=False,
    required=False,
    help='if True extract feature from the last pool layer',
)

# Parse the command line and restore the saved model state.
args = parser.parse_args()
checkpoint = load_checkpoint(args.resume)
print(args.pool_feature)
epoch = checkpoint['epoch']

# Embed the gallery and query splits with the restored network.
(gallery_feature, gallery_labels,
 query_feature, query_labels) = Model2Feature(
    data=args.data,
    root=args.data_root,
    width=args.width,
    net=args.net,
    checkpoint=checkpoint,
    dim=args.dim,
    batch_size=args.batch_size,
    nThreads=args.nThreads,
    pool_feature=args.pool_feature,
)

# Query-vs-gallery similarity. When both sides are the same set, each
# item's self-similarity is lowered by 1 on the diagonal.
sim_mat = pairwise_similarity(query_feature, gallery_feature)
sim_mat = (sim_mat - torch.eye(sim_mat.size(0))
           if args.gallery_eq_query is True
           else sim_mat)

recall_ks = Recall_at_ks(
    sim_mat,
    query_ids=query_labels,
    gallery_ids=gallery_labels,
    data=args.data,
)

# One fixed-precision column per K value, separated by two spaces.
result = '  '.join('%.4f' % k for k in recall_ks)
print('Epoch-%d' % epoch, result)
# Example no. 8
# Choose the split to evaluate: the test split when --test is 1, otherwise
# the training split. Only the chosen split is requested from DataSet.create.
if args.test == 1:
    print('evaluation on test set of %s with model: %s' % (args.data, args.r))
    data = DataSet.create(args.data, train=False)
    _eval_split = data.test
else:
    print('evaluation on train set of %s with model: %s' % (args.data, args.r))
    data = DataSet.create(args.data, test=False)
    _eval_split = data.train

# Deterministic, full-coverage pass over the chosen split.
data_loader = torch.utils.data.DataLoader(
    _eval_split, batch_size=64, shuffle=False, drop_last=False)

features, labels = extract_features(
    model, data_loader, print_freq=32, metric=None)
print('embedding dimension is:', len(features[0]))
num_class = len(set(labels))
print('number of classes is :', num_class)
print('compute the NMI index:', NMI(features, labels, n_cluster=num_class))

# Self-similarity of the split; the same labels serve as query and gallery ids.
sim_mat = pairwise_similarity(features)
_recall_fn = Recall_at_ks_products if args.data == 'products' else Recall_at_ks
print(_recall_fn(sim_mat, query_ids=labels, gallery_ids=labels))