Beispiel #1
0
    def test_find_strength_diff(self):
        """Check find_strength_diff against hand-computed strengths.

        The model state is injected directly (``fit`` is never called and
        ``is_fitted`` is forced to True), so the expected strengths for
        INDEXED_DATA_NORESCOL can be written down as constants.
        """
        entities = ['A', 'B', 'C', 'D']

        model = BradleyTerry()
        model.is_fitted = True
        model.target_col_name = 'result'
        # index -> entity name, and the reverse mapping
        model.lkp = dict(enumerate(entities))
        model.rplc_lkp = {name: idx for idx, name in enumerate(entities)}
        # Two independent arrays (not aliases) for the private/public params.
        model._params = np.array([0.3, 0.2, -0.2, -0.4])
        model.params_ = np.array([0.3, 0.2, -0.2, -0.4])

        actual = model.find_strength_diff(INDEXED_DATA_NORESCOL)
        expected = np.array([-0.4, -0.1, 0.2, -0.2])
        np.testing.assert_array_almost_equal(actual, expected, decimal=10)
Beispiel #2
0
    def test_check_for_no_new_entities(self):
        """check_for_no_new_entities must reject unseen entities only.

        ``wrong_data`` contains entity 'K', which is not one of the names in
        ``bt.lkp``, so the check is expected to raise. TRANSITIVE_DATA_INDEXED
        is expected to contain only known entities and must pass without
        raising.
        """
        bt = BradleyTerry()
        bt.is_fitted = True
        bt.target_col_name = 'result'
        bt.lkp = {1: 'A', 2: 'B', 3: 'C', 4: 'D'}
        # 'K' is the only unknown entity here; every other name appears in
        # bt.lkp, so a raise can be attributed to 'K' alone.
        wrong_data = pd.DataFrame({
            'ent1': ['K', 'A'],
            'ent2': ['A', 'B'],
            'result': [1, 0]
        })
        wrong_data = wrong_data.set_index(['ent1', 'ent2'])
        with self.assertRaises(Exception):
            bt.check_for_no_new_entities(wrong_data)

        # Catch Exception (not a bare ``except:``) so KeyboardInterrupt and
        # SystemExit still propagate, and surface the actual error.
        try:
            bt.check_for_no_new_entities(TRANSITIVE_DATA_INDEXED)
        except Exception as exc:
            self.fail(
                "check_for_no_new_entities failed unexpectedly: {!r}".format(
                    exc))
Beispiel #3
0
    def test_predict_proba(self):
        """predict_proba should return the logistic of the expected strengths.

        Model state is injected directly (``fit`` is never called). With
        these parameters, the expected probabilities for INDEXED_DATA_NORESCOL
        are sigmoid([-0.4, -0.1, 0.2, -0.2]).
        """
        bt = BradleyTerry()
        bt.is_fitted = True
        bt.target_col_name = 'result'
        bt.lkp = {0: 'A', 1: 'B', 2: 'C', 3: 'D'}
        bt.rplc_lkp = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
        bt._params = np.array([0.3, 0.2, -0.2, -0.4])
        bt.params_ = np.array([0.3, 0.2, -0.2, -0.4])
        bt.pylogit_fit = False

        def sigmoid(x):
            """Logistic function 1 / (1 + exp(-x)), applied element-wise."""
            return 1 / (1 + np.exp(-x))

        pred_probs = bt.predict_proba(INDEXED_DATA_NORESCOL)
        correct_probs = sigmoid(np.array([-0.4, -0.1, 0.2, -0.2]))

        # Use a tolerance: exact float equality is brittle if predict_proba
        # reaches the logistic value through a numerically different route.
        np.testing.assert_array_almost_equal(pred_probs,
                                             correct_probs,
                                             decimal=10)