def test_predict_epsilon_not_adf_args_error_2(self):
    """A non-ADF ``--cb_explore`` learner must reject a changed action count.

    After a successful predict over four actions, a predict over three actions
    is expected to raise an exception whose message mentions
    ``--cb_explore_adf``.
    """
    learner = VowpalLearner("--cb_explore --epsilon 0.75 --random_seed 20")

    self.assertEqual(
        [0.25 + 0.25 * 0.75, 0.25 * 0.75, 0.25 * 0.75, 0.25 * 0.75],
        learner.predict(1, None, [1, 2, 3, 4]),
    )

    # Only the call expected to raise goes inside the context manager: any
    # assertion placed here would have its own AssertionError caught by
    # assertRaises(Exception), masking a missing exception.
    with self.assertRaises(Exception) as e:
        learner.predict(1, None, [1, 2, 3])

    self.assertIn("--cb_explore_adf", str(e.exception))
def test_predict_epsilon_dict_context_adf(self):
    """An ADF epsilon learner given a dict context is expected to predict uniformly over four actions."""
    # NOTE(review): this test passes `adf=True` while other tests in this file
    # pass `is_adf=`; confirm which keyword VowpalLearner actually accepts.
    expected = [0.25] * 4
    learner = VowpalLearner(epsilon=0.05, adf=True, seed=20)
    context = {1: 10.2, 2: 3.5}
    actual = learner.predict(1, context, [1, 2, 3, 4])  # type: ignore
    self.assertEqual(expected, actual)
def test_predict_epsilon_not_adf(self):
    """A non-ADF learner with epsilon=0.75 over four actions.

    The expected distribution gives each action 0.25 * 0.75 and adds an extra
    0.25 to the first action.
    """
    explore_share = 0.25 * 0.75
    expected = [0.25 + explore_share] + [explore_share] * 3

    learner = VowpalLearner(epsilon=0.75, is_adf=False, seed=30)
    actual = learner.predict(1, None, [1, 2, 3, 4])

    self.assertEqual(expected, actual)
def test_predict_epsilon_adf(self):
    """An ADF learner with epsilon=0.05 is expected to predict uniformly over four actions."""
    uniform = [0.25] * 4
    learner = VowpalLearner(epsilon=0.05, is_adf=True, seed=20)
    self.assertEqual(uniform, learner.predict(1, None, [1, 2, 3, 4]))
def test_predict_cover_not_adf(self):
    """A cover=5 learner is expected to put all probability on the first of four actions."""
    actions = [1, 2, 3, 4]
    learner = VowpalLearner(cover=5, seed=30)
    predicted = learner.predict(1, None, actions)
    self.assertEqual([1, 0, 0, 0], predicted)
def test_predict_bag_not_adf(self):
    """A non-ADF bag=5 learner is expected to put all probability on the first of four actions."""
    actions = [1, 2, 3, 4]
    learner = VowpalLearner(bag=5, is_adf=False, seed=30)
    predicted = learner.predict(1, None, actions)
    self.assertEqual([1, 0, 0, 0], predicted)
def test_predict_bag_adf_int_actions(self):
    """An ADF bag=5 learner over integer actions is expected to predict uniformly.

    Renamed from ``test_predict_bag_adf``: another method with that name is
    defined later in the class, and the later definition replaced this one in
    the class namespace, so this test never ran.
    """
    learner = VowpalLearner(bag=5, is_adf=True, seed=30)
    self.assertEqual([0.25, 0.25, 0.25, 0.25], learner.predict(1, None, [1, 2, 3, 4]))
def test_predict_bag_adf(self):
    """An ADF bag=5 learner over string actions is expected to predict uniformly."""
    # NOTE(review): uses `adf=True` whereas other tests in this file use
    # `is_adf=`; confirm which keyword VowpalLearner accepts.
    string_actions = [str(i) for i in range(1, 5)]
    learner = VowpalLearner(bag=5, adf=True, seed=30)
    self.assertEqual([0.25] * 4, learner.predict(1, None, string_actions))
def test_predict_epsilon_not_adf_args(self):
    """A learner configured from a raw VW argument string with epsilon 0.75."""
    # NOTE(review): the args declare `--cb_explore 20`, yet predict is called
    # with four actions and four probabilities are expected; presumably the
    # learner overrides the declared count. Confirm against VowpalLearner.
    vw_args = "--cb_explore 20 --epsilon 0.75 --random_seed 20"
    learner = VowpalLearner(vw_args)

    explore_share = 0.25 * 0.75
    expected = [0.25 + explore_share, explore_share, explore_share, explore_share]

    self.assertEqual(expected, learner.predict(1, None, [1, 2, 3, 4]))
def test_predict_epsilon_tuple_context_adf(self):
    """An ADF epsilon learner given a (keys, values) tuple context is expected to predict uniformly."""
    expected = [0.25] * 4
    learner = VowpalLearner(epsilon=0.05, adf=True, seed=20)
    context = ((1, 2), (10.2, 3.5))
    actual = learner.predict(1, context, [1, 2, 3, 4])  # type: ignore
    self.assertEqual(expected, actual)