import os
import json

import numpy as np
import tensorflow as tf

# ModelConfig, get_model_creation_fn, TestReader, test_worker, update_progress,
# TEST_SET, and FLAGS are assumed to come from the surrounding repo modules.


def test(checkpoint_path=None):
    subsets = ['kpval', 'kptest', 'kprestval']
    quest_ids = []
    result = []

    # Build the model config from command-line flags.
    config = ModelConfig()
    config.sample_negative = FLAGS.sample_negative
    config.use_fb_bn = FLAGS.use_fb_bn

    # Get the model creation function for the requested model type.
    model_fn = get_model_creation_fn(FLAGS.model_type)

    # Build and restore the model.
    model = model_fn(config, phase='test')
    model.build()

    # Fall back to the latest checkpoint if none was given explicitly;
    # without this, os.path.basename(None) below would raise a TypeError.
    if checkpoint_path is None:
        ckpt = tf.train.get_checkpoint_state(
            FLAGS.checkpoint_dir % (FLAGS.version, FLAGS.model_type))
        checkpoint_path = ckpt.model_checkpoint_path

    sess = tf.Session(graph=tf.get_default_graph())
    tf.logging.info('Restore from model %s' % os.path.basename(checkpoint_path))
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    # Run inference on each subset and pool the answers.
    for subset in subsets:
        _quest_ids, _result = test_worker(model, sess, subset)
        quest_ids += _quest_ids
        result += _result
    quest_ids = np.concatenate(quest_ids)

    # Save results.
    tf.logging.info('Saving results')
    res_file = FLAGS.result_format % (FLAGS.version, 'val')
    with open(res_file, 'w') as fd:
        json.dump(result, fd)
    tf.logging.info('Done!')
    tf.logging.info('#Num eval samples %d' % len(result))
    return res_file, quest_ids
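# The multi-subset variant above delegates per-subset inference to
# test_worker(), which is not defined in this section. Below is a minimal
# sketch of a plausible implementation, inferred from the inline batch loop
# in the single-split variant that follows; the reader construction and the
# output layout (candidates, question ids, image ids) are assumptions carried
# over from that variant, not confirmed for this function.
def test_worker(model, sess, subset):
    reader = TestReader(batch_size=64, subset=subset,
                        use_fb_data=FLAGS.use_fb_data)
    quest_ids, result = [], []
    for i in range(reader.num_batches):
        if i % 10 == 0:
            update_progress(i / float(reader.num_batches))
        outputs = reader.get_test_batch()
        # Score every candidate answer and keep the arg-max per question.
        mc_scores = sess.run(
            model._logits, feed_dict=model.fill_feed_dict(outputs[:-3]))
        choice_idx = np.argmax(mc_scores, axis=1)
        cands, _qids, _image_ids = outputs[-3:]
        for qid, cid, mcs in zip(_qids, choice_idx, cands):
            result.append({u'answer': mcs['cands'][cid],
                           u'question_id': qid})
        quest_ids.append(_qids)
    # Return a list of id arrays (the caller concatenates them) plus answers.
    return quest_ids, result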
def test(checkpoint_path=None):
    batch_size = 64

    # Build the model config from command-line flags.
    config = ModelConfig()
    config.sample_negative = FLAGS.sample_negative
    config.use_fb_bn = FLAGS.use_fb_bn

    # Get the model creation function for the requested model type.
    model_fn = get_model_creation_fn(FLAGS.model_type)

    # Build the data reader for the test split.
    reader = TestReader(batch_size=batch_size, subset=TEST_SET,
                        use_fb_data=FLAGS.use_fb_data)

    # Fall back to the latest checkpoint if none was given explicitly.
    if checkpoint_path is None:
        ckpt = tf.train.get_checkpoint_state(
            FLAGS.checkpoint_dir % (FLAGS.version, FLAGS.model_type))
        checkpoint_path = ckpt.model_checkpoint_path
    print(checkpoint_path)

    # Build and restore the model.
    model = model_fn(config, phase='test')
    model.build()
    prob = model.prob  # NOTE: unused below; scoring reads model._logits directly.

    sess = tf.Session(graph=tf.get_default_graph())
    tf.logging.info('Restore from model %s' % os.path.basename(checkpoint_path))
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    quest_ids = []
    result = []

    print('Running inference on split %s...' % TEST_SET)
    for i in range(reader.num_batches):
        if i % 10 == 0:
            update_progress(i / float(reader.num_batches))
        outputs = reader.get_test_batch()

        # Score every multiple-choice candidate, then pick the arg-max.
        mc_scores = sess.run(
            model._logits, feed_dict=model.fill_feed_dict(outputs[:-3]))
        choice_idx = np.argmax(mc_scores, axis=1)

        # The last three reader outputs are the candidate answers,
        # question ids, and image ids.
        cands, _qids, image_ids = outputs[-3:]
        for qid, cid, mcs in zip(_qids, choice_idx, cands):
            answer = mcs['cands'][cid]
            assert mcs['quest_id'] == qid
            result.append({u'answer': answer, u'question_id': qid})
        quest_ids.append(_qids)
    quest_ids = np.concatenate(quest_ids)

    # Save results.
    tf.logging.info('Saving results')
    res_file = FLAGS.result_format % (FLAGS.version, TEST_SET)
    with open(res_file, 'w') as fd:
        json.dump(result, fd)
    tf.logging.info('Done!')
    tf.logging.info('#Num eval samples %d' % len(result))
    return res_file, quest_ids
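# For reference: a typical TF 1.x entry point for evaluation scripts of this
# shape. This is a sketch following the standard tf.app.run() convention; the
# scripts' actual main() is not shown in this section.
def main(_):
    res_file, quest_ids = test()
    print('Wrote %d answers to %s' % (len(quest_ids), res_file))


if __name__ == '__main__':
    tf.app.run()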