Example #1
def main():
    opt = parse_args()
    print(json.dumps(vars(opt), indent=2))

    rootpath = opt.rootpath
    trainCollection = opt.trainCollection
    valCollection = opt.valCollection
    testCollection = opt.testCollection

    if opt.loss_fun == "mrl" and opt.measure == "cosine":
        assert opt.text_norm is True
        assert opt.visual_norm is True

    # checkpoint path
    model_info = '%s_concate_%s_dp_%.1f_measure_%s' % (
        opt.model, opt.concate, opt.dropout, opt.measure)
    # text-side multi-level encoding info
    text_encode_info = 'vocab_%s_word_dim_%s_text_rnn_size_%s_text_norm_%s' % \
            (opt.vocab, opt.word_dim, opt.text_rnn_size, opt.text_norm)
    text_encode_info += "_kernel_sizes_%s_num_%s" % (opt.text_kernel_sizes,
                                                     opt.text_kernel_num)
    # video-side multi-level encoding info
    visual_encode_info = 'visual_feature_%s_visual_rnn_size_%d_visual_norm_%s' % \
            (opt.visual_feature, opt.visual_rnn_size, opt.visual_norm)
    visual_encode_info += "_kernel_sizes_%s_num_%s" % (opt.visual_kernel_sizes,
                                                       opt.visual_kernel_num)
    # common space learning info
    mapping_info = "mapping_text_%s_img_%s" % (opt.text_mapping_layers,
                                               opt.visual_mapping_layers)
    loss_info = 'loss_func_%s_margin_%s_direction_%s_max_violation_%s_cost_style_%s' % \
                    (opt.loss_fun, opt.margin, opt.direction, opt.max_violation, opt.cost_style)
    optimizer_info = 'optimizer_%s_lr_%s_decay_%.2f_grad_clip_%.1f_val_metric_%s' % \
                    (opt.optimizer, opt.learning_rate, opt.lr_decay_rate, opt.grad_clip, opt.val_metric)

    opt.logger_name = os.path.join(rootpath, trainCollection, opt.cv_name,
                                   valCollection, model_info, text_encode_info,
                                   visual_encode_info, mapping_info, loss_info,
                                   optimizer_info, opt.postfix)
    print(opt.logger_name)

    if checkToSkip(os.path.join(opt.logger_name, 'model_best.pth.tar'),
                   opt.overwrite):
        sys.exit(0)
    if checkToSkip(os.path.join(opt.logger_name, 'val_metric.txt'),
                   opt.overwrite):
        sys.exit(0)
    makedirsforfile(os.path.join(opt.logger_name, 'val_metric.txt'))
    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    tb_logger.configure(opt.logger_name, flush_secs=5)

    opt.text_kernel_sizes = list(map(int, opt.text_kernel_sizes.split('-')))
    opt.visual_kernel_sizes = list(map(int, opt.visual_kernel_sizes.split('-')))
    # collections: train, val
    collections = {'train': trainCollection, 'val': valCollection}
    cap_file = {
        'train': '%s.caption.txt' % trainCollection,
        'val': '%s.caption.txt' % valCollection
    }
    # caption
    caption_files = {
        x: os.path.join(rootpath, collections[x], 'TextData', cap_file[x])
        for x in collections
    }
    # Load visual features
    visual_feat_path = {
        x: os.path.join(rootpath, collections[x], 'FeatureData',
                        opt.visual_feature)
        for x in collections
    }
    visual_feats = {x: BigFile(visual_feat_path[x]) for x in visual_feat_path}
    opt.visual_feat_dim = visual_feats['train'].ndims

    # set bow vocabulary and encoding
    bow_vocab_file = os.path.join(rootpath, opt.trainCollection, 'TextData',
                                  'vocabulary', 'bow', opt.vocab + '.pkl')
    bow_vocab = pickle.load(open(bow_vocab_file, 'rb'))
    bow2vec = get_text_encoder('bow')(bow_vocab)
    opt.bow_vocab_size = len(bow_vocab)

    # set rnn vocabulary
    rnn_vocab_file = os.path.join(rootpath, opt.trainCollection, 'TextData',
                                  'vocabulary', 'rnn', opt.vocab + '.pkl')
    rnn_vocab = pickle.load(open(rnn_vocab_file, 'rb'))
    opt.vocab_size = len(rnn_vocab)

    # initialize word embedding
    opt.we_parameter = None
    if opt.word_dim == 500:
        w2v_data_path = os.path.join(rootpath, "word2vec", 'flickr',
                                     'vec500flickr30m')
        opt.we_parameter = get_we_parameter(rnn_vocab, w2v_data_path)

    # mapping layer structure
    opt.text_mapping_layers = list(map(int, opt.text_mapping_layers.split('-')))
    opt.visual_mapping_layers = list(map(int, opt.visual_mapping_layers.split('-')))
    if opt.concate == 'full':
        opt.text_mapping_layers[
            0] = opt.bow_vocab_size + opt.text_rnn_size * 2 + opt.text_kernel_num * len(
                opt.text_kernel_sizes)
        opt.visual_mapping_layers[
            0] = opt.visual_feat_dim + opt.visual_rnn_size * 2 + opt.visual_kernel_num * len(
                opt.visual_kernel_sizes)
    elif opt.concate == 'reduced':
        opt.text_mapping_layers[
            0] = opt.text_rnn_size * 2 + opt.text_kernel_num * len(
                opt.text_kernel_sizes)
        opt.visual_mapping_layers[
            0] = opt.visual_rnn_size * 2 + opt.visual_kernel_num * len(
                opt.visual_kernel_sizes)
    else:
        raise NotImplementedError('Concatenation %s not implemented' %
                                  opt.concate)

    # set data loader
    video2frames = {
        x: read_dict(
            os.path.join(rootpath, collections[x], 'FeatureData',
                         opt.visual_feature, 'video2frames.txt'))
        for x in collections
    }
    data_loaders = data.get_data_loaders(caption_files,
                                         visual_feats,
                                         rnn_vocab,
                                         bow2vec,
                                         opt.batch_size,
                                         opt.workers,
                                         opt.n_caption,
                                         video2frames=video2frames)

    # Construct the model
    model = get_model(opt.model)(opt)
    opt.we_parameter = None  # free the embedding matrix; it now lives inside the model

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, data_loaders['val'], model, measure=opt.measure)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model
    best_rsum = 0
    no_impr_counter = 0
    lr_counter = 0
    best_epoch = None
    fout_val_metric_hist = open(
        os.path.join(opt.logger_name, 'val_metric_hist.txt'), 'w')
    for epoch in range(opt.num_epochs):
        print('Epoch[{0} / {1}] LR: {2}'.format(
            epoch, opt.num_epochs,
            get_learning_rate(model.optimizer)[0]))
        print('-' * 10)
        # train for one epoch
        train(opt, data_loaders['train'], model, epoch)

        # evaluate on validation set
        rsum = validate(opt, data_loaders['val'], model, measure=opt.measure)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        print(' * Current perf: {}'.format(rsum))
        print(' * Best perf: {}'.format(best_rsum))
        print('')
        fout_val_metric_hist.write('epoch_%d: %f\n' % (epoch, rsum))
        fout_val_metric_hist.flush()

        if is_best:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model': model.state_dict(),
                    'best_rsum': best_rsum,
                    'opt': opt,
                    'Eiters': model.Eiters,
                },
                is_best,
                filename='checkpoint_epoch_%s.pth.tar' % epoch,
                prefix=opt.logger_name + '/',
                best_epoch=best_epoch)
            best_epoch = epoch

        lr_counter += 1
        decay_learning_rate(opt, model.optimizer, opt.lr_decay_rate)
        if not is_best:
            # Early stopping occurs if the validation performance does not
            # improve for ten consecutive epochs.
            no_impr_counter += 1
            if no_impr_counter > 10:
                print('Early stopping happened.\n')
                break

            # When the validation performance decreased after an epoch,
            # we divide the learning rate by 2 and continue training;
            # but we use each learning rate for at least 3 epochs.
            if lr_counter > 2:
                decay_learning_rate(opt, model.optimizer, 0.5)
                lr_counter = 0
        else:
            no_impr_counter = 0

    fout_val_metric_hist.close()

    print('best performance on validation: {}\n'.format(best_rsum))
    with open(os.path.join(opt.logger_name, 'val_metric.txt'), 'w') as fout:
        fout.write('best performance on validation: ' + str(best_rsum))

    # generate evaluation shell script
    if testCollection == 'iacc.3':
        template = ''.join(open('util/TEMPLATE_do_predict.sh').readlines())
        scriptStr = template.replace('@@@query_sets@@@',
                                     'tv16.avs.txt,tv17.avs.txt,tv18.avs.txt')
    else:
        template = ''.join(open('util/TEMPLATE_do_test.sh').readlines())
        scriptStr = template.replace('@@@n_caption@@@', str(opt.n_caption))
    scriptStr = scriptStr.replace('@@@rootpath@@@', rootpath)
    scriptStr = scriptStr.replace('@@@testCollection@@@', testCollection)
    scriptStr = scriptStr.replace('@@@logger_name@@@', opt.logger_name)
    scriptStr = scriptStr.replace('@@@overwrite@@@', str(opt.overwrite))

    # perform evaluation on test set
    runfile = 'do_test_%s_%s.sh' % (opt.model, testCollection)
    open(runfile, 'w').write(scriptStr + '\n')
    os.system('chmod +x %s' % runfile)
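
The training loop above relies on two learning-rate helpers that are not shown in this example. Here is a minimal sketch of what they plausibly do for a standard PyTorch optimizer (assumed implementations, inferred only from the call sites above):

def get_learning_rate(optimizer):
    """Return the current learning rate of every parameter group."""
    return [group['lr'] for group in optimizer.param_groups]


def decay_learning_rate(opt, optimizer, decay_rate):
    """Scale the learning rate of every parameter group by decay_rate."""
    for group in optimizer.param_groups:
        group['lr'] *= decay_rate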
Example #2
def main():
    opt = parse_args()
    print(json.dumps(vars(opt), indent=2))

    rootpath = opt.rootpath
    testCollection = opt.testCollection
    n_caption = opt.n_caption
    resume = os.path.join(opt.logger_name, opt.checkpoint_name)

    if not os.path.exists(resume):
        logging.info(resume + ' does not exist.')
        sys.exit(0)

    checkpoint = torch.load(resume)
    start_epoch = checkpoint['epoch']
    best_rsum = checkpoint['best_rsum']
    print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
        resume, start_epoch, best_rsum))
    options = checkpoint['opt']
    if not hasattr(options, 'concate'):
        setattr(options, "concate", "full")

    trainCollection = options.trainCollection
    output_dir = resume.replace(trainCollection, testCollection)
    output_dir = output_dir.replace('/%s/' % options.cv_name,
                                    '/results/%s/' % trainCollection)
    result_pred_sents = os.path.join(output_dir, 'id.sent.score.txt')
    pred_error_matrix_file = os.path.join(output_dir,
                                          'pred_errors_matrix.pth.tar')
    if checkToSkip(pred_error_matrix_file, opt.overwrite):
        sys.exit(0)
    makedirsforfile(pred_error_matrix_file)

    # data loader prepare
    caption_files = {
        'test':
        os.path.join(rootpath, testCollection, 'TextData',
                     '%s.caption.txt' % testCollection)
    }
    img_feat_path = os.path.join(rootpath, testCollection, 'FeatureData',
                                 options.visual_feature)
    visual_feats = {'test': BigFile(img_feat_path)}
    assert options.visual_feat_dim == visual_feats['test'].ndims
    video2frames = {
        'test':
        read_dict(
            os.path.join(rootpath, testCollection, 'FeatureData',
                         options.visual_feature, 'video2frames.txt'))
    }

    # set bow vocabulary and encoding
    bow_vocab_file = os.path.join(rootpath, options.trainCollection,
                                  'TextData', 'vocabulary', 'bow',
                                  options.vocab + '.pkl')
    bow_vocab = pickle.load(open(bow_vocab_file, 'rb'))
    bow2vec = get_text_encoder('bow')(bow_vocab)
    options.bow_vocab_size = len(bow_vocab)

    # set rnn vocabulary
    rnn_vocab_file = os.path.join(rootpath, options.trainCollection,
                                  'TextData', 'vocabulary', 'rnn',
                                  options.vocab + '.pkl')
    rnn_vocab = pickle.load(open(rnn_vocab_file, 'rb'))
    options.vocab_size = len(rnn_vocab)

    # Construct the model
    model = get_model(options.model)(options)
    model.load_state_dict(checkpoint['model'])
    model.Eiters = checkpoint['Eiters']
    model.val_start()

    if testCollection.startswith('msvd'):  # or testCollection.startswith('msrvtt'):
        # set data loader
        video_ids_list = data.read_video_ids(caption_files['test'])
        vid_data_loader = data.get_vis_data_loader(visual_feats['test'],
                                                   opt.batch_size,
                                                   opt.workers,
                                                   video2frames['test'],
                                                   video_ids=video_ids_list)
        text_data_loader = data.get_txt_data_loader(caption_files['test'],
                                                    rnn_vocab, bow2vec,
                                                    opt.batch_size,
                                                    opt.workers)
        # mapping
        video_embs, video_ids = evaluation.encode_text_or_vid(
            model.embed_vis, vid_data_loader)
        cap_embs, caption_ids = evaluation.encode_text_or_vid(
            model.embed_txt, text_data_loader)
    else:
        # set data loader
        data_loader = data.get_test_data_loaders(caption_files,
                                                 visual_feats,
                                                 rnn_vocab,
                                                 bow2vec,
                                                 opt.batch_size,
                                                 opt.workers,
                                                 opt.n_caption,
                                                 video2frames=video2frames)
        # mapping
        video_embs, cap_embs, video_ids, caption_ids = evaluation.encode_data(
            model, data_loader['test'], opt.log_step, logging.info)
        # remove duplicate videos (each video appears once per caption)
        idx = range(0, video_embs.shape[0], n_caption)
        video_embs = video_embs[idx, :]
        video_ids = video_ids[::n_caption]

    c2i_all_errors = evaluation.cal_error(video_embs, cap_embs,
                                          options.measure)
    torch.save(
        {
            'errors': c2i_all_errors,
            'videos': video_ids,
            'captions': caption_ids
        }, pred_error_matrix_file)
    print("write into: %s" % pred_error_matrix_file)

    if testCollection.startswith('msvd'):  # or testCollection.startswith('msrvtt'):
        # caption retrieval
        (r1, r5, r10, medr, meanr,
         i2t_map_score) = evaluation.i2t_varied(c2i_all_errors, caption_ids,
                                                video_ids)
        # video retrieval
        (r1i, r5i, r10i, medri, meanri,
         t2i_map_score) = evaluation.t2i_varied(c2i_all_errors, caption_ids,
                                                video_ids)
    else:
        # video retrieval (text to video)
        (r1i, r5i, r10i, medri, meanri) = evaluation.t2i(c2i_all_errors,
                                                         n_caption=n_caption)
        t2i_map_score = evaluation.t2i_map(c2i_all_errors, n_caption=n_caption)

        # caption retrieval (video to text)
        (r1, r5, r10, medr, meanr) = evaluation.i2t(c2i_all_errors,
                                                    n_caption=n_caption)
        i2t_map_score = evaluation.i2t_map(c2i_all_errors, n_caption=n_caption)

    print(" * Text to Video:")
    print(" * r_1_5_10, medr, meanr: {}".format([
        round(r1i, 1),
        round(r5i, 1),
        round(r10i, 1),
        round(medri, 1),
        round(meanri, 1)
    ]))
    print(" * recall sum: {}".format(round(r1i + r5i + r10i, 1)))
    print(" * mAP: {}".format(round(t2i_map_score, 3)))
    print(" * " + '-' * 10)

    # caption retrieval
    print(" * Video to text:")
    print(" * r_1_5_10, medr, meanr: {}".format([
        round(r1, 1),
        round(r5, 1),
        round(r10, 1),
        round(medr, 1),
        round(meanr, 1)
    ]))
    print(" * recall sum: {}".format(round(r1 + r5 + r10, 1)))
    print(" * mAP: {}".format(round(i2t_map_score, 3)))
    print(" * " + '-' * 10)
Example #3
def main():
    opt = parse_args()
    print(json.dumps(vars(opt), indent=2))

    rootpath = opt.rootpath
    collectionStrt = opt.collectionStrt
    resume = os.path.join(opt.logger_name, opt.checkpoint_name)

    if not os.path.exists(resume):
        logging.info(resume + ' does not exist.')
        sys.exit(0)

    checkpoint = torch.load(resume)
    start_epoch = checkpoint['epoch']
    best_rsum = checkpoint['best_rsum']
    print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
        resume, start_epoch, best_rsum))
    options = checkpoint['opt']

    # collection setting
    testCollection = opt.testCollection
    collections_pathname = options.collections_pathname
    collections_pathname['test'] = testCollection

    trainCollection = options.trainCollection
    output_dir = resume.replace(trainCollection, testCollection)
    if 'checkpoints' in output_dir:
        output_dir = output_dir.replace('/checkpoints/', '/results/')
    else:
        output_dir = output_dir.replace(
            '/%s/' % options.cv_name,
            '/results/%s/%s/' % (options.cv_name, trainCollection))
    result_pred_sents = os.path.join(output_dir, 'id.sent.score.txt')
    pred_error_matrix_file = os.path.join(output_dir,
                                          'pred_errors_matrix.pth.tar')
    if checkToSkip(pred_error_matrix_file, opt.overwrite):
        sys.exit(0)
    makedirsforfile(pred_error_matrix_file)

    log_config(output_dir)
    logging.info(json.dumps(vars(opt), indent=2))

    # data loader prepare
    if collectionStrt == 'single':
        test_cap = os.path.join(
            rootpath, collections_pathname['test'], 'TextData',
            '%s%s.caption.txt' % (testCollection, opt.split))
    elif collectionStrt == 'multiple':
        test_cap = os.path.join(rootpath, collections_pathname['test'],
                                'TextData', '%s.caption.txt' % testCollection)
    else:
        raise NotImplementedError('collection structure %s not implemented' %
                                  collectionStrt)

    caption_files = {'test': test_cap}
    img_feat_path = os.path.join(rootpath, collections_pathname['test'],
                                 'FeatureData', options.visual_feature)
    visual_feats = {'test': BigFile(img_feat_path)}
    assert options.visual_feat_dim == visual_feats['test'].ndims
    video2frames = {
        'test':
        read_dict(
            os.path.join(rootpath, collections_pathname['test'], 'FeatureData',
                         options.visual_feature, 'video2frames.txt'))
    }

    # set bow vocabulary and encoding
    bow_vocab_file = os.path.join(rootpath, collections_pathname['train'],
                                  'TextData', 'vocabulary', 'bow',
                                  options.vocab + '.pkl')
    bow_vocab = pickle.load(open(bow_vocab_file, 'rb'))
    bow2vec = get_text_encoder('bow')(bow_vocab)
    options.bow_vocab_size = len(bow_vocab)

    # set rnn vocabulary
    rnn_vocab_file = os.path.join(rootpath, collections_pathname['train'],
                                  'TextData', 'vocabulary', 'rnn',
                                  options.vocab + '.pkl')
    rnn_vocab = pickle.load(open(rnn_vocab_file, 'rb'))
    options.vocab_size = len(rnn_vocab)

    # Construct the model
    model = get_model(options.model)(options)
    model.load_state_dict(checkpoint['model'])
    model.Eiters = checkpoint['Eiters']
    model.val_start()

    # set data loader
    video_ids_list = data.read_video_ids(caption_files['test'])
    vid_data_loader = data.get_vis_data_loader(visual_feats['test'],
                                               opt.batch_size,
                                               opt.workers,
                                               video2frames['test'],
                                               video_ids=video_ids_list)
    text_data_loader = data.get_txt_data_loader(caption_files['test'],
                                                rnn_vocab, bow2vec,
                                                opt.batch_size, opt.workers)

    # mapping
    if options.space == 'hybrid':
        video_embs, video_tag_probs, video_ids = evaluation.encode_text_or_vid_tag_hist_prob(
            model.embed_vis, vid_data_loader)
        cap_embs, cap_tag_probs, caption_ids = evaluation.encode_text_or_vid_tag_hist_prob(
            model.embed_txt, text_data_loader)
    else:
        video_embs, video_ids = evaluation.encode_text_or_vid(
            model.embed_vis, vid_data_loader)
        cap_embs, caption_ids = evaluation.encode_text_or_vid(
            model.embed_txt, text_data_loader)

    v2t_gt, t2v_gt = metrics.get_gt(video_ids, caption_ids)

    logging.info("write into: %s" % output_dir)
    if options.space != 'latent':
        tag_vocab_path = os.path.join(
            rootpath, collections_pathname['train'], 'TextData', 'tags',
            'video_label_th_1', 'tag_vocab_%d.json' % options.tag_vocab_size)
        evaluation.pred_tag(video_tag_probs, video_ids, tag_vocab_path,
                            os.path.join(output_dir, 'video'))
        evaluation.pred_tag(cap_tag_probs, caption_ids, tag_vocab_path,
                            os.path.join(output_dir, 'text'))

    if options.space in ['latent', 'hybrid']:
        # logging.info("=======Latent Space=======")
        t2v_all_errors_1 = evaluation.cal_error(video_embs, cap_embs,
                                                options.measure)

    if options.space in ['concept', 'hybrid']:
        # logging.info("=======Concept Space=======")
        t2v_all_errors_2 = evaluation.cal_error_batch(video_tag_probs,
                                                      cap_tag_probs,
                                                      options.measure_2)

    if options.space == 'hybrid':
        w = 0.6
        t2v_all_errors_1 = norm_score(t2v_all_errors_1)
        t2v_all_errors_2 = norm_score(t2v_all_errors_2)
        t2v_tag_all_errors = w * t2v_all_errors_1 + (1 - w) * t2v_all_errors_2
        cal_perf(t2v_tag_all_errors, v2t_gt, t2v_gt)
        torch.save(
            {
                'errors': t2v_tag_all_errors,
                'videos': video_ids,
                'captions': caption_ids
            }, pred_error_matrix_file)
        logging.info("write into: %s" % pred_error_matrix_file)

    elif options.space == 'latent':
        cal_perf(t2v_all_errors_1, v2t_gt, t2v_gt)
        torch.save(
            {
                'errors': t2v_all_errors_1,
                'videos': video_ids,
                'captions': caption_ids
            }, pred_error_matrix_file)
        logging.info("write into: %s" % pred_error_matrix_file)
Example #4
def main():
    opt = parse_args()
    logging.info(json.dumps(vars(opt), indent=2))

    rootpath = opt.rootpath
    testCollection = opt.testCollection
    collectionStrt = opt.collectionStrt
    assert collectionStrt == "multiple"
    resume = os.path.join(opt.logger_name, opt.checkpoint_name)

    if not os.path.exists(resume):
        logging.info(resume + ' does not exist.')
        sys.exit(0)

    checkpoint = torch.load(resume)
    start_epoch = checkpoint['epoch']
    best_rsum = checkpoint['best_rsum']
    logging.info("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
        resume, start_epoch, best_rsum))
    options = checkpoint['opt']

    trainCollection = options.trainCollection
    valCollection = options.valCollection

    visual_feat_file = BigFile(
        os.path.join(rootpath, testCollection, 'FeatureData',
                     options.visual_feature))
    assert options.visual_feat_dim == visual_feat_file.ndims
    video2frame = read_dict(
        os.path.join(rootpath, testCollection, 'FeatureData',
                     options.visual_feature, 'video2frames.txt'))
    vid_data_loader = data.get_vis_data_loader(visual_feat_file,
                                               opt.batch_size, opt.workers,
                                               video2frame)
    video_embs = None  # encoded lazily below, then cached across query sets

    # set bow vocabulary and encoding
    bow_vocab_file = os.path.join(rootpath, options.trainCollection,
                                  'TextData', 'vocabulary', 'bow',
                                  options.vocab + '.pkl')
    bow_vocab = pickle.load(open(bow_vocab_file, 'rb'))
    bow2vec = get_text_encoder('bow')(bow_vocab)
    options.bow_vocab_size = len(bow_vocab)

    # set rnn vocabulary
    rnn_vocab_file = os.path.join(rootpath, options.trainCollection,
                                  'TextData', 'vocabulary', 'rnn',
                                  options.vocab + '.pkl')
    rnn_vocab = pickle.load(open(rnn_vocab_file, 'rb'))
    options.vocab_size = len(rnn_vocab)

    model = get_model(options.model)(options)
    model.load_state_dict(checkpoint['model'])
    model.val_start()

    output_dir = resume.replace(trainCollection, testCollection)
    for query_set in opt.query_sets.strip().split(','):
        output_dir_tmp = output_dir.replace(
            valCollection,
            '%s/%s/%s' % (query_set, trainCollection, valCollection))
        output_dir_tmp = output_dir_tmp.replace('/%s/' % options.cv_name,
                                                '/results/')
        pred_result_file = os.path.join(output_dir_tmp, 'id.sent.score.txt')
        logging.info(pred_result_file)
        if checkToSkip(pred_result_file, opt.overwrite):
            sys.exit(0)
        makedirsforfile(pred_result_file)

        # query data loader
        query_file = os.path.join(rootpath, testCollection, 'TextData',
                                  query_set)
        query_loader = data.get_txt_data_loader(query_file, rnn_vocab, bow2vec,
                                                opt.batch_size, opt.workers)

        # encode videos (only once; the embeddings are reused for every query set)
        if video_embs is None:
            start = time.time()
            if options.space == 'hybrid':
                video_embs, video_tag_probs, video_ids = evaluation.encode_text_or_vid_tag_hist_prob(
                    model.embed_vis, vid_data_loader)
            else:
                video_embs, video_ids = evaluation.encode_text_or_vid(
                    model.embed_vis, vid_data_loader)
            logging.info("encode video time: %.3f s" % (time.time() - start))

        # encode text
        start = time.time()
        if options.space == 'hybrid':
            query_embs, query_tag_probs, query_ids = evaluation.encode_text_or_vid_tag_hist_prob(
                model.embed_txt, query_loader)
        else:
            query_embs, query_ids = evaluation.encode_text_or_vid(
                model.embed_txt, query_loader)
        logging.info("encode text time: %.3f s" % (time.time() - start))

        if options.space == 'hybrid':
            t2v_matrix_1 = evaluation.cal_simi(query_embs, video_embs)
            # eval_avs(t2v_matrix_1, query_ids, video_ids, pred_result_file, rootpath, testCollection, query_set)

            t2v_matrix_2 = evaluation.cal_simi(query_tag_probs,
                                               video_tag_probs)
            # pred_result_file = os.path.join(output_dir_tmp, 'id.sent.score_2.txt')
            # eval_avs(t2v_matrix_2, query_ids, video_ids, pred_result_file, rootpath, testCollection, query_set)

            t2v_matrix_1 = norm_score(t2v_matrix_1)
            t2v_matrix_2 = norm_score(t2v_matrix_2)
            for w in [0.8]:
                print("\n")
                t2v_matrix = w * t2v_matrix_1 + (1 - w) * t2v_matrix_2
                pred_result_file = os.path.join(output_dir_tmp,
                                                'id.sent.score_%.1f.txt' % w)
                eval_avs(t2v_matrix, query_ids, video_ids, pred_result_file,
                         rootpath, testCollection, query_set)
        else:
            t2v_matrix_1 = evaluation.cal_simi(query_embs, video_embs)
            eval_avs(t2v_matrix_1, query_ids, video_ids, pred_result_file,
                     rootpath, testCollection, query_set)
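
evaluation.cal_simi is used above as the counterpart of cal_error: it returns a query-by-video similarity matrix where higher means more similar. A plausible sketch, assuming the embeddings are already L2-normalized so a plain dot product equals cosine similarity:

import numpy as np

def cal_simi_sketch(query_embs, video_embs):
    # shape: (num_queries, num_videos); higher = more similar
    return np.dot(query_embs, video_embs.T)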
Example #5
def main():
    opt = parse_args()
    print(json.dumps(vars(opt), indent=2))

    rootpath = opt.rootpath
    evalpath = opt.evalpath
    testCollection = opt.testCollection
    batchsize = opt.batch_size

    # n_caption = opt.n_caption
    resume = os.path.join(opt.logger_name, opt.checkpoint_name)

    if not os.path.exists(resume):
        logging.info(resume + ' does not exist.')
        sys.exit(0)

    saveFile_AVS16 = (opt.logger_name + '/AVS16_' + testCollection +
                      '_Dense_Dual_model_bin.txt')
    saveFile_AVS17 = (opt.logger_name + '/AVS17_' + testCollection +
                      '_Dense_Dual_model_bin.txt')
    saveFile_AVS18 = (opt.logger_name + '/AVS18_' + testCollection +
                      '_Dense_Dual_model_bin.txt')

    if os.path.exists(saveFile_AVS17):
        sys.exit(0)

    queriesFile = 'AVS/tv16_17_18.avs.topics_parsed.txt'
    lineList = [line.rstrip('\n') for line in open(queriesFile)]

    checkpoint = torch.load(resume)
    start_epoch = checkpoint['epoch']
    best_rsum = checkpoint['best_rsum']
    print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
        resume, start_epoch, best_rsum))
    options = checkpoint['opt']

    if not hasattr(options, 'do_visual_feas_norm'):
        setattr(options, "do_visual_feas_norm", 0)

    if not hasattr(options, 'concate'):
        setattr(options, "concate", "full")

    trainCollection = options.trainCollection
    output_dir = resume.replace(trainCollection, testCollection)
    output_dir = output_dir.replace('/%s/' % options.cv_name,
                                    '/results/%s/' % trainCollection)
    result_pred_sents = os.path.join(output_dir, 'id.sent.score.txt')
    pred_error_matrix_file = os.path.join(output_dir,
                                          'pred_errors_matrix.pth.tar')
    if checkToSkip(pred_error_matrix_file, opt.overwrite):
        sys.exit(0)
    makedirsforfile(pred_error_matrix_file)

    # data loader prepare
    caption_files = {
        'test':
        os.path.join(evalpath, testCollection, 'TextData',
                     '%s.caption.txt' % testCollection)
    }
    img_feat_path = os.path.join(evalpath, testCollection, 'FeatureData',
                                 options.visual_feature)
    visual_feats = {'test': BigFile(img_feat_path)}
    assert options.visual_feat_dim == visual_feats['test'].ndims
    video2frames = {
        'test':
        read_dict(
            os.path.join(evalpath, testCollection, 'FeatureData',
                         options.visual_feature, 'video2frames.txt'))
    }
    # video2frames = None

    # set bow vocabulary and encoding
    bow_vocab_file = os.path.join(rootpath, options.trainCollection,
                                  'TextData', 'vocabulary', 'bow',
                                  options.vocab + '.pkl')
    bow_vocab = pickle.load(open(bow_vocab_file, 'rb'))
    bow2vec = get_text_encoder('bow')(bow_vocab)
    options.bow_vocab_size = len(bow_vocab)

    # set rnn vocabulary
    rnn_vocab_file = os.path.join(rootpath, options.trainCollection,
                                  'TextData', 'vocabulary', 'rnn',
                                  options.vocab + '.pkl')
    rnn_vocab = pickle.load(open(rnn_vocab_file, 'rb'))
    options.vocab_size = len(rnn_vocab)

    # initialize word embedding
    options.we_parameter = None
    if options.word_dim == 500:
        w2v_data_path = os.path.join(rootpath, "word2vec", 'flickr',
                                     'vec500flickr30m')
        options.we_parameter = get_we_parameter(rnn_vocab, w2v_data_path)

    # Construct the model
    model = get_model(options.model)(options)
    model.load_state_dict(checkpoint['model'])
    model.Eiters = checkpoint['Eiters']
    # switch to evaluate mode
    model.val_start()

    video2frames = video2frames['test']
    videoIDs = list(video2frames.keys())

    # Queries embeddings
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                              do_lower_case=True)
    queryEmbeddingsTMP = []
    for query in lineList:
        videoBatch = videoIDs[0]  # a dummy video paired with each query
        data = dataLoadedVideoText_one(video2frames, videoBatch,
                                       visual_feats['test'], query, bow2vec,
                                       rnn_vocab, tokenizer, options)
        videos, captions = collate_frame_gru_fn(data)
        # compute the embeddings
        vid_emb, cap_emb = model.forward_emb(videos, captions, True)
        # preserve the embeddings by copying from gpu and converting to numpy
        cap_embs = cap_emb.data.cpu().numpy().copy()
        queryEmbeddingsTMP.append(cap_embs[0])

    queryEmbeddings = np.stack(queryEmbeddingsTMP)
    # print(queryEmbeddings.shape)

    start = time.time()
    VideoIDS = []
    errorlistList = []

    for i in range(0, len(videoIDs), batchsize):
        videoBatch = videoIDs[i:i + batchsize]
        VideoIDS.extend(videoBatch)

        data = []
        for bb in videoBatch:
            data.extend(
                dataLoadedVideoText_one(video2frames, bb, visual_feats['test'],
                                        lineList[0], bow2vec, rnn_vocab,
                                        tokenizer, options))
        videos, captions = collate_frame_gru_fn(data)

        # compute the embeddings
        vid_emb, cap_emb = model.forward_emb(videos, captions, True)
        # preserve the embeddings by copying from gpu and converting to numpy
        video_embs = vid_emb.data.cpu().numpy().copy()

        # calculate cosine distance
        errorlistList.extend(cosine_calculate(video_embs, queryEmbeddings))

        if i % 100000 == 0:
            # print (i)
            end = time.time()
            print(str(i) + ' in: ' + str(end - start))
            start = time.time()

    errorlist = np.asarray(errorlistList)
    f = open(saveFile_AVS16, "w")
    for num, name in enumerate(lineList[:30], start=1):
        queryError = errorlist[:, num - 1]
        scoresIndex = np.argsort(queryError)

        f = open(saveFile_AVS16, "a")
        c = 0
        for ind in scoresIndex:
            imgID = VideoIDS[ind]
            c = c + 1
            f.write('15%02d' % num)
            f.write(' 0 ' + imgID + ' ' + str(c) + ' ' + str(1000 - c) +
                    ' ITI-CERTH' + '\n')
            if c == 1000:
                break
    f.close()

    # AVS17
    f = open(saveFile_AVS17, "w")
    for num, name in enumerate(lineList[30:60], start=31):
        queryError = errorlist[:, num - 1]
        scoresIndex = np.argsort(queryError)

        f = open(saveFile_AVS17, "a")
        c = 0
        for ind in scoresIndex:
            imgID = VideoIDS[ind]
            c = c + 1
            f.write('15%02d' % num)
            f.write(' 0 ' + imgID + ' ' + str(c) + ' ' + str(1000 - c) +
                    ' ITI-CERTH' + '\n')
            if c == 1000:
                break
    f.close()

    # AVS18
    f = open(saveFile_AVS18, "w")
    for num, name in enumerate(lineList[60:90], start=61):
        queryError = errorlist[:, num - 1]
        scoresIndex = np.argsort(queryError)

        f = open(saveFile_AVS18, "a")
        c = 0
        for ind in scoresIndex:
            imgID = VideoIDS[ind]
            c = c + 1
            f.write('15%02d' % num)
            f.write(' 0 ' + imgID + ' ' + str(c) + ' ' + str(1000 - c) +
                    ' ITI-CERTH' + '\n')
            if c == 1000:
                break
    f.close()

    resultAVSFile16 = saveFile_AVS16[:-4] + '_results.txt'
    command = "perl data/AVS/sample_eval.pl -q data/AVS/avs.qrels.tv16 {} > {}".format(
        saveFile_AVS16, resultAVSFile16)
    os.system(command)
    resultAVSFile17 = saveFile_AVS17[:-4] + '_results.txt'
    command = "perl data/AVS/sample_eval.pl -q data/AVS/avs.qrels.tv17 {} > {}".format(
        saveFile_AVS17, resultAVSFile17)
    os.system(command)
    resultAVSFile18 = saveFile_AVS18[:-4] + '_results.txt'
    command = "perl data/AVS/sample_eval.pl -q data/AVS/avs.qrels.tv18 {} > {}".format(
        saveFile_AVS18, resultAVSFile18)
    os.system(command)
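
cosine_calculate is the distance routine behind errorlistList: for each video embedding in the current batch it yields the cosine distance to every query, so np.argsort over a column ranks videos from most to least similar. A sketch of that assumed contract (not the project's actual implementation):

import numpy as np

def cosine_calculate_sketch(video_embs, query_embs):
    # Rows: videos in the current batch; columns: queries.
    v = video_embs / np.linalg.norm(video_embs, axis=1, keepdims=True)
    q = query_embs / np.linalg.norm(query_embs, axis=1, keepdims=True)
    return 1.0 - np.dot(v, q.T)  # cosine distance, lower = more similar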
Example #6
def main():
    opt = parse_args()
    print(json.dumps(vars(opt), indent=2))

    rootpath = opt.rootpath
    testCollection = opt.testCollection
    resume = os.path.join(opt.logger_name, opt.checkpoint_name)

    if not os.path.exists(resume):
        logging.info(resume + ' does not exist.')
        sys.exit(0)

    checkpoint = torch.load(resume)
    start_epoch = checkpoint['epoch']
    best_rsum = checkpoint['best_rsum']
    print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
          .format(resume, start_epoch, best_rsum))

    options = checkpoint['opt']
    if not hasattr(options, 'concate'):
        setattr(options, "concate", "full")
    model = get_model(options.model)(options)
    model.load_state_dict(checkpoint['model'])
    model.val_start()

    trainCollection = options.trainCollection
    valCollection = options.valCollection

    visual_feat_file = BigFile(os.path.join(rootpath, testCollection, 'FeatureData', options.visual_feature))
    assert options.visual_feat_dim == visual_feat_file.ndims
    video2frame = read_dict(os.path.join(rootpath, testCollection, 'FeatureData', options.visual_feature, 'video2frames.txt'))
    visual_loader = data.get_vis_data_loader(visual_feat_file, opt.batch_size, opt.workers, video2frame)
    vis_embs = None

    # set bow vocabulary and encoding
    bow_vocab_file = os.path.join(rootpath, options.trainCollection, 'TextData', 'vocabulary', 'bow', options.vocab+'.pkl')
    bow_vocab = pickle.load(open(bow_vocab_file, 'rb'))
    bow2vec = get_text_encoder('bow')(bow_vocab)
    options.bow_vocab_size = len(bow_vocab)

    # set rnn vocabulary 
    rnn_vocab_file = os.path.join(rootpath, options.trainCollection, 'TextData', 'vocabulary', 'rnn', options.vocab+'.pkl')
    rnn_vocab = pickle.load(open(rnn_vocab_file, 'rb'))
    options.vocab_size = len(rnn_vocab)

    output_dir = resume.replace(trainCollection, testCollection)
    for query_set in opt.query_sets.strip().split(','):
        output_dir_tmp = output_dir.replace(valCollection, '%s/%s/%s' % (query_set, trainCollection, valCollection))
        output_dir_tmp = output_dir_tmp.replace('/%s/' % options.cv_name, '/results/')
        pred_result_file = os.path.join(output_dir_tmp, 'id.sent.score.txt')
        print(pred_result_file)
        if checkToSkip(pred_result_file, opt.overwrite):
            continue
        try:
            makedirsforfile(pred_result_file)
        except Exception as e:
            print(e)

        # data loader prepare
        query_file = os.path.join(rootpath, testCollection, 'TextData', query_set)

        # set data loader
        query_loader = data.get_txt_data_loader(query_file, rnn_vocab, bow2vec, opt.batch_size, opt.workers)

        if vis_embs is None:
            start = time.time()
            vis_embs, vis_ids = encode_data(model.embed_vis, visual_loader)
            print("encode image time: %.3f s" % (time.time()-start))

        start = time.time()
        query_embs, query_ids = encode_data(model.embed_txt, query_loader)
        print("encode text time: %.3f s" % (time.time()-start))

        start = time.time()
        t2i_matrix = query_embs.dot(vis_embs.T)
        inds = np.argsort(t2i_matrix, axis=1)
        print("compute similarity time: %.3f s" % (time.time()-start))

        with open(pred_result_file, 'w') as fout:
            for index in range(inds.shape[0]):
                ind = inds[index][::-1]
                fout.write(query_ids[index]+' '+' '.join([vis_ids[i]+' %s'%t2i_matrix[index][i]
                    for i in ind])+'\n')

        if testCollection == 'iacc.3':
            template = ''.join(open('tv-avs-eval/TEMPLATE_do_eval.sh').readlines())
            scriptStr = template.replace('@@@rootpath@@@', rootpath)
            scriptStr = scriptStr.replace('@@@testCollection@@@', testCollection)
            scriptStr = scriptStr.replace('@@@topic_set@@@', query_set.split('.')[0])
            scriptStr = scriptStr.replace('@@@overwrite@@@', str(opt.overwrite))
            scriptStr = scriptStr.replace('@@@score_file@@@', pred_result_file)

            runfile = 'do_eval_%s.sh' % testCollection
            open(os.path.join('tv-avs-eval', runfile), 'w').write(scriptStr + '\n')
            os.system('cd tv-avs-eval; chmod +x %s; bash %s; cd -' % (runfile, runfile))
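
The core of this last example is the ranking step: each row of t2i_matrix holds one query's similarity to every video, and argsort(...)[::-1] turns it into a best-first index list. A self-contained toy run of just that step, on made-up embeddings:

import numpy as np

query_embs = np.array([[1.0, 0.0], [0.0, 1.0]])            # 2 queries
vis_embs = np.array([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]])  # 3 videos
t2i_matrix = query_embs.dot(vis_embs.T)
for q, row in enumerate(t2i_matrix):
    ranked = np.argsort(row)[::-1]  # video indices, most similar first
    print('query %d ranking: %s' % (q, ranked.tolist()))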