コード例 #1
0
 def test_get_latest_model(self):
   """Checks that the newest saved model is returned as a dict of file paths.

   model_2 is newer than model_1 (assumes setUp saved both models with their
   data and index files), so model_2's paths are the expected result.
   """
   latest_paths = train_rnn_generator_task.get_last_saved_model(MODEL_DIR)
   self.assertDictEqual(
       latest_paths,
       dict(data=self.model_2_data_path, index=self.model_2_index_path))
コード例 #2
0
  def test_train_rnn(self):
    """Trains the RNN generator end to end on the small sample corpus.

    Verifies three things: no saved model exists beforehand, training
    finishes with SUCCESS and without timing out, and at least one saved
    model exists afterwards.
    """
    # Precondition: nothing has been saved to the model directory yet.
    saved_before = train_rnn_generator_task.get_last_saved_model(
        self.model_directory)
    self.assertFalse(saved_before)

    # The sample corpus is tiny, so this run should only take a few seconds.
    training_result = train_rnn_generator_task.train_rnn(
        self.input_directory,
        self.model_directory,
        self.log_directory,
        self.batch_size,
        self.hidden_state_size,
        self.hidden_layer_size,
    )
    self.assertEqual(training_result.return_code, constants.ExitCode.SUCCESS)
    self.assertFalse(training_result.timed_out)

    # Postcondition: training left at least one saved model behind.
    saved_after = train_rnn_generator_task.get_last_saved_model(
        self.model_directory)
    self.assertTrue(saved_after)
コード例 #3
0
  def test_no_valid_model(self):
    """Checks that an empty dict is returned when every saved model is invalid.

    Deletes model_1's data file and model_2's index file, so neither model
    keeps its full set of files.
    """
    broken_files = (self.model_1_data_path, self.model_2_index_path)
    for path in broken_files:
      os.remove(path)
    for path in broken_files:
      self.assertFalse(os.path.exists(path))

    # With both models incomplete, the lookup should find nothing.
    self.assertDictEqual(
        train_rnn_generator_task.get_last_saved_model(MODEL_DIR), {})
コード例 #4
0
  def test_get_valid_model(self):
    """Checks that the latest model is skipped when it is invalid.

    Deleting model_2's index file leaves model_2 incomplete, so the lookup
    is expected to fall back to model_1's data and index paths.
    """
    os.remove(self.model_2_index_path)
    self.assertFalse(os.path.exists(self.model_2_index_path))

    found = train_rnn_generator_task.get_last_saved_model(MODEL_DIR)
    self.assertDictEqual(
        found,
        {'data': self.model_1_data_path, 'index': self.model_1_index_path})
コード例 #5
0
  def test_small_corpus(self):
    """Checks that training on a too-small corpus reports CORPUS_TOO_SMALL.

    The run must not time out and must leave no saved model in the model
    directory.
    """
    # A batch size this large makes the sample corpus count as too small.
    self.batch_size = 100

    outcome = train_rnn_generator_task.train_rnn(
        self.input_directory,
        self.model_directory,
        self.log_directory,
        self.batch_size,
        self.hidden_state_size,
        self.hidden_layer_size,
    )
    self.assertEqual(
        outcome.return_code, constants.ExitCode.CORPUS_TOO_SMALL)
    self.assertFalse(outcome.timed_out)

    # Nothing should have been saved to the model directory.
    leftover = train_rnn_generator_task.get_last_saved_model(
        self.model_directory)
    self.assertFalse(leftover)