コード例 #1
0
    def test_update_dictionary_correct_shorten_dictionary(self):
        """Swapping in a shortened dictionary yields the shortened fixtures.

        Trains on the sports fixture, then calls ``update_dictionary`` with
        ``self.shortened_dictionary``. Checks the fitted parameters, the
        per-label scores for the first document of ``a_very_close_game``
        and its predicted label against the precomputed expectations.
        """
        clf = Multinomial()
        clf.train(self.sports_labels, self.sports_data)
        clf.update_dictionary(self.shortened_dictionary)
        predicted, scores = clf.predict(self.a_very_close_game)

        # Fitted parameters after the dictionary swap.
        self.assertEqual(clf.priors, self.correct_priors)
        self.assertEqual(clf.label_counts, self.correct_label_count)
        self.assertEqual(clf.empty_likelihoods,
                         self.correct_shortened_empty_likelihoods)
        self.assertDictEqual(clf.likelihoods,
                             self.correct_shortened_likelihoods)

        # Per-label scores of the first document, compared approximately
        # because they are floating-point values.
        self.assertIsNotNone(scores)
        expected = self.correct_shortened_a_very_close_game_score
        for label in ("sport", "not sport"):
            self.assertAlmostEqual(scores[0][label], expected[label])

        # The first document must still be classified as "sport".
        self.assertEqual(predicted[0], "sport")
コード例 #2
0
    def test_train_model_params(self):
        """Training on the sports fixture produces the expected parameters."""
        clf = Multinomial()
        clf.train(self.sports_labels, self.sports_data)

        # Every fitted attribute must equal its precomputed fixture.
        self.assertEqual(clf.priors, self.correct_priors)
        self.assertEqual(clf.label_counts, self.correct_label_count)
        self.assertEqual(clf.empty_likelihoods,
                         self.correct_empty_likelihoods)
        self.assertDictEqual(clf.likelihoods, self.correct_likelihoods)
コード例 #3
0
    def test_predict_prediction_and_score(self):
        """The first ``a_very_close_game`` document predicts "sport".

        Also checks that its per-label scores match the precomputed fixture.
        """
        clf = Multinomial()
        clf.train(self.sports_labels, self.sports_data)
        predicted, scores = clf.predict(self.a_very_close_game)

        # Predicted label of the first document.
        self.assertEqual(predicted[0], "sport")

        # Floating-point scores of the first document, per label.
        self.assertIsNotNone(scores)
        expected = self.correct_a_very_close_game_score
        for label in ("sport", "not sport"):
            self.assertAlmostEqual(scores[0][label], expected[label])
コード例 #4
0
    def test_update_add_more_training_data(self):
        """Training on four examples then ``update`` with the fifth.

        The resulting parameters, scores and prediction must equal the same
        fixtures that ``test_train_model_params`` and
        ``test_predict_prediction_and_score`` expect after one ``train`` call
        on the whole sports fixture.
        """
        clf = Multinomial()
        clf.train(self.sports_labels[0:4], self.sports_data[0:4])
        # ``update`` takes parallel lists, so the fifth example is wrapped.
        extra_labels = [self.sports_labels[4]]
        extra_docs = [self.sports_data[4]]
        clf.update(extra_labels, extra_docs)
        predicted, scores = clf.predict(self.a_very_close_game)

        # Fitted parameters after the incremental update.
        self.assertEqual(clf.priors, self.correct_priors)
        self.assertEqual(clf.label_counts, self.correct_label_count)
        self.assertEqual(clf.empty_likelihoods,
                         self.correct_empty_likelihoods)
        self.assertDictEqual(clf.likelihoods, self.correct_likelihoods)

        # Floating-point scores of the first document, per label.
        self.assertIsNotNone(scores)
        expected = self.correct_a_very_close_game_score
        for label in ("sport", "not sport"):
            self.assertAlmostEqual(scores[0][label], expected[label])

        # Predicted label of the first document.
        self.assertEqual(predicted[0], "sport")
コード例 #5
0
 def test_update_dictionary_new_dictionary_is_empty(self):
     """``update_dictionary`` rejects an empty set with ``TypeError``."""
     model = Multinomial()
     model.train(self.sports_labels, self.sports_data)
     # ``{}`` is an empty *dict*, not an empty set. A non-set argument is
     # expected to raise TypeError on type alone (see the "is_not_set" test),
     # so ``{}`` would never exercise the empty case. ``set()`` does.
     # NOTE(review): assumes emptiness is reported as TypeError, as this test
     # has always claimed -- confirm against the implementation.
     self.assertRaises(TypeError, model.update_dictionary, set())
コード例 #6
0
 def test_update_dictionary_new_dictionary_does_not_contain_strings(self):
     """A dictionary set holding a non-string word raises ``TypeError``."""
     clf = Multinomial()
     clf.train(self.sports_labels, self.sports_data)
     with self.assertRaises(TypeError):
         clf.update_dictionary({0})
コード例 #7
0
 def test_update_dictionary_new_dictionary_is_not_set(self):
     """A dictionary given as a list rather than a set raises ``TypeError``."""
     clf = Multinomial()
     clf.train(self.sports_labels, self.sports_data)
     as_list = list(self.extended_dictionary)
     with self.assertRaises(TypeError):
         clf.update_dictionary(as_list)
コード例 #8
0
 def test_update_number_of_labels_and_docs_differ(self):
     """``update`` raises ``ValueError`` when label and document counts differ."""
     clf = Multinomial()
     clf.train(self.sports_labels, self.sports_data)
     # Only the first four labels, paired with every training document.
     too_few_labels = self.sports_labels[0:4]
     with self.assertRaises(ValueError):
         clf.update(too_few_labels, self.sports_data)
コード例 #9
0
 def test_update_training_data_does_not_contains_lists_of_strs(self):
     """``update`` raises ``TypeError`` when document tokens are not strings."""
     model = Multinomial()
     model.train(self.sports_labels, self.sports_data)
     # Build a real list (``[[0], [0], ...]``). A bare ``map`` object is a
     # lazy iterator in Python 3, not a list. It could therefore be rejected
     # by the "not a list" check (see the "is_not_in_a_list" test) instead of
     # by the non-string tokens this test is about.
     non_str_docs = [[0] for _ in self.sports_data]
     self.assertRaises(TypeError, model.update, self.sports_labels,
                       non_str_docs)
コード例 #10
0
 def test_update_training_data_is_not_in_a_list(self):
     """Training documents passed as a set make ``update`` raise ``TypeError``."""
     clf = Multinomial()
     clf.train(self.sports_labels, self.sports_data)
     with self.assertRaises(TypeError):
         clf.update(self.sports_labels, set())
コード例 #11
0
 def test_update_labels_are_not_str_or_int(self):
     """Labels that are neither ``str`` nor ``int`` make ``update`` raise."""
     clf = Multinomial()
     clf.train(self.sports_labels, self.sports_data)
     # One ``None`` label per original label, so the counts still line up.
     none_labels = [None for _ in self.sports_labels]
     with self.assertRaises(TypeError):
         clf.update(none_labels, self.sports_data)
コード例 #12
0
 def test_predict_test_data_does_not_contains_lists(self):
     """``predict`` raises ``TypeError`` when test documents are not lists."""
     model = Multinomial()
     model.train(self.sports_labels, self.sports_data)
     # Pass a real list whose elements are sets. A bare ``map`` object is a
     # lazy iterator in Python 3 and could be rejected by the top-level
     # "not a list" check (see the "is_not_in_a_list" test below). That would
     # bypass the per-document check this test is named for.
     set_docs = [set() for _ in self.a_very_close_game]
     self.assertRaises(TypeError, model.predict, set_docs)
コード例 #13
0
 def test_predict_test_data_is_not_in_a_list(self):
     """Test documents passed as a set make ``predict`` raise ``TypeError``."""
     clf = Multinomial()
     clf.train(self.sports_labels, self.sports_data)
     with self.assertRaises(TypeError):
         clf.predict(set())
コード例 #14
0
# Load the pickled wine sample data. The path is relative, so this assumes the
# script runs from the directory containing ``sample_data/``.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only use it with trusted, locally generated data.
# NOTE(review): everything below runs while ``f`` is still open; the model
# code does not need the file and could move outside the ``with`` block.
with open('./sample_data/nltk/wine/wine_data.pkl', 'rb') as f:
    wine_data = pickle.load(f)

    ####################
    # Dictionary Model #
    ####################

    # Bag-of-words view of the data. This section reads three keys from it:
    # training input, prediction input and the labels.
    wine_bow_data = wine_data["bagofwords"]

    # NOTE(review): presumably "top_removed" means the most frequent words were
    # stripped from each document -- confirm against the pickling script.
    raw_data_top_removed = wine_bow_data["raw_data_top_removed"]
    raw_data = wine_bow_data["raw_data"]
    labels = wine_bow_data["labels"]

    # Train on the reduced documents, then predict on ``raw_data``. The
    # predictions are scored against the same ``labels`` used for training.
    # If both inputs hold the same documents (presumably -- verify), this is
    # training-set accuracy, not held-out accuracy.
    model_dict = DictMultinomial()
    model_dict.train(labels, raw_data_top_removed)
    dict_predictions, dict_scores = model_dict.predict(raw_data)

    # Count positions where the predicted label equals the true label.
    matches = 0
    for i in range(0, len(labels)):
        if dict_predictions[i] == labels[i]:
            matches += 1

    print("- Dictionary")
    # Fraction of correctly classified documents. This relies on Python 3
    # true division; under Python 2 it would floor to 0 or 1.
    print("Accuracy: " + str(matches / len(labels)))

    ################
    # Vector Model #
    ################

    wine_vector_data = wine_data["vectors"]