def test(epoch, dataset, config, log_dir, question_type_dict):
    """Test the model and write predictions to a json file.

    Args:
        epoch: int, epoch tag used to name the prediction output file.
        dataset: dataset object exposing ``has_test_example``,
            ``get_test_example()``, ``test_example_total`` and
            ``reset_test()``.
        config: dict with 'model', 'session' and 'preprocess_dir' entries.
        log_dir: directory containing 'checkpoint' and 'stats' subdirs;
            predictions are written here as ``prediction_<epoch>.json``.
        question_type_dict: dict mapping question-type key -> type name.

    Returns:
        Overall test accuracy as a float.
    """
    model_config = config['model']
    sess_config = config['session']
    # Per-question-type counters, keyed identically to question_type_dict.
    question_type_correct_count = {k: 0 for k in question_type_dict}
    question_type_all_count = {k: 0 for k in question_type_dict}
    # Column 0 of answer_set.txt maps class index -> answer string.
    answerset = pd.read_csv(
        os.path.join(config['preprocess_dir'], 'answer_set.txt'),
        header=None)[0]

    with tf.Graph().as_default():
        model = TripleAttentiveModal_DMN(model_config)
        model.build_inference()
        with tf.Session(config=sess_config) as sess:
            ckpt_dir = os.path.join(log_dir, 'checkpoint')
            save_path = tf.train.latest_checkpoint(ckpt_dir)
            saver = tf.train.Saver()
            if save_path:
                print('load checkpoint {}.'.format(save_path))
                saver.restore(sess, save_path)
            else:
                # No trained weights — nothing meaningful to evaluate.
                print('no checkpoint.')
                exit()

            stats_dir = os.path.join(log_dir, 'stats')
            stats_path = os.path.join(stats_dir, 'test.json')
            if os.path.exists(stats_path):
                print('load stats file {}.'.format(stats_path))
                stats = pd.read_json(stats_path, 'records')
            else:
                print('no stats file.')
                if not os.path.exists(stats_dir):
                    os.makedirs(stats_dir)
                stats = pd.DataFrame(columns=['epoch', 'acc'])

            # test iterate over examples
            # Accumulate prediction rows in a plain list and build the
            # DataFrame once at the end: DataFrame.append was deprecated and
            # removed in pandas >= 2.0, and per-row append is O(n^2).
            result_rows = []
            correct = 0
            groundtruth_answer_list = []
            predict_answer_list = []
            while dataset.has_test_example:
                vgg, c3d, vgg_conv5, vgg_conv5_3, mfcc, question, answer, \
                    example_id, question_len = dataset.get_test_example()
                # Fixed number of video steps fed to the model.
                input_len = 20
                feed_dict = {
                    model.c3d_video_feature: [c3d],
                    model.vgg_video_feature: [vgg],
                    model.mfcc_video_feature: [mfcc],
                    model.question_encode: [question],
                    model.question_len_placeholder: [question_len],
                    model.video_len_placeholder: [input_len],
                    model.keep_placeholder: 1.0
                }
                prediction = sess.run(model.prediction, feed_dict=feed_dict)
                prediction = prediction[0]
                result_rows.append(
                    {'id': example_id, 'answer': answerset[prediction]})
                # question[0] appears to encode the question type (it indexes
                # the per-type counters) — TODO confirm against the dataset.
                if answerset[prediction] == answer:
                    correct += 1
                    question_type_correct_count[question[0]] += 1
                question_type_all_count[question[0]] += 1
                groundtruth_answer_list.append(answer)
                predict_answer_list.append(answerset[prediction])

            result = DataFrame(result_rows, columns=['id', 'answer'])
            result.to_json(
                os.path.join(log_dir, 'prediction_' + str(epoch) + '.json'),
                'records')

            acc = correct * 1.0 / dataset.test_example_total
            WUPS_0_0 = metrics.compute_wups(
                groundtruth_answer_list, predict_answer_list, 0.0)
            WUPS_0_9 = metrics.compute_wups(
                groundtruth_answer_list, predict_answer_list, 0.9)
            WUPS_acc = metrics.compute_wups(
                groundtruth_answer_list, predict_answer_list, -1)
            print('[TEST] acc {:.5f}.\n'.format(acc))
            print('[TEST], WUPS@acc {:.5f}.\n'.format(WUPS_acc))
            print('[TEST], [email protected] {:.5f}.\n'.format(WUPS_0_0))
            print('[TEST], [email protected] {:.5f}.\n'.format(WUPS_0_9))

            print('######## question type acc list ######### ')
            for k in question_type_dict:
                total = question_type_all_count[k]
                # Guard against ZeroDivisionError for question types that
                # never occur in the test split.
                type_acc = (question_type_correct_count[k] * 1.0 / total
                            if total else 0.0)
                print(question_type_dict[k] + ' acc {:.5f}.'.format(type_acc))
                print('correct = {:d}, all = {:d}'.format(
                    question_type_correct_count[k], total))

            # Append this epoch's record via pd.concat (Series/DataFrame
            # .append was removed in pandas >= 2.0).
            record = pd.DataFrame([[epoch, acc]], columns=['epoch', 'acc'])
            stats = pd.concat([stats, record], ignore_index=True)
            stats.to_json(stats_path, 'records')
            dataset.reset_test()
    return acc
def val(epoch, dataset, config, log_dir):
    """Validate the model for one epoch and log accuracy.

    Args:
        epoch: int, epoch number recorded in stats and summaries.
        dataset: dataset object exposing ``has_val_example``,
            ``get_val_example()``, ``val_example_total`` and ``reset_val()``.
        config: dict with 'model', 'session' and 'preprocess_dir' entries.
        log_dir: directory containing 'checkpoint', 'summary' and 'stats'
            subdirectories.

    Returns:
        Validation accuracy as a float.
    """
    model_config = config['model']
    sess_config = config['session']
    # Column 0 of answer_set.txt maps class index -> answer string.
    answerset = pd.read_csv(
        os.path.join(config['preprocess_dir'], 'answer_set.txt'),
        header=None)[0]

    with tf.Graph().as_default():
        model = TripleAttentiveModal_DMN(model_config)
        model.build_inference()
        with tf.Session(config=sess_config) as sess:
            sum_dir = os.path.join(log_dir, 'summary')
            summary_writer = tf.summary.FileWriter(sum_dir)

            ckpt_dir = os.path.join(log_dir, 'checkpoint')
            save_path = tf.train.latest_checkpoint(ckpt_dir)
            saver = tf.train.Saver()
            if save_path:
                print('load checkpoint {}.'.format(save_path))
                saver.restore(sess, save_path)
            else:
                # No trained weights — nothing meaningful to validate.
                print('no checkpoint.')
                exit()

            stats_dir = os.path.join(log_dir, 'stats')
            stats_path = os.path.join(stats_dir, 'val.json')
            if os.path.exists(stats_path):
                print('load stats file {}.'.format(stats_path))
                stats = pd.read_json(stats_path, 'records')
            else:
                print('no stats file.')
                if not os.path.exists(stats_dir):
                    os.makedirs(stats_dir)
                stats = pd.DataFrame(columns=['epoch', 'acc'])

            # val iterate over examples
            correct = 0
            groundtruth_answer_list = []
            predict_answer_list = []
            while dataset.has_val_example:
                vgg, c3d, vgg_conv5, vgg_conv5_3, mfcc, question, answer, \
                    question_len = dataset.get_val_example()
                # Fixed number of video steps fed to the model.
                input_len = 20
                feed_dict = {
                    model.c3d_video_feature: [c3d],
                    model.vgg_video_feature: [vgg],
                    model.mfcc_video_feature: [mfcc],
                    model.question_encode: [question],
                    model.question_len_placeholder: [question_len],
                    model.video_len_placeholder: [input_len],
                    model.keep_placeholder: 1.0
                }
                prediction = sess.run(model.prediction, feed_dict=feed_dict)
                prediction = prediction[0]
                if answerset[prediction] == answer:
                    correct += 1
                groundtruth_answer_list.append(answer)
                predict_answer_list.append(answerset[prediction])

            acc = correct * 1.0 / dataset.val_example_total
            WUPS_0_0 = metrics.compute_wups(
                groundtruth_answer_list, predict_answer_list, 0.0)
            WUPS_0_9 = metrics.compute_wups(
                groundtruth_answer_list, predict_answer_list, 0.9)
            WUPS_acc = metrics.compute_wups(
                groundtruth_answer_list, predict_answer_list, -1)
            print('[VAL] epoch {}, acc {:.5f}.\n'.format(epoch, acc))
            print('[VAL] epoch {}, WUPS@acc {:.5f}.\n'.format(epoch, WUPS_acc))
            print('[VAL] epoch {}, [email protected] {:.5f}.\n'.format(epoch, WUPS_0_0))
            print('[VAL] epoch {}, [email protected] {:.5f}.\n'.format(epoch, WUPS_0_9))

            # Log the scalar to TensorBoard under val/acc.
            summary = tf.Summary()
            summary.value.add(tag='val/acc', simple_value=float(acc))
            summary_writer.add_summary(summary, epoch)

            # Append this epoch's record via pd.concat (Series/DataFrame
            # .append was removed in pandas >= 2.0).
            record = pd.DataFrame([[epoch, acc]], columns=['epoch', 'acc'])
            stats = pd.concat([stats, record], ignore_index=True)
            stats.to_json(stats_path, 'records')
            dataset.reset_val()
    return acc