def test_small_corpus(self):
    """Verify training reports CORPUS_TOO_SMALL for an undersized corpus."""
    # Raise the batch size so that the sample corpus looks too small
    # relative to a single batch.
    self.batch_size = 100

    train_args = (
        self.input_directory,
        self.model_directory,
        self.log_directory,
        self.batch_size,
        self.hidden_state_size,
        self.hidden_layer_size,
    )
    outcome = train_rnn_generator_task.train_rnn(*train_args)

    self.assertEqual(outcome.return_code,
                     constants.ExitCode.CORPUS_TOO_SMALL)
    self.assertFalse(outcome.timed_out)

    # No model exists after execution.
    saved_model = train_rnn_generator_task.get_last_saved_model(
        self.model_directory)
    self.assertFalse(saved_model)
def test_train_rnn(self):
    """Verify an RNN model can be trained on a simple sample corpus."""
    # Precondition: the model directory holds no saved model yet.
    model_before = train_rnn_generator_task.get_last_saved_model(
        self.model_directory)
    self.assertFalse(model_before)

    # The sample corpus is extremely small, so training is expected to
    # finish quickly (a few seconds).
    train_args = (
        self.input_directory,
        self.model_directory,
        self.log_directory,
        self.batch_size,
        self.hidden_state_size,
        self.hidden_layer_size,
    )
    outcome = train_rnn_generator_task.train_rnn(*train_args)

    self.assertEqual(outcome.return_code, constants.ExitCode.SUCCESS)
    self.assertFalse(outcome.timed_out)

    # Postcondition: at least one model has been saved.
    model_after = train_rnn_generator_task.get_last_saved_model(
        self.model_directory)
    self.assertTrue(model_after)