class TestRNNLM(unittest.TestCase):
    """Unit tests for RNNLM: prediction, loss, backprop, state reset and persistence.

    Every test gets a fresh model built over the vocabulary of one sample
    sentence (the assertions below expect 7 word IDs), fed a mini-batch of
    2 sequences x 4 time steps.
    """

    def setUp(self):
        text = 'You said good-bye and I said hello.'
        cbm = CountBasedMethod()
        word_list = cbm.text_to_word_list(text)
        word_to_id, *_ = cbm.preprocess(word_list)
        vocab_size = len(word_to_id)
        wordvec_size = 100
        hidden_size = 100
        self.rnnlm = RNNLM(vocab_size, wordvec_size, hidden_size)
        # Word-ID inputs and targets, shaped (batch=2, time=4).
        self.xs = np.array([
            [0, 4, 4, 1],
            [4, 0, 2, 1],
        ])
        self.ts = np.array([
            [0, 1, 0, 0],
            [0, 0, 0, 1],
        ])

    def test_predict(self):
        # One score per vocabulary entry for every (batch, time) position.
        score = self.rnnlm._predict(self.xs)
        self.assertEqual((2, 4, 7), score.shape)

    def test_forward(self):
        loss = self.rnnlm.forward(self.xs, self.ts)
        self.assertEqual(1.94, round(loss, 2))

    def test_backward(self):
        self.rnnlm.forward(self.xs, self.ts)
        dout = self.rnnlm.backward()
        # backward() is expected to return nothing.
        self.assertIsNone(dout)

    def test_reset_state(self):
        self.rnnlm.forward(self.xs, self.ts)
        self.rnnlm.backward()
        # After a pass, the LSTM keeps a hidden state of shape (batch, hidden_size).
        self.assertEqual((2, 100), self.rnnlm.lstm_layer.h.shape)
        self.rnnlm.reset_state()
        self.assertIsNone(self.rnnlm.lstm_layer.h)

    def test_save_params(self):
        # NOTE(review): save_params() writes to a shared relative path
        # (resolved against the CWD). If the training script's
        # model.save_params() uses the same default, running the tests
        # overwrites the trained model — confirm and consider a temp path.
        self.rnnlm.forward(self.xs, self.ts)
        self.rnnlm.backward()
        self.rnnlm.save_params()
        self.assertTrue(path.exists('../pkl/rnnlm.pkl'))

    def test_load_params(self):
        # Save first so this test is self-contained: unittest runs methods in
        # alphabetical order, so test_load_params executes before
        # test_save_params and would otherwise depend on a pre-existing (and
        # possibly stale or differently-sized) pickle on disk.
        self.rnnlm.save_params()
        self.rnnlm.load_params()
        expected_shapes = [
            (7, 100),
            (100, 400),
            (100, 400),
            (400,),
            (100, 7),
            (7,),
        ]
        self.assertEqual(len(expected_shapes), len(self.rnnlm.params))
        for index, (param, shape) in enumerate(zip(self.rnnlm.params, expected_shapes)):
            with self.subTest(param_index=index):
                self.assertEqual(shape, param.shape)
# NOTE(review): this chunk relies on names defined earlier in the script
# (corpus, word_to_id, wordvec_size, hidden_size, learning_rate, max_epoch,
# batch_size, time_size, max_grad) — confirm against the preceding lines.
corpus_test, *_ = load_data('test')
vocab_size = len(word_to_id)
# Language-model targets are the inputs shifted one token ahead.
xs = corpus[:-1]
ts = corpus[1:]

# Generate a model, optimiser and trainer
model = RNNLM(vocab_size, wordvec_size, hidden_size)
optimiser = SGD(learning_rate)
trainer = RNNLMTrainer(model, optimiser)

# 1. Train with gradient clipping (threshold passed as max_grad).
#    fit() yields progress reports, which are printed as they arrive.
training_process = trainer.fit(xs, ts, max_epoch, batch_size, time_size,
                               max_grad, eval_interval=20)
for progress in training_process:
    print(progress)

file_path = '../img/train_rnnlm.png'
trainer.save_plot_image(file_path, ylim=(0, 500))

# 2. Evaluate on the test data, starting from a cleared recurrent state.
model.reset_state()
ppl_test = eval_perplexity(model, corpus_test)
print('Test perplexity: ', ppl_test)

# 3. Save parameters
model.save_params()