def test_episodic_returns(size=2560):
    """Validate ``BasePolicy.compute_episodic_return`` on hand-computed cases.

    Each case stores a batch of transitions in a 20-slot replay buffer. The
    ``obs``/``act`` fields are dummy ``1`` values. The returns produced for
    every stored index are compared with precomputed targets, and the buffer
    is reset between cases.

    Cases that carry a value array pass it as the fourth positional argument;
    the others omit that argument entirely. Two cases attach per-step
    ``info['TimeLimit.truncated']`` flags.

    If the module is run as ``__main__``, the reference
    ``compute_episodic_return_base`` and the buffer-based implementation are
    also timed on ``size`` random transitions, and both timings are printed.
    """
    episodic_return = BasePolicy.compute_episodic_return
    replay = ReplayBuffer(20)

    # Each entry: (batch, extra positional args, gamma, gae_lambda, expected)
    cases = [
        (
            Batch(
                done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
                rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
                info=Batch({
                    'TimeLimit.truncated': np.array([
                        False, False, False, False, False, True, False, False
                    ])
                }),
            ),
            (),
            .1,
            1,
            np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7]),
        ),
        (
            Batch(
                done=np.array([0, 1, 0, 1, 0, 1, 0.]),
                rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
            ),
            (),
            .1,
            1,
            np.array([7.6, 6, 1.2, 2, 3.4, 4, 5]),
        ),
        (
            Batch(
                done=np.array([0, 1, 0, 1, 0, 0, 1.]),
                rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
            ),
            (),
            .1,
            1,
            np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5]),
        ),
        (
            Batch(
                done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
                rew=np.array([
                    101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202
                ]),
            ),
            (np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3]), ),
            0.99,
            0.95,
            np.array([
                454.8344, 376.1143, 291.298, 200., 464.5610, 383.1085,
                295.387, 201., 474.2876, 390.1027, 299.476, 202.
            ]),
        ),
        (
            Batch(
                done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
                rew=np.array([
                    101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202
                ]),
                info=Batch({
                    'TimeLimit.truncated': np.array([
                        False, False, False, True, False, False, False, True,
                        False, False, False, False
                    ])
                }),
            ),
            (np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3]), ),
            0.99,
            0.95,
            np.array([
                454.0109, 375.2386, 290.3669, 199.01, 462.9138, 381.3571,
                293.5248, 199.02, 474.2876, 390.1027, 299.476, 202.
            ]),
        ),
    ]

    for case_no, case in enumerate(cases):
        batch, value_args, gamma, gae_lambda, expected = case
        if case_no:
            # Every case after the first starts from an empty buffer.
            replay.reset()
        for transition in batch:
            transition.obs = transition.act = 1
            replay.add(transition)
        # value_args is empty when the case has no value array, so the
        # callee's default for that argument is used instead of an explicit None.
        returns, _ = episodic_return(
            batch, replay, replay.sample_indices(0), *value_args,
            gamma=gamma, gae_lambda=gae_lambda)
        assert np.allclose(returns, expected)

    if __name__ == '__main__':
        bench_buf = ReplayBuffer(size)
        bench_batch = Batch(
            done=np.random.randint(100, size=size) == 0,
            rew=np.random.random(size),
        )
        for transition in bench_batch:
            transition.obs = transition.act = 1
            bench_buf.add(transition)
        bench_indices = bench_buf.sample_indices(0)

        def run_vanilla():
            return compute_episodic_return_base(bench_batch, gamma=.1)

        def run_optimized():
            return episodic_return(bench_batch, bench_buf, bench_indices,
                                   gamma=.1, gae_lambda=1.0)

        repeats = 3000
        print('GAE vanilla',
              timeit(run_vanilla, setup=run_vanilla, number=repeats))
        print('GAE optim ',
              timeit(run_optimized, setup=run_optimized, number=repeats))
def test_replaybuffer(size=10, bufsize=20):
    """Exercise basic ``ReplayBuffer`` storage, sampling and bookkeeping.

    Covers:

    - Filling a ``bufsize``-slot buffer from ``MyTestEnv(size)`` for 25 steps
      (more than the default ``bufsize`` of 20).
    - The dtype and shape of the stored ``act``.
    - Bounds of the ``sample`` output.
    - ``sample_indices`` with a negative batch size.
    - The ``(ptr, ep_rew, ep_len, ep_idx)`` tuple returned by ``add`` for
      single and stacked batches.
    - Creation of ``info`` keys that first appear in a later transition.
    - Out-of-range indexing, which must raise ``IndexError``.
    - ``prev`` / ``next`` on a few slots.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    # NOTE(review): self-update is presumably a no-op on an empty buffer; the
    # assert below expects str() to still be the bare class name with "()".
    buf.update(buf)
    assert str(buf) == buf.__class__.__name__ + '()'
    obs = env.reset()
    # 25 actions: with the default bufsize=20 the buffer overflows, and
    # len(buf) must saturate at bufsize.
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(
            Batch(obs=obs, act=[a], rew=rew, done=done, obs_next=obs_next,
                  info=info))
        obs = obs_next
        assert len(buf) == min(bufsize, i + 1)
    # act is stored as a length-1 list, hence the trailing dimension of 1.
    assert buf.act.dtype == int
    assert buf.act.shape == (bufsize, 1)
    # Request twice the buffer size; sampled indices must stay in range.
    data, indices = buf.sample(bufsize * 2)
    assert (indices < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    b = ReplayBuffer(size=10)
    # neg bsz should return empty index
    assert b.sample_indices(-1).tolist() == []
    # Add a single unstacked transition. It has a nested info dict and a
    # string obs_next.
    ptr, ep_rew, ep_len, ep_idx = b.add(
        Batch(obs=1, act=1, rew=1, done=1, obs_next='str',
              info={
                  'a': 3,
                  'b': {
                      'c': 5.0
                  }
              }))
    assert b.obs[0] == 1
    assert b.done[0]
    assert b.obs_next[0] == 'str'
    # Slots not yet written read back as 0, or as None for the string field.
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.obs_next[1:] == np.array(None))
    assert b.info.a[0] == 3 and b.info.a.dtype == int
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == float
    assert np.all(b.info.b.c[1:] == 0.0)
    # Each value returned by add() is an array of shape (1, ).
    assert ptr.shape == (1, ) and ptr[0] == 0
    assert ep_rew.shape == (1, ) and ep_rew[0] == 1
    assert ep_len.shape == (1, ) and ep_len[0] == 1
    assert ep_idx.shape == (1, ) and ep_idx[0] == 0
    # test extra keys pop up, the buffer should handle it dynamically
    batch = Batch(obs=2, act=2, rew=2, done=0, obs_next="str2",
                  info={
                      "a": 4,
                      "d": {
                          "e": -np.inf
                      }
                  })
    b.add(batch)
    info_keys = ["a", "b", "d"]
    assert set(b.info.keys()) == set(info_keys)
    # "b" was absent from this transition, so its slot-1 value is 0.
    assert b.info.a[1] == 4 and b.info.b.c[1] == 0
    assert b.info.d.e[1] == -np.inf
    # test batch-style adding method, where len(batch) == 1
    batch.done = [1]
    batch.info.e = np.zeros([1, 4])
    batch = Batch.stack([batch])
    ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0])
    # This step lands in slot 2 and ends the episode begun at slot 1 (rew 2
    # each), hence ep_rew 4, ep_len 2 and ep_idx 1.
    assert ptr.shape == (1, ) and ptr[0] == 2
    assert ep_rew.shape == (1, ) and ep_rew[0] == 4
    assert ep_len.shape == (1, ) and ep_len[0] == 2
    assert ep_idx.shape == (1, ) and ep_idx[0] == 1
    assert set(b.info.keys()) == set(info_keys + ["e"])
    # The new key gets a leading buffer-size axis over its per-step shape.
    assert b.info.e.shape == (b.maxsize, 1, 4)
    # Index 22 is beyond this size-10 buffer.
    with pytest.raises(IndexError):
        b[22]
    # test prev / next
    assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1])
    assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2])
    # Add a non-terminal step, which lands in slot 3; prev/next of slot 3 are
    # expected to be slot 3 itself.
    batch.done = [0]
    b.add(batch, buffer_ids=[0])
    assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3])
    assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3])