예제 #1
0
def evaluate(img_ids,
             img_embs,
             t_embs,
             measure='cosine',
             n_caption=2,
             val_metric='map',
             direction='t2i'):
    """Compute a scalar retrieval score for paired image/text embeddings.

    Row ``i`` of ``img_embs`` and ``t_embs`` belongs to the image whose id is
    ``img_ids[i]`` (one row per image/caption pair). Only images that have at
    least ``n_caption`` rows are evaluated. For each such image, the first row
    supplies the image embedding and the first ``n_caption`` rows supply the
    caption embeddings, so every kept image has exactly ``n_caption``
    captions. Images with fewer rows are dropped together with their captions.

    NOTE(review): the ``evaluation`` helpers only receive ``n_caption``, so
    they presumably expect the captions of one image to be consecutive. If
    rows of different images are interleaved, the kept rows will not line
    up. Confirm against the data loader.

    Args:
        img_ids: per-row image ids; each must be convertible with ``int()``.
        img_embs: per-row image embeddings, indexable with a list of bools
            (presumably a numpy array -- TODO confirm).
        t_embs: per-row text embeddings, indexable like ``img_embs``.
        measure: similarity measure forwarded to ``evaluation.cal_error``.
        n_caption: number of captions kept per image.
        val_metric: ``'recall'`` (score is R@1 + R@5 + R@10) or ``'map'``.
        direction: ``'i2t'``, ``'t2i'`` or ``'all'`` (sum of both).

    Returns:
        The combined score. It is 0 when ``val_metric`` or ``direction`` is
        not one of the recognised values.
    """
    # Total number of rows (captions) per image id.
    totals = {}
    for iid in img_ids:
        key = int(iid)
        totals[key] = totals.get(key, 0) + 1

    # Keep the image embedding from the first row of each eligible image and
    # the caption embeddings from its first n_caption rows.
    seen = {}
    img_mask = [False] * len(img_ids)
    text_mask = [False] * len(img_ids)
    for idx, iid in enumerate(img_ids):
        key = int(iid)
        if totals[key] < n_caption:
            continue  # too few captions: drop the image and all its captions
        rank = seen.get(key, 0)
        seen[key] = rank + 1
        if rank < n_caption:
            text_mask[idx] = True
            if rank == 0:
                img_mask[idx] = True

    img_embs = img_embs[img_mask]
    t_embs = t_embs[text_mask]

    c2i_all_errors = evaluation.cal_error(img_embs, t_embs, measure)

    use_i2t = direction in ('i2t', 'all')
    use_t2i = direction in ('t2i', 'all')

    currscore = 0
    if val_metric == "recall":
        if use_i2t:
            # caption retrieval
            r1, r5, r10, _, _ = evaluation.i2t(c2i_all_errors,
                                               n_caption=n_caption)
            currscore += r1 + r5 + r10
        if use_t2i:
            # meme retrieval
            r1i, r5i, r10i, _, _ = evaluation.t2i(c2i_all_errors,
                                                  n_caption=n_caption)
            currscore += r1i + r5i + r10i
    elif val_metric == "map":
        if use_i2t:
            # caption retrieval
            currscore += evaluation.i2t_map(c2i_all_errors,
                                            n_caption=n_caption)
        if use_t2i:
            # meme retrieval
            currscore += evaluation.t2i_map(c2i_all_errors,
                                            n_caption=n_caption)

    return currscore
예제 #2
0
파일: trainer.py 프로젝트: nttung1110/W2VV
def validate(opt, val_loader, model, measure='cosine'):
    """Encode the validation set, report retrieval metrics and return a score.

    The loader yields video/caption pairs, so a video appears once per
    caption. Only the first occurrence of each video id keeps its video
    embedding; all caption embeddings are kept.

    Depending on ``opt.val_metric``:

    * ``"recall"``: prints and logs R@1/5/10, median and mean rank for both
      directions.
    * ``"map"``: prints and logs mAP for both directions.

    Args:
        opt: options object; reads ``log_step``, ``val_metric``,
            ``n_caption`` and ``direction``.
        val_loader: validation data loader passed to
            ``evaluation.encode_data``.
        model: model passed to ``evaluation.encode_data``. Its ``Eiters``
            value is used as the tensorboard step.
        measure: similarity measure forwarded to ``evaluation.cal_error``.

    Returns:
        The combined score, also logged as ``'rsum'``. For ``direction``
        ``'i2t'``/``'t2i'``/``'all'`` it is R@1+R@5+R@10 or mAP of that
        direction (or the sum of both). It is 0 for unrecognised
        ``val_metric``/``direction`` values.
    """
    # compute the encoding for all the validation videos and captions
    video_embs, cap_embs, video_ids, _ = evaluation.encode_data(
        model, val_loader, opt.log_step, logging.info)

    # Keep one video embedding per distinct video id (its first occurrence).
    seen_videos = set()
    feature_mask = []
    for video_id in video_ids:
        feature_mask.append(video_id not in seen_videos)
        seen_videos.add(video_id)
    video_embs = video_embs[feature_mask]

    c2i_all_errors = evaluation.cal_error(video_embs, cap_embs, measure)
    if opt.val_metric == "recall":

        # video retrieval
        (r1i, r5i, r10i, medri,
         meanri) = evaluation.t2i(c2i_all_errors, n_caption=opt.n_caption)
        print(" * Text to video:")
        print(" * r_1_5_10: {}".format(
            [round(r1i, 3), round(r5i, 3),
             round(r10i, 3)]))
        print(" * medr, meanr: {}".format([round(medri, 3), round(meanri, 3)]))
        print(" * " + '-' * 10)

        # caption retrieval
        (r1, r5, r10, medr, meanr) = evaluation.i2t(c2i_all_errors,
                                                    n_caption=opt.n_caption)
        print(" * Video to text:")
        print(" * r_1_5_10: {}".format(
            [round(r1, 3), round(r5, 3),
             round(r10, 3)]))
        print(" * medr, meanr: {}".format([round(medr, 3), round(meanr, 3)]))
        print(" * " + '-' * 10)

        # record metrics in tensorboard (same names and order as before)
        recall_metrics = (
            ('r1', r1), ('r5', r5), ('r10', r10),
            ('medr', medr), ('meanr', meanr),
            ('r1i', r1i), ('r5i', r5i), ('r10i', r10i),
            ('medri', medri), ('meanri', meanri),
        )
        for name, value in recall_metrics:
            tb_logger.log_value(name, value, step=model.Eiters)

    elif opt.val_metric == "map":
        i2t_map_score = evaluation.i2t_map(c2i_all_errors,
                                           n_caption=opt.n_caption)
        t2i_map_score = evaluation.t2i_map(c2i_all_errors,
                                           n_caption=opt.n_caption)
        tb_logger.log_value('i2t_map', i2t_map_score, step=model.Eiters)
        tb_logger.log_value('t2i_map', t2i_map_score, step=model.Eiters)
        print('i2t_map', i2t_map_score)
        print('t2i_map', t2i_map_score)

    use_i2t = opt.direction in ('i2t', 'all')
    use_t2i = opt.direction in ('t2i', 'all')

    currscore = 0
    if opt.val_metric == "recall":
        if use_i2t:
            currscore += (r1 + r5 + r10)
        if use_t2i:
            currscore += (r1i + r5i + r10i)
    elif opt.val_metric == "map":
        if use_i2t:
            currscore += i2t_map_score
        if use_t2i:
            currscore += t2i_map_score

    tb_logger.log_value('rsum', currscore, step=model.Eiters)

    return currscore
예제 #3
0
def validate_split(opt,
                   vid_data_loader,
                   text_data_loader,
                   model,
                   measure='cosine'):
    """Encode videos and captions from separate loaders and score retrieval.

    For MSVD test collections (``opt.testCollection`` starting with
    ``'msvd'``), metrics come from ``evaluation.t2i_varied`` /
    ``evaluation.i2t_varied``. These helpers receive the caption and video
    ids rather than a fixed ``n_caption`` (presumably because MSVD videos
    have a varying number of captions). This now applies to both the recall
    and the mAP metric, matching how ``main`` evaluates MSVD. Other
    collections use the fixed-``n_caption`` helpers.

    Args:
        opt: options object; reads ``testCollection``, ``val_metric``,
            ``n_caption`` and ``direction``.
        vid_data_loader: loader of videos, encoded with ``model.embed_vis``.
        text_data_loader: loader of captions, encoded with ``model.embed_txt``.
        model: model switched to evaluation via ``val_start()``. Its
            ``Eiters`` value is used as the tensorboard step.
        measure: similarity measure forwarded to ``evaluation.cal_error``.

    Returns:
        The combined score, also logged as ``'rsum'``. It is R@1+R@5+R@10 or
        mAP of the selected ``opt.direction`` (sum of both for ``'all'``),
        and 0 for unrecognised ``val_metric``/``direction`` values.
    """
    # compute the encoding for all the validation videos and captions
    model.val_start()
    video_embs, video_ids = evaluation.encode_text_or_vid(
        model.embed_vis, vid_data_loader)
    cap_embs, caption_ids = evaluation.encode_text_or_vid(
        model.embed_txt, text_data_loader)

    c2i_all_errors = evaluation.cal_error(video_embs, cap_embs, measure)
    varied = opt.testCollection.startswith('msvd')

    if opt.val_metric == "recall":

        # video retrieval
        if varied:
            (r1i, r5i, r10i, medri, meanri,
             _) = evaluation.t2i_varied(c2i_all_errors, caption_ids,
                                        video_ids)
        else:
            (r1i, r5i, r10i, medri,
             meanri) = evaluation.t2i(c2i_all_errors, n_caption=opt.n_caption)
        print(" * Text to video:")
        print(" * r_1_5_10: {}".format(
            [round(r1i, 3), round(r5i, 3),
             round(r10i, 3)]))
        print(" * medr, meanr: {}".format([round(medri, 3), round(meanri, 3)]))
        print(" * " + '-' * 10)

        # caption retrieval
        if varied:
            (r1, r5, r10, medr, meanr,
             _) = evaluation.i2t_varied(c2i_all_errors, caption_ids,
                                        video_ids)
        else:
            (r1, r5, r10, medr,
             meanr) = evaluation.i2t(c2i_all_errors, n_caption=opt.n_caption)
        print(" * Video to text:")
        print(" * r_1_5_10: {}".format(
            [round(r1, 3), round(r5, 3),
             round(r10, 3)]))
        print(" * medr, meanr: {}".format([round(medr, 3), round(meanr, 3)]))
        print(" * " + '-' * 10)

        # record metrics in tensorboard (same names and order as before)
        recall_metrics = (
            ('r1', r1), ('r5', r5), ('r10', r10),
            ('medr', medr), ('meanr', meanr),
            ('r1i', r1i), ('r5i', r5i), ('r10i', r10i),
            ('medri', medri), ('meanri', meanri),
        )
        for name, value in recall_metrics:
            tb_logger.log_value(name, value, step=model.Eiters)

    elif opt.val_metric == "map":
        if varied:
            # The *_varied helpers return the mAP as the last tuple element.
            *_, i2t_map_score = evaluation.i2t_varied(c2i_all_errors,
                                                      caption_ids, video_ids)
            *_, t2i_map_score = evaluation.t2i_varied(c2i_all_errors,
                                                      caption_ids, video_ids)
        else:
            i2t_map_score = evaluation.i2t_map(c2i_all_errors,
                                               n_caption=opt.n_caption)
            t2i_map_score = evaluation.t2i_map(c2i_all_errors,
                                               n_caption=opt.n_caption)
        tb_logger.log_value('i2t_map', i2t_map_score, step=model.Eiters)
        tb_logger.log_value('t2i_map', t2i_map_score, step=model.Eiters)
        print('i2t_map', i2t_map_score)
        print('t2i_map', t2i_map_score)

    use_i2t = opt.direction in ('i2t', 'all')
    use_t2i = opt.direction in ('t2i', 'all')

    currscore = 0
    if opt.val_metric == "recall":
        if use_i2t:
            currscore += (r1 + r5 + r10)
        if use_t2i:
            currscore += (r1i + r5i + r10i)
    elif opt.val_metric == "map":
        if use_i2t:
            currscore += i2t_map_score
        if use_t2i:
            currscore += t2i_map_score

    tb_logger.log_value('rsum', currscore, step=model.Eiters)

    return currscore
예제 #4
0
def main():
    """Evaluate a trained checkpoint on ``opt.testCollection``.

    Steps:

    1. Load the checkpoint ``<logger_name>/<checkpoint_name>``.
    2. Rebuild the model from the options stored in the checkpoint.
    3. Encode the test videos and captions.
    4. Save the caption/video error matrix (with the video and caption ids)
       to ``pred_errors_matrix.pth.tar``.
    5. Print recall@{1,5,10}, median/mean rank, recall sum and mAP for
       text-to-video and video-to-text retrieval.

    Exits with status 0, without evaluating, when the checkpoint does not
    exist or when ``checkToSkip`` reports the output should be skipped.
    """
    opt = parse_args()
    print(json.dumps(vars(opt), indent=2))

    rootpath = opt.rootpath
    testCollection = opt.testCollection
    n_caption = opt.n_caption
    resume = os.path.join(opt.logger_name, opt.checkpoint_name)

    if not os.path.exists(resume):
        # NOTE(review): exits with status 0 although nothing was evaluated,
        # so wrapper scripts cannot distinguish this from success.
        logging.info(resume + ' not exists.')
        sys.exit(0)

    checkpoint = torch.load(resume)
    start_epoch = checkpoint['epoch']
    best_rsum = checkpoint['best_rsum']
    print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
        resume, start_epoch, best_rsum))
    options = checkpoint['opt']
    # Checkpoints whose options lack `concate` get the default "full".
    if not hasattr(options, 'concate'):
        setattr(options, "concate", "full")

    # Output directory: the checkpoint path with the train collection name
    # replaced by the test collection name, and '/<cv_name>/' replaced by
    # '/results/<trainCollection>/'.
    trainCollection = options.trainCollection
    output_dir = resume.replace(trainCollection, testCollection)
    output_dir = output_dir.replace('/%s/' % options.cv_name,
                                    '/results/%s/' % trainCollection)
    pred_error_matrix_file = os.path.join(output_dir,
                                          'pred_errors_matrix.pth.tar')
    if checkToSkip(pred_error_matrix_file, opt.overwrite):
        sys.exit(0)
    makedirsforfile(pred_error_matrix_file)

    # data loader prepare
    caption_files = {
        'test':
        os.path.join(rootpath, testCollection, 'TextData',
                     '%s.caption.txt' % testCollection)
    }
    img_feat_path = os.path.join(rootpath, testCollection, 'FeatureData',
                                 options.visual_feature)
    visual_feats = {'test': BigFile(img_feat_path)}
    # Explicit check (not `assert`) so it still runs under `python -O`.
    if options.visual_feat_dim != visual_feats['test'].ndims:
        raise ValueError(
            'visual feature dim mismatch: model expects %s, %s has %s' %
            (options.visual_feat_dim, img_feat_path,
             visual_feats['test'].ndims))
    video2frames = {
        'test':
        read_dict(
            os.path.join(rootpath, testCollection, 'FeatureData',
                         options.visual_feature, 'video2frames.txt'))
    }

    # set bow vocabulary and encoding
    bow_vocab_file = os.path.join(rootpath, options.trainCollection,
                                  'TextData', 'vocabulary', 'bow',
                                  options.vocab + '.pkl')
    with open(bow_vocab_file, 'rb') as vocab_fh:
        bow_vocab = pickle.load(vocab_fh)
    bow2vec = get_text_encoder('bow')(bow_vocab)
    options.bow_vocab_size = len(bow_vocab)

    # set rnn vocabulary
    rnn_vocab_file = os.path.join(rootpath, options.trainCollection,
                                  'TextData', 'vocabulary', 'rnn',
                                  options.vocab + '.pkl')
    with open(rnn_vocab_file, 'rb') as vocab_fh:
        rnn_vocab = pickle.load(vocab_fh)
    options.vocab_size = len(rnn_vocab)

    # Construct the model
    model = get_model(options.model)(options)
    model.load_state_dict(checkpoint['model'])
    model.Eiters = checkpoint['Eiters']
    model.val_start()

    # MSVD is encoded with separate video/caption loaders and scored with
    # the id-based *_varied metrics. Other collections use paired loaders
    # and a fixed n_caption.
    varied = testCollection.startswith('msvd')

    if varied:
        # set data loader
        video_ids_list = data.read_video_ids(caption_files['test'])
        vid_data_loader = data.get_vis_data_loader(visual_feats['test'],
                                                   opt.batch_size,
                                                   opt.workers,
                                                   video2frames['test'],
                                                   video_ids=video_ids_list)
        text_data_loader = data.get_txt_data_loader(caption_files['test'],
                                                    rnn_vocab, bow2vec,
                                                    opt.batch_size,
                                                    opt.workers)
        # mapping
        video_embs, video_ids = evaluation.encode_text_or_vid(
            model.embed_vis, vid_data_loader)
        cap_embs, caption_ids = evaluation.encode_text_or_vid(
            model.embed_txt, text_data_loader)
    else:
        # set data loader
        data_loader = data.get_test_data_loaders(caption_files,
                                                 visual_feats,
                                                 rnn_vocab,
                                                 bow2vec,
                                                 opt.batch_size,
                                                 opt.workers,
                                                 opt.n_caption,
                                                 video2frames=video2frames)
        # mapping
        video_embs, cap_embs, video_ids, caption_ids = evaluation.encode_data(
            model, data_loader['test'], opt.log_step, logging.info)
        # Remove duplicate videos: keep every n_caption-th row. This
        # presumably relies on each video occupying n_caption consecutive
        # rows -- TODO confirm against get_test_data_loaders.
        idx = range(0, video_embs.shape[0], n_caption)
        video_embs = video_embs[idx, :]
        video_ids = video_ids[::n_caption]

    c2i_all_errors = evaluation.cal_error(video_embs, cap_embs,
                                          options.measure)
    torch.save(
        {
            'errors': c2i_all_errors,
            'videos': video_ids,
            'captions': caption_ids
        }, pred_error_matrix_file)
    print("write into: %s" % pred_error_matrix_file)

    if varied:
        # caption retrieval
        (r1, r5, r10, medr, meanr,
         i2t_map_score) = evaluation.i2t_varied(c2i_all_errors, caption_ids,
                                                video_ids)
        # video retrieval
        (r1i, r5i, r10i, medri, meanri,
         t2i_map_score) = evaluation.t2i_varied(c2i_all_errors, caption_ids,
                                                video_ids)
    else:
        # video retrieval
        (r1i, r5i, r10i, medri, meanri) = evaluation.t2i(c2i_all_errors,
                                                         n_caption=n_caption)
        t2i_map_score = evaluation.t2i_map(c2i_all_errors, n_caption=n_caption)

        # caption retrieval
        (r1, r5, r10, medr, meanr) = evaluation.i2t(c2i_all_errors,
                                                    n_caption=n_caption)
        i2t_map_score = evaluation.i2t_map(c2i_all_errors, n_caption=n_caption)

    # video retrieval
    print(" * Text to Video:")
    print(" * r_1_5_10, medr, meanr: {}".format([
        round(r1i, 1),
        round(r5i, 1),
        round(r10i, 1),
        round(medri, 1),
        round(meanri, 1)
    ]))
    print(" * recall sum: {}".format(round(r1i + r5i + r10i, 1)))
    print(" * mAP: {}".format(round(t2i_map_score, 3)))
    print(" * " + '-' * 10)

    # caption retrieval
    print(" * Video to text:")
    print(" * r_1_5_10, medr, meanr: {}".format([
        round(r1, 1),
        round(r5, 1),
        round(r10, 1),
        round(medr, 1),
        round(meanr, 1)
    ]))
    print(" * recall sum: {}".format(round(r1 + r5 + r10, 1)))
    print(" * mAP: {}".format(round(i2t_map_score, 3)))
    print(" * " + '-' * 10)