def main(args):
    batch_time = AverageMeter()
    end = time.time()

    checkpoint = load_checkpoint(args.resume)  # load the trained checkpoint from args.resume
    print('pool_feature:', args.pool_feature)
    epoch = checkpoint['epoch']

    gallery_feature, gallery_labels, query_feature, query_labels = \
        Model2Feature(data=args.data, root=args.data_root, net=args.net, checkpoint=checkpoint,
                      batch_size=args.batch_size, nThreads=args.nThreads, pool_feature=args.pool_feature)  # extract gallery/query embeddings

    sim_mat = pairwise_similarity(query_feature, gallery_feature)  # pairwise similarity (query x gallery)
    if args.gallery_eq_query:
        # gallery and query are the same split: remove self-similarity on the diagonal
        sim_mat = sim_mat - torch.eye(sim_mat.size(0))

    print('query labels:', query_labels)
    print('gallery features:', gallery_feature)

    recall_ks = Recall_at_ks(sim_mat,
                             query_ids=query_labels,
                             gallery_ids=gallery_labels,
                             data=args.data)

    result = '  '.join(['%.4f' % k for k in recall_ks])  # format Recall@K values for printing
    print('Epoch-%d' % epoch, result)
    batch_time.update(time.time() - end)

    print('Epoch-%d\t' % epoch,
          'Time {batch_time.avg:.3f}\t'.format(batch_time=batch_time))

    import matplotlib.pyplot as plt
    import torchvision
    import numpy as np

    # visualize the gallery self-similarity matrix as a heat map
    similarity = torch.mm(gallery_feature, gallery_feature.t())
    print('similarity matrix size:', similarity.size())

    # make_grid treats the 2-D matrix as a single grayscale image; detach and move it to
    # the CPU before converting to NumPy in case the features live on the GPU
    img = torchvision.utils.make_grid(similarity.detach().cpu()).numpy()
    plt.imshow(np.transpose(img, (1, 2, 0)))
    plt.show()
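
# pairwise_similarity is imported from the project's utilities and is not shown here; a
# minimal sketch of what it is assumed to compute, i.e. the cosine-similarity matrix
# between L2-normalized query and gallery embeddings (hypothetical re-implementation,
# not necessarily the project's actual code):
import torch

def pairwise_similarity_sketch(query, gallery):
    # query: (num_query, dim), gallery: (num_gallery, dim) float tensors
    query = torch.nn.functional.normalize(query, dim=1)
    gallery = torch.nn.functional.normalize(gallery, dim=1)
    # result has shape (num_query, num_gallery)
    return torch.mm(query, gallery.t())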
Example #2
def extract_recalls(data,
                    data_root,
                    width,
                    net,
                    checkpoint,
                    dim,
                    batch_size,
                    nThreads,
                    pool_feature,
                    gallery_eq_query,
                    model=None,
                    epoch=0,
                    org_feature=False,
                    save_txt="",
                    args=None):


    gallery_feature, gallery_labels, query_feature, query_labels = \
        Model2Feature(data=data, root=data_root, width=width, net=net, checkpoint=checkpoint,
                    dim=dim, batch_size=batch_size, nThreads=nThreads, pool_feature=pool_feature, model=model, org_feature=org_feature, args=args)

    sim_mat = pairwise_similarity(query_feature, gallery_feature)
    if gallery_eq_query:
        sim_mat = sim_mat - torch.eye(sim_mat.size(0))

    recall_ks = Recall_at_ks(sim_mat,
                             query_ids=query_labels,
                             gallery_ids=gallery_labels,
                             data=data,
                             args=args,
                             epoch=epoch)

    labels = [x.item() for x in gallery_labels]

    nmi = NMI(gallery_feature, gallery_labels, n_cluster=len(set(labels)))
    print(recall_ks, nmi)
    result = '  '.join(['%.4f' % k for k in (recall_ks.tolist() + [nmi])])

    print('Epoch-%d' % epoch, result)
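
# The NMI helper used above is not defined in this file; a minimal sketch of what it
# could compute, assuming k-means clustering of the gallery embeddings followed by
# scikit-learn's normalized mutual information against the true labels (hypothetical
# re-implementation, not necessarily the project's actual code):
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score

def nmi_sketch(features, labels, n_cluster):
    # features: torch.Tensor (num_samples, dim); labels: iterable of class ids
    pred = KMeans(n_clusters=n_cluster).fit_predict(features.cpu().numpy())
    return normalized_mutual_info_score(list(labels), pred)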
def eval(ckp_path=None, model=None):
    args = Config()
    if ckp_path is not None:
        checkpoint = load_checkpoint(ckp_path, args)
    else:
        # evaluate the in-memory model directly; switch it to eval mode first
        checkpoint = model
        checkpoint.eval()
    # print(args.pool_feature)

    gallery_feature, gallery_labels, query_feature, query_labels = \
        Model2Feature(data=args.data, model=checkpoint, batch_size=args.batch_size, nThreads=args.nThreads, pool_feature=args.pool_feature)

    sim_mat = pairwise_similarity(query_feature, gallery_feature)
    if args.gallery_eq_query:
        sim_mat = sim_mat - torch.eye(sim_mat.size(0))

    recall_ks = Recall_at_ks(sim_mat,
                             query_ids=query_labels,
                             gallery_ids=gallery_labels,
                             data=args.data)
    if ckp_path is None:
        # restore training mode when a live model (rather than a checkpoint path) was evaluated
        checkpoint.train()
    return recall_ks
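
# Recall_at_ks comes from the project's evaluation code; a rough sketch of the metric it
# is assumed to report (Recall@K: the fraction of queries whose K most similar gallery
# items contain at least one sample with the same label). Names and the default K values
# here are illustrative, not the project's exact implementation:
import torch

def recall_at_ks_sketch(sim_mat, query_ids, gallery_ids, ks=(1, 2, 4, 8)):
    query_ids = torch.as_tensor(query_ids)
    gallery_ids = torch.as_tensor(gallery_ids)
    # indices of the top-K most similar gallery items for every query
    _, topk = sim_mat.topk(max(ks), dim=1)
    match = gallery_ids[topk] == query_ids.unsqueeze(1)  # (num_query, max_k) bool
    return [match[:, :k].any(dim=1).float().mean().item() for k in ks]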
Example #4
# argparse's type=bool treats any non-empty string (including 'False') as True, so parse
# the literal instead, as is already done for --pool_feature below
parser.add_argument('--Incremental_flag', default=False, type=ast.literal_eval, help='incremental learning or not')
parser.add_argument('--Retrieval_visualization', default=False, type=ast.literal_eval, help='visualize the retrieved images for a given query image')
parser.add_argument('--batch_size', type=int, default=80)
parser.add_argument('--nThreads', '-j', default=16, type=int, metavar='N', help='number of data loading threads (default: 16)')
parser.add_argument('--pool_feature', type=ast.literal_eval, default=False, required=False, help='if True extract feature from the last pool layer')

args = parser.parse_args()

checkpoint = load_checkpoint(args.resume)
print(args.pool_feature)

epoch = checkpoint['epoch']
print('Training Epoch:', epoch)

gallery_feature, gallery_labels, query_feature, query_labels, img_name, img_name_shuffled = \
    Model2Feature(data=args.data, root=args.data_root, width=args.width, net=args.net, checkpoint=checkpoint, Retrieval_visualization=args.Retrieval_visualization,
                   dim=args.dim, batch_size=args.batch_size, nThreads=args.nThreads, pool_feature=args.pool_feature)

if args.Retrieval_visualization:
    specific_query = 599  # the index of a query image
    sim_mat = pairwise_similarity(query_feature[specific_query:specific_query+1, :], gallery_feature) #query * gallery
else:
    sim_mat = pairwise_similarity(query_feature, gallery_feature) #query * gallery

if args.gallery_eq_query:
    sim_mat = sim_mat - torch.eye(sim_mat.size(0))

# Retrieval visualization: the query image is fixed above and args.Retrieval_visualization is True
if args.Retrieval_visualization:
    topK_visual(sim_mat, img_name, img_name_shuffled, specific_query, query_ids=query_labels[specific_query:specific_query+1], gallery_ids=gallery_labels, data=args.data)

recall_ks = Recall_at_ks(sim_mat, img_name, query_ids=query_labels, gallery_ids=gallery_labels, data=args.data)
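
# When the gallery and query sets are the same split (gallery_eq_query), every sample's
# best match would be itself, so the code above subtracts an identity matrix to push the
# self-similarity off the top of the ranking; a tiny toy illustration of that step:
import torch

toy = torch.tensor([[1.0, 0.8, 0.2],
                    [0.8, 1.0, 0.4],
                    [0.2, 0.4, 1.0]])
masked = toy - torch.eye(toy.size(0))
print(toy.argmax(dim=1))     # tensor([0, 1, 2]): each sample retrieves itself
print(masked.argmax(dim=1))  # tensor([1, 0, 1]): the nearest *other* sample wins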
Example #5
File: test.py  Project: CH-Liang/DSDML
parser.add_argument('--nThreads',
                    '-j',
                    default=16,
                    type=int,
                    metavar='N',
                    help='number of data loading threads (default: 16)')
parser.add_argument('--pool_feature',
                    type=ast.literal_eval,
                    default=False,
                    required=False,
                    help='if True extract feature from the last pool layer')

args = parser.parse_args()
checkpoint = load_checkpoint(args.resume)
print(args.pool_feature)
epoch = checkpoint['epoch']

gallery_feature, gallery_labels, query_feature, query_labels = \
    Model2Feature(data=args.data, root=args.data_root, width=args.width, net=args.net, checkpoint=checkpoint,
                   dim=args.dim, batch_size=args.batch_size, nThreads=args.nThreads, pool_feature=args.pool_feature)

sim_mat = pairwise_similarity(query_feature, gallery_feature)
if args.gallery_eq_query:
    sim_mat = sim_mat - torch.eye(sim_mat.size(0))

recall_ks = Recall_at_ks(sim_mat,
                         query_ids=query_labels,
                         gallery_ids=gallery_labels,
                         data=args.data)

result = '  '.join(['%.4f' % k for k in recall_ks])
print('Epoch-%d' % epoch, result)
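
# Model2Feature is the project's feature-extraction entry point and is not defined in this
# file; a minimal sketch of the core loop it is assumed to run: pushing every batch of a
# data loader through the network in eval mode and collecting L2-normalized embeddings
# with their labels (hypothetical re-implementation, names are illustrative):
import torch

@torch.no_grad()
def extract_features_sketch(model, loader, device='cpu'):
    model.eval()
    features, labels = [], []
    for images, targets in loader:
        emb = model(images.to(device))
        features.append(torch.nn.functional.normalize(emb, dim=1).cpu())
        labels.extend(targets.tolist())
    return torch.cat(features), labels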