Beispiel #1
0
 def test_ModelHandler_evaluation(self):
     """``_evaluate`` returns the RNN output for a name tensor.

     For a single name ('Steve') the output is expected to have shape
     (1, n_categories).
     """
     model_handler = app.ModelHandler(
         TestModelHandler.dummy_data,
         utils.TextFileLoader.all_letters,
         100,
     )
     steve = utils.word_to_tensor('Steve')
     output_shape = tuple(model_handler._evaluate(steve).size())
     assert output_shape == (1, TestModelHandler.n_categories)
Beispiel #2
0
 def test_ModelHandler_train_iteration(self):
     """A single ``_train_iteration`` step runs backprop and a gradient-descent update.

     It is described as returning the RNN output tensor followed by the loss.
     Only the output (element 0) is checked here: its shape should be
     (1, n_categories).

     NOTE(review): the call below passes only a learning rate, so the
     training sample appears to be chosen inside the method — confirm.
     """
     model_handler = app.ModelHandler(
         TestModelHandler.dummy_data,
         utils.TextFileLoader.all_letters,
         100,
     )
     step_result = model_handler._train_iteration(learning_rate=0.1)
     output_shape = tuple(step_result[0].size())
     assert output_shape == (1, TestModelHandler.n_categories)
Beispiel #3
0
    def test_ModelHandler_can_determine_most_likely_category(self):
        """``_most_likely_category`` maps an output tensor to its top category and index.

        In the scores [[1.1, 2.7]], column 1 holds the maximum. The dummy data
        is expected to list 'English' as the category at that index.
        """
        scores = torch.tensor([[1.1, 2.7]])
        model_handler = app.ModelHandler(
            TestModelHandler.dummy_data,
            utils.TextFileLoader.all_letters,
            100,
        )
        top = model_handler._most_likely_category(scores)
        # Element 0 is the category name, element 1 its column index.
        assert top[0] == 'English'
        assert top[1] == 1
Beispiel #4
0
 def test_ModelHandler_training(self):
     """``train`` runs repeated backprop / parameter-update iterations.

     With ``output_losses=True`` the returned collection should hold one
     entry per iteration.
     """
     n_iter = 4
     model_handler = app.ModelHandler(
         TestModelHandler.dummy_data,
         utils.TextFileLoader.all_letters,
         100,
     )
     losses = model_handler.train(n_iter=n_iter, learning_rate=0.1, output_losses=True)
     assert len(losses) == n_iter
Beispiel #5
0
 def test_ModelHandler_can_pick_random_training_sample(self):
     """``_random_training_sample`` returns one random training sample.

     The returned sequence holds, in order:

     - element 0: a category present in the dummy data;
     - element 1: a name present in the dummy data;
     - element 2: a tensor of shape (1,) — presumably the category index
       used as the training target;
     - element 3: the encoded name (network input), whose dims 1 and 2 are
       1 and ``len(handler.letters)``.
     """
     handler = app.ModelHandler(TestModelHandler.dummy_data,
                                utils.TextFileLoader.all_letters, 100)
     sample = handler._random_training_sample()

     assert sample[0] in TestModelHandler.categories
     assert sample[1] in TestModelHandler.names
     assert tuple(sample[2].size()) == (1,)

     # Convert the shape once and check each dimension in its own assert,
     # so a failure shows which dimension is wrong (a single `a and b`
     # assert hides that).
     name_shape = tuple(sample[3].size())
     assert name_shape[1] == 1
     assert name_shape[2] == len(handler.letters)
Beispiel #6
0
 def test_ModelHandler_prediction(self):
     """``predict`` returns the top-N predictions for a name.

     It evaluates the RNN on the given name and prints the requested number
     of top predictions. It returns them as a list of (value, category)
     tuples, where each value is a float and each category is one of
     ``handler.categories``.
     """
     predictions = 2
     handler = app.ModelHandler(TestModelHandler.dummy_data,
                                utils.TextFileLoader.all_letters, 100)
     result = handler.predict('Steve', top_predictions=predictions)
     assert len(result) == predictions
     # Unpacking enforces the documented (value, category) pair shape.
     for value, category in result:
         # isinstance is the idiomatic type check (accepts float subclasses).
         assert isinstance(value, float)
         assert category in handler.categories
Beispiel #7
0
    def test_ModelHandler_can_load_model_and_data(self):
        """Constructing a ModelHandler stores the data dict, categories, letters and an RNN.

        The RNN is sized from the inputs:

        - input size is the alphabet length;
        - hidden size is the requested value;
        - output size is the number of categories.
        """
        hidden_units = 100
        handler = app.ModelHandler(
            TestModelHandler.dummy_data,
            utils.TextFileLoader.all_letters,
            hidden_units,
        )

        # Stored attributes.
        assert isinstance(handler.data, dict)
        assert handler.categories == tuple(TestModelHandler.dummy_data.keys())
        assert isinstance(handler.rnn, app.RNN)
        assert isinstance(handler.letters, str)

        # RNN dimensions derived from the constructor arguments.
        network = handler.rnn
        assert network.input_size == len(utils.TextFileLoader.all_letters)
        assert network.hidden_size == hidden_units
        assert network.output_size == len(handler.data)