Example #1
def test(checkpoint_path=None):
    batch_size = 100
    config = ModelConfig()
    # Get model creation function (model_fn is used below to build the model)
    model_fn = get_model_creation_fn(FLAGS.model_type)

    # build data reader
    reader = AttentionFetcher(batch_size=batch_size,
                              subset=TEST_SET,
                              feat_type=config.feat_type,
                              version=FLAGS.version)
    if checkpoint_path is None:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir %
                                             (FLAGS.version, FLAGS.model_type))
        checkpoint_path = ckpt.model_checkpoint_path
    print(checkpoint_path)

    # build and restore model
    model = model_fn(config, phase='test')
    # model.set_agent_ids([0])
    model.build()
    prob = model.prob

    sess = tf.Session(graph=tf.get_default_graph())
    tf.logging.info('Restore from model %s' %
                    os.path.basename(checkpoint_path))
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    # Create the vocabulary.
    top_ans_file = '../VQA-tensorflow/data/vqa_trainval_top2000_answers.txt'
    to_sentence = SentenceGenerator(trainset='trainval',
                                    top_ans_file=top_ans_file)
    # to_sentence = SentenceGenerator(trainset='trainval')

    results = []

    print('Running inference on split %s...' % TEST_SET)
    for i in range(reader.num_batches):
        if i % 10 == 0:
            update_progress(i / float(reader.num_batches))
        outputs = reader.get_test_batch()
        generated_ans = sess.run(prob,
                                 feed_dict=model.fill_feed_dict(outputs[:-2]))
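        # suppress the last answer class (UNK) so it is never predicted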
        generated_ans[:, -1] = 0
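        # rank all candidate answers by descending score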
        ans_cand_ids = np.argsort(-generated_ans, axis=1)

        quest_ids = outputs[-2]

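        # decode the top-_K candidate answers for each question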
        for quest_id, ids in zip(quest_ids, ans_cand_ids):
            answers = []
            for k in range(_K):
                aid = ids[k]
                ans = to_sentence.index_to_top_answer(aid)
                answers.append(ans)
            res_i = {'question_id': int(quest_id), 'answers': answers}
            results.append(res_i)

    eval_recall(results)
Example #2
def test(checkpoint_path=None):
    batch_size = 100
    config = ModelConfig()
    # Get model creation function (model_fn is used below to build the model)
    model_fn = get_model_creation_fn(FLAGS.model_type)
    _model_suffix = 'var_' if FLAGS.use_var else ''

    # build data reader
    reader = AttentionFetcher(batch_size=batch_size,
                              subset=FLAGS.testset,
                              feat_type=config.feat_type,
                              version=FLAGS.version,
                              var_suffix=_model_suffix)
    if checkpoint_path is None:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir %
                                             (FLAGS.version, FLAGS.model_type))
        checkpoint_path = ckpt.model_checkpoint_path
    print(checkpoint_path)

    # build and restore model
    model = model_fn(config, phase='test')
    # model.set_agent_ids([0])
    model.build()
    prob = model.prob

    sess = tf.Session(graph=tf.get_default_graph())
    tf.logging.info('Restore from model %s' %
                    os.path.basename(checkpoint_path))
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    # Create the vocabulary.
    # top_ans_file = '../VQA-tensorflow/data/vqa_trainval_top2000_answers.txt'
    # to_sentence = SentenceGenerator(trainset='trainval',
    #                                 top_ans_file=top_ans_file)
    # to_sentence = SentenceGenerator(trainset='trainval')

    ans_ids = []
    ans_scores = []
    gt_scores = []
    quest_ids = []

    print('Running inference on split %s...' % FLAGS.testset)
    for i in range(reader.num_batches):
        if i % 10 == 0:
            update_progress(i / float(reader.num_batches))
        outputs = reader.get_test_batch()
        generated_ans = sess.run(prob,
                                 feed_dict=model.fill_feed_dict(outputs[:-2]))
        _gt_labels = outputs[3]
        _this_batch_size = _gt_labels.size
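        # gather each example's predicted score at its ground-truth answer label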
        _gt_scores = generated_ans[np.arange(_this_batch_size), _gt_labels]
        gt_scores.append(_gt_scores)
        generated_ans[:, -1] = 0
        top_ans = np.argmax(generated_ans, axis=1)
        top_scores = np.max(generated_ans, axis=1)

        ans_ids.append(top_ans)
        ans_scores.append(top_scores)
        quest_id = outputs[-2]
        quest_ids.append(quest_id)

    quest_ids = np.concatenate(quest_ids)
    ans_ids = np.concatenate(ans_ids)
    ans_scores = np.concatenate(ans_scores)
    gt_scores = np.concatenate(gt_scores)

    # save results
    tf.logging.info('Saving results')
    # res_file = FLAGS.result_format % (FLAGS.version, FLAGS.testset)
    from util import save_hdf5
    save_hdf5(
        'data4/%sv2qa_%s_qa_scores.data' % (_model_suffix, FLAGS.testset), {
            'ext_quest_ids': quest_ids,
            'ext_cand_scores': gt_scores,
            'ext_cand_pred_labels': ans_ids,
            'ext_cand_pred_scores': ans_scores
        })
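For reference, a minimal sketch of loading the scores saved above back into numpy arrays; it assumes util.save_hdf5 writes each dictionary key as a top-level HDF5 dataset (the actual file name is built from _model_suffix and FLAGS.testset at run time):

import h5py

def load_qa_scores(path):
    # Assumption: save_hdf5 stores each dict key as a top-level dataset.
    with h5py.File(path, 'r') as f:
        return {key: f[key][...]
                for key in ('ext_quest_ids', 'ext_cand_scores',
                            'ext_cand_pred_labels', 'ext_cand_pred_scores')}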
Example #3
def test(checkpoint_path=None):
    batch_size = 100
    config = ModelConfig()
    # Get model creation function (model_fn is used below to build the model)
    model_fn = get_model_creation_fn(FLAGS.model_type)

    # build data reader
    reader = AttentionFetcher(batch_size=batch_size, subset=TEST_SET,
                              feat_type=config.feat_type, version=FLAGS.version)
    if checkpoint_path is None:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir % (FLAGS.version,
                                                                     FLAGS.model_type))
        checkpoint_path = ckpt.model_checkpoint_path
    print(checkpoint_path)

    # build and restore model
    model = model_fn(config, phase='test')
    # model.set_agent_ids([0])
    model.build()
    prob = model.prob

    sess = tf.Session(graph=tf.get_default_graph())
    tf.logging.info('Restore from model %s' % os.path.basename(checkpoint_path))
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    # Create the vocabulary.
    top_ans_file = '../VQA-tensorflow/data/vqa_trainval_top2000_answers.txt'
    to_sentence = SentenceGenerator(trainset='trainval',
                                    top_ans_file=top_ans_file)
    # to_sentence = SentenceGenerator(trainset='trainval')

    ans_ids = []
    quest_ids = []
    ans_preds = []

    print('Running inference on split %s...' % TEST_SET)
    for i in range(reader.num_batches):
        if i % 10 == 0:
            update_progress(i / float(reader.num_batches))
        outputs = reader.get_test_batch()
        generated_ans = sess.run(
            prob, feed_dict=model.fill_feed_dict(outputs[:-2]))
        ans_preds.append(generated_ans.copy())  # keep the raw scores; masking below is in-place
        generated_ans[:, -1] = 0  # suppress the UNK class before the arg-max
        top_ans = np.argmax(generated_ans, axis=1)

        ans_ids.append(top_ans)
        quest_id = outputs[-2]
        quest_ids.append(quest_id)

    quest_ids = np.concatenate(quest_ids)
    ans_ids = np.concatenate(ans_ids)
    ans_preds = np.concatenate(ans_preds)
    result = [{u'answer': to_sentence.index_to_top_answer(aid),
               u'question_id': int(qid)} for aid, qid in zip(ans_ids, quest_ids)]
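    # entries follow the standard VQA submission format: [{"question_id": ..., "answer": ...}]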

    # save results
    tf.logging.info('Saving results')
    res_file = FLAGS.result_format % (FLAGS.version, TEST_SET)
    data_file = 'data5/%s_%s_scores_flt.data' % (TEST_SET, FLAGS.model_type)
    from util import save_hdf5
    save_hdf5(data_file, {'quest_ids': quest_ids, 'ans_preds': ans_preds})
    json.dump(result, open(res_file, 'w'))
    tf.logging.info('Done!')
    tf.logging.info('#Num eval samples %d' % len(result))
    return res_file, quest_ids
Example #4
def train():
    _model_suffix = 'var_' if FLAGS.use_var else ''
    model_config = ModelConfig()
    training_config = TrainConfig()

    # Get model creation function (model_fn is used below to build the model)
    model_fn = get_model_creation_fn(FLAGS.model_type)

    # Create training directory.
    train_dir = FLAGS.train_dir % (FLAGS.model_trainset, FLAGS.model_type)
    do_counter_sampling = FLAGS.version == 'v2'
    if not tf.gfile.IsDirectory(train_dir):
        tf.logging.info("Creating training directory: %s", train_dir)
        tf.gfile.MakeDirs(train_dir)

    g = tf.Graph()
    with g.as_default():
        # Build the model.
        model = model_fn(model_config, phase='train')
        model.build()

        # Set up the learning rate
        learning_rate = tf.constant(training_config.initial_learning_rate)

        def _learning_rate_decay_fn(learn_rate, global_step):
            return tf.train.exponential_decay(
                learn_rate,
                global_step,
                decay_steps=training_config.decay_step,
                decay_rate=training_config.decay_factor,
                staircase=False)

        learning_rate_decay_fn = _learning_rate_decay_fn

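        # optimize_loss applies learning_rate_decay_fn(learning_rate, global_step)
        # to obtain the decayed rate actually used for the updates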
        train_op = tf.contrib.layers.optimize_loss(
            loss=model.loss,
            global_step=model.global_step,
            learning_rate=learning_rate,
            optimizer=training_config.optimizer,
            clip_gradients=training_config.clip_gradients,
            learning_rate_decay_fn=learning_rate_decay_fn)

        # Set up the Saver for saving and restoring model checkpoints.
        saver = tf.train.Saver(
            max_to_keep=training_config.max_checkpoints_to_keep)

        # setup summaries
        summary_op = tf.summary.merge_all()

    # create reader
    model_name = os.path.split(train_dir)[1]
    reader = Reader(batch_size=64, un_ratio=0)
    # reader = Reader(batch_size=64,
    #                 known_set='kprestval',
    #                 unknown_set='kptrain',  # 'kptrain'
    #                 un_ratio=1,
    #                 hide_label=False)

    # Run training.
    training_util.train(train_op,
                        train_dir,
                        log_every_n_steps=FLAGS.log_every_n_steps,
                        graph=g,
                        global_step=model.global_step,
                        number_of_steps=FLAGS.number_of_steps,
                        init_fn=model.init_fn,
                        saver=saver,
                        reader=reader,
                        feed_fn=model.fill_feed_dict)
Example #5
def extract_answer_proposals(checkpoint_path=None, subset='kpval'):
    batch_size = 100
    config = ModelConfig()
    # Get model creation function (model_fn is used below to build the model)
    model_fn = get_model_creation_fn(FLAGS.model_type)

    if FLAGS.append_gt:
        ann_set = 'train' if 'train' in subset else 'val'
        mc_ctx = MultiChoiceQuestionManger(subset=ann_set,
                                           load_ans=True,
                                           answer_coding='sequence')
    else:
        mc_ctx = None

    # build data reader
    reader = AttentionFetcher(batch_size=batch_size, subset=subset,
                              feat_type=config.feat_type, version=FLAGS.version)
    if checkpoint_path is None:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir % (FLAGS.version,
                                                                     FLAGS.model_type))
        checkpoint_path = ckpt.model_checkpoint_path
    print(checkpoint_path)

    # build and restore model
    model = model_fn(config, phase='test')
    # model.set_agent_ids([0])
    model.build()
    prob = model.prob

    sess = tf.Session(graph=tf.get_default_graph())
    tf.logging.info('Restore from model %s' % os.path.basename(checkpoint_path))
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    # Create the vocabulary.
    top_ans_file = '../VQA-tensorflow/data/vqa_trainval_top2000_answers.txt'
    to_sentence = SentenceGenerator(trainset='trainval',
                                    top_ans_file=top_ans_file)
    w2v_encoder = SentenceEncoder()
    # to_sentence = SentenceGenerator(trainset='trainval')

    cands_meta = []
    cands_scores = []
    cands_coding = []
    quest_ids = []
    is_oov = []
    print('Running inference on split %s...' % subset)
    for i in range(reader.num_batches):
        if i % 10 == 0:
            update_progress(i / float(reader.num_batches))
        outputs = reader.get_test_batch()
        raw_ans = sess.run(
            prob, feed_dict=model.fill_feed_dict(outputs[:-2]))
        generated_ans = raw_ans.copy()
        generated_ans[:, -1] = -1.0  # by default do not predict UNK
        # print('Max: %0.3f, Min: %0.3f' % (raw_ans.max(), raw_ans.min()))

        gt_labels = outputs[-3]
        if FLAGS.append_gt:
            generated_ans[np.arange(gt_labels.size), gt_labels] = 10.0

        ans_cand_ids = np.argsort(-generated_ans, axis=1)

        q_ids = outputs[-2]

        if FLAGS.append_gt:
            assert np.all(np.equal(ans_cand_ids[:, 0], gt_labels))

        for quest_id, ids, cand_scs, _gt in zip(q_ids, ans_cand_ids,
                                                raw_ans, gt_labels):
            answers = []
            answer_w2v = []

            # check out of vocabulary
            is_oov.append(_gt == 2000)

            cands_scores.append(cand_scs[ids[:_K]][np.newaxis, :])
            for k in range(_K):
                aid = ids[k]
                if aid == 2000:  # gt is out of vocab
                    ans = mc_ctx.get_gt_answer(quest_id)
                else:
                    ans = to_sentence.index_to_top_answer(aid)
                answer_w2v.append(w2v_encoder.encode(ans))
                answers.append(ans)
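            # stack the _K per-answer w2v encodings into one row vector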
            answer_w2v = np.concatenate(answer_w2v, axis=1)
            res_i = {'quest_id': int(quest_id), 'cands': answers}
            cands_meta.append(res_i)
            cands_coding.append(answer_w2v)
            quest_ids.append(quest_id)
    quest_ids = np.array(quest_ids, dtype=np.int32)
    is_oov = np.array(is_oov, dtype=bool)  # plain bool; np.bool was removed from recent numpy
    labels = np.zeros_like(quest_ids, dtype=np.int32)
    cands_scores = np.concatenate(cands_scores, axis=0).astype(np.float32)
    cands_coding = np.concatenate(cands_coding, axis=0).astype(np.float32)
    save_hdf5('data3/vqa_ap_w2v_coding_%s.data' % subset, {'cands_w2v': cands_coding,
                                                           'cands_scs': cands_scores,
                                                           'quest_ids': quest_ids,
                                                           'is_oov': is_oov,
                                                           'labels': labels})
    save_json('data3/vqa_ap_cands_%s.meta' % subset, cands_meta)
    print('\n\nExtraction Done!')
    print('OOV percentage: %0.2f' % (100.*is_oov.sum()/reader._num))