def test_kl_upper_bound():
    """Check ``kl_upper_bound`` on fixed reference cases and one random case.

    The reference cases all use an empirical mean of 0.5 and compare the
    returned bound with a hard-coded value. The random case checks that the
    Bernoulli KL divergence between the empirical mean and the returned bound
    is close to ``c * log(time) / count``.
    """
    c = 1
    eps = 1e-3

    # (count, time, expected bound) for an empirical mean of 0.5.
    reference_cases = (
        (1, 10, 0.997),
        (10, 20, 0.835),
        (20, 40, 0.777),
    )
    for pulls, elapsed, expected in reference_cases:
        bound = kl_upper_bound(0.5 * pulls, pulls, elapsed, c=c, eps=eps)
        assert bound == pytest.approx(expected, abs=1e-3)

    # Random case: two integers in [1, 500) sorted so that count <= time.
    # NOTE(review): the draws use NumPy's global RNG and no seed is set in this
    # function, so the inputs may differ from run to run.
    draws = np.random.randint(1, 500, 2)
    draws.sort()
    mean = np.random.random()
    pulls, elapsed = draws[0], draws[1]

    bound = kl_upper_bound(mean * pulls, pulls, elapsed, c=c, eps=eps)
    assert not np.isnan(bound)

    # The divergence at the returned bound should be close to this budget.
    max_divergence = c * np.log(elapsed) / pulls
    divergence = bernoulli_kullback_leibler(mean, bound)
    assert divergence == pytest.approx(max_divergence, abs=1e-2)
def compute_ucb(self):
    """Recompute ``self.mu_ucb`` from this node's reward statistics.

    The behaviour is driven by ``self.planner.config["upper_bound"]``:

    * ``"time"`` selects the time value passed to the bound function:
      ``"local"`` uses ``self.planner.episode + 1`` and ``"global"`` uses
      ``self.planner.config["episodes"]``.
    * ``"type"`` selects the bound function: ``"hoeffding"``, ``"laplace"``
      or ``"kullback-leibler"``.
    * ``"c"`` is forwarded to the bound function as the keyword ``c``.

    An unrecognised ``"time"`` or ``"type"`` is reported through
    ``logger.error`` and ``self.mu_ucb`` is left untouched. Previously an
    unknown time reference combined with a valid type failed with an
    ``UnboundLocalError`` because ``time`` had never been assigned.
    """
    settings = self.planner.config["upper_bound"]
    valid = True

    # Resolve the time reference handed to the bound function.
    time_reference = settings["time"]
    if time_reference == "local":
        time = self.planner.episode + 1
    elif time_reference == "global":
        time = self.planner.config["episodes"]
    else:
        logger.error("Unknown upper-bound time reference")
        time = None
        valid = False

    # Resolve the bound function. Names are looked up only in the branch that
    # needs them, as in the original code.
    bound_type = settings["type"]
    if bound_type == "hoeffding":
        upper_bound = hoeffding_upper_bound
    elif bound_type == "laplace":
        upper_bound = laplace_upper_bound
    elif bound_type == "kullback-leibler":
        upper_bound = kl_upper_bound
    else:
        logger.error("Unknown upper-bound type")
        upper_bound = None
        valid = False

    # Both settings are checked before bailing out so that every
    # configuration problem is logged in a single call.
    if not valid:
        return

    self.mu_ucb = upper_bound(
        self.cumulative_reward, self.count, time, c=settings["c"])
def compute_reward_ucb(self):
    """Recompute ``self.mu_ucb`` and ``self.mu_lcb`` with ``kl_upper_bound``.

    Only the ``"kullback-leibler"`` bound type is handled; any other value of
    ``self.planner.config["upper_bound"]["type"]`` is reported through
    ``logger.error`` and both attributes are left untouched.

    The threshold passed to ``kl_upper_bound`` is obtained by evaluating the
    expression string stored in
    ``self.planner.config["upper_bound"]["threshold"]``.
    """
    if self.planner.config["upper_bound"]["type"] != "kullback-leibler":
        logger.error("Unknown upper-bound type")
        return

    # Variables available for threshold evaluation.
    # These locals look unused but are read implicitly by ``eval`` below, so
    # their names must not change and no other locals should be added here.
    horizon = self.planner.config["horizon"]
    actions = self.planner.env.action_space.n
    confidence = self.planner.config["confidence"]
    count = self.count
    time = self.planner.config["episodes"]

    # NOTE(review): ``eval`` executes arbitrary Python taken from the
    # configuration; make sure the configuration only comes from a trusted
    # source.
    threshold = eval(self.planner.config["upper_bound"]["threshold"])

    self.mu_ucb = kl_upper_bound(
        self.cumulative_reward, self.count, threshold)
    # Same call with ``lower=True``; presumably yields the lower confidence
    # bound -- confirm against ``kl_upper_bound``.
    self.mu_lcb = kl_upper_bound(
        self.cumulative_reward, self.count, threshold, lower=True)