コード例 #1
0
ファイル: test_trainer.py プロジェクト: vyaslkv/headliner
    def test_train(self) -> None:
        """Train a tiny AttentionSummarizer for two epochs and check the final loss.

        A local callback captures the ``logs`` argument passed to
        ``on_epoch_end``; after training, ``logs['loss']`` from the last
        epoch is compared against a pinned reference value to 6 places.
        """

        class LogCallback(Callback):
            """Keep the ``logs`` argument of the most recent ``on_epoch_end`` call."""

            def __init__(self):
                super().__init__()
                # Initialised so that a run which never reaches on_epoch_end
                # fails on the assertion below rather than with AttributeError.
                self.logs = None

            def on_epoch_end(self, epoch, logs=None):
                self.logs = logs

        data = [('a b', 'a'), ('a b c', 'b')]

        summarizer = AttentionSummarizer(lstm_size=16,
                                         embedding_size=10)
        log_callback = LogCallback()
        trainer = Trainer(batch_size=2,
                          steps_per_epoch=10,
                          max_vocab_size_encoder=10,
                          max_vocab_size_decoder=10,
                          model_save_path=None,
                          max_output_len=3)

        trainer.train(summarizer,
                      data,
                      num_epochs=2,
                      callbacks=[log_callback])

        logs = log_callback.logs
        self.assertIsNotNone(logs, 'on_epoch_end was never called with logs')
        # NOTE(review): exact reference loss from a previous run; presumably
        # depends on deterministic initialisation — confirm RNG seeding.
        self.assertAlmostEqual(1.7135955810546875, logs['loss'], places=6)
コード例 #2
0
 def test_init_model(self) -> None:
     """Check the vectorizer produced by a single training epoch.

     After training on two tiny (text, summary) pairs, the summarizer's
     vectorizer must exist.

     - Its encoding and decoding dims equal the number of unique tokens
       plus 4 reserved ones (pad, start, end, oov).
     - Its max_output_len equals the value given to the trainer.
     """
     logging.basicConfig(level=logging.INFO)
     pairs = [('a b', 'a'), ('a b c', 'b')]
     model = SummarizerAttention(lstm_size=16, embedding_size=10)
     trainer_kwargs = dict(batch_size=2,
                           steps_per_epoch=10,
                           max_vocab_size_encoder=10,
                           max_vocab_size_decoder=10,
                           max_output_len=3)
     Trainer(**trainer_kwargs).train(model, pairs, num_epochs=1)
     vectorizer = model.vectorizer
     self.assertIsNotNone(vectorizer)
     # Encoder tokens {a, b, c}: 3 + 4 = 7; decoder tokens {a, b}: 2 + 4 = 6.
     self.assertEqual(7, vectorizer.encoding_dim)
     self.assertEqual(6, vectorizer.decoding_dim)
     self.assertEqual(3, vectorizer.max_output_len)