def test(checkpoint_path=None):
    """Run VQA evaluation over the kp* validation subsets and dump results.

    Restores the model from `checkpoint_path`, runs `test_worker` on each
    subset, merges the per-subset results, and writes them to a JSON file.

    Args:
        checkpoint_path: Path to the TF checkpoint to restore. Required
            (passed straight to `os.path.basename` and `Saver.restore`).

    Returns:
        Tuple of (result JSON file path, concatenated question-id array).
    """
    subsets = ['kpval', 'kptest', 'kprestval']

    quest_ids = []
    result = []

    config = ModelConfig()
    config.sample_negative = FLAGS.sample_negative
    config.use_fb_bn = FLAGS.use_fb_bn
    # Get model function
    model_fn = get_model_creation_fn(FLAGS.model_type)

    # build and restore model
    model = model_fn(config, phase='test')
    model.build()

    sess = tf.Session(graph=tf.get_default_graph())
    tf.logging.info('Restore from model %s' %
                    os.path.basename(checkpoint_path))
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    for subset in subsets:
        _quest_ids, _result = test_worker(model, sess, subset)
        quest_ids += _quest_ids
        result += _result

    quest_ids = np.concatenate(quest_ids)
    # save results
    tf.logging.info('Saving results')
    res_file = FLAGS.result_format % (FLAGS.version, 'val')
    # Use a context manager so the result file handle is closed promptly.
    with open(res_file, 'w') as fp:
        json.dump(result, fp)
    tf.logging.info('Done!')
    tf.logging.info('#Num eval samples %d' % len(result))
    return res_file, quest_ids
Esempio n. 2
0
def test(checkpoint_path=None):
    """Run multiple-choice VQA inference on TEST_SET and dump answers to JSON.

    For each test batch, scores the candidate answers with the model's
    logits, picks the argmax candidate, and records the chosen answer per
    question id.

    Args:
        checkpoint_path: Checkpoint to restore; when None, the latest
            checkpoint in `FLAGS.checkpoint_dir` is used.

    Returns:
        Tuple of (result JSON file path, concatenated question-id array).
    """
    batch_size = 64
    config = ModelConfig()
    config.sample_negative = FLAGS.sample_negative
    config.use_fb_bn = FLAGS.use_fb_bn
    # Get model function
    model_fn = get_model_creation_fn(FLAGS.model_type)

    # build data reader
    reader = TestReader(batch_size=batch_size,
                        subset=TEST_SET,
                        use_fb_data=FLAGS.use_fb_data)
    if checkpoint_path is None:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir %
                                             (FLAGS.version, FLAGS.model_type))
        checkpoint_path = ckpt.model_checkpoint_path
    print(checkpoint_path)

    # build and restore model
    model = model_fn(config, phase='test')
    model.build()
    # NOTE: scoring below reads `model._logits` directly; no separate prob
    # tensor is needed here.

    sess = tf.Session(graph=tf.get_default_graph())
    tf.logging.info('Restore from model %s' %
                    os.path.basename(checkpoint_path))
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    quest_ids = []
    result = []

    print('Running inference on split %s...' % TEST_SET)
    for i in range(reader.num_batches):
        if i % 10 == 0:
            update_progress(i / float(reader.num_batches))
        outputs = reader.get_test_batch()
        mc_scores = sess.run(model._logits,
                             feed_dict=model.fill_feed_dict(outputs[:-3]))
        choice_idx = np.argmax(mc_scores, axis=1)

        # Trailing three batch entries: candidates, question ids, image ids.
        cands, _qids, image_ids = outputs[-3:]
        for qid, cid, mcs in zip(_qids, choice_idx, cands):
            answer = mcs['cands'][cid]
            assert (mcs['quest_id'] == qid)
            result.append({u'answer': answer, u'question_id': qid})

        quest_ids.append(_qids)

    quest_ids = np.concatenate(quest_ids)

    # save results
    tf.logging.info('Saving results')
    res_file = FLAGS.result_format % (FLAGS.version, TEST_SET)
    # Use a context manager so the result file handle is closed promptly.
    with open(res_file, 'w') as fp:
        json.dump(result, fp)
    tf.logging.info('Done!')
    tf.logging.info('#Num eval samples %d' % len(result))
    return res_file, quest_ids
Esempio n. 3
0
def test(checkpoint_path=None):
    """Run top-K answer prediction on TEST_SET and evaluate recall.

    For each question, ranks the answer vocabulary by predicted probability
    (excluding the trailing UNK class) and keeps the top `_K` answers.

    Args:
        checkpoint_path: Checkpoint to restore; when None, the latest
            checkpoint in `FLAGS.checkpoint_dir` is used.
    """
    batch_size = 100
    config = ModelConfig()
    # Get model function (restored: `model_fn` is called below and would
    # otherwise raise NameError).
    model_fn = get_model_creation_fn(FLAGS.model_type)

    # build data reader
    reader = AttentionFetcher(batch_size=batch_size,
                              subset=TEST_SET,
                              feat_type=config.feat_type,
                              version=FLAGS.version)
    if checkpoint_path is None:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir %
                                             (FLAGS.version, FLAGS.model_type))
        checkpoint_path = ckpt.model_checkpoint_path
    print(checkpoint_path)

    # build and restore model
    model = model_fn(config, phase='test')
    # model.set_agent_ids([0])
    model.build()
    prob = model.prob

    sess = tf.Session(graph=tf.get_default_graph())
    tf.logging.info('Restore from model %s' %
                    os.path.basename(checkpoint_path))
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    # Create the vocabulary.
    top_ans_file = '../VQA-tensorflow/data/vqa_trainval_top2000_answers.txt'
    to_sentence = SentenceGenerator(trainset='trainval',
                                    top_ans_file=top_ans_file)
    # to_sentence = SentenceGenerator(trainset='trainval')

    results = []

    print('Running inference on split %s...' % TEST_SET)
    for i in range(reader.num_batches):
        if i % 10 == 0:
            update_progress(i / float(reader.num_batches))
        outputs = reader.get_test_batch()
        generated_ans = sess.run(prob,
                                 feed_dict=model.fill_feed_dict(outputs[:-2]))
        # Zero out the last (UNK/OOV) class so it is never ranked first.
        generated_ans[:, -1] = 0
        ans_cand_ids = np.argsort(-generated_ans, axis=1)

        quest_ids = outputs[-2]

        for quest_id, ids in zip(quest_ids, ans_cand_ids):
            answers = []
            for k in range(_K):
                aid = ids[k]
                ans = to_sentence.index_to_top_answer(aid)
                answers.append(ans)
            res_i = {'question_id': int(quest_id), 'answers': answers}
            results.append(res_i)

    eval_recall(results)
Esempio n. 4
0
def test(checkpoint_path=None):
    """Run top-1 VQA inference on TEST_SET and dump answers to JSON.

    Args:
        checkpoint_path: Checkpoint to restore; when None, the latest
            checkpoint in `FLAGS.checkpoint_dir` is used.

    Returns:
        Tuple of (result JSON file path, concatenated question-id array).
    """
    batch_size = 100
    config = ModelConfig()
    # Get model function
    model_fn = get_model_creation_fn(FLAGS.model_type)

    # build data reader
    reader = AttentionFetcher(batch_size=batch_size, subset=TEST_SET,
                              feat_type=config.feat_type, version=FLAGS.version)
    if checkpoint_path is None:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir % (FLAGS.version,
                                                                     FLAGS.model_type))
        checkpoint_path = ckpt.model_checkpoint_path
    print(checkpoint_path)

    # build and restore model
    model = model_fn(config, phase='test')
    model.build()
    prob = model.prob

    sess = tf.Session(graph=tf.get_default_graph())
    tf.logging.info('Restore from model %s' % os.path.basename(checkpoint_path))
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    # Create the vocabulary.
    top_ans_file = '../VQA-tensorflow/data/vqa_trainval_top2000_answers.txt'
    to_sentence = SentenceGenerator(trainset='trainval',
                                    top_ans_file=top_ans_file)

    ans_ids = []
    quest_ids = []

    print('Running inference on split %s...' % TEST_SET)
    for i in range(reader.num_batches):
        if i % 10 == 0:
            update_progress(i / float(reader.num_batches))
        outputs = reader.get_test_batch()
        generated_ans = sess.run(
            prob, feed_dict=model.fill_feed_dict(outputs[:-2]))
        # Zero out the last (UNK/OOV) class so it is never predicted.
        generated_ans[:, -1] = 0
        top_ans = np.argmax(generated_ans, axis=1)

        ans_ids.append(top_ans)
        quest_id = outputs[-2]
        quest_ids.append(quest_id)

    quest_ids = np.concatenate(quest_ids)
    ans_ids = np.concatenate(ans_ids)
    result = [{u'answer': to_sentence.index_to_top_answer(aid),
               u'question_id': qid} for aid, qid in zip(ans_ids, quest_ids)]

    # save results
    tf.logging.info('Saving results')
    res_file = FLAGS.result_format % (FLAGS.version, TEST_SET)
    # Use a context manager so the result file handle is closed promptly.
    with open(res_file, 'w') as fp:
        json.dump(result, fp)
    tf.logging.info('Done!')
    tf.logging.info('#Num eval samples %d' % len(result))
    return res_file, quest_ids
Esempio n. 5
0
def test(checkpoint_path=None):
    """Compute top-1 VQA accuracy on TEST_SET against the reader's labels.

    Args:
        checkpoint_path: Checkpoint to restore; when None, the latest
            checkpoint in `FLAGS.checkpoint_dir` is used.

    Returns:
        Accuracy as a float in [0, 1].
    """
    batch_size = 4
    config = ModelConfig()
    # Resolve the constructor for the configured model type.
    model_fn = get_model_creation_fn(FLAGS.model_type)

    # Reader over the evaluation split.
    reader = AttentionFetcher(batch_size=batch_size,
                              subset=TEST_SET,
                              feat_type=config.feat_type,
                              version=FLAGS.version)
    if checkpoint_path is None:
        ckpt_dir = FLAGS.checkpoint_dir % (FLAGS.version, FLAGS.model_type)
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        checkpoint_path = ckpt.model_checkpoint_path
    print(checkpoint_path)

    # Build the graph and restore the weights.
    model = model_fn(config, phase='test')
    model.build()
    prob = model.prob

    sess = tf.Session(graph=tf.get_default_graph())
    tf.logging.info('Restore from model %s' %
                    os.path.basename(checkpoint_path))
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    # Create the vocabulary.
    to_sentence = SentenceGenerator(
        trainset='trainval',
        top_ans_file='../VQA-tensorflow/data/vqa_trainval_top2000_answers.txt')

    predictions = []
    question_ids = []

    print('Running inference on split %s...' % TEST_SET)
    num_batches = reader.num_batches
    for batch_idx in range(num_batches):
        if batch_idx % 10 == 0:
            update_progress(batch_idx / float(num_batches))
        batch = reader.get_test_batch()
        scores = sess.run(prob, feed_dict=model.fill_feed_dict(batch[:-2]))
        # Suppress the trailing UNK/OOV class before taking the argmax.
        scores[:, -1] = 0
        predictions.append(np.argmax(scores, axis=1))
        question_ids.append(batch[-2])

    question_ids = np.concatenate(question_ids)
    predictions = np.concatenate(predictions)
    gt = reader._answer
    num_correct = (gt == predictions).sum()
    num_total = gt.size
    acc = num_correct / float(num_total)
    print('\nAcc: %0.2f, %d/%d' % (acc * 100., num_correct, num_total))
    return acc
Esempio n. 6
0
def test(checkpoint_path=None):
    """Score question relevance on TEST_SET and report average precision.

    Saves raw predictions and labels to a .mat file, then computes AP with
    scikit-learn.

    Args:
        checkpoint_path: Checkpoint to restore; when None, the latest
            checkpoint in `FLAGS.checkpoint_dir` is used.

    Returns:
        Average precision as a float.
    """
    batch_size = 100
    config = ModelConfig()
    # Get model function (restored: `model_fn` is called below and would
    # otherwise raise NameError).
    model_fn = get_model_creation_fn(FLAGS.model_type)

    # build data reader
    reader = AttentionFetcher(batch_size=batch_size,
                              subset=TEST_SET,
                              feat_type=config.feat_type,
                              version=FLAGS.version)
    if checkpoint_path is None:
        print(FLAGS.checkpoint_dir % (FLAGS.version, FLAGS.model_type))
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir %
                                             (FLAGS.version, FLAGS.model_type))
        checkpoint_path = ckpt.model_checkpoint_path
    print(checkpoint_path)

    # build and restore model
    model = model_fn(config, phase='test')
    # model.set_agent_ids([0])
    model.build()
    prob = model.qrd_prob

    # Cap GPU memory so this eval job can share the device.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
    sess = tf.Session(graph=tf.get_default_graph(),
                      config=tf.ConfigProto(gpu_options=gpu_options))
    tf.logging.info('Restore from model %s' %
                    os.path.basename(checkpoint_path))
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    gts = []
    preds = []

    print('Running inference on split %s...' % TEST_SET)
    for i in range(reader.num_batches):
        if i % 10 == 0:
            update_progress(i / float(reader.num_batches))
        outputs = reader.get_test_batch()
        scores = sess.run(prob, feed_dict=model.fill_feed_dict(outputs))
        preds.append(scores.flatten())
        gts.append(outputs[-1])  # last batch element holds the labels

    gts = np.concatenate(gts)
    preds = np.concatenate(preds)
    from scipy.io import savemat
    from sklearn.metrics import average_precision_score
    sv_file_name = os.path.basename(checkpoint_path)
    savemat('result/predictions_%s.mat' % sv_file_name, {
        'gt': gts,
        'preds': preds
    })
    # Flip scores/labels so the "relevant" class is the positive one.
    ap = average_precision_score(1.0 - gts, 1.0 - preds)

    return float(ap)
Esempio n. 7
0
def create_model():
    """Build the soft-attention VQA model and restore a fixed checkpoint.

    Returns:
        Tuple (sess, model): a live TF session with weights restored, and
        the built model (phase 'test_broadcast').
    """
    from models.vqa_soft_attention import AttentionModel
    from vqa_config import ModelConfig
    model = AttentionModel(ModelConfig(), phase='test_broadcast')
    model.build()

    # NOTE: checkpoint path is hard-coded to a specific training run.
    checkpoint_path = 'model/v1_vqa_VQA/v1_vqa_VQA_best2/model.ckpt-135000'
    sess = tf.Session(graph=tf.get_default_graph())
    tf.logging.info('Restore from model %s' %
                    os.path.basename(checkpoint_path))
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)
    return sess, model
Esempio n. 8
0
def test(checkpoint_path=None):
    """Collect answer probabilities and labels on TEST_SET, then evaluate.

    Args:
        checkpoint_path: Checkpoint to restore; when None, the latest
            checkpoint in `FLAGS.checkpoint_dir` is used.

    Returns:
        Whatever `evaluate_result` returns for the stacked predictions.
    """
    batch_size = 100
    config = ModelConfig()
    # Get model function (restored: `model_fn` is called below and would
    # otherwise raise NameError).
    model_fn = get_model_creation_fn(FLAGS.model_type)

    # build data reader
    reader = Reader(batch_size=batch_size,
                    subset=TEST_SET,
                    feat_type=config.feat_type,
                    version=FLAGS.version)
    if checkpoint_path is None:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir % (FLAGS.version,
                                                                     FLAGS.model_type))
        checkpoint_path = ckpt.model_checkpoint_path
    print(checkpoint_path)

    # build and restore model
    model = model_fn(config, phase='test')
    # model.set_agent_ids([0])
    model.build()
    prob = model.prob

    sess = tf.Session(graph=tf.get_default_graph())
    tf.logging.info('Restore from model %s' % os.path.basename(checkpoint_path))
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    quest_ids = []
    ans_preds = []
    gt_labels = []

    print('Running inference on split %s...' % TEST_SET)
    for i in range(reader.num_batches):
        if i % 10 == 0:
            update_progress(i / float(reader.num_batches))
        outputs = reader.get_test_batch()
        generated_ans = sess.run(
            prob, feed_dict=model.fill_feed_dict(outputs[:-2]))
        # assumes ground-truth labels sit at index 1 of the batch tuple —
        # TODO confirm against the Reader's output layout.
        _gt_labels = outputs[1]
        gt_labels.append(_gt_labels)
        ans_preds.append(generated_ans)

        quest_id = outputs[-2]
        quest_ids.append(quest_id)

    ans_preds = np.concatenate(ans_preds)
    gt_labels = np.concatenate(gt_labels)
    return evaluate_result(ans_preds, gt_labels)
 def __init__(self, ckpt_file='model/v1_vqa_VQA/v1_vqa_VQA_best2/model.ckpt-135000',
              use_dis_reward=False):
     """Build the soft-attention VQA scorer in its own graph and restore it.

     Args:
         ckpt_file: Checkpoint to restore the model weights from.
         use_dis_reward: Flag stored on the instance; not used here.
     """
     self.g = tf.Graph()
     self.ckpt_file = ckpt_file
     from models.vqa_soft_attention import AttentionModel
     from vqa_config import ModelConfig
     config = ModelConfig()
     self.ans2id = AnswerTokenToTopAnswer()
     self.use_dis_reward = use_dis_reward
     with self.g.as_default():
         self.sess = tf.Session()
         self.model = AttentionModel(config, phase='test_broadcast')
         self.model.build()
         # Renamed from `vars` to avoid shadowing the builtin.
         train_vars = tf.trainable_variables()
         self.saver = tf.train.Saver(var_list=train_vars)
         self.saver.restore(self.sess, ckpt_file)
Esempio n. 10
0
 def __init__(self, ckpt_file='/usr/data/fl302/code/inverse_vqa/model/mlb_attention_v2/model.ckpt-170000',
              use_dis_reward=False):
     """Build the v2 soft-attention VQA scorer in its own graph and restore it.

     Args:
         ckpt_file: Checkpoint to restore the model weights from.
         use_dis_reward: Flag stored on the instance; not used here.
     """
     self.g = tf.Graph()
     self.ckpt_file = ckpt_file
     self.v1tov2 = TopAnswerVersionConverter()
     from models.vqa_soft_attention_v2 import AttentionModel
     from vqa_config import ModelConfig
     config = ModelConfig()
     self.ans2id = AnswerTokenToTopAnswer()
     self.use_dis_reward = use_dis_reward
     with self.g.as_default():
         self.sess = tf.Session()
         self.model = AttentionModel(config, phase='test_broadcast')
         self.model.build()
         # Renamed from `vars` to avoid shadowing the builtin.
         train_vars = tf.trainable_variables()
         self.saver = tf.train.Saver(var_list=train_vars)
         self.saver.restore(self.sess, ckpt_file)
Esempio n. 11
0
    def __init__(
            self,
            ckpt_file='model/v1_vqa_VQA/v1_vqa_VQA_best2/model.ckpt-135000'):
        """Build the MLB-attention VQA model in its own graph and restore it.

        Args:
            ckpt_file: Checkpoint to restore the model weights from.
        """
        BaseVQAModel.__init__(self)
        self.g = tf.Graph()
        self.ckpt_file = ckpt_file
        from models.vqa_soft_attention import AttentionModel
        from vqa_config import ModelConfig
        config = ModelConfig()
        self.name = ' ------- MLB-attention ------- '

        with self.g.as_default():
            self.sess = tf.Session()
            self.model = AttentionModel(config, phase='test_broadcast')
            self.model.build()
            # Renamed from `vars` to avoid shadowing the builtin.
            train_vars = tf.trainable_variables()
            self.saver = tf.train.Saver(var_list=train_vars)
            self.saver.restore(self.sess, ckpt_file)
Esempio n. 12
0
    def __init__(self,
                 ckpt_file='model/kprestval_VQA-BaseNorm/model.ckpt-26000'):
        """Build the DeeperLSTM VQA base model in its own graph and restore it.

        Args:
            ckpt_file: Checkpoint to restore the model weights from.
        """
        BaseVQAModel.__init__(self)
        # Number of top answers exposed by this wrapper.
        self.top_k = 2
        self.g = tf.Graph()
        self.ckpt_file = ckpt_file
        from models.vqa_base import BaseModel
        from vqa_config import ModelConfig
        config = ModelConfig()
        self._subset = 'test'
        self._year = 2015
        self.name = ' ------- DeeperLSTM ------- '

        with self.g.as_default():
            self.sess = tf.Session()
            self.model = BaseModel(config, phase='test')
            self.model.build()
            # Renamed from `vars` to avoid shadowing the builtin.
            train_vars = tf.trainable_variables()
            self.saver = tf.train.Saver(var_list=train_vars)
            self.saver.restore(self.sess, ckpt_file)
Esempio n. 13
0
 def __init__(self, ckpt_file='', use_dis_reward=False,
              use_attention_model=False):
     """Build a VQA scorer (base or attention variant) and restore weights.

     Args:
         ckpt_file: Checkpoint to restore the model weights from.
         use_dis_reward: Flag stored on the instance; not used here.
         use_attention_model: When True, build `AttentionModel` instead of
             `BaseModel`.
     """
     self.g = tf.Graph()
     self.ckpt_file = ckpt_file
     self.use_attention_model = use_attention_model
     from models.vqa_base import BaseModel
     from vqa_config import ModelConfig
     config = ModelConfig()
     self.ans2id = AnswerTokenToTopAnswer()
     self.use_dis_reward = use_dis_reward
     with self.g.as_default():
         self.sess = tf.Session()
         if self.use_attention_model:
             # NOTE(review): `AttentionModel` is not imported in this method;
             # presumably it is available at module scope — confirm before
             # enabling this branch.
             self.model = AttentionModel(config, phase='test')
             self.model.build()
         else:
             self.model = BaseModel(config, phase='test')
             self.model.build()
         # Renamed from `vars` to avoid shadowing the builtin.
         train_vars = tf.trainable_variables()
         self.saver = tf.train.Saver(var_list=train_vars)
         self.saver.restore(self.sess, ckpt_file)
import tensorflow as tf
# from models.model_creater import get_model_creation_fn
from models.vqa_adversary import BaseModel

# model_fn = get_model_creation_fn('LM')
from vqa_config import ModelConfig

# Smoke test: build the adversary model in training phase and run one
# inference and one training step on random data.
model = BaseModel(ModelConfig(), phase='train')
model.build()

sess = tf.Session()
model.set_session(sess)
model.setup_model()

import numpy as np

batch_size = 8
# Random token ids in [0, 1000) padded/truncated to length 20.
quest = np.random.randint(low=0, high=1000,
                          size=(batch_size, 20)).astype(np.int32)
# All question lengths set to 1 — only the first token is "real".
quest_len = np.ones(shape=(batch_size, ), dtype=np.int32)
# Random 2048-d image features (presumably pooled CNN features — TODO confirm).
images = np.random.rand(batch_size, 2048)
# Random answer labels over the 2000-way answer vocabulary.
labels = np.random.randint(low=0,
                           high=2000,
                           size=(batch_size, ),
                           dtype=np.int32)
prob = model.inference([images, quest, quest_len])
print(prob)
loss = model.trainstep([images, quest, quest_len, labels])
print(loss)
Esempio n. 15
0
def test(checkpoint_path=None):
    """Run VQA inference on FLAGS.testset and save per-question score data.

    Records, for each question: the ground-truth answer's score, the
    predicted top answer and its score. Results are written to an HDF5 file.

    Args:
        checkpoint_path: Checkpoint to restore; when None, the latest
            checkpoint in `FLAGS.checkpoint_dir` is used.
    """
    batch_size = 100
    config = ModelConfig()
    # Get model function (restored: `model_fn` is called below and would
    # otherwise raise NameError).
    model_fn = get_model_creation_fn(FLAGS.model_type)
    _model_suffix = 'var_' if FLAGS.use_var else ''

    # build data reader
    reader = AttentionFetcher(batch_size=batch_size,
                              subset=FLAGS.testset,
                              feat_type=config.feat_type,
                              version=FLAGS.version,
                              var_suffix=_model_suffix)
    if checkpoint_path is None:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir %
                                             (FLAGS.version, FLAGS.model_type))
        checkpoint_path = ckpt.model_checkpoint_path
    print(checkpoint_path)

    # build and restore model
    model = model_fn(config, phase='test')
    # model.set_agent_ids([0])
    model.build()
    prob = model.prob

    sess = tf.Session(graph=tf.get_default_graph())
    tf.logging.info('Restore from model %s' %
                    os.path.basename(checkpoint_path))
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    ans_ids = []
    ans_scores = []
    gt_scores = []
    quest_ids = []

    print('Running inference on split %s...' % FLAGS.testset)
    for i in range(reader.num_batches):
        if i % 10 == 0:
            update_progress(i / float(reader.num_batches))
        outputs = reader.get_test_batch()
        generated_ans = sess.run(prob,
                                 feed_dict=model.fill_feed_dict(outputs[:-2]))
        # assumes ground-truth labels sit at index 3 of the batch tuple —
        # TODO confirm against the AttentionFetcher output layout.
        _gt_labels = outputs[3]
        _this_batch_size = _gt_labels.size
        # Score assigned to each question's ground-truth answer.
        _gt_scores = generated_ans[np.arange(_this_batch_size, ), _gt_labels]
        gt_scores.append(_gt_scores)
        # Zero out the last (UNK/OOV) class so it is never predicted.
        generated_ans[:, -1] = 0
        top_ans = np.argmax(generated_ans, axis=1)
        top_scores = np.max(generated_ans, axis=1)

        ans_ids.append(top_ans)
        ans_scores.append(top_scores)
        quest_id = outputs[-2]
        quest_ids.append(quest_id)

    quest_ids = np.concatenate(quest_ids)
    ans_ids = np.concatenate(ans_ids)
    ans_scores = np.concatenate(ans_scores)
    gt_scores = np.concatenate(gt_scores)

    # save results
    tf.logging.info('Saving results')
    from util import save_hdf5
    save_hdf5(
        'data4/%sv2qa_%s_qa_scores.data' % (_model_suffix, FLAGS.testset), {
            'ext_quest_ids': quest_ids,
            'ext_cand_scores': gt_scores,
            'ext_cand_pred_labels': ans_ids,
            'ext_cand_pred_scores': ans_scores
        })
def test(checkpoint_path=None):
    """Run the iVQA re-rank evaluation and dump both raw scores and answers.

    Saves the re-rank/VQA score matrices and candidate labels to HDF5, and
    the final answers to a JSON result file.

    Args:
        checkpoint_path: Checkpoint to restore; when None, the latest
            checkpoint in `FLAGS.checkpoint_dir` is used (only applied when
            `FLAGS.restore` is set).

    Returns:
        Tuple of (result JSON file path, concatenated question-id array).
    """
    batch_size = 40
    config = ModelConfig()
    config.convert = True
    config.ivqa_rerank = True  # VQA baseline or re-rank
    config.loss_type = FLAGS.loss_type
    # Get model function
    model_fn = get_model_creation_fn(FLAGS.model_type)
    # ana_ctx = RerankAnalysiser()

    # build data reader
    reader_fn = create_reader(FLAGS.model_type, phase='test')
    reader = reader_fn(batch_size=batch_size, subset='kp%s' % FLAGS.testset,
                       version=FLAGS.version)
    if checkpoint_path is None:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir % (FLAGS.version,
                                                                     FLAGS.model_type))
        checkpoint_path = ckpt.model_checkpoint_path
    print(checkpoint_path)

    # build and restore model
    model = model_fn(config, phase='evaluate')
    model.build()

    sess = tf.Session(graph=tf.get_default_graph())
    tf.logging.info('Restore from model %s' % os.path.basename(checkpoint_path))
    if FLAGS.restore:
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint_path)
    else:
        sess.run(tf.initialize_all_variables())
        model.init_fn(sess)

    # Create the vocabulary.
    to_sentence = SentenceGenerator(trainset='trainval')

    ans_ids = []
    quest_ids = []

    b_rerank_scores = []
    b_vqa_scores = []
    b_cand_labels = []
    print('Running inference on split %s...' % FLAGS.testset)
    for i in range(reader.num_batches):
        if i % 10 == 0:
            update_progress(i / float(reader.num_batches))
        outputs = reader.get_test_batch()
        model_preds = model.inference_rerank_vqa(outputs[:4], sess)
        # Single unpack (the original unpacked the same tuple twice); the
        # re-ranked top answer doubles as the final prediction.
        _, top_ans, ivqa_scores, vqa_top_ans, vqa_scores = model_preds
        b_rerank_scores.append(ivqa_scores)
        b_vqa_scores.append(vqa_scores)
        b_cand_labels.append(vqa_top_ans)
        # ana_ctx.update(outputs, model_preds)

        ans_ids.append(top_ans)
        quest_id = outputs[-2]
        quest_ids.append(quest_id)
    # save preds
    b_rerank_scores = np.concatenate(b_rerank_scores, axis=0)
    b_vqa_scores = np.concatenate(b_vqa_scores, axis=0)
    b_cand_labels = np.concatenate(b_cand_labels, axis=0)
    quest_ids = np.concatenate(quest_ids)
    from util import save_hdf5
    save_hdf5('data/rerank_kptest.h5', {'ivqa': b_rerank_scores,
                                        'vqa': b_vqa_scores,
                                        'cands': b_cand_labels,
                                        'quest_ids': quest_ids})

    # ana_ctx.compute_accuracy()

    ans_ids = np.concatenate(ans_ids)
    result = [{u'answer': to_sentence.index_to_top_answer(aid),
               u'question_id': qid} for aid, qid in zip(ans_ids, quest_ids)]

    # save results
    tf.logging.info('Saving results')
    res_file = FLAGS.result_format % (FLAGS.version, FLAGS.testset)
    # Use a context manager so the result file handle is closed promptly.
    with open(res_file, 'w') as fp:
        json.dump(result, fp)
    tf.logging.info('Done!')
    tf.logging.info('#Num eval samples %d' % len(result))
    # ana_ctx.close()
    return res_file, quest_ids
Esempio n. 17
0
def extract_answer_proposals(checkpoint_path=None, subset='kpval'):
    """Extract top-K answer candidates per question with w2v codings.

    Runs the VQA model on `subset`, keeps the `_K` highest-scoring answers
    (optionally forcing the ground truth to rank first when
    `FLAGS.append_gt`), encodes each candidate with word2vec, and saves the
    codings/scores to HDF5 plus the candidate metadata to JSON.

    Args:
        checkpoint_path: Checkpoint to restore; when None, the latest
            checkpoint in `FLAGS.checkpoint_dir` is used.
        subset: Data split to process (e.g. 'kpval').
    """
    batch_size = 100
    config = ModelConfig()
    # Get model function (restored: `model_fn` is called below and would
    # otherwise raise NameError).
    model_fn = get_model_creation_fn(FLAGS.model_type)

    if FLAGS.append_gt:
        ann_set = 'train' if 'train' in subset else 'val'
        mc_ctx = MultiChoiceQuestionManger(subset=ann_set,
                                           load_ans=True,
                                           answer_coding='sequence')
    else:
        mc_ctx = None

    # build data reader
    reader = AttentionFetcher(batch_size=batch_size, subset=subset,
                              feat_type=config.feat_type, version=FLAGS.version)
    if checkpoint_path is None:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir % (FLAGS.version,
                                                                     FLAGS.model_type))
        checkpoint_path = ckpt.model_checkpoint_path
    print(checkpoint_path)

    # build and restore model
    model = model_fn(config, phase='test')
    # model.set_agent_ids([0])
    model.build()
    prob = model.prob

    sess = tf.Session(graph=tf.get_default_graph())
    tf.logging.info('Restore from model %s' % os.path.basename(checkpoint_path))
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    # Create the vocabulary.
    top_ans_file = '../VQA-tensorflow/data/vqa_trainval_top2000_answers.txt'
    to_sentence = SentenceGenerator(trainset='trainval',
                                    top_ans_file=top_ans_file)
    w2v_encoder = SentenceEncoder()
    # to_sentence = SentenceGenerator(trainset='trainval')

    cands_meta = []
    cands_scores = []
    cands_coding = []
    quest_ids = []
    is_oov = []
    print('Running inference on split %s...' % subset)
    for i in range(reader.num_batches):
        if i % 10 == 0:
            update_progress(i / float(reader.num_batches))
        outputs = reader.get_test_batch()
        raw_ans = sess.run(
            prob, feed_dict=model.fill_feed_dict(outputs[:-2]))
        generated_ans = raw_ans.copy()
        generated_ans[:, -1] = -1.0  # by default do not predict UNK

        gt_labels = outputs[-3]
        if FLAGS.append_gt:
            # Force the ground-truth label to outrank everything else.
            generated_ans[np.arange(gt_labels.size), gt_labels] = 10.0

        ans_cand_ids = np.argsort(-generated_ans, axis=1)

        q_ids = outputs[-2]

        if FLAGS.append_gt:
            assert (np.all(np.equal(ans_cand_ids[:, 0], gt_labels)))

        for quest_id, ids, cand_scs, _gt in zip(q_ids, ans_cand_ids,
                                                raw_ans, gt_labels):
            answers = []
            answer_w2v = []

            # check out of vocabulary (label 2000 is the OOV bucket)
            is_oov.append(_gt == 2000)

            cands_scores.append(cand_scs[ids[:_K]][np.newaxis, :])
            for k in range(_K):
                aid = ids[k]
                if aid == 2000:  # gt is out of vocab
                    ans = mc_ctx.get_gt_answer(quest_id)
                else:
                    ans = to_sentence.index_to_top_answer(aid)
                answer_w2v.append(w2v_encoder.encode(ans))
                answers.append(ans)
            answer_w2v = np.concatenate(answer_w2v, axis=1)
            res_i = {'quest_id': int(quest_id), 'cands': answers}
            cands_meta.append(res_i)
            cands_coding.append(answer_w2v)
            quest_ids.append(quest_id)
    quest_ids = np.array(quest_ids, dtype=np.int32)
    # `np.bool` was removed from NumPy; the builtin `bool` is the supported
    # spelling and produces the same boolean array dtype.
    is_oov = np.array(is_oov, dtype=bool)
    labels = np.zeros_like(quest_ids, dtype=np.int32)
    cands_scores = np.concatenate(cands_scores, axis=0).astype(np.float32)
    cands_coding = np.concatenate(cands_coding, axis=0).astype(np.float32)
    save_hdf5('data3/vqa_ap_w2v_coding_%s.data' % subset, {'cands_w2v': cands_coding,
                                                           'cands_scs': cands_scores,
                                                           'quest_ids': quest_ids,
                                                           'is_oov': is_oov,
                                                           'labels': labels})
    save_json('data3/vqa_ap_cands_%s.meta' % subset, cands_meta)
    print('\n\nExtraction Done!')
    print('OOV percentage: %0.2f' % (100.*is_oov.sum()/reader._num))
Esempio n. 18
0
 def _build_vqa_agent(self):
     with tf.variable_scope('vqa_agent'):
         self.vqa_agent = VQAAgent(config=ModelConfig(), phase='test')