Exemplo n.º 1
0
 def test_get_latest_model(self):
     """Check that the newest saved model is reported as a path dictionary.

     The fixture is expected to make model_2 newer than model_1, so the
     lookup in MODEL_DIR should return the three model_2 file paths keyed
     by 'meta', 'data' and 'index'.
     """
     found_paths = ml_train_task.get_last_saved_model(MODEL_DIR)

     # Paths of the model_2 files, keyed the same way as the lookup result.
     model_2_paths = dict(
         meta=self.model_2_meta_path,
         data=self.model_2_data_path,
         index=self.model_2_index_path,
     )
     self.assertDictEqual(found_paths, model_2_paths)
Exemplo n.º 2
0
    def test_train_rnn(self):
        """Train the RNN model on the sample corpus and expect a saved model.

        Verifies that the model directory starts without a saved model, that
        training reports SUCCESS without timing out, and that a saved model
        can be found afterwards.
        """
        # Precondition: the model directory holds no saved model yet.
        model_before = ml_train_task.get_last_saved_model(
            self.model_directory)
        self.assertFalse(model_before)

        # The sample corpus is extremely small, so training is expected to
        # take only a few seconds.
        training_args = (
            self.input_directory,
            self.model_directory,
            self.log_directory,
            self.batch_size,
            self.hidden_state_size,
            self.hidden_layer_size,
        )
        outcome = ml_train_task.train_rnn(*training_args)

        self.assertEqual(outcome.return_code, constants.ExitCode.SUCCESS)
        self.assertFalse(outcome.timed_out)

        # Postcondition: at least one saved model is now present.
        model_after = ml_train_task.get_last_saved_model(self.model_directory)
        self.assertTrue(model_after)
Exemplo n.º 3
0
    def test_no_valid_model(self):
        """Check that an empty dict is returned when every model is invalid.

        One file is deleted from model_1 and one from model_2, so neither
        model is complete when the lookup runs.
        """
        # Break both models: model_1 loses its data file and model_2 loses
        # its index file.
        deleted_files = (self.model_1_data_path, self.model_2_index_path)
        for file_path in deleted_files:
            os.remove(file_path)
        for file_path in deleted_files:
            self.assertFalse(os.path.exists(file_path))

        # With both models invalid, the lookup should yield an empty dict.
        found_paths = ml_train_task.get_last_saved_model(MODEL_DIR)
        self.assertDictEqual(found_paths, {})
Exemplo n.º 4
0
    def test_get_valid_model(self):
        """Check that an invalid latest model is skipped for an older one.

        The index file of model_2 is deleted, which makes model_2 invalid;
        the lookup in MODEL_DIR should then return the model_1 file paths.
        """
        # Invalidate model_2 by deleting its index file.
        removed_index = self.model_2_index_path
        os.remove(removed_index)
        self.assertFalse(os.path.exists(removed_index))

        # The lookup should now fall back to model_1.
        found_paths = ml_train_task.get_last_saved_model(MODEL_DIR)
        model_1_paths = dict(
            meta=self.model_1_meta_path,
            data=self.model_1_data_path,
            index=self.model_1_index_path,
        )
        self.assertDictEqual(found_paths, model_1_paths)
Exemplo n.º 5
0
    def test_small_corpus(self):
        """Check that training rejects a corpus too small for the batch size.

        The batch size is raised to 100 so that the sample corpus appears
        small; training should then report CORPUS_TOO_SMALL without timing
        out, and no saved model should be found afterwards.
        """
        # Raise the batch size so the sample corpus counts as too small.
        self.batch_size = 100

        training_args = (
            self.input_directory,
            self.model_directory,
            self.log_directory,
            self.batch_size,
            self.hidden_state_size,
            self.hidden_layer_size,
        )
        outcome = ml_train_task.train_rnn(*training_args)

        expected_code = constants.ExitCode.CORPUS_TOO_SMALL
        self.assertEqual(outcome.return_code, expected_code)
        self.assertFalse(outcome.timed_out)

        # No model exists after execution.
        saved_model = ml_train_task.get_last_saved_model(self.model_directory)
        self.assertFalse(saved_model)