示例#1
0
    def test_reproduce(self):
        """Check that a seeded train / validate / infer run reproduces exactly.

        Seeds the run with ``set_reproducibility(0)`` and builds a ``Trainer``
        with the dataset loaded and the model initialised. It then checks, in
        order:

        1. the per-step losses returned by ``self._test_train``;
        2. the validation loss returned by ``self._test_valid``, after which
           the model is saved with ``trainer.save_model()``;
        3. the hypotheses returned by ``self._test_infer``.

        Each result is compared against hard-coded reference values. The
        comparisons use exact ``assertEqual`` rather than
        ``assertAlmostEqual`` on purpose: this test guards reproducibility, so
        any numeric drift should fail it. NOTE(review): the reference numbers
        presumably depend on the library/hardware versions they were recorded
        with; regenerate them if those change.
        """
        # Silence all logging while this test runs. logging.disable() is
        # process-global, so register a cleanup that re-enables logging even
        # if an assertion below fails; otherwise the disable would leak into
        # every test that runs afterwards.
        logging.disable(logging.CRITICAL)
        self.addCleanup(logging.disable, logging.NOTSET)

        set_reproducibility(0)

        trainer = Trainer(self.args, tf_log=False)
        train_set, dev_set = trainer.load_dataset()
        trainer._init_model()

        # --- Training: per-step losses must match the recorded run exactly.
        expected_losses = [
            483.9374983910236, 396.7885422831847, 313.4529746738469,
            412.6149972063738, 294.88370214788677, 276.9205821059003,
            242.83269801561872, 225.63258253138852, 223.06413635684396,
            209.42066629479157, 182.57185706984563, 199.2475491835321,
            203.06883826238152, 166.84528868322133, 177.92200707610107,
            163.95598015036012, 143.41510641020656, 164.14947102979863,
            157.85085816150064, 145.01048472122704
        ]
        losses = self._test_train(trainer, train_set)
        self.assertEqual(losses, expected_losses)

        # --- Validation: loss on the dev set after training.
        expected_valid_loss = -19.518319110144216
        valid_loss = self._test_valid(trainer, dev_set)
        # Save the trained model before inference; presumably
        # self._test_infer reads the saved checkpoint (TODO confirm).
        trainer.save_model()
        self.assertEqual(valid_loss, expected_valid_loss)

        # --- Inference: decoded hypotheses (tokens + score) per input.
        expected_hypos = [
            [
                Hypothesis(value=[
                    'Examines', 'the', 'undistorted', 'gray', 'input', 'image',
                    'for', 'squares', '<con>', '<con>', '.', '.', '.'
                ],
                           score=-9.169617064004049),
                Hypothesis(value=[
                    'Examines', 'the', 'undistorted', 'gray', 'input', 'image',
                    'for', 'squares', '<con>', '<con>', '.', '.'
                ],
                           score=-8.4981576175081)
            ],
            [
                Hypothesis(value=[
                    'Save', 'basic', 'clusters', '<con>', '.', '.', '.', '.',
                    '.'
                ],
                           score=-4.295478831250667),
                Hypothesis(value=[
                    'Save', 'basic', 'clusters', '<con>', '.', '.', '.', '.',
                    '.', '.'
                ],
                           score=-4.809735736699733)
            ],
            [
                Hypothesis(value=[
                    'Configure', 'a', 'ssl', '<con>', 'Config', 'for', 'the',
                    'server', 'using', 'the', 'legacy', 'legacy',
                    'configuration', 'configuration', 'configuration',
                    'configuration', 'configuration', 'configuration',
                    'configuration', 'configuration', 'configuration',
                    'configuration', 'configuration', 'configuration',
                    'configuration', 'configuration', 'configuration',
                    'configuration', 'configuration', 'configuration',
                    'configuration', 'configuration', 'configuration',
                    'configuration', 'configuration', 'configuration',
                    'configuration', 'configuration', 'configuration',
                    'configuration', 'configuration', 'configuration',
                    'configuration', 'configuration', 'configuration',
                    'configuration', 'configuration', 'configuration',
                    'configuration', 'configuration'
                ],
                           score=-35.71404207185598)
            ]
        ]
        hypos = self._test_infer()
        self.assertEqual(hypos, expected_hypos)