def test_simple_run(self):
    """Check the regret a freshly reset thresholding bandit reports per goal."""
    bandit = ThresholdingBandit(
        arms=[BernoulliArm(mean) for mean in [0.3, 0.5, 0.7]],
        theta=0.5,
        eps=0,
    )
    bandit.reset()

    # No arm is pulled in this test; regret is queried right after reset.
    max_correct_goal = MaxCorrectAnswers(answers=[0, 1, 1])
    assert bandit.regret(max_correct_goal) == 0

    all_correct_goal = AllCorrect(answers=[0, 1, 0])
    assert bandit.regret(all_correct_goal) == 1
def test_simple_run(self):
    """Check that EpsGreedy requests each arm once, in id order, at the start."""
    arms = [BernoulliArm(mean) for mean in [0, 0.5, 0.7, 1]]
    learner = EpsGreedy(arm_num=len(arms))
    learner.reset()

    # Text-format templates for the expected action and the fed-back reward.
    actions_template = """ arm_pulls < arm < id: {arm_id} > times: 1 > """
    feedback_template = """ arm_feedbacks < arm < id: {arm_id} > rewards: 0 > """

    # Pull each arm once during the initial steps
    for arm_id in range(len(arms)):
        actual_bytes = learner.actions(Context()).SerializeToString()
        expected_actions = text_format.Parse(
            actions_template.format(arm_id=arm_id), Actions())
        assert actual_bytes == expected_actions.SerializeToString()

        # Report a reward of 0 for the arm that was just requested.
        feedback = text_format.Parse(
            feedback_template.format(arm_id=arm_id), Feedback())
        learner.update(feedback)
def test_simple_run(self):
    """Check the regret of a two-armed bandit after 100 pulls of arm 0."""
    bandit = OrdinaryBandit([BernoulliArm(mean) for mean in [0, 1]])
    bandit.reset()

    # pull arm 0 for 100 times
    pulls = [(0, 100)]
    bandit.feed(pulls)

    assert bandit.regret(MaxReward()) == 100
    assert bandit.regret(BestArmId(best_arm=1)) == 0
def test_simple_run(self):
    """Check that EpsGreedy first requests every arm once, in index order."""
    arms = [BernoulliArm(mean) for mean in [0, 0.5, 0.7, 1]]
    arm_count = len(arms)
    learner = EpsGreedy(arm_num=arm_count, horizon=10)
    learner.reset()

    for expected_arm in range(arm_count):
        # Each initial step should ask for a single pull of the next arm.
        assert learner.actions() == [(expected_arm, 1)]
        # Feed back a single reward of 0 for that pull.
        zero_reward = ([np.array([0])], )
        learner.update(zero_reward)
def test_simple_run(self):
    """Check that playing 3 trials writes 3 lines to the output file.

    The temporary output file is managed by a context manager so that it is
    closed (and therefore removed) even if ``play`` or the assertion fails,
    instead of lingering until garbage collection.
    """
    arms = [BernoulliArm(mean) for mean in [0.3, 0.5, 0.7]]
    ordinary_bandit = OrdinaryBandit(arms)
    eps_greedy_learner = EpsGreedy(arm_num=len(arms), horizon=10)
    single_player = SinglePlayerProtocol(bandit=ordinary_bandit,
                                         learners=[eps_greedy_learner])

    with tempfile.NamedTemporaryFile() as temp_file:
        single_player.play(trials=3, output_filename=temp_file.name)
        # Re-open by name to read what play() wrote to that path.
        with open(temp_file.name, 'r') as f:
            lines = f.readlines()

    # check number of records is 3
    assert len(lines) == 3
def test_simple_run(self):
    """Check that playing 3 trials produces 3 parsable trial records.

    The temporary output file is managed by a context manager so that it is
    closed (and therefore removed) even if ``play`` or the assertion fails,
    instead of lingering until garbage collection.
    """
    arms = [BernoulliArm(mean) for mean in [0.3, 0.5, 0.7]]
    ordinary_bandit = MultiArmedBandit(arms)
    eps_greedy_learner = EpsGreedy(arm_num=len(arms))
    single_player = SinglePlayerProtocol(bandit=ordinary_bandit,
                                         learners=[eps_greedy_learner])

    with tempfile.NamedTemporaryFile() as temp_file:
        single_player.play(3, temp_file.name, horizon=10)
        # Re-open by name in binary mode; the records are parsed from bytes.
        with open(temp_file.name, 'rb') as f:
            trials = parse_trials_from_bytes(f.read())

    # check number of records is 3
    assert len(trials) == 3
def test_simple_run(self):
    """Check the regret of a two-armed bandit after 100 pulls of arm 0."""
    bandit = MultiArmedBandit([BernoulliArm(mean) for mean in [0, 1]])
    bandit.reset()

    # Pull arm 0 for 100 times
    pull_request = text_format.Parse(
        """ arm_pulls { arm { id: 0 } times: 100 } """, Actions())
    bandit.feed(pull_request)

    assert bandit.regret(MaximizeTotalRewards()) == 100

    # Build the Arm message with id 1 to use as the claimed best arm.
    best_arm = Arm()
    best_arm.id = 1
    assert bandit.regret(IdentifyBestArm(best_arm=best_arm)) == 0