Example #1
0
    def test_set_model_file_without_dict_file(self):
        """Check that moving a model without moving the dictionary raises the
        appropriate error.

        The unittest models are downloaded into a clean directory. Then the
        ``.dict`` file that sits next to the seq2seq model is deleted, and
        evaluating that model must raise ``RuntimeError``. The unittest model
        directory is removed before the test and, via ``finally``, after it
        even when the test fails. That way the deliberately broken copy is
        never left on disk for other tests.
        """
        datapath = ParlaiParser().parse_args(print_args=False)['datapath']
        unittest_dir = os.path.join(datapath, 'models/unittest')

        def _remove_unittest_models():
            # Best-effort cleanup: the directory may legitimately not exist.
            try:
                shutil.rmtree(unittest_dir)
            except FileNotFoundError:
                pass

        # Start from a fresh download so leftover state cannot affect the test.
        _remove_unittest_models()
        try:
            testing_utils.download_unittest_models()

            zoo_path = 'models:unittest/seq2seq/model'
            model_path = modelzoo_path(datapath, zoo_path)
            # Simulate moving the model file without its dictionary.
            os.remove(model_path + '.dict')
            # Evaluating a model whose dictionary is missing must fail.
            with self.assertRaises(RuntimeError):
                testing_utils.eval_model(
                    dict(task='babi:task1k:1', model_file=model_path)
                )
        finally:
            # Always remove the now-broken models, even if an assertion failed.
            _remove_unittest_models()
Example #2
0
    def test_backwards_compatibility(self):
        """Check that the stored seq2seq unittest model still scores
        (near-)perfectly on the multipass integration task.

        Perplexity must stay at or below 1.01. Accuracy and F1 must reach at
        least 0.999, on both the valid and the test split.
        """
        testing_utils.download_unittest_models()

        eval_opt = dict(
            task='integration_tests:multipass',
            model='seq2seq',
            model_file='zoo:unittest/seq2seq/model',
            dict_file='zoo:unittest/seq2seq/model.dict',
            no_cuda=True,
        )
        stdout, valid, test = testing_utils.eval_model(eval_opt)

        for split_name, report in (('valid', valid), ('test', test)):
            ppl = report['ppl']
            self.assertLessEqual(
                ppl, 1.01, '{} ppl = {}\nLOG:\n{}'.format(split_name, ppl, stdout)
            )
            for metric in ('accuracy', 'f1'):
                score = report[metric]
                self.assertGreaterEqual(
                    score,
                    0.999,
                    '{} {} = {}\nLOG:\n{}'.format(split_name, metric, score, stdout),
                )

    def test_generator_backcomp(self):
        """
        Verify that previously saved transformer generator model files still
        load and evaluate well on the multipass integration task.
        """
        testing_utils.download_unittest_models()

        stdout, valid, test = testing_utils.eval_model(
            dict(
                task='integration_tests:multipass',
                model='transformer/generator',
                model_file='models:unittest/transformer_generator2/model',
                dict_file='models:unittest/transformer_generator2/model.dict',
                rank_candidates=True,
                batch_size=64,
            )
        )

        # (metric, threshold, assertion), in the order they are checked.
        # Perplexity is an upper bound; every other metric is a lower bound.
        checks = (
            ('hits@1', 0.95, self.assertGreaterEqual),
            ('ppl', 1.01, self.assertLessEqual),
            ('accuracy', 0.99, self.assertGreaterEqual),
            ('f1', 0.99, self.assertGreaterEqual),
        )
        for split_name, report in (('valid', valid), ('test', test)):
            for metric, threshold, check in checks:
                value = report[metric]
                check(
                    value,
                    threshold,
                    '{} {} = {}\nLOG:\n{}'.format(split_name, metric, value, stdout),
                )
Example #4
0
    def test_backcomp(self):
        """
        Verify that stored transformer ranker model files keep evaluating
        correctly on the multipass integration task as the code evolves.

        hits@1, accuracy and F1 must each reach at least 0.99 on both the
        valid and the test split.
        """
        testing_utils.download_unittest_models()

        eval_opt = {
            'task': 'integration_tests:multipass',
            'model': 'transformer/ranker',
            'model_file': 'zoo:unittest/transformer_ranker/model',
            'dict_file': 'zoo:unittest/transformer_ranker/model.dict',
            'batch_size': 64,
        }
        stdout, valid, test = testing_utils.eval_model(eval_opt)

        reports = {'valid': valid, 'test': test}
        for split_name in ('valid', 'test'):
            for metric in ('hits@1', 'accuracy', 'f1'):
                observed = reports[split_name][metric]
                message = '{} {} = {}\nLOG:\n{}'.format(
                    split_name, metric, observed, stdout
                )
                self.assertGreaterEqual(observed, 0.99, message)