def test_discrete_env(rate_power): env = GridWorld() counter = DiscreteCounter(env.observation_space, env.action_space, rate_power=rate_power) for N in range(10, 20): assert counter.get_n_visited_states() == 0 assert counter.get_entropy() == 0.0 for ss in range(env.observation_space.n): for aa in range(env.action_space.n): for _ in range(N): ns, rr, _, _ = env.sample(ss, aa) counter.update(ss, aa, ns, rr) assert counter.N_sa[ss, aa] == N assert counter.count(ss, aa) == N if rate_power == pytest.approx(1): assert np.allclose(counter.measure(ss, aa), 1.0 / N) elif rate_power == pytest.approx(0.5): assert np.allclose(counter.measure(ss, aa), np.sqrt(1.0 / N)) assert counter.get_n_visited_states() == env.observation_space.n assert np.allclose(counter.get_entropy(), np.log2(env.observation_space.n)) counter.reset()
def test_continuous_state_env(rate_power): env = MountainCar() counter = DiscreteCounter(env.observation_space, env.action_space, rate_power=rate_power) for N in [10, 20]: for _ in range(50): ss = env.observation_space.sample() aa = env.action_space.sample() for _ in range(N): ns, rr, _, _ = env.sample(ss, aa) counter.update(ss, aa, ns, rr) dss = counter.state_discretizer.discretize(ss) assert counter.N_sa[dss, aa] == N assert counter.count(ss, aa) == N if rate_power == pytest.approx(1): assert np.allclose(counter.measure(ss, aa), 1.0 / N) elif rate_power == pytest.approx(0.5): assert np.allclose(counter.measure(ss, aa), np.sqrt(1.0 / N)) counter.reset()