Example #1
    def test_metrics_select(self):
        """
        Test output of running eval_model.
        """
        parser = setup_args()
        parser.set_defaults(
            task='integration_tests',
            model='repeat_label',
            datatype='valid',
            num_examples=5,
            display_examples=False,
            metrics='accuracy,rouge',
        )

        opt = parser.parse_args([], print_args=False)
        valid, test = testing_utils.eval_model(opt)

        self.assertEqual(valid['accuracy'], 1)
        self.assertEqual(valid['rouge-L'], 1)
        self.assertEqual(valid['rouge-1'], 1)
        self.assertEqual(valid['rouge-2'], 1)
        self.assertEqual(test['accuracy'], 1)
        self.assertEqual(test['rouge-L'], 1)
        self.assertEqual(test['rouge-1'], 1)
        self.assertEqual(test['rouge-2'], 1)

        self.assertNotIn('bleu-4', valid)
        self.assertNotIn('bleu-4', test)
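The metrics='accuracy,rouge' default above restricts the report to those two metric families, which is why the test can demand that bleu-4 be absent. As a purely illustrative sketch (select_metrics and full_report are made up here, not ParlAI APIs), the selection boils down to filtering a report dict by metric-name prefixes:

# Illustrative sketch only -- mirrors what metrics='accuracy,rouge' asks for.
def select_metrics(report, requested):
    """Keep metrics matching a requested family, plus the example count."""
    return {
        name: value
        for name, value in report.items()
        if name == 'exs'
        or any(name == fam or name.startswith(fam + '-') for fam in requested)
    }

full_report = {
    'exs': 5, 'accuracy': 1, 'f1': 1, 'bleu-4': 1,
    'rouge-1': 1, 'rouge-2': 1, 'rouge-L': 1,
}
print(select_metrics(full_report, {'accuracy', 'rouge'}))
# accuracy and rouge-* survive; bleu-4 and f1 are dropped, matching the assertions.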
Example #2
    def test_metrics_select(self):
        """Test output of running eval_model"""
        parser = setup_args()
        parser.set_defaults(
            task='integration_tests',
            model='repeat_label',
            datatype='valid',
            num_examples=5,
            display_examples=False,
            metrics='accuracy,rouge',
        )

        opt = parser.parse_args(print_args=False)
        str_output, valid, test = testing_utils.eval_model(opt)
        self.assertGreater(len(str_output), 0, "Output is empty")

        # decode the output
        scores = str_output.split("\n---\n")

        for i in range(1, len(scores)):
            score = ast.literal_eval(scores[i])
            # check totals
            self.assertEqual(score['exs'], i, "Total is incorrect")
            # accuracy should be one
            self.assertIn('accuracy', score, "Accuracy is missing from selection")
            self.assertEqual(score['accuracy'], 1, "Accuracy != 1")
            self.assertIn('rouge-1', score, "Rouge is missing from selection")
            self.assertEqual(score['rouge-1'], 1, 'rouge-1 != 1')
            self.assertEqual(score['rouge-2'], 1, 'rouge-2 != 1')
            self.assertEqual(score['rouge-L'], 1, 'rouge-L != 1')
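The decoding step above relies on the assumption, already built into the test, that eval_model prints one report dict per logging interval separated by "\n---\n", and that ast.literal_eval can turn each printed dict back into a Python dict. A tiny self-contained demo of that parsing, with a fabricated log string:

import ast

# Fabricated stand-in for str_output; separator and dict layout follow the
# assumptions the test above already makes about eval_model's printed log.
str_output = (
    "[ running eval ... ]"
    "\n---\n{'exs': 1, 'accuracy': 1, 'rouge-1': 1, 'rouge-2': 1, 'rouge-L': 1}"
    "\n---\n{'exs': 2, 'accuracy': 1, 'rouge-1': 1, 'rouge-2': 1, 'rouge-L': 1}"
)

scores = str_output.split("\n---\n")
for chunk in scores[1:]:              # scores[0] is the preamble, not a report
    report = ast.literal_eval(chunk)
    print(report['exs'], report['accuracy'])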
Example #3
    def test_output(self):
        """Test output of running eval_model"""
        parser = setup_args()
        parser.set_defaults(
            task='tasks.repeat:RepeatTeacher:10',
            model='repeat_label',
            datatype='valid',
            num_examples=5,
            display_examples=False,
        )

        opt = parser.parse_args(print_args=False)
        str_output, valid, test = testing_utils.eval_model(opt)
        self.assertGreater(len(str_output), 0, "Output is empty")

        # decode the output
        scores = str_output.split("\n---\n")
        for i in range(1, len(scores)):
            score = ast.literal_eval(scores[i])
            # check totals
            self.assertEqual(score['exs'], i, "Total is incorrect")
            # accuracy should be one
            self.assertEqual(score['accuracy'], 1, "accuracy != 1")
Example #4
    def test_output(self):
        """Test output of running eval_model"""
        class display_output(object):
            def __init__(self):
                self.data = []

            def write(self, s):
                self.data.append(s)

            def __str__(self):
                return "".join(self.data)

        parser = setup_args()
        parser.set_defaults(
            task='tasks.repeat:RepeatTeacher:10',
            model='repeat_label',
            datatype='valid',
            num_examples=5,
            display_examples=False,
        )

        old_out = sys.stdout
        output = display_output()
        try:
            sys.stdout = output
            opt = parser.parse_args(print_args=False)
            eval_model(opt, print_parser=parser)
        finally:
            # restore sys.stdout
            sys.stdout = old_out

        str_output = str(output)
        self.assertGreater(len(str_output), 0, "Output is empty")

        # decode the output
        scores = str_output.split("\n---\n")
        for i in range(1, len(scores)):
            score = ast.literal_eval(scores[i])
            # check totals
            self.assertEqual(score['total'], i, "Total is incorrect")
            # accuracy should be one
            self.assertEqual(score['accuracy'], 1, "accuracy != 1")
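The display_output helper above exists only to collect what eval_model writes to sys.stdout. If the test is free to use the standard library, io.StringIO plus contextlib.redirect_stdout does the same job; this sketch assumes the same parser, opt, and eval_model objects as the test:

import contextlib
import io

output = io.StringIO()
with contextlib.redirect_stdout(output):
    # same call as in the test; stdout is restored automatically on exit
    eval_model(opt, print_parser=parser)
str_output = output.getvalue()
assert len(str_output) > 0, "Output is empty"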
Example #5
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
"""Evaluate pre-trained model trained for ppl metric.
This seq2seq model was trained on convai2:self.
"""

from projects.convai2.baselines.download_models import download
from parlai.core.params import ParlaiParser
from examples.eval_model import setup_args, eval_model
from parlai.agents.seq2seq.seq2seq import Seq2seqAgent

if __name__ == '__main__':
    parser = setup_args()
    parser.set_defaults(
        task='convai2:self',
        model='seq2seq',
        model_file='models:convai2/seq2seq/convai2_self_seq2seq_model/convai2_self_seq2seq_model',
        dict_file='models:convai2/seq2seq/dict_convai2_self/dict_convai2_self',
        datatype='valid',
        batchsize=128,
    )
    opt = parser.parse_args()
    download(opt, 'convai2/seq2seq', 'convai2_self_seq2seq_model.tgz')
    download(opt, 'convai2/seq2seq', 'dict_convai2_self')
    eval_model(parser, printargs=False)
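A hypothetical variant of the same script, reusing only the calls shown above (setup_args, set_defaults, download, eval_model), evaluates on a different split with a smaller batch. Whether labels exist for that split depends on the dataset, so treat it as a sketch rather than a tested recipe:

# Hypothetical variant -- same skeleton as above, different datatype/batchsize.
parser = setup_args()
parser.set_defaults(
    task='convai2:self',
    model='seq2seq',
    model_file='models:convai2/seq2seq/convai2_self_seq2seq_model/convai2_self_seq2seq_model',
    dict_file='models:convai2/seq2seq/dict_convai2_self/dict_convai2_self',
    datatype='test',
    batchsize=32,
)
opt = parser.parse_args()
download(opt, 'convai2/seq2seq', 'convai2_self_seq2seq_model.tgz')
download(opt, 'convai2/seq2seq', 'dict_convai2_self')
eval_model(parser, printargs=False)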