def test_collect_Q():
    """Check the snapshots recorded by ``CollectQ`` and ``CollectMaxQ``.

    A SARSA agent learns for 1000 steps on a 3x3 ``GridWorld`` with both
    callbacks attached. The test then verifies that:

    * row 0 of the last snapshot returned by ``CollectQ`` matches the
      reference values obtained with seed 88;
    * the maximum over row 2 of every ``CollectQ`` snapshot equals the
      series returned by ``CollectMaxQ`` (built with ``np.array([2])``).
    """
    np.random.seed(88)

    mdp = GridWorld(3, 3, (2, 2))
    policy = EpsGreedy(Parameter(0.1))
    learning_rate = Parameter(0.1)
    agent = SARSA(mdp.info, policy, learning_rate)

    q_collector = CollectQ(agent.Q)
    max_q_collector = CollectMaxQ(agent.Q, np.array([2]))

    core = Core(agent, mdp, callbacks=[q_collector, max_q_collector])
    core.learn(n_steps=1000, n_steps_per_fit=1, quiet=True)

    # Reference values for row 0 of the final snapshot (seed 88).
    expected_row = np.array([2.4477574, 0.02246188, 1.6210059, 6.01867052])
    q_snapshots = q_collector.get()
    final_q = q_snapshots[-1]
    assert np.allclose(final_q[0, :], expected_row)

    # Recompute the per-snapshot maximum of row 2 by hand and compare it
    # with what the max-Q callback recorded.
    manual_max = []
    for snapshot in q_snapshots:
        manual_max.append(np.max(snapshot[2, :], axis=-1))
    manual_max = np.array(manual_max)

    recorded_max = np.array(max_q_collector.get())
    assert np.allclose(manual_max, recorded_max)
def test_dataset_utils():
    """Exercise the dataset helper functions on a fixed evaluation run.

    Ten episodes are collected on a 3x3 ``GridWorld`` with seed 88, and
    the results of ``compute_J``, ``episodes_length``,
    ``select_first_episodes``, ``select_random_samples``,
    ``parse_dataset`` and ``compute_metrics`` are compared against
    reference values recorded for that seed.

    Floating-point results are compared with a tolerance
    (``np.allclose`` / ``np.isclose``) rather than exact equality, so the
    test does not depend on the last bit of the platform's arithmetic.
    Integer results are compared exactly.
    """
    np.random.seed(88)

    mdp = GridWorld(3, 3, (2, 2))
    epsilon = Parameter(value=0.)
    alpha = Parameter(value=0.)
    pi = EpsGreedy(epsilon=epsilon)

    agent = SARSA(mdp.info, pi, alpha)
    core = Core(agent, mdp)

    dataset = core.evaluate(n_episodes=10)

    # Per-episode return computed with the MDP's gamma.
    J = compute_J(dataset, mdp.info.gamma)
    J_test = np.array([
        1.16106307e-03, 2.78128389e-01, 1.66771817e+00, 3.09031544e-01,
        1.19725152e-01, 9.84770902e-01, 1.06111661e-02, 2.05891132e+00,
        2.28767925e+00, 4.23911583e-01,
    ])
    assert np.allclose(J, J_test)

    # Per-episode lengths.
    L = episodes_length(dataset)
    L_test = np.array([87, 35, 18, 34, 43, 23, 66, 16, 15, 31])
    assert np.array_equal(L, L_test)

    # Restricting to the first three episodes must give the matching
    # prefixes of the reference arrays.
    dataset_ep = select_first_episodes(dataset, 3)
    J = compute_J(dataset_ep, mdp.info.gamma)
    assert np.allclose(J, J_test[:3])

    L = episodes_length(dataset_ep)
    # Lengths are integers: compare exactly, as for the full dataset.
    assert np.array_equal(L, L_test[:3])

    # Random sub-sampling followed by parsing into separate arrays.
    samples = select_random_samples(dataset, 2)
    s, a, r, ss, ab, last = parse_dataset(samples)
    s_test = np.array([[6.], [1.]])
    a_test = np.array([[0.], [1.]])
    r_test = np.zeros(2)
    ss_test = np.array([[3], [4]])
    ab_test = np.zeros(2)
    last_test = np.zeros(2)
    assert np.array_equal(s, s_test)
    assert np.array_equal(a, a_test)
    assert np.array_equal(r, r_test)
    assert np.array_equal(ss, ss_test)
    assert np.array_equal(ab, ab_test)
    assert np.array_equal(last, last_test)

    # Cut the dataset halfway through the third episode. The expected
    # metrics below correspond to the first two episodes only.
    index = np.sum(L_test[:2]) + L_test[2] // 2
    min_J, max_J, mean_J, n_episodes = compute_metrics(dataset[:index],
                                                       mdp.info.gamma)
    assert np.isclose(min_J, 0.0011610630703530948)
    assert np.isclose(max_J, 0.2781283894436937)
    assert np.isclose(mean_J, 0.1396447262570234)
    assert n_episodes == 2
def test_collect_dataset():
    """Check that ``CollectDataset`` accumulates samples and can be reset.

    The callback is attached to a ``Core`` running SARSA on a 4x4
    ``GridWorld``. The collected data must contain 10 entries after
    10 learning steps, 15 after 5 more, and none once ``clean()`` has
    been called.
    """
    np.random.seed(88)

    collector = CollectDataset()
    mdp = GridWorld(4, 4, (2, 2))

    policy = EpsGreedy(Parameter(0.1))
    learning_rate = Parameter(0.2)
    agent = SARSA(mdp.info, policy, learning_rate)

    core = Core(agent, mdp, callbacks=[collector])

    core.learn(n_steps=10, n_steps_per_fit=1, quiet=True)
    collected = collector.get()
    assert len(collected) == 10

    # The object obtained above is deliberately checked again without a
    # new get() call, exactly as the length is expected to grow.
    core.learn(n_steps=5, n_steps_per_fit=1, quiet=True)
    assert len(collected) == 15

    collector.clean()
    collected = collector.get()
    assert len(collected) == 0
def test_collect_parameter():
    """Check the values recorded by ``CollectParameters``.

    An ``ExponentialParameter`` (``value=1``, ``exp=.5``) drives the
    epsilon-greedy policy of a SARSA agent on a 3x3 ``GridWorld``. After
    10 learning steps, the values gathered by the callback must match
    the reference sequence obtained with seed 88.
    """
    np.random.seed(88)

    mdp = GridWorld(3, 3, (2, 2))

    epsilon = ExponentialParameter(
        value=1,
        exp=.5,
        size=mdp.info.observation_space.size,
    )
    policy = EpsGreedy(epsilon)
    learning_rate = Parameter(0.1)
    agent = SARSA(mdp.info, policy, learning_rate)

    eps_collector = CollectParameters(epsilon, 1)
    core = Core(agent, mdp, callbacks=[eps_collector])
    core.learn(n_steps=10, n_steps_per_fit=1, quiet=True)

    # Reference sequence recorded for seed 88.
    expected = np.array([
        1., 0.70710678, 0.70710678,
        0.57735027, 0.57735027, 0.57735027, 0.57735027,
        0.57735027, 0.57735027, 0.57735027,
    ])
    recorded = eps_collector.get()
    assert np.allclose(recorded, expected)