def validate(opt, val_loader, model):
    """Evaluate text-to-image nDCG@5 on the validation set.

    Returns a ``(org_ndcg5, rerank_ndcg5)`` pair; for attention models the
    rerank score is a ``-1`` placeholder (reranking disabled).
    """
    with torch.no_grad():
        # Encode every validation sample; result is keyed by query id.
        q_id_dict = encode_data(model, val_loader, opt.log_step, logging.info)

        t_start = time.time()
        if opt.use_atten:
            # Reranking for the attention path is disabled upstream.
            rerank_ndcg5 = -1
            org_ndcg5 = nDCG5_t2i_atten(
                q_id_dict, val_loader.dataset.answer, opt)
        else:
            rerank_ndcg5 = nDCG5_t2i_rerank(
                q_id_dict, val_loader.dataset.answer, opt)
            org_ndcg5 = nDCG5_t2i(q_id_dict, val_loader.dataset.answer, opt)
        t_end = time.time()
        print("calculate similarity time:", t_end - t_start)

    logging.info("Text to image: org:%.5f rerank:%.5f" %
                 (org_ndcg5, rerank_ndcg5))

    return org_ndcg5, rerank_ndcg5
# Esempio n. 2
def validate(val_loader,
             model,
             tb_logger,
             measure='cosine',
             log_step=10,
             ndcg_scorer=None):
    """Run text-to-image retrieval on the validation set.

    Logs recall/rank/nDCG metrics and returns ``(rsum, mean_spice_ndcg_i)``,
    where rsum = r1i + r5i + r10i is used for early stopping.
    """
    # Encode all validation images and captions.
    img_embs, cap_embs = encode_data(model, val_loader, log_step, logging.info)

    # Image retrieval (text -> image) metrics.
    metrics = t2i(img_embs, cap_embs, ndcg_scorer=ndcg_scorer, measure=measure)
    (r1i, r5i, r10i, medri, meanr,
     mean_rougel_ndcg_i, mean_spice_ndcg_i) = metrics

    logging.info(
        "Text to image: %.1f, %.1f, %.1f, %.1f, %.1f, ndcg_rouge=%.4f ndcg_spice=%.4f"
        %
        (r1i, r5i, r10i, medri, meanr, mean_rougel_ndcg_i, mean_spice_ndcg_i))
    # Sum of recalls, used for early stopping.
    currscore = r1i + r5i + r10i

    # Record metrics in tensorboard.
    step = model.Eiters
    tb_logger.add_scalar('r1i', r1i, step)
    tb_logger.add_scalar('r5i', r5i, step)
    tb_logger.add_scalar('r10i', r10i, step)
    tb_logger.add_scalars('mean_ndcg_i', {
        'rougeL': mean_rougel_ndcg_i,
        'spice': mean_spice_ndcg_i
    }, step)
    tb_logger.add_scalar('medri', medri, step)
    tb_logger.add_scalar('meanr', meanr, step)
    tb_logger.add_scalar('rsum', currscore, step)

    return currscore, mean_spice_ndcg_i
# Esempio n. 3
def validate(opt, val_loader, model, measure='cosine'):
    """Validate video-text retrieval and return the model-selection score.

    Encodes the validation split, deduplicates videos (the loader yields
    video-sentence pairs, so each video appears once per caption), then
    computes either recall@K or mAP per ``opt.val_metric``, logging to the
    module-level ``tb_logger``.

    Fix vs. the original: the fragile ``feature_mask[idx] is True`` identity
    test is replaced by plain truthiness.
    """
    # compute the encoding for all the validation video and captions
    video_embs, cap_embs, video_ids, caption_ids = evaluation.encode_data(
        model, val_loader, opt.log_step, logging.info)

    # We only need to forward each video once for evaluation, so keep the
    # first occurrence of every video id and mask out the duplicates.
    feature_mask = []
    evaluate_videos = set()
    for video_id in video_ids:
        feature_mask.append(video_id not in evaluate_videos)
        evaluate_videos.add(video_id)
    video_embs = video_embs[feature_mask]
    video_ids = [x for idx, x in enumerate(video_ids) if feature_mask[idx]]

    c2i_all_errors = evaluation.cal_error(video_embs, cap_embs, measure)
    if opt.val_metric == "recall":

        # video retrieval (text -> video)
        (r1i, r5i, r10i, medri,
         meanri) = evaluation.t2i(c2i_all_errors, n_caption=opt.n_caption)
        print(" * Text to video:")
        print(" * r_1_5_10: {}".format(
            [round(r1i, 3), round(r5i, 3),
             round(r10i, 3)]))
        print(" * medr, meanr: {}".format([round(medri, 3), round(meanri, 3)]))
        print(" * " + '-' * 10)

        # caption retrieval (video -> text)
        (r1, r5, r10, medr, meanr) = evaluation.i2t(c2i_all_errors,
                                                    n_caption=opt.n_caption)
        print(" * Video to text:")
        print(" * r_1_5_10: {}".format(
            [round(r1, 3), round(r5, 3),
             round(r10, 3)]))
        print(" * medr, meanr: {}".format([round(medr, 3), round(meanr, 3)]))
        print(" * " + '-' * 10)

        # record metrics in tensorboard
        tb_logger.log_value('r1', r1, step=model.Eiters)
        tb_logger.log_value('r5', r5, step=model.Eiters)
        tb_logger.log_value('r10', r10, step=model.Eiters)
        tb_logger.log_value('medr', medr, step=model.Eiters)
        tb_logger.log_value('meanr', meanr, step=model.Eiters)
        tb_logger.log_value('r1i', r1i, step=model.Eiters)
        tb_logger.log_value('r5i', r5i, step=model.Eiters)
        tb_logger.log_value('r10i', r10i, step=model.Eiters)
        tb_logger.log_value('medri', medri, step=model.Eiters)
        tb_logger.log_value('meanri', meanri, step=model.Eiters)

    elif opt.val_metric == "map":
        i2t_map_score = evaluation.i2t_map(c2i_all_errors,
                                           n_caption=opt.n_caption)
        t2i_map_score = evaluation.t2i_map(c2i_all_errors,
                                           n_caption=opt.n_caption)
        tb_logger.log_value('i2t_map', i2t_map_score, step=model.Eiters)
        tb_logger.log_value('t2i_map', t2i_map_score, step=model.Eiters)
        print('i2t_map', i2t_map_score)
        print('t2i_map', t2i_map_score)

    # Aggregate score for early stopping, per requested metric/direction.
    currscore = 0
    if opt.val_metric == "recall":
        if opt.direction == 'i2t' or opt.direction == 'all':
            currscore += (r1 + r5 + r10)
        if opt.direction == 't2i' or opt.direction == 'all':
            currscore += (r1i + r5i + r10i)
    elif opt.val_metric == "map":
        if opt.direction == 'i2t' or opt.direction == 'all':
            currscore += i2t_map_score
        if opt.direction == 't2i' or opt.direction == 'all':
            currscore += t2i_map_score

    tb_logger.log_value('rsum', currscore, step=model.Eiters)

    return currscore
# Esempio n. 4
def main():
    """Evaluate a trained video-text retrieval checkpoint on a test collection.

    Loads the checkpoint and its vocabularies, encodes videos and captions,
    saves the error matrix to disk, and prints recall/medr/meanr/mAP for both
    retrieval directions. Exits early if the checkpoint is missing or the
    output already exists (unless --overwrite).
    """
    opt = parse_args()
    print(json.dumps(vars(opt), indent=2))

    rootpath = opt.rootpath
    testCollection = opt.testCollection
    n_caption = opt.n_caption
    # Checkpoint to evaluate.
    resume = os.path.join(opt.logger_name, opt.checkpoint_name)

    if not os.path.exists(resume):
        logging.info(resume + ' not exists.')
        sys.exit(0)

    checkpoint = torch.load(resume)
    start_epoch = checkpoint['epoch']
    best_rsum = checkpoint['best_rsum']
    print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
        resume, start_epoch, best_rsum))
    # Training-time options saved inside the checkpoint.
    options = checkpoint['opt']
    # Backward compatibility for older checkpoints without 'concate'.
    if not hasattr(options, 'concate'):
        setattr(options, "concate", "full")

    # Derive the output directory from the checkpoint path, swapping the
    # train collection / cv segment for the test collection / results one.
    trainCollection = options.trainCollection
    output_dir = resume.replace(trainCollection, testCollection)
    output_dir = output_dir.replace('/%s/' % options.cv_name,
                                    '/results/%s/' % trainCollection)
    result_pred_sents = os.path.join(output_dir, 'id.sent.score.txt')
    pred_error_matrix_file = os.path.join(output_dir,
                                          'pred_errors_matrix.pth.tar')
    if checkToSkip(pred_error_matrix_file, opt.overwrite):
        sys.exit(0)
    makedirsforfile(pred_error_matrix_file)

    # data loader prepare
    caption_files = {
        'test':
        os.path.join(rootpath, testCollection, 'TextData',
                     '%s.caption.txt' % testCollection)
    }
    img_feat_path = os.path.join(rootpath, testCollection, 'FeatureData',
                                 options.visual_feature)
    visual_feats = {'test': BigFile(img_feat_path)}
    assert options.visual_feat_dim == visual_feats['test'].ndims
    video2frames = {
        'test':
        read_dict(
            os.path.join(rootpath, testCollection, 'FeatureData',
                         options.visual_feature, 'video2frames.txt'))
    }

    # set bow vocabulary and encoding
    bow_vocab_file = os.path.join(rootpath, options.trainCollection,
                                  'TextData', 'vocabulary', 'bow',
                                  options.vocab + '.pkl')
    bow_vocab = pickle.load(open(bow_vocab_file, 'rb'))
    bow2vec = get_text_encoder('bow')(bow_vocab)
    options.bow_vocab_size = len(bow_vocab)

    # set rnn vocabulary
    rnn_vocab_file = os.path.join(rootpath, options.trainCollection,
                                  'TextData', 'vocabulary', 'rnn',
                                  options.vocab + '.pkl')
    rnn_vocab = pickle.load(open(rnn_vocab_file, 'rb'))
    options.vocab_size = len(rnn_vocab)

    # Construct the model and restore its weights/state.
    model = get_model(options.model)(options)
    model.load_state_dict(checkpoint['model'])
    model.Eiters = checkpoint['Eiters']
    model.val_start()

    # msvd has a variable number of captions per video, so it uses separate
    # video/text loaders; other collections use paired loaders.
    if testCollection.startswith(
            'msvd'):  # or testCollection.startswith('msrvtt'):
        # set data loader
        video_ids_list = data.read_video_ids(caption_files['test'])
        vid_data_loader = data.get_vis_data_loader(visual_feats['test'],
                                                   opt.batch_size,
                                                   opt.workers,
                                                   video2frames['test'],
                                                   video_ids=video_ids_list)
        text_data_loader = data.get_txt_data_loader(caption_files['test'],
                                                    rnn_vocab, bow2vec,
                                                    opt.batch_size,
                                                    opt.workers)
        # mapping
        video_embs, video_ids = evaluation.encode_text_or_vid(
            model.embed_vis, vid_data_loader)
        cap_embs, caption_ids = evaluation.encode_text_or_vid(
            model.embed_txt, text_data_loader)
    else:
        # set data loader
        data_loader = data.get_test_data_loaders(caption_files,
                                                 visual_feats,
                                                 rnn_vocab,
                                                 bow2vec,
                                                 opt.batch_size,
                                                 opt.workers,
                                                 opt.n_caption,
                                                 video2frames=video2frames)
        # mapping
        video_embs, cap_embs, video_ids, caption_ids = evaluation.encode_data(
            model, data_loader['test'], opt.log_step, logging.info)
        # remove duplicate videos (each video repeats once per caption)
        idx = range(0, video_embs.shape[0], n_caption)
        video_embs = video_embs[idx, :]
        video_ids = video_ids[::opt.n_caption]

    # Caption-to-video error matrix, persisted for later analysis.
    c2i_all_errors = evaluation.cal_error(video_embs, cap_embs,
                                          options.measure)
    torch.save(
        {
            'errors': c2i_all_errors,
            'videos': video_ids,
            'captions': caption_ids
        }, pred_error_matrix_file)
    print("write into: %s" % pred_error_matrix_file)

    if testCollection.startswith(
            'msvd'):  # or testCollection.startswith('msrvtt'):
        # caption retrieval (variable captions per video)
        (r1, r5, r10, medr, meanr,
         i2t_map_score) = evaluation.i2t_varied(c2i_all_errors, caption_ids,
                                                video_ids)
        # video retrieval
        (r1i, r5i, r10i, medri, meanri,
         t2i_map_score) = evaluation.t2i_varied(c2i_all_errors, caption_ids,
                                                video_ids)
    else:
        # caption retrieval
        (r1i, r5i, r10i, medri, meanri) = evaluation.t2i(c2i_all_errors,
                                                         n_caption=n_caption)
        t2i_map_score = evaluation.t2i_map(c2i_all_errors, n_caption=n_caption)

        # video retrieval
        (r1, r5, r10, medr, meanr) = evaluation.i2t(c2i_all_errors,
                                                    n_caption=n_caption)
        i2t_map_score = evaluation.i2t_map(c2i_all_errors, n_caption=n_caption)

    print(" * Text to Video:")
    print(" * r_1_5_10, medr, meanr: {}".format([
        round(r1i, 1),
        round(r5i, 1),
        round(r10i, 1),
        round(medri, 1),
        round(meanri, 1)
    ]))
    print(" * recall sum: {}".format(round(r1i + r5i + r10i, 1)))
    print(" * mAP: {}".format(round(t2i_map_score, 3)))
    print(" * " + '-' * 10)

    # caption retrieval
    print(" * Video to text:")
    print(" * r_1_5_10, medr, meanr: {}".format([
        round(r1, 1),
        round(r5, 1),
        round(r10, 1),
        round(medr, 1),
        round(meanr, 1)
    ]))
    print(" * recall sum: {}".format(round(r1 + r5 + r10, 1)))
    print(" * mAP: {}".format(round(i2t_map_score, 3)))
    print(" * " + '-' * 10)
def create_embs(data_loader, model):
    """Encode the loader's data with ``model`` and return only the image embeddings."""
    encoded = encode_data(model, data_loader)
    return encoded[0]
# Esempio n. 6
def validate(opt, val_loader, model):
    """Validate contextual video-paragraph retrieval.

    Computes sequence-level retrieval in both directions, logs the recalls,
    records them to the module-level ``tb_logger``, and returns the combined
    recall sum used for model selection.
    """
    # Encode the full validation split (sequence- and clip-level embeddings).
    (vid_seq_embs, para_seq_embs, clip_embs, cap_embs,
     _, _, num_clips, cur_vid_total) = encode_data(
        opt, model, val_loader, opt.log_step, logging.info,
        contextual_model=True)

    # Video -> paragraph retrieval.
    vid_seq_rep, top1_v2p, rank_vid_v2p = i2t(
        vid_seq_embs, para_seq_embs, measure=opt.measure)
    # Paragraph -> video retrieval.
    para_seq_rep, top1_p2v, rank_para_p2v = t2i(
        vid_seq_embs, para_seq_embs, measure=opt.measure)

    currscore = vid_seq_rep['sum'] + para_seq_rep['sum']

    logging.info("Video to Paragraph: %.1f, %.1f, %.1f, %.1f, %.1f" %
         (vid_seq_rep['r1'], vid_seq_rep['r5'], vid_seq_rep['r10'], vid_seq_rep['medr'], vid_seq_rep['meanr']))
    logging.info("Paragraph to Video: %.1f, %.1f, %.1f, %.1f, %.1f" %
         (para_seq_rep['r1'], para_seq_rep['r5'], para_seq_rep['r10'], para_seq_rep['medr'], para_seq_rep['meanr']))
    logging.info("Currscore: %.1f" % (currscore))

    # Record metrics in tensorboard.
    LogReporter(tb_logger, vid_seq_rep, model.Eiters, 'seq')
    LogReporter(tb_logger, para_seq_rep, model.Eiters, 'seqi')
    tb_logger.log_value('rsum', currscore, step=model.Eiters)

    return currscore
# Esempio n. 7
# File: test.py  Progetto: kiminh/cbvr
def main():
    """Rank candidate videos with a trained ReLearning model.

    Encodes query and candidate videos, writes the predicted ranking CSV,
    and — on the validation set — reports hit@k / recall@k and saves them
    to perf.txt.

    Fixes vs. the original: Python 2 ``print`` statements converted to the
    ``print()`` function (the rest of the file is Python 3), and the
    'ecall_top_k' typo in perf.txt corrected to 'recall_top_k'.
    """
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument("--rootpath", default=ROOT_PATH, type=str, help="rootpath (default: %s)" % ROOT_PATH)
    parser.add_argument('--collection', default='track_1_shows', type=str, help='collection')
    parser.add_argument('--checkpoint_path', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
    parser.add_argument("--test_set", default="val", type=str, help="val or test")
    parser.add_argument('--batch_size', default=128, type=int, help='Size of a training mini-batch.')
    parser.add_argument("--overwrite", default=0, type=int, help="overwrite existing file (default: 0)")

    opt = parser.parse_args()
    print(json.dumps(vars(opt), indent=2))

    assert opt.test_set in ['val', 'test']
    # Results live next to the checkpoint, under results/<test_set>/.
    output_dir = os.path.dirname(
        opt.checkpoint_path.replace('/cv/', '/results/%s/' % opt.test_set))
    output_file = os.path.join(output_dir, 'pred_video2rank.csv')
    if checkToSkip(output_file, opt.overwrite):
        sys.exit(0)
    makedirsforfile(output_file)

    # reading data
    train_video_set_file = os.path.join(opt.rootpath, opt.collection, 'split', 'train.csv')
    val_video_set_file = os.path.join(opt.rootpath, opt.collection, 'split', 'val.csv')
    train_video_list = read_video_set(train_video_set_file)
    val_video_list = read_video_set(val_video_set_file)
    if opt.test_set == 'test':
        test_video_set_file = os.path.join(opt.rootpath, opt.collection, 'split', 'test.csv')
        test_video_list = read_video_set(test_video_set_file)

    # optionally resume from a checkpoint
    print("=> loading checkpoint '{}'".format(opt.checkpoint_path))
    checkpoint = torch.load(opt.checkpoint_path)
    options = checkpoint['opt']

    # set feature reader
    video_feat_path = os.path.join(opt.rootpath, opt.collection, 'FeatureData', options.feature)
    video_feats = BigFile(video_feat_path)

    # Build query/candidate loaders. On 'val' the candidates are train+val;
    # on 'test' they additionally include the test videos.
    if opt.test_set == 'val':
        val_rootpath = os.path.join(opt.rootpath, opt.collection, 'relevance_val.csv')
        val_video2gtrank = read_csv_to_dict(val_rootpath)
        val_feat_loader = data.get_feat_loader(val_video_list, video_feats, opt.batch_size, False, 1)
        cand_feat_loader = data.get_feat_loader(train_video_list + val_video_list, video_feats, opt.batch_size, False, 1)
    elif opt.test_set == 'test':
        val_feat_loader = data.get_feat_loader(test_video_list, video_feats, opt.batch_size, False, 1)
        cand_feat_loader = data.get_feat_loader(train_video_list + val_video_list + test_video_list, video_feats, opt.batch_size, False, 1)

    # Construct the model and encode both sides.
    model = ReLearning(options)
    model.load_state_dict(checkpoint['model'])
    val_video_embs, val_video_ids_list = encode_data(model, val_feat_loader, options.log_step, logging.info)
    cand_video_embs, cand_video_ids_list = encode_data(model, cand_feat_loader, options.log_step, logging.info)

    video2predrank = do_predict(val_video_embs, val_video_ids_list, cand_video_embs, cand_video_ids_list, output_dir=output_dir, overwrite=1, no_imgnorm=options.no_imgnorm)
    write_csv_video2rank(output_file, video2predrank)

    if opt.test_set == 'val':
        hit_top_k = [5, 10, 20, 30]
        recall_top_k = [50, 100, 200, 300]
        hit_k_scores = hit_k_own(val_video2gtrank, video2predrank, top_k=hit_top_k)
        recall_K_scores = recall_k_own(val_video2gtrank, video2predrank, top_k=recall_top_k)

        # output val performance
        print('\nbest performance on validation:')
        print('hit_top_k', [round(x, 3) for x in hit_k_scores])
        print('recall_top_k', [round(x, 3) for x in recall_K_scores])
        with open(os.path.join(output_dir, 'perf.txt'), 'w') as fout:
            fout.write('best performance on validation:')
            fout.write('\nhit_top_k: ' + ", ".join(map(str, [round(x, 3) for x in hit_k_scores])))
            # fixed typo: the original wrote '\necall_top_k'
            fout.write('\nrecall_top_k: ' + ", ".join(map(str, [round(x, 3) for x in recall_K_scores])))
# Esempio n. 8
def train(opt, model, epoch, train_loader, val_loader):
    """Train ``model`` for one epoch over ``train_loader``.

    When ``opt.cluster_loss`` is set, k-means centers are computed once per
    epoch from the raw features and refreshed every batch from the current
    embeddings; both are passed into ``model.train_emb``. Timing and engine
    stats go to the module-level ``tb_logger``; ``validate`` runs every
    ``opt.val_step`` iterations.
    """
    # average meters to record the training statistics
    batch_time = AverageMeter()
    data_time = AverageMeter()
    train_logger = LogCollector()

    # k-means centers; only populated when opt.cluster_loss is enabled
    kmeans_features = None
    kmeans_emb = None

    end = time.time()

    if opt.cluster_loss:
        # Centers over the raw input features, fixed for the whole epoch.
        features = retrieve_features(train_loader)
        kmeans_features = get_centers(features, opt.n_clusters)

    # https://stats.stackexchange.com/questions/299013/cosine-distance-as-similarity-measure-in-kmeans
    # normalizing and euclidian distance is linear correlated with cosine distance

    for j, (images, targets, lengths, ids) in enumerate(train_loader):

        if opt.cluster_loss:
            # Centers over the current embedding space, refreshed per batch.
            # NOTE(review): this re-encodes the entire train set every batch,
            # which is expensive — presumably intentional; verify.
            img_embs, _, _ = encode_data(model, train_loader)
            kmeans_emb = get_centers(img_embs, opt.n_clusters)

        # switch to train mode
        model.train_start()

        # measure data loading time
        data_time.update(time.time() - end)

        # make sure train logger is used
        model.logger = train_logger

        # Update the model
        model.train_emb(epoch, images, targets, lengths, ids, opt.cluster_loss,
                        kmeans_features, kmeans_emb)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Print log info
        if model.Eiters % opt.log_step == 0:
            logging.info(
                'Epoch: [{0}][{1}/{2}]\t'
                '{e_log}\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'.format(
                    epoch,
                    j,
                    len(train_loader),
                    batch_time=batch_time,
                    data_time=data_time,
                    e_log=str(model.logger)))

        # Record logs in tensorboard
        tb_logger.log_value('epoch', epoch, step=model.Eiters)
        tb_logger.log_value('step', j, step=model.Eiters)
        tb_logger.log_value('batch_time', batch_time.val, step=model.Eiters)
        tb_logger.log_value('data_time', data_time.val, step=model.Eiters)
        model.logger.tb_log(tb_logger, step=model.Eiters)

        # validate at every val_step
        if model.Eiters % opt.val_step == 0:
            validate(opt, val_loader, model)
# Esempio n. 9
def validate(opt, val_loader, model):
    """Validate an image-text retrieval model and return the recall sum.

    Supports two model families: plain 'vse++' embeddings scored by a global
    measure, and cross-attention models scored with sharded attention sims.

    Fixes vs. the original:
      * the t2i unpack clobbered the i2t ``meanr`` before it was written to
        tensorboard, so the 'meanr' scalar recorded the t2i value — t2i's
        mean rank is now unpacked into ``meanri``;
      * 'meanr' was logged twice and 'meanri' never; the t2i mean rank now
        goes to its own 'meanri' key.

    Note: relies on a module-level ``tb_logger``.
    """
    # compute the encoding for all the validation images and captions
    if opt.model_type == 'vse++':
        img_embs, cap_embs = encode_data(model, val_loader, opt.log_step,
                                         logging.info)
    else:
        img_embs, cap_embs, cap_lens, cap_inds, img_inds, tag_masks = encode_data(
            model, val_loader, opt.log_step, logging.info)

        # 5 captions per image: keep one image embedding per group of 5.
        img_embs = numpy.array(
            [img_embs[i] for i in range(0, len(img_embs), 5)])

        start = time.time()
        if opt.cross_attn == 't2i':
            sims = shard_xattn_t2i(img_embs,
                                   cap_embs,
                                   cap_lens,
                                   opt,
                                   tag_masks,
                                   shard_size=128)
        elif opt.cross_attn == 'i2t':
            sims = shard_xattn_i2t(img_embs,
                                   cap_embs,
                                   cap_lens,
                                   opt,
                                   tag_masks,
                                   shard_size=128)
        else:
            raise NotImplementedError
        end = time.time()
        print("calculate similarity time:", end - start)

    # caption retrieval
    if opt.model_type == 'vse++':
        (r1, r5, r10, medr, meanr) = i2t(img_embs,
                                         cap_embs,
                                         measure=opt.measure)
    else:
        (r1, r5, r10, medr, meanr) = i2t(img_embs, cap_embs, cap_lens, sims)
    logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" %
                 (r1, r5, r10, medr, meanr))
    # image retrieval (meanri, so the i2t meanr is not clobbered)
    if opt.model_type == 'vse++':
        (r1i, r5i, r10i, medri, meanri) = t2i(img_embs,
                                              cap_embs,
                                              measure=opt.measure)
    else:
        (r1i, r5i, r10i, medri, meanri) = t2i(img_embs, cap_embs, cap_lens,
                                              sims)
    logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" %
                 (r1i, r5i, r10i, medri, meanri))
    # sum of recalls to be used for early stopping
    currscore = r1 + r5 + r10 + r1i + r5i + r10i

    # record metrics in tensorboard
    tb_logger.log_value('r1', r1, step=model.Eiters)
    tb_logger.log_value('r5', r5, step=model.Eiters)
    tb_logger.log_value('r10', r10, step=model.Eiters)
    tb_logger.log_value('medr', medr, step=model.Eiters)
    tb_logger.log_value('meanr', meanr, step=model.Eiters)
    tb_logger.log_value('r1i', r1i, step=model.Eiters)
    tb_logger.log_value('r5i', r5i, step=model.Eiters)
    tb_logger.log_value('r10i', r10i, step=model.Eiters)
    tb_logger.log_value('medri', medri, step=model.Eiters)
    tb_logger.log_value('meanri', meanri, step=model.Eiters)
    tb_logger.log_value('rsum', currscore, step=model.Eiters)

    return currscore
# Esempio n. 10
def test(test_loader,
         model,
         tb_logger,
         measure='cosine',
         log_step=10,
         ndcg_scorer=None):
    """5-fold (5 x 5000 images) test evaluation of image-text retrieval.

    Runs i2t and t2i on each fold, prints per-fold and mean metrics, and
    records the means to tensorboard.

    Fix vs. the original: removed the dead ``sim_fn`` selection (assigned but
    never used, and unbound for any other ``measure``) and the unused
    ``rt``/``rti`` locals. ``measure`` is kept in the signature for caller
    compatibility.
    """
    # compute the encoding for all the validation images and captions
    img_embs, cap_embs = encode_data(model, test_loader, log_step,
                                     logging.info)

    results = []
    for i in range(5):
        # caption retrieval on this fold
        r, rt0 = i2t(img_embs[i * 5000:(i + 1) * 5000],
                     cap_embs[i * 5000:(i + 1) * 5000],
                     None,
                     None,
                     return_ranks=True,
                     ndcg_scorer=ndcg_scorer,
                     fold_index=i)
        print(
            "Image to text: %.1f, %.1f, %.1f, %.1f, %.1f, ndcg_rouge=%.4f ndcg_spice=%.4f"
            % r)
        # image retrieval on this fold
        ri, rti0 = t2i(img_embs[i * 5000:(i + 1) * 5000],
                       cap_embs[i * 5000:(i + 1) * 5000],
                       None,
                       None,
                       return_ranks=True,
                       ndcg_scorer=ndcg_scorer,
                       fold_index=i)
        print(
            "Text to image: %.1f, %.1f, %.1f, %.1f, %.1f, ndcg_rouge=%.4f, ndcg_spice=%.4f"
            % ri)
        ar = (r[0] + r[1] + r[2]) / 3
        ari = (ri[0] + ri[1] + ri[2]) / 3
        rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
        print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
        results += [list(r) + list(ri) + [ar, ari, rsum]]

    print("-----------------------------------")
    print("Mean metrics: ")
    mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
    print("rsum: %.1f" % (mean_metrics[16] * 6))
    print("Average i2t Recall: %.1f" % mean_metrics[14])
    print(
        "Image to text: %.1f %.1f %.1f %.1f %.1f ndcg_rouge=%.4f ndcg_spice=%.4f"
        % mean_metrics[:7])
    print("Average t2i Recall: %.1f" % mean_metrics[15])
    print(
        "Text to image: %.1f %.1f %.1f %.1f %.1f ndcg_rouge=%.4f ndcg_spice=%.4f"
        % mean_metrics[7:14])

    # record metrics in tensorboard
    tb_logger.add_scalar('test/r1', mean_metrics[0], model.Eiters)
    tb_logger.add_scalar('test/r5', mean_metrics[1], model.Eiters)
    tb_logger.add_scalar('test/r10', mean_metrics[2], model.Eiters)
    tb_logger.add_scalars('test/mean_ndcg', {
        'rougeL': mean_metrics[5],
        'spice': mean_metrics[6]
    }, model.Eiters)
    tb_logger.add_scalar('test/r1i', mean_metrics[7], model.Eiters)
    tb_logger.add_scalar('test/r5i', mean_metrics[8], model.Eiters)
    tb_logger.add_scalar('test/r10i', mean_metrics[9], model.Eiters)
    tb_logger.add_scalars('test/mean_ndcg_i', {
        'rougeL': mean_metrics[12],
        'spice': mean_metrics[13]
    }, model.Eiters)
# Esempio n. 11
def validate(val_loader,
             model,
             tb_logger,
             measure='cosine',
             log_step=10,
             ndcg_scorer=None,
             alignment_mode=None):
    """Validate image-text retrieval with optional alignment-based similarity.

    Runs i2t and t2i (using an AlignmentContrastiveLoss similarity matrix when
    ``alignment_mode`` is given), logs all metrics, and returns
    ``(rsum, spice_ndcg_sum)`` for early stopping.

    Fixes vs. the original:
      * t2i's mean rank clobbered the i2t ``meanr`` before it was written to
        tensorboard — it is now unpacked into ``meanri``;
      * the duplicate 'meanr' scalar is now logged under 'meanri';
      * removed the dead ``sim_fn`` selection (assigned but never used).
    """
    # compute the encoding for all the validation images and captions
    img_embs, cap_embs, img_lenghts, cap_lenghts = encode_data(
        model, val_loader, log_step, logging.info)

    # initialize similarity matrix evaluator (only when alignment is enabled)
    sim_matrix_fn = AlignmentContrastiveLoss(
        aggregation=alignment_mode,
        return_similarity_mat=True) if alignment_mode is not None else None

    # caption retrieval
    (r1, r5, r10, medr, meanr, mean_rougel_ndcg,
     mean_spice_ndcg) = i2t(img_embs,
                            cap_embs,
                            img_lenghts,
                            cap_lenghts,
                            measure=measure,
                            ndcg_scorer=ndcg_scorer,
                            sim_function=sim_matrix_fn)
    logging.info(
        "Image to text: %.1f, %.1f, %.1f, %.1f, %.1f, ndcg_rouge=%.4f ndcg_spice=%.4f"
        % (r1, r5, r10, medr, meanr, mean_rougel_ndcg, mean_spice_ndcg))
    # image retrieval (meanri, so the i2t meanr is not clobbered)
    (r1i, r5i, r10i, medri, meanri, mean_rougel_ndcg_i,
     mean_spice_ndcg_i) = t2i(img_embs,
                              cap_embs,
                              img_lenghts,
                              cap_lenghts,
                              ndcg_scorer=ndcg_scorer,
                              measure=measure,
                              sim_function=sim_matrix_fn)

    logging.info(
        "Text to image: %.1f, %.1f, %.1f, %.1f, %.1f, ndcg_rouge=%.4f ndcg_spice=%.4f"
        %
        (r1i, r5i, r10i, medri, meanri, mean_rougel_ndcg_i, mean_spice_ndcg_i))
    # sum of recalls to be used for early stopping
    currscore = r1 + r5 + r10 + r1i + r5i + r10i
    spice_ndcg_sum = mean_spice_ndcg + mean_spice_ndcg_i

    # record metrics in tensorboard
    tb_logger.add_scalar('r1', r1, model.Eiters)
    tb_logger.add_scalar('r5', r5, model.Eiters)
    tb_logger.add_scalar('r10', r10, model.Eiters)
    tb_logger.add_scalars('mean_ndcg', {
        'rougeL': mean_rougel_ndcg,
        'spice': mean_spice_ndcg
    }, model.Eiters)
    tb_logger.add_scalar('medr', medr, model.Eiters)
    tb_logger.add_scalar('meanr', meanr, model.Eiters)
    tb_logger.add_scalar('r1i', r1i, model.Eiters)
    tb_logger.add_scalar('r5i', r5i, model.Eiters)
    tb_logger.add_scalar('r10i', r10i, model.Eiters)
    tb_logger.add_scalars('mean_ndcg_i', {
        'rougeL': mean_rougel_ndcg_i,
        'spice': mean_spice_ndcg_i
    }, model.Eiters)
    tb_logger.add_scalar('medri', medri, model.Eiters)
    tb_logger.add_scalar('meanri', meanri, model.Eiters)
    tb_logger.add_scalar('rsum', currscore, model.Eiters)
    tb_logger.add_scalar('spice_ndcg_sum', spice_ndcg_sum, model.Eiters)

    return currscore, spice_ndcg_sum