def test_train(self) -> None:
    """Train on a tiny two-sample corpus and pin the final-epoch loss.

    Uses a minimal Keras callback to capture the logs dict of the last
    epoch, then asserts the loss matches a known value to 6 places.
    """

    class LogCallback(Callback):
        """Stores the ``logs`` dict handed to ``on_epoch_end``."""

        def __init__(self):
            super().__init__()

        def on_epoch_end(self, epoch, logs=None):
            # Overwritten each epoch; after training this holds the
            # final epoch's metrics.
            self.logs = logs

    data = [('a b', 'a'), ('a b c', 'b')]
    summarizer = AttentionSummarizer(lstm_size=16, embedding_size=10)
    log_callback = LogCallback()
    trainer = Trainer(
        batch_size=2,
        steps_per_epoch=10,
        max_vocab_size_encoder=10,
        max_vocab_size_decoder=10,
        model_save_path=None,
        max_output_len=3,
    )
    trainer.train(summarizer, data, num_epochs=2, callbacks=[log_callback])
    # NOTE(review): exact-loss assertions are fragile across library
    # versions/seeds — kept as-is to preserve the original contract.
    self.assertAlmostEqual(1.7135955810546875, log_callback.logs['loss'], 6)
def test_init_model(self) -> None:
    """Train one epoch on a tiny corpus and verify vectorizer dimensions."""
    logging.basicConfig(level=logging.INFO)
    data = [('a b', 'a'), ('a b c', 'b')]
    # NOTE(review): the sibling test_train constructs AttentionSummarizer
    # while this one uses SummarizerAttention — confirm which class name
    # the package actually exports.
    summarizer = SummarizerAttention(lstm_size=16, embedding_size=10)
    trainer = Trainer(
        batch_size=2,
        steps_per_epoch=10,
        max_vocab_size_encoder=10,
        max_vocab_size_decoder=10,
        max_output_len=3,
    )
    trainer.train(summarizer, data, num_epochs=1)
    vectorizer = summarizer.vectorizer
    self.assertIsNotNone(vectorizer)
    # encoding dim and decoding dim are num unique tokens + 4
    # (pad, start, end, oov)
    self.assertEqual(7, vectorizer.encoding_dim)
    self.assertEqual(6, vectorizer.decoding_dim)
    self.assertEqual(3, vectorizer.max_output_len)