class TestEpsilonGreedyPolicy:
    """Tests for EpsilonGreedyPolicy: epsilon decay, params and pickling.

    The expected values 0.902 (one step) and 0.412 (six steps) correspond to
    a decrement of 0.098 per step. Presumably this is
    (max_epsilon - min_epsilon) / (total_timesteps * decay_ratio) with the
    arguments passed in ``setup_method`` -- TODO confirm against
    EpsilonGreedyPolicy.

    NOTE(review): this file defines a second class with this same name after
    this one. At module level the later definition rebinds the name, so the
    tests here would likely not be collected -- confirm which variant is
    intended and remove or rename the other.
    """

    def setup_method(self):
        """Create a fresh environment and epsilon-greedy policy per test."""
        self.env = DummyDiscreteEnv()
        self.policy = SimplePolicy(env_spec=self.env)
        self.epsilon_greedy_policy = EpsilonGreedyPolicy(env_spec=self.env,
                                                         policy=self.policy,
                                                         total_timesteps=100,
                                                         max_epsilon=1.0,
                                                         min_epsilon=0.02,
                                                         decay_ratio=0.1)
        self.env.reset()

    def _empirical_exploration_rate(self, n_samples=100000):
        """Return the fraction of uniform draws below the current epsilon.

        Args:
            n_samples (int): Number of uniform [0, 1) samples to draw.

        Returns:
            float: Share of the samples that are smaller than the policy's
                current epsilon.

        """
        draws = np.random.random(n_samples)
        # np.mean reduces in native code; the builtin sum() would iterate
        # over the array element by element in Python.
        return np.mean(draws < self.epsilon_greedy_policy._epsilon())

    def test_epsilon_greedy_policy(self):
        """Check sampled actions are valid and epsilon decays as expected."""
        obs, _, _, _ = self.env.step(1)

        action, _ = self.epsilon_greedy_policy.get_action(obs)
        assert self.env.action_space.contains(action)

        # epsilon decay by 1 step, new epsilon = 1 - 0.098 = 0.902
        assert np.isclose(0.902,
                          self._empirical_exploration_rate(),
                          atol=0.01)

        actions, _ = self.epsilon_greedy_policy.get_actions([obs] * 5)

        # epsilon decay by 6 steps in total
        # new epsilon = 1 - 6 * 0.098 = 0.412
        assert np.isclose(0.412,
                          self._empirical_exploration_rate(),
                          atol=0.01)

        for action in actions:
            assert self.env.action_space.contains(action)

    def test_set_param(self):
        """Check epsilon after overwriting 'total_env_steps' in the params."""
        params = self.epsilon_greedy_policy.get_param_values()
        params['total_env_steps'] = 6
        self.epsilon_greedy_policy.set_param_values(params)
        assert np.isclose(self.epsilon_greedy_policy._epsilon(), 0.412)

    def test_update(self):
        """Check epsilon after update() with a batch of lengths 1, 2, 3."""
        DummyBatch = collections.namedtuple('EpisodeBatch', ['lengths'])
        # The lengths sum to 6; presumably update() advances the step count
        # by that total, giving the six-step epsilon -- TODO confirm.
        batch = DummyBatch(np.array([1, 2, 3]))
        self.epsilon_greedy_policy.update(batch)
        assert np.isclose(self.epsilon_greedy_policy._epsilon(), 0.412)

    def test_epsilon_greedy_policy_is_pickleable(self):
        """Check epsilon is unchanged by a pickle round trip."""
        obs, _, _, _ = self.env.step(1)
        for _ in range(5):
            self.epsilon_greedy_policy.get_action(obs)

        h_data = pickle.dumps(self.epsilon_greedy_policy)
        policy = pickle.loads(h_data)
        assert policy._epsilon() == self.epsilon_greedy_policy._epsilon()
class TestEpsilonGreedyPolicy:
    """Tests for EpsilonGreedyPolicy: epsilon decay and pickling.

    The expected values 0.902 (one step) and 0.412 (six steps) correspond to
    a decrement of 0.098 per step. Presumably this is
    (max_epsilon - min_epsilon) / (total_timesteps * decay_ratio) with the
    arguments passed in ``setup_method`` -- TODO confirm against
    EpsilonGreedyPolicy.

    NOTE(review): a class with this same name is already defined earlier in
    this file; this definition rebinds the name. This variant reads
    ``_epsilon`` as an attribute while the earlier one calls ``_epsilon()``
    as a method, so only one of them can match the real policy API --
    confirm which is intended and remove the other.
    """

    def setup_method(self):
        """Create a fresh environment and epsilon-greedy policy per test."""
        self.env = DummyDiscreteEnv()
        self.policy = SimplePolicy(env_spec=self.env)
        self.epsilon_greedy_policy = EpsilonGreedyPolicy(env_spec=self.env,
                                                         policy=self.policy,
                                                         total_timesteps=100,
                                                         max_epsilon=1.0,
                                                         min_epsilon=0.02,
                                                         decay_ratio=0.1)
        self.env.reset()

    def _empirical_exploration_rate(self, n_samples=100000):
        """Return the fraction of uniform draws below the current epsilon.

        Args:
            n_samples (int): Number of uniform [0, 1) samples to draw.

        Returns:
            float: Share of the samples that are smaller than the policy's
                current ``_epsilon`` attribute.

        """
        draws = np.random.random(n_samples)
        # np.mean reduces in native code; the builtin sum() would iterate
        # over the array element by element in Python.
        return np.mean(draws < self.epsilon_greedy_policy._epsilon)

    def test_epsilon_greedy_policy(self):
        """Check sampled actions are valid and epsilon decays as expected."""
        obs, _, _, _ = self.env.step(1)

        action, _ = self.epsilon_greedy_policy.get_action(obs)
        assert self.env.action_space.contains(action)

        # epsilon decay by 1 step, new epsilon = 1 - 0.098 = 0.902
        assert np.isclose(0.902,
                          self._empirical_exploration_rate(),
                          atol=0.01)

        actions, _ = self.epsilon_greedy_policy.get_actions([obs] * 5)

        # epsilon decay by 6 steps in total,
        # new epsilon = 1 - 6 * 0.098 = 0.412
        assert np.isclose(0.412,
                          self._empirical_exploration_rate(),
                          atol=0.01)

        for action in actions:
            assert self.env.action_space.contains(action)

    def test_epsilon_greedy_policy_is_pickleable(self):
        """Check epsilon is unchanged by a pickle round trip."""
        obs, _, _, _ = self.env.step(1)
        for _ in range(5):
            self.epsilon_greedy_policy.get_action(obs)

        h_data = pickle.dumps(self.epsilon_greedy_policy)
        policy = pickle.loads(h_data)
        assert policy._epsilon == self.epsilon_greedy_policy._epsilon