Пример #1
0
def test(epoch, dataset, config, log_dir, question_type_dict):
    """Test model, output prediction as json file.

    Restores the latest checkpoint found in ``<log_dir>/checkpoint`` and runs
    inference on every example ``dataset`` yields for the test split. It then:

    * writes the predictions to ``<log_dir>/prediction_<epoch>.json``
      (``records`` orient, columns ``id`` and ``answer``);
    * prints overall accuracy, the three WUPS scores computed by
      ``metrics.compute_wups`` and a per-question-type accuracy breakdown;
    * appends an ``{epoch, acc}`` record to ``<log_dir>/stats/test.json``.

    Args:
        epoch: Used to name the prediction file and to tag the stats record.
        dataset: Object providing ``has_test_example``,
            ``get_test_example()``, ``test_example_total`` and
            ``reset_test()``.
        config: Mapping with at least the ``'model'``, ``'session'`` and
            ``'preprocess_dir'`` keys.
        log_dir: Directory holding checkpoints, stats and predictions.
        question_type_dict: Mapping whose keys are compared against
            ``question[0]`` of each encoded question and whose values are
            printable type names (string concatenation is used below).

    Returns:
        Exact-match accuracy over the test split (0.0 if the split reports
        zero examples).

    Raises:
        SystemExit: if no checkpoint exists under ``<log_dir>/checkpoint``.
        KeyError: if a question's first token is not a key of
            ``question_type_dict``.
    """
    import sys

    model_config = config['model']
    sess_config = config['session']

    # Per-question-type counters sharing the keys of question_type_dict.
    question_type_correct_count = {k: 0 for k in question_type_dict}
    question_type_all_count = {k: 0 for k in question_type_dict}

    answerset = pd.read_csv(os.path.join(config['preprocess_dir'],
                                         'answer_set.txt'),
                            header=None)[0]

    # Every example is fed the same video length; presumably the number of
    # sampled frames used in preprocessing -- TODO confirm.
    input_len = 20

    with tf.Graph().as_default():
        model = TripleAttentiveModal_DMN(model_config)
        model.build_inference()

        with tf.Session(config=sess_config) as sess:
            ckpt_dir = os.path.join(log_dir, 'checkpoint')
            save_path = tf.train.latest_checkpoint(ckpt_dir)
            saver = tf.train.Saver()
            if not save_path:
                print('no checkpoint.')
                # sys.exit instead of site's exit(), which is meant only for
                # the interactive interpreter; non-zero status signals failure.
                sys.exit(1)
            print('load checkpoint {}.'.format(save_path))
            saver.restore(sess, save_path)

            stats_dir = os.path.join(log_dir, 'stats')
            stats_path = os.path.join(stats_dir, 'test.json')
            if os.path.exists(stats_path):
                print('load stats file {}.'.format(stats_path))
                # orient is keyword-only in pandas >= 2.0.
                stats = pd.read_json(stats_path, orient='records')
            else:
                print('no stats file.')
                if not os.path.exists(stats_dir):
                    os.makedirs(stats_dir)
                stats = pd.DataFrame(columns=['epoch', 'acc'])

            # Test: iterate over examples, collecting rows in a list and
            # building the DataFrame once (DataFrame.append is gone in
            # pandas 2.0 and was quadratic anyway).
            result_rows = []
            correct = 0

            groundtruth_answer_list = []
            predict_answer_list = []
            while dataset.has_test_example:
                (vgg, c3d, _vgg_conv5, _vgg_conv5_3, mfcc, question, answer,
                 example_id, question_len) = dataset.get_test_example()
                feed_dict = {
                    model.c3d_video_feature: [c3d],
                    model.vgg_video_feature: [vgg],
                    model.mfcc_video_feature: [mfcc],
                    model.question_encode: [question],
                    model.question_len_placeholder: [question_len],
                    model.video_len_placeholder: [input_len],
                    model.keep_placeholder: 1.0
                }
                # Batch of one: take the single predicted answer index.
                prediction = sess.run(model.prediction, feed_dict=feed_dict)[0]
                predicted_answer = answerset[prediction]

                result_rows.append({'id': example_id,
                                    'answer': predicted_answer})

                # The first token of the encoded question selects its type.
                question_type = question[0]
                if predicted_answer == answer:
                    correct += 1
                    question_type_correct_count[question_type] += 1
                question_type_all_count[question_type] += 1

                groundtruth_answer_list.append(answer)
                predict_answer_list.append(predicted_answer)

            result = DataFrame(result_rows, columns=['id', 'answer'])
            result.to_json(
                os.path.join(log_dir, 'prediction_' + str(epoch) + '.json'),
                orient='records')

            total = dataset.test_example_total
            acc = correct * 1.0 / total if total else 0.0
            WUPS_0_0 = metrics.compute_wups(groundtruth_answer_list,
                                            predict_answer_list, 0.0)
            WUPS_0_9 = metrics.compute_wups(groundtruth_answer_list,
                                            predict_answer_list, 0.9)
            WUPS_acc = metrics.compute_wups(groundtruth_answer_list,
                                            predict_answer_list, -1)
            print('[TEST] acc {:.5f}.\n'.format(acc))
            print('[TEST], WUPS@acc {:.5f}.\n'.format(WUPS_acc))
            print('[TEST], WUPS@0.0 {:.5f}.\n'.format(WUPS_0_0))
            print('[TEST], WUPS@0.9 {:.5f}.\n'.format(WUPS_0_9))

            print('######## question type acc list ######### ')
            for k in question_type_dict:
                n_correct = question_type_correct_count[k]
                n_all = question_type_all_count[k]
                # A type with no test questions prints 'nan' instead of
                # raising ZeroDivisionError.
                type_acc = n_correct * 1.0 / n_all if n_all else float('nan')
                print(question_type_dict[k] + ' acc {:.5f}.'.format(type_acc))
                print('correct = {:d}, all = {:d}'.format(n_correct, n_all))

            record = DataFrame([[epoch, acc]], columns=['epoch', 'acc'])
            # Concatenate instead of the removed DataFrame.append; skip the
            # concat for an empty history to avoid empty-frame dtype warnings.
            if stats.empty:
                stats = record
            else:
                stats = pd.concat([stats, record], ignore_index=True)
            stats.to_json(stats_path, orient='records')

            dataset.reset_test()
            return acc
Пример #2
0
def val(epoch, dataset, config, log_dir):
    """Validate model.

    Restores the latest checkpoint found in ``<log_dir>/checkpoint`` and runs
    inference on every example ``dataset`` yields for the validation split.
    It then:

    * prints accuracy and the three WUPS scores computed by
      ``metrics.compute_wups``;
    * writes a ``val/acc`` scalar summary for ``epoch`` to
      ``<log_dir>/summary``;
    * appends an ``{epoch, acc}`` record to ``<log_dir>/stats/val.json``.

    Args:
        epoch: Used to tag the printed lines, the summary step and the stats
            record.
        dataset: Object providing ``has_val_example``, ``get_val_example()``,
            ``val_example_total`` and ``reset_val()``.
        config: Mapping with at least the ``'model'``, ``'session'`` and
            ``'preprocess_dir'`` keys.
        log_dir: Directory holding checkpoints, summaries and stats.

    Returns:
        Exact-match accuracy over the validation split (0.0 if the split
        reports zero examples).

    Raises:
        SystemExit: if no checkpoint exists under ``<log_dir>/checkpoint``.
    """
    import sys

    model_config = config['model']
    sess_config = config['session']

    answerset = pd.read_csv(os.path.join(config['preprocess_dir'],
                                         'answer_set.txt'),
                            header=None)[0]

    # Every example is fed the same video length; presumably the number of
    # sampled frames used in preprocessing -- TODO confirm.
    input_len = 20

    with tf.Graph().as_default():
        model = TripleAttentiveModal_DMN(model_config)
        model.build_inference()

        with tf.Session(config=sess_config) as sess:
            ckpt_dir = os.path.join(log_dir, 'checkpoint')
            save_path = tf.train.latest_checkpoint(ckpt_dir)
            saver = tf.train.Saver()
            if not save_path:
                print('no checkpoint.')
                # sys.exit instead of site's exit(), which is meant only for
                # the interactive interpreter; non-zero status signals failure.
                sys.exit(1)
            print('load checkpoint {}.'.format(save_path))
            saver.restore(sess, save_path)

            stats_dir = os.path.join(log_dir, 'stats')
            stats_path = os.path.join(stats_dir, 'val.json')
            if os.path.exists(stats_path):
                print('load stats file {}.'.format(stats_path))
                # orient is keyword-only in pandas >= 2.0.
                stats = pd.read_json(stats_path, orient='records')
            else:
                print('no stats file.')
                if not os.path.exists(stats_dir):
                    os.makedirs(stats_dir)
                stats = pd.DataFrame(columns=['epoch', 'acc'])

            # Validation: iterate over examples.
            correct = 0

            groundtruth_answer_list = []
            predict_answer_list = []
            while dataset.has_val_example:
                (vgg, c3d, _vgg_conv5, _vgg_conv5_3, mfcc, question, answer,
                 question_len) = dataset.get_val_example()
                feed_dict = {
                    model.c3d_video_feature: [c3d],
                    model.vgg_video_feature: [vgg],
                    model.mfcc_video_feature: [mfcc],
                    model.question_encode: [question],
                    model.question_len_placeholder: [question_len],
                    model.video_len_placeholder: [input_len],
                    model.keep_placeholder: 1.0
                }
                # Batch of one: take the single predicted answer index.
                prediction = sess.run(model.prediction, feed_dict=feed_dict)[0]
                predicted_answer = answerset[prediction]
                if predicted_answer == answer:
                    correct += 1
                groundtruth_answer_list.append(answer)
                predict_answer_list.append(predicted_answer)

            total = dataset.val_example_total
            acc = correct * 1.0 / total if total else 0.0
            WUPS_0_0 = metrics.compute_wups(groundtruth_answer_list,
                                            predict_answer_list, 0.0)
            WUPS_0_9 = metrics.compute_wups(groundtruth_answer_list,
                                            predict_answer_list, 0.9)
            WUPS_acc = metrics.compute_wups(groundtruth_answer_list,
                                            predict_answer_list, -1)
            print('[VAL] epoch {}, acc {:.5f}.\n'.format(epoch, acc))
            print('[VAL] epoch {}, WUPS@acc {:.5f}.\n'.format(epoch, WUPS_acc))
            print('[VAL] epoch {}, WUPS@0.0 {:.5f}.\n'.format(epoch, WUPS_0_0))
            print('[VAL] epoch {}, WUPS@0.9 {:.5f}.\n'.format(epoch, WUPS_0_9))

            summary = tf.Summary()
            summary.value.add(tag='val/acc', simple_value=float(acc))
            sum_dir = os.path.join(log_dir, 'summary')
            summary_writer = tf.summary.FileWriter(sum_dir)
            try:
                summary_writer.add_summary(summary, epoch)
            finally:
                # Closing flushes the pending event to disk and releases the
                # writer's file handle; previously every call leaked a writer.
                summary_writer.close()

            record = DataFrame([[epoch, acc]], columns=['epoch', 'acc'])
            # Concatenate instead of the removed DataFrame.append; skip the
            # concat for an empty history to avoid empty-frame dtype warnings.
            if stats.empty:
                stats = record
            else:
                stats = pd.concat([stats, record], ignore_index=True)
            stats.to_json(stats_path, orient='records')

            dataset.reset_val()
            return acc