def sample_cst_questions(checkpoint_path=None, subset='kptrain'):
    model_config = ModelConfig()
    model_config.convert = FLAGS.convert
    model_config.loss_type = 'pairwise'
    model_config.top_k = 3
    batch_size = 8
    # Get the data reader creation function
    create_fn = create_reader(FLAGS.model_type, phase='test')

    # Create the vocabulary.
    to_sentence = SentenceGenerator(trainset='trainval')

    # get data reader
    reader = create_fn(batch_size=batch_size,
                       subset=subset,
                       version=FLAGS.test_version)

    # Build model
    g = tf.Graph()
    with g.as_default():
        # Build the model.
        model = ContrastQuestionSampler(model_config)
        model.build()
        # Restore from checkpoint
        restorer = Restorer(g)
        sess = tf.Session()
        restorer.restore(sess, checkpoint_path)

    num_batches = reader.num_batches

    print('Running beam search inference...')

    for i in range(num_batches):
        outputs = reader.get_test_batch()

        # inference
        quest_ids, image_ids = outputs[-2:]
        c_ans, c_ans_len, pathes, scores = model.greedy_inference(
            outputs[:-2], sess)
        scores, pathes = post_process_prediction(scores, pathes)

        k = 3  # number of samples to display; matches model_config.top_k
        capt, capt_len = outputs[2:4]

        gt = capt[0, :capt_len[0]]
        print('gt: %s [%s]' %
              (to_sentence.index_to_question(gt),
               to_sentence.index_to_answer(c_ans[0, :c_ans_len[0]])))
        for ix in range(k):
            question = to_sentence.index_to_question(pathes[ix])
            answer = to_sentence.index_to_answer(c_ans[ix, :c_ans_len[ix]])
            print('%s %d: %s [%s]' %
                  ('pre' if ix == 0 else 'cst', ix, question, answer))
        import pdb
        pdb.set_trace()  # debugging pause: inspect each batch interactively
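

# A minimal sketch of what `post_process_prediction` is assumed to do in the
# snippets in this section: truncate each sampled path at its end-of-sequence
# token and keep the score of every surviving path. The EOS id and the exact
# return types are assumptions, not taken from the original codebase.
def _post_process_prediction_sketch(scores, pathes, eos_token=2):
    trimmed_scores, trimmed_pathes = [], []
    for score, path in zip(scores, pathes):
        path = list(path)
        if eos_token in path:
            path = path[:path.index(eos_token)]  # drop EOS and trailing pads
        trimmed_scores.append(score)
        trimmed_pathes.append(path)
    return trimmed_scores, trimmed_pathes

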
def test():
    top_ans_file = '/import/vision-ephemeral/fl302/code/' \
                   'VQA-tensorflow/data/vqa_trainval_top2000_answers.txt'
    # top_ans_file = 'data/vqa_trainval_top2000_answers.txt'
    mc_ctx = MultiChoiceQuestionManger(subset='val', load_ans=True,
                                       top_ans_file=top_ans_file)
    to_sentence = SentenceGenerator(trainset='trainval',
                                    top_ans_file=top_ans_file)
    answer_enc = mc_ctx.encoder
    # quest_ids = mc_ctx._quest_id2image_id.keys()
    # quest_ids = np.array(quest_ids)

    # qids = np.random.choice(quest_ids, size=(5,), replace=False)

    create_fn = create_reader('VAQ-CA', 'train')
    reader = create_fn(batch_size=4, subset='kprestval')
    reader.start()

    for _ in range(20):
        # inputs = reader.get_test_batch()
        inputs = reader.pop_batch()

        _, _, _, _, labels, ans_seq, ans_len, quest_ids, image_ids = inputs

        b_top_ans = answer_enc.get_top_answers(labels)
        for i, (quest_id, i_a) in enumerate(zip(quest_ids, b_top_ans)):
            print('question id: %d' % quest_id)
            gt = mc_ctx.get_gt_answer(quest_id)
            print('GT: %s' % gt)
            print('Top: %s' % i_a)
            print('SG: top: %s' % to_sentence.index_to_top_answer(labels[i]))
            seq = ans_seq[i][:ans_len[i]].tolist()
            print('SG: seq: %s\n' % to_sentence.index_to_answer(seq))

    reader.stop()


def convert():
    model_name = 'ivaq_var_restval'
    checkpoint_path = 'model/var_ivqa_pretrain_restval/model.ckpt-505000'
    # build model
    from config import ModelConfig
    model_config = ModelConfig()
    model_fn = get_model_creation_fn('VAQ-Var')
    # create graph
    g = tf.Graph()
    with g.as_default():
        # Build the model.
        model = model_fn(model_config, 'beam')
        model.build()
        tf_embedding = model._answer_embed
        tf_answer_feed = model._ans
        tf_answer_len_feed = model._ans_len
        # Restore from checkpoint
        print('Restore from %s' % checkpoint_path)
        restorer = Restorer(g)
        sess = tf.Session()
        restorer.restore(sess, checkpoint_path)

    # build reader
    top_ans_file = '/import/vision-ephemeral/fl302/code/' \
                   'VQA-tensorflow/data/vqa_trainval_top2000_answers.txt'
    mc_ctx = MultiChoiceQuestionManger(subset='val',
                                       load_ans=True,
                                       top_ans_file=top_ans_file)
    to_sentence = SentenceGenerator(trainset='trainval',
                                    top_ans_file=top_ans_file)
    answer_encoder = mc_ctx.encoder

    top_answer_inds = list(range(2000))  # indices of the top-2000 answers
    top_answers = answer_encoder.get_top_answers(top_answer_inds)

    answer_seqs = answer_encoder.encode_to_sequence(top_answers)
    for i, (ans, seq) in enumerate(zip(top_answers, answer_seqs)):
        rec_ans = to_sentence.index_to_answer(seq)
        ans = ' '.join(_tokenize_sentence(ans))
        print('%d: Raw: %s, Rec: %s' % (i + 1, ans, rec_ans))
        assert (ans == rec_ans)
    print('Checking passed')

    # extract
    print('Converting...')
    ans_arr, ans_arr_len = put_to_array(answer_seqs)
    embedding = sess.run(tf_embedding,
                         feed_dict={
                             tf_answer_feed: ans_arr.astype(np.int32),
                             tf_answer_len_feed: ans_arr_len.astype(np.int32)
                         })
    # save
    sv_file = 'data/v1_%s_top2000_lstm_embedding.h5' % model_name
    from util import save_hdf5
    save_hdf5(sv_file, {'answer_embedding': embedding})
    print('Done')
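

# A minimal sketch of the assumed `put_to_array` contract: pad a list of
# variable-length token-id sequences into a dense int32 matrix plus a length
# vector, matching the (ans_arr, ans_arr_len) pair fed to the graph above.
# The pad value of 0 is an assumption.
import numpy as np

def _put_to_array_sketch(seqs, pad_value=0):
    lengths = np.array([len(s) for s in seqs], dtype=np.int32)
    arr = np.full((len(seqs), int(lengths.max())), pad_value, dtype=np.int32)
    for i, s in enumerate(seqs):
        arr[i, :len(s)] = s
    return arr, lengths

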
def main(_):
    # Build the inference graph.
    config = QuestionGeneratorConfig()
    reader = TFRecordDataFetcher(FLAGS.input_files, config.image_feature_key)

    g = tf.Graph()
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    checkpoint_path = ckpt.model_checkpoint_path
    print(checkpoint_path)
    with g.as_default():
        model = QuestionGenerator(config, phase='evaluate')
        model.build()
    # g.finalize()

    # Create the vocabulary.
    to_sentence = SentenceGenerator(trainset=FLAGS.model_trainset)

    filenames = []
    for file_pattern in FLAGS.input_files.split(","):
        filenames.extend(tf.gfile.Glob(file_pattern))
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info("Running caption generation on %d files matching %s",
                    len(filenames), FLAGS.input_files)

    with tf.Session(graph=g) as sess:
        # Load the model from checkpoint.
        saver = tf.train.Saver(var_list=tf.all_variables())
        saver.restore(sess, checkpoint_path)

        itr = 0
        while not reader.eof():
            outputs = reader.pop_batch()
            im_ids, quest_id, im_feat, ans_w2v, quest_ids, ans_ids = outputs
            inputs = post_processing_data(outputs)
            perplexity = sess.run(model.likelihood,
                                  feed_dict=model.fill_feed_dict(inputs))

            # generated = [generated[0]]  # sample 3
            question = to_sentence.index_to_question(quest_ids)
            answer = to_sentence.index_to_answer(ans_ids)

            print('============== %d ============' % itr)
            print('image id: %d, question id: %d' % (im_ids, quest_id))
            print('question\t: %s' % question)
            elems = question.split(' ')
            tmp = ' '.join([
                '%s (%0.2f)' % (w, p)
                for w, p in zip(elems, perplexity.flatten())
            ][:-1])
            print('question\t' + tmp)
            print('answer\t: %s' % answer)
            print('perplexity\t: %0.2f\n' % perplexity.mean())

            itr += 1
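

# Note: `model.likelihood` above is assumed to return per-word negative
# log-likelihoods, which the loop prints per word and then averages. A
# minimal sketch of turning such values into a conventional sentence
# perplexity (exp of the mean per-word NLL), for comparison:
import numpy as np

def _sentence_perplexity_sketch(per_word_nll):
    return float(np.exp(np.mean(per_word_nll)))

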
def test_rerank_reader():
    reader = RetrievalDataReader(batch_size=1, n_contrast=10, subset='train')
    reader.start()
    outputs = reader.pop_batch()
    im_feat, quest_arr, quest_len, ans_arr, ans_len = outputs
    from inference_utils.question_generator_util import SentenceGenerator
    to_sentence = SentenceGenerator(
        trainset='trainval',
        ans_vocab_file='data/vqa_trainval_question_answer_word_counts.txt',
        quest_vocab_file='data/vqa_trainval_question_answer_word_counts.txt')
    for q_seq, q_len, a_seq, a_len in zip(quest_arr, quest_len, ans_arr,
                                          ans_len):
        q_ = np.array([0] + q_seq[:q_len].tolist() + [0])
        a_ = np.array([0] + a_seq[:a_len].tolist() + [0])
        q = to_sentence.index_to_question(q_)
        a = to_sentence.index_to_answer(a_)
        print('Q: %s' % q)
        print('A: %s\n' % a)
    reader.stop()


def var_vqa_decoding_beam_search(checkpoint_path=None, subset='kpval'):
    model_config = ModelConfig()
    res_file = 'result/quest_vaq_greedy_%s.json' % FLAGS.model_type.upper()
    # Get model
    model_fn = get_model_creation_fn(FLAGS.model_type)
    create_fn = create_reader(FLAGS.model_type, phase='test')

    # Create the vocabulary.
    to_sentence = SentenceGenerator(trainset='trainval')

    # get data reader (uses the `subset` argument, default 'kpval')
    reader = create_fn(batch_size=1, subset=subset, version=FLAGS.test_version)

    if checkpoint_path is None:
        ckpt_dir = FLAGS.checkpoint_dir % (FLAGS.version, FLAGS.model_type)
        # ckpt_dir = '/import/vision-ephemeral/fl302/models/v2_kpvaq_VAQ-RL/'
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        checkpoint_path = ckpt.model_checkpoint_path

    # Build model
    g = tf.Graph()
    with g.as_default():
        # Build the model.
        model = model_fn(model_config, 'sampling')
        model.build()
        # Restore from checkpoint
        restorer = Restorer(g)
        sess = tf.Session()
        restorer.restore(sess, checkpoint_path)

    num_batches = reader.num_batches

    print('Running beam search inference...')
    results = []
    for i in range(num_batches):
        outputs = reader.get_test_batch()
        # pdb.set_trace()
        if i % 100 == 0:
            print('batch: %d/%d' % (i, num_batches))

        # inference
        images, quest, quest_len, ans, ans_len, quest_ids, image_ids = outputs
        scores, pathes = model.greedy_inference([images, quest, quest_len],
                                                sess)
        scores, pathes = post_process_prediction(scores, pathes)
        pathes, pathes_len = put_to_array(pathes)
        scores, pathes = find_unique_rows(scores, pathes)
        scores, pathes = post_process_prediction(scores, pathes[:, 1:])
        # question = to_sentence.index_to_question(pathes[0])
        # print('%d/%d: %s' % (i, num_batches, question))

        answers = []
        for path in pathes:
            sentence = to_sentence.index_to_answer(path)
            answers.append(sentence)
            # print(sentence)

        res_i = {'question_id': int(quest_ids[0]), 'answers': answers}
        results.append(res_i)

    eval_recall(results)
    return
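

# A minimal sketch of the assumed `find_unique_rows` behaviour used above:
# drop duplicate sampled paths (rows of the padded path matrix) and keep the
# score of the first occurrence of each unique row. Inputs are assumed to be
# numpy arrays with one row per sample.
import numpy as np

def _find_unique_rows_sketch(scores, pathes):
    _, first_idx = np.unique(pathes, axis=0, return_index=True)
    keep = np.sort(first_idx)  # preserve the original sample order
    return scores[keep], pathes[keep]

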
def ivqa_decoding_beam_search(checkpoint_path=None):
    model_config = ModelConfig()
    method = FLAGS.method
    res_file = 'result/bs_cand_for_vis.json'
    # Get model
    model_fn = get_model_creation_fn('VAQ-Var')
    create_fn = create_reader('VAQ-VVIS', phase='test')

    # Create the vocabulary.
    to_sentence = SentenceGenerator(trainset='trainval',
                                    top_ans_file='../VQA-tensorflow/data/vqa_trainval_top2000_answers.txt')

    # get data reader
    subset = 'kpval'
    reader = create_fn(batch_size=1, subset=subset,
                       version=FLAGS.test_version)

    exemplar = ExemplarLanguageModel()

    if checkpoint_path is None:
        if FLAGS.checkpoint_dir:
            ckpt_dir = FLAGS.checkpoint_dir
        else:
            ckpt_dir = FLAGS.checkpoint_pat % (FLAGS.version, FLAGS.model_type)
        # ckpt_dir = '/import/vision-ephemeral/fl302/models/v2_kpvaq_VAQ-RL/'
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        checkpoint_path = ckpt.model_checkpoint_path

    # Build model
    g = tf.Graph()
    with g.as_default():
        # Build the model.
        model = model_fn(model_config, 'sampling')
        model.set_num_sampling_points(5000)
        model.build()
        # Restore from checkpoint
        restorer = Restorer(g)
        sess = tf.Session()
        restorer.restore(sess, checkpoint_path)

        # build language model
        language_model = LanguageModel()
        language_model.build()
        language_model.set_cache_dir('test_empty')
        # language_model.set_cache_dir('v1_var_att_lowthresh_cache_restval_VAQ-VarRL')
        language_model.set_session(sess)
        language_model.setup_model()

    # build VQA model (disabled)
    # vqa_model = N2MNWrapper()
    # vqa_model = MLBWrapper()
    num_batches = reader.num_batches

    quest_ids_to_vis = {5682052: 'bread',
                        965492: 'plane',
                        681282: 'station'}

    print('Running beam search inference...')
    results = []
    batch_vqa_scores = []

    num = FLAGS.max_iters if FLAGS.max_iters > 0 else num_batches
    for i in range(num):

        outputs = reader.get_test_batch()

        # inference
        quest_ids, image_ids = outputs[-2:]
        quest_id_key = int(quest_ids)

        if quest_id_key not in quest_ids_to_vis:
            continue
        # pdb.set_trace()

        im, gt_q, _, top_ans, ans_tokens, ans_len = outputs[:-2]
        # pdb.set_trace()
        if top_ans == 2000:  # skip samples whose answer is outside the top-2000 vocabulary
            continue

        print('\n%d/%d' % (i, num))
        question_id = int(quest_ids[0])
        image_id = int(image_ids[0])

        t1 = time()
        pathes, scores = model.greedy_inference([im, ans_tokens, ans_len], sess)

        # find unique
        ivqa_scores, ivqa_pathes = process_one(scores, pathes)
        t2 = time()
        print('Time for sample generation: %0.2fs' % (t2 - t1))

        # apply language model
        language_model_inputs = wrap_samples_for_language_model([ivqa_pathes],
                                                                pad_token=model.pad_token - 1,
                                                                max_length=20)
        match_gt = exemplar.query(ivqa_pathes)
        legality_scores = language_model.inference(language_model_inputs)
        legality_scores[match_gt] = 1.0
        num_keep = max(100, (legality_scores > 0.1).sum())  # no less than 100
        valid_inds = (-legality_scores).argsort()[:num_keep]
        print('keep: %d/%d' % (num_keep, len(ivqa_pathes)))

        t3 = time()
        print('Time for language model filtration: %0.2fs' % (t3 - t2))

        def token_arr_to_list(arr):
            return arr.flatten().tolist()

        for _pid, idx in enumerate(valid_inds):
            path = ivqa_pathes[idx]
            # sc = vqa_scores[idx]
            sentence = to_sentence.index_to_question(path)
            aug_quest_id = question_id * 1000 + _pid  # encode candidate rank into the id
            res_i = {'image_id': int(image_id),
                     'aug_id': aug_quest_id,
                     'question_id': question_id,
                     'target': sentence,
                     'top_ans_id': int(top_ans),
                     'question': to_sentence.index_to_question(token_arr_to_list(gt_q)),
                     'answer': to_sentence.index_to_answer(token_arr_to_list(ans_tokens))}
            results.append(res_i)

    save_json(res_file, results)
    return None
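

# A minimal sketch of the assumed `wrap_samples_for_language_model` contract:
# flatten the groups of candidate questions and pad each one to a fixed
# length with the given pad token, so the language model can score the whole
# candidate set in a single batch. The exact output structure is an
# assumption.
import numpy as np

def _wrap_samples_sketch(path_groups, pad_token, max_length=20):
    flat = [p for group in path_groups for p in group]
    arr = np.full((len(flat), max_length), pad_token, dtype=np.int32)
    lengths = np.zeros((len(flat),), dtype=np.int32)
    for i, p in enumerate(flat):
        n = min(len(p), max_length)
        arr[i, :n] = p[:n]
        lengths[i] = n
    return arr, lengths

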
def ivqa_decoding_beam_search(checkpoint_path=None, subset='kpval'):
    model_config = ModelConfig()
    res_file = 'result/quest_vaq_greedy_%s.json' % FLAGS.model_type.upper()
    # Get model
    model_fn = get_model_creation_fn('VAQ-Var')
    create_fn = create_reader('VAQ-Var', phase='test')
    writer = ExperimentWriter('latex/examples_noimage_tmp')

    # Create the vocabulary.
    to_sentence = SentenceGenerator(trainset='trainval')

    # get data reader (uses the `subset` argument, default 'kpval')
    reader = create_fn(batch_size=1, subset=subset, version=FLAGS.test_version)

    if checkpoint_path is None:
        # ckpt_dir = FLAGS.checkpoint_dir % (FLAGS.version, FLAGS.model_type)
        ckpt_dir = 'model/v1_var_att_noimage_cache_restval_VAQ-VarRL'
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        checkpoint_path = ckpt.model_checkpoint_path

    # Build model
    g = tf.Graph()
    with g.as_default():
        # Build the model.
        model = model_fn(model_config, 'sampling')
        model.build()
        # Restore from checkpoint
        restorer = Restorer(g)
        sess = tf.Session()
        restorer.restore(sess, checkpoint_path)

    num_batches = reader.num_batches

    print('Running beam search inference...')
    results = []
    for i in range(num_batches):
        outputs = reader.get_test_batch()

        # inference
        quest_ids, image_ids = outputs[-2:]
        scores, pathes = model.greedy_inference(outputs[:-2], sess)
        scores, pathes = post_process_prediction(scores, pathes)
        pathes, pathes_len = put_to_array(pathes)
        scores, pathes = find_unique_rows(scores, pathes)
        scores, pathes = post_process_prediction(scores, pathes[:, 1:])
        # question = to_sentence.index_to_question(pathes[0])
        # print('%d/%d: %s' % (i, num_batches, question))

        # show image
        os.system('clear')
        im_file = '%s2014/COCO_%s2014_%012d.jpg' % ('val', 'val', image_ids[0])
        im_path = os.path.join(IM_ROOT, im_file)
        # im = imread(im_path)
        # plt.imshow(im)
        ans, ans_len = outputs[1:3]
        answers = extract_gt(ans, ans_len)
        answer = to_sentence.index_to_answer(answers[0])
        # plt.title(answer)

        print('Answer: %s' % answer)
        questions = []
        for path in pathes:
            sentence = to_sentence.index_to_question(path)
            questions.append(sentence)
            print(sentence)
        # plt.show()
        writer.add_result(image_ids[0], quest_ids[0], im_path, answer,
                          questions)

        for quest_id, image_id, path in zip(quest_ids, image_ids, pathes):
            sentence = to_sentence.index_to_question(path)
            res_i = {
                'image_id': int(image_id),
                'question_id': int(quest_id),
                'question': sentence
            }
            results.append(res_i)

        if i == 40:
            break

    writer.render()
    return
    # NOTE: unreachable with the early return above; kept for reference
    # save_json(res_file, results)
    # return res_file
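

# A minimal sketch of the assumed `extract_gt` helper used above: slice each
# padded ground-truth sequence down to its true length before decoding it
# back to words.
def _extract_gt_sketch(seqs, seq_lens):
    return [seq[:n].tolist() for seq, n in zip(seqs, seq_lens)]

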
def test():
    # Build the inference graph.
    config = QuestionGeneratorConfig()
    reader = TFRecordDataFetcher(FLAGS.input_files, config.image_feature_key)

    # Create model creator
    model_creator = create_model_fn(FLAGS.model_type)

    # create multiple choice question manager
    mc_manager = MultiChoiceQuestionManger(
        subset='trainval', answer_coding=model_creator.ans_coding)

    # Create reader post-processing function
    reader_post_proc_fn = build_mc_reader_proc_fn(model_creator.ans_coding)

    g = tf.Graph()
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    checkpoint_path = ckpt.model_checkpoint_path
    print(checkpoint_path)
    with g.as_default():
        model = model_creator(config, phase='evaluate')
        model.build()
    # g.finalize()

    # Create the vocabulary.
    to_sentence = SentenceGenerator(trainset=FLAGS.model_trainset)

    filenames = []
    for file_pattern in FLAGS.input_files.split(","):
        filenames.extend(tf.gfile.Glob(file_pattern))
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info("Running caption generation on %d files matching %s",
                    len(filenames), FLAGS.input_files)

    result, rescore_data, state_rescore_data = [], [], []
    with tf.Session(graph=g) as sess:
        # Load the model from checkpoint.
        saver = tf.train.Saver(var_list=tf.all_variables())
        saver.restore(sess, checkpoint_path)

        itr = 0
        while not reader.eof():
            if itr > 50000:  # cache at most 50k questions
                break
            outputs = reader.pop_batch()
            im_ids, quest_id, im_feat, ans_w2v, quest_ids, ans_ids = outputs
            mc_ans, mc_coding = mc_manager.get_candidate_answer_and_word_coding(
                quest_id)
            inputs = reader_post_proc_fn(outputs, mc_coding)
            perplexity, state = sess.run(
                [model.likelihood, model.final_decoder_state],
                feed_dict=model.fill_feed_dict(inputs))
            perplexity = perplexity.reshape(inputs[-1].shape)
            loss = perplexity[:, :-1].mean(axis=1)

            # generated = [generated[0]]  # sample 3
            question = to_sentence.index_to_question(quest_ids)
            answer = to_sentence.index_to_answer(ans_ids)
            top1_mc_ans = mc_ans[loss.argmin()]
            result.append({u'answer': top1_mc_ans, u'question_id': quest_id})

            # add hidden state saver
            label = mc_manager.get_binary_label(quest_id)
            state_sv = {'quest_id': quest_id, 'states': state, 'label': label}
            state_rescore_data.append(state_sv)

            if itr % 100 == 0:
                print('============== %d ============' % itr)
                print('image id: %d, question id: %d' % (im_ids, quest_id))
                print('question\t: %s' % question)
                print('answer\t: %s' % answer)
                top_k_ids = loss.argsort()[:3].tolist()
                for i, idx in enumerate(top_k_ids):
                    t_mc_ans = mc_ans[idx]
                    print('VAQ answer <%d>\t: %s (%0.2f)' %
                          (i, t_mc_ans, loss[idx]))

            itr += 1
            # save information for train classifier
            mc_label = np.array([a == answer for a in mc_ans],
                                dtype=np.float32)
            quest_target = inputs[-2]
            datum = {
                'quest_seq': quest_target,
                'perplex': perplexity,
                'label': mc_label,
                'quest_id': quest_id
            }
            rescore_data.append(datum)

        quest_ids = [res[u'question_id'] for res in result]
        # save results
        tf.logging.info('Saving results')
        res_file = FLAGS.result_file % get_model_iteration(checkpoint_path)
        json.dump(result, open(res_file, 'w'))
        tf.logging.info('Saving rescore data...')
        from util import pickle
        # pickle('data/rescore_dev.pkl', rescore_data)
        pickle('data/rescore_state_dev.pkl', state_rescore_data)
        tf.logging.info('Done!')
        return res_file, quest_ids
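

# A minimal sketch of the reranking rule in `test()` above: score every
# multiple-choice answer by the mean per-token question loss (excluding the
# final position) and pick the candidate with the smallest loss. Shapes are
# assumptions.
import numpy as np

def _rank_mc_answers_sketch(perplexity, mc_ans):
    # perplexity: (num_candidates, seq_len) per-token losses
    loss = perplexity[:, :-1].mean(axis=1)
    return mc_ans[int(loss.argmin())], loss

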
def var_vqa_decoding_beam_search(checkpoint_path=None, subset='kpval'):
    model_config = ModelConfig()
    res_file = 'result/quest_vaq_greedy_%s.json' % FLAGS.model_type.upper()
    # Get model
    model_fn = get_model_creation_fn(FLAGS.model_type)
    create_fn = create_reader('V7W-VarDS', phase='test')
    writer = ExperimentWriter('latex/v7w_%s' % FLAGS.model_type.lower())

    # Create the vocabulary.
    to_sentence = SentenceGenerator(
        trainset='train',
        ans_vocab_file='data2/v7w_train_answer_word_counts.txt',
        quest_vocab_file='data2/v7w_train_question_word_counts.txt',
        top_ans_file='data2/v7w_train_top2000_answers.txt')

    # get data reader (note: overrides the `subset` argument with 'val')
    subset = 'val'
    reader = create_fn(batch_size=1, subset=subset, version=FLAGS.test_version)

    if checkpoint_path is None:
        ckpt_dir = FLAGS.checkpoint_dir % (FLAGS.trainset, FLAGS.model_type)
        # ckpt_dir = '/import/vision-ephemeral/fl302/models/v2_kpvaq_VAQ-RL/'
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        checkpoint_path = ckpt.model_checkpoint_path

    # Build model
    g = tf.Graph()
    with g.as_default():
        # Build the model.
        model = model_fn(model_config, 'sampling')
        model.build()
        # Restore from checkpoint
        restorer = Restorer(g)
        sess = tf.Session()
        restorer.restore(sess, checkpoint_path)

    num_batches = reader.num_batches

    print('Running beam search inference...')
    results = []
    for i in range(num_batches):
        outputs = reader.get_test_batch()
        # pdb.set_trace()

        # inference
        images, quest, quest_len, ans, ans_len, quest_ids, image_ids = outputs
        scores, pathes = model.greedy_inference([images, quest, quest_len],
                                                sess)
        scores, pathes = post_process_prediction(scores, pathes)
        pathes, pathes_len = put_to_array(pathes)
        scores, pathes = find_unique_rows(scores, pathes)
        scores, pathes = post_process_prediction(scores, pathes[:, 1:])
        # question = to_sentence.index_to_question(pathes[0])
        # print('%d/%d: %s' % (i, num_batches, question))

        # show image
        os.system('clear')
        image_id = image_ids[0]
        im_path = _get_vg_image_root(image_id)
        # im = imread(im_path)
        # plt.imshow(im)
        questions = extract_gt(quest, quest_len)
        question = to_sentence.index_to_question(questions[0])
        print('Question: %s' % question)

        answers = extract_gt(ans, ans_len)
        answer = to_sentence.index_to_answer(answers[0])
        # plt.title(answer)

        print('Answer: %s' % answer)
        answers = []
        for path in pathes:
            sentence = to_sentence.index_to_answer(path)
            answers.append(sentence)
            print(sentence)
        # plt.show()
        qa = '%s - %s' % (question, answer)
        writer.add_result(image_ids[0], quest_ids[0], im_path, qa, answers)

        for quest_id, image_id, path in zip(quest_ids, image_ids, pathes):
            sentence = to_sentence.index_to_question(path)
            res_i = {
                'image_id': int(image_id),
                'question_id': int(quest_id),
                'question': sentence
            }
            results.append(res_i)

        if i == 40:
            break

    writer.render()
    return


def test(T=3.0, num_cands=10):
    # Build the inference graph.
    cand_file = 'result/vqa_cands.json'
    config = QuestionGeneratorConfig()
    reader = TFRecordDataFetcher(FLAGS.input_files,
                                 config.image_feature_key)

    # Create model creator
    model_creator = create_model_fn(FLAGS.model_type)

    # create the candidate answer manager (threads through `num_cands`)
    oe_manager = CandidateAnswerManager(cand_file, max_num_cands=num_cands)

    # Create reader post-processing function
    reader_post_proc_fn = build_mc_reader_proc_fn(model_creator.ans_coding)

    g = tf.Graph()
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    checkpoint_path = ckpt.model_checkpoint_path
    print(checkpoint_path)
    with g.as_default():
        model = model_creator(config, phase='evaluate')
        model.build()

    # Create the vocabulary.
    to_sentence = SentenceGenerator(trainset=FLAGS.model_trainset)

    filenames = []
    for file_pattern in FLAGS.input_files.split(","):
        filenames.extend(tf.gfile.Glob(file_pattern))
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info("Running caption generation on %d files matching %s",
                    len(filenames), FLAGS.input_files)

    result = []
    with tf.Session(graph=g) as sess:
        # Load the model from checkpoint.
        saver = tf.train.Saver(var_list=tf.all_variables())
        saver.restore(sess, checkpoint_path)

        itr = 0
        while not reader.eof():
            outputs = reader.pop_batch()
            im_ids, quest_id, im_feat, ans_w2v, quest_ids, ans_ids = outputs
            oe_ans, oe_coding, scores = oe_manager.get_answer_sequence(quest_id)
            inputs = reader_post_proc_fn(outputs, oe_coding)
            perplexity, state = sess.run([model.likelihood, model.final_decoder_state],
                                         feed_dict=model.fill_feed_dict(inputs))
            perplexity = perplexity.reshape(inputs[-1].shape)
            loss = perplexity[:, :-1].mean(axis=1)
            weight = np.exp(-loss * T)
            weight = weight / weight.sum()  # l1 normalise
            score = scores * weight
            score = score[:num_cands]

            question = to_sentence.index_to_question(quest_ids)
            answer = to_sentence.index_to_answer(ans_ids)
            top1_ans = oe_ans[score.argmax()]
            result.append({u'answer': top1_ans, u'question_id': quest_id})

            if itr % 100 == 0:
                print('============== %d ============' % itr)
                print('image id: %d, question id: %d' % (im_ids, quest_id))
                print('question\t: %s' % question)
                print('answer\t: %s' % answer)
                top_k_ids = (-score).argsort()[:3].tolist()
                print('VQA answer\t: %s' % oe_ans[0])
                for i, idx in enumerate(top_k_ids):
                    t_mc_ans = oe_ans[idx]
                    print('VAQ answer <%d>\t: %s (%0.2f)' % (i, t_mc_ans, weight[idx]))

            itr += 1

        quest_ids = [res[u'question_id'] for res in result]
        # save results
        tf.logging.info('Saving results')
        res_file = FLAGS.result_file % get_model_iteration(checkpoint_path)
        json.dump(result, open(res_file, 'w'))
        return res_file, quest_ids
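

# A minimal sketch of the score fusion in `test(T, num_cands)` above: VQA
# candidate scores are reweighted by exp(-T * question-generation loss),
# l1-normalised over the candidate set, then truncated to the top
# `num_cands` entries.
import numpy as np

def _fuse_scores_sketch(vqa_scores, vaq_loss, T=3.0, num_cands=10):
    weight = np.exp(-vaq_loss * T)
    weight = weight / weight.sum()  # l1 normalise
    return (vqa_scores * weight)[:num_cands]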