def test_multibuf_stack():
    """Exercise CachedReplayBuffer when the wrapped main buffer stacks frames.

    Three configurations of the main ReplayBuffer are covered:

    * ``buf4``: ``stack_num`` + ``ignore_obs_next``;
    * ``buf5``: the same plus ``sample_avail=True`` (corner case);
    * ``buf6``: ``save_only_last_obs`` + ``ignore_obs_next`` with
      Atari-shaped ``(4, 84, 84)`` observations.

    Flat storage layout asserted below: the main buffer occupies
    ``[0, bufsize)`` and cached buffer ``k`` occupies
    ``[bufsize + k * size, bufsize + (k + 1) * size)``.
    """
    size = 5
    bufsize = 9
    stack_num = 4
    cached_num = 3
    env = MyTestEnv(size)
    # test if CachedReplayBuffer can handle stack_num + ignore_obs_next
    buf4 = CachedReplayBuffer(
        ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True),
        cached_num, size)
    # test if CachedReplayBuffer can handle corner case:
    # buffer + stack_num + ignore_obs_next + sample_avail
    buf5 = CachedReplayBuffer(
        ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True,
                     sample_avail=True),
        cached_num, size)
    obs = env.reset(1)
    # Only the number of steps matters; the step counter itself is unused and
    # is kept distinct from the per-buffer index in the comprehension below.
    for _ in range(18):
        obs_next, rew, done, info = env.step(1)
        # Every cached buffer gets the same transition, shifted by
        # ``size * buffer_idx`` so their contents can be told apart.
        obs_list = np.array(
            [obs + size * buffer_idx for buffer_idx in range(cached_num)])
        act_list = [1] * cached_num
        rew_list = [rew] * cached_num
        done_list = [done] * cached_num
        # Deliberately negated: every obs_next value asserted below is
        # positive, so it cannot have been read back from this field.
        obs_next_list = -obs_list
        info_list = [info] * cached_num
        batch = Batch(obs=obs_list, act=act_list, rew=rew_list,
                      done=done_list, obs_next=obs_next_list, info=info_list)
        buf5.add(batch)
        buf4.add(batch)
        # sample_avail must not change what is stored.
        assert np.all(buf4.obs == buf5.obs)
        assert np.all(buf4.done == buf5.done)
        obs = obs_next
        if done:
            obs = env.reset(1)
    # check the `add` order is correct
    assert np.allclose(
        buf4.obs.reshape(-1), [
            12, 13, 14, 4, 6, 7, 8, 9, 11,  # main_buffer
            1, 2, 3, 4, 0,  # cached_buffer[0]
            6, 7, 8, 9, 0,  # cached_buffer[1]
            11, 12, 13, 14, 0,  # cached_buffer[2]
        ]), buf4.obs
    assert np.allclose(
        buf4.done, [
            0, 0, 1, 1, 0, 0, 0, 1, 0,  # main_buffer
            0, 0, 0, 1, 0,  # cached_buffer[0]
            0, 0, 0, 1, 0,  # cached_buffer[1]
            0, 0, 0, 1, 0,  # cached_buffer[2]
        ]), buf4.done
    assert np.allclose(buf4.unfinished_index(), [10, 15, 20])
    indices = sorted(buf4.sample_indices(0))
    assert np.allclose(indices, list(range(bufsize)) + [9, 10, 14, 15, 19, 20])
    # One row of stack_num (=4) frames per sampled index; last axis sliced
    # at 0.
    assert np.allclose(buf4[indices].obs[..., 0], [
        [11, 11, 11, 12],
        [11, 11, 12, 13],
        [11, 12, 13, 14],
        [4, 4, 4, 4],
        [6, 6, 6, 6],
        [6, 6, 6, 7],
        [6, 6, 7, 8],
        [6, 7, 8, 9],
        [11, 11, 11, 11],
        [1, 1, 1, 1],
        [1, 1, 1, 2],
        [6, 6, 6, 6],
        [6, 6, 6, 7],
        [11, 11, 11, 11],
        [11, 11, 11, 12],
    ])
    assert np.allclose(buf4[indices].obs_next[..., 0], [
        [11, 11, 12, 13],
        [11, 12, 13, 14],
        [11, 12, 13, 14],
        [4, 4, 4, 4],
        [6, 6, 6, 7],
        [6, 6, 7, 8],
        [6, 7, 8, 9],
        [6, 7, 8, 9],
        [11, 11, 11, 12],
        [1, 1, 1, 2],
        [1, 1, 1, 2],
        [6, 6, 6, 7],
        [6, 6, 6, 7],
        [11, 11, 11, 12],
        [11, 11, 11, 12],
    ])
    # With sample_avail=True far fewer indices are sampleable.
    indices = buf5.sample_indices(0)
    assert np.allclose(sorted(indices), [2, 7])
    assert np.all(np.isin(buf5.sample_indices(100), indices))
    # manually change the stack num
    buf5.stack_num = 2
    for buf in buf5.buffers:
        buf.stack_num = 2
    indices = buf5.sample_indices(0)
    assert np.allclose(sorted(indices), [0, 1, 2, 5, 6, 7, 10, 15, 20])
    # Smoke test: sampling after the stack_num change must not raise and must
    # unpack into two values. Results are intentionally unused.
    _batch, _ = buf5.sample(0)
    # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next
    buf6 = CachedReplayBuffer(
        ReplayBuffer(bufsize, stack_num=stack_num,
                     save_only_last_obs=True, ignore_obs_next=True),
        cached_num, size)
    obs = np.random.rand(size, 4, 84, 84)
    # Write into cached buffers 1 and 2 (flat indices 14 and 19). The second
    # row finishes its episode (done=1) and is also expected at main index 0.
    buf6.add(Batch(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0],
                   done=[0, 1], obs_next=[obs[3], obs[1]]),
             buffer_ids=[1, 2])
    # Storage keeps only the last frame of each observation ...
    assert buf6.obs.shape == (buf6.maxsize, 84, 84)
    assert np.allclose(buf6.obs[0], obs[0, -1])
    assert np.allclose(buf6.obs[14], obs[2, -1])
    assert np.allclose(buf6.obs[19], obs[0, -1])
    # ... while indexing yields a full 4-frame observation.
    assert buf6[0].obs.shape == (4, 84, 84)
def test_cachedbuffer():
    """Check CachedReplayBuffer storage, sampling and episode bookkeeping.

    The first half wraps a 10-slot main ReplayBuffer with four 5-slot cached
    buffers and writes into individual cached buffers through
    ``buffer_ids``. The assertions imply the flat layout: main at
    ``[0, 10)``, cached buffer ``k`` at ``[10 + 5 * k, 15 + 5 * k)``.
    The second half wraps a zero-sized main buffer with ``sample_avail=True``.
    """
    def assert_links(buffer, idx, want_prev, want_next):
        # Compare the prev/next neighbours of ``idx`` with the expectation.
        assert np.allclose(buffer.prev(idx), want_prev)
        assert np.allclose(buffer.next(idx), want_next)

    cached = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
    assert cached.sample_indices(0).tolist() == []

    # check the normal function/usage/storage in CachedReplayBuffer
    # Step 1: an unfinished transition into cached buffer 1 (flat index 15).
    insert_ptr, episode_rew, episode_len, episode_idx = cached.add(
        Batch(obs=[1], act=[1], rew=[1], done=[0]), buffer_ids=[1])
    expected_obs = np.zeros(cached.maxsize)
    expected_obs[15] = 1
    sampled = cached.sample_indices(0)
    assert np.allclose(sampled, [15])
    assert_links(cached, sampled, [15], [15])
    assert np.allclose(cached.obs, expected_obs)
    assert np.all(episode_len == [0]) and np.all(episode_rew == [0.0])
    assert np.all(insert_ptr == [15]) and np.all(episode_idx == [15])

    # Step 2: a one-step episode finished in cached buffer 3 (flat index 25);
    # the assertions expect it at main index 0 as well.
    insert_ptr, episode_rew, episode_len, episode_idx = cached.add(
        Batch(obs=[2], act=[2], rew=[2], done=[1]), buffer_ids=[3])
    expected_obs[[0, 25]] = 2
    sampled = cached.sample_indices(0)
    assert np.allclose(sampled, [0, 15])
    assert_links(cached, sampled, [0, 15], [0, 15])
    assert np.allclose(cached.obs, expected_obs)
    assert np.all(episode_len == [1]) and np.all(episode_rew == [2.0])
    assert np.all(insert_ptr == [0]) and np.all(episode_idx == [0])
    assert np.allclose(cached.unfinished_index(), [15])
    assert np.allclose(cached.sample_indices(0), [0, 15])

    # Step 3: write into two cached buffers at once. The episode in buffer 1
    # finishes and its transitions (obs 1 then 4) appear at main slots 1, 2.
    insert_ptr, episode_rew, episode_len, episode_idx = cached.add(
        Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], done=[0, 1]),
        buffer_ids=[3, 1])
    assert np.all(episode_len == [0, 2]) and np.all(episode_rew == [0, 5.0])
    assert np.all(insert_ptr == [25, 2]) and np.all(episode_idx == [25, 1])
    expected_obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3]
    assert np.allclose(cached.obs, expected_obs)
    assert np.allclose(cached.unfinished_index(), [25])
    sampled = cached.sample_indices(0)
    assert np.allclose(sampled, [0, 1, 2, 25])
    assert np.allclose(cached.done[sampled], [1, 0, 1, 0])
    assert_links(cached, sampled, [0, 1, 1, 25], [0, 2, 2, 25])
    # uniform sample: 10000 draws over 4 valid indices, ~2500 hits each
    hit_counts = np.bincount(cached.sample_indices(10000))
    assert hit_counts[[0, 1, 2, 25]].min() > 2000

    # cached buffer with main_buffer size == 0 (no update)
    # used in test_collector
    no_main = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5)
    payload = np.zeros(4)
    rewards = np.ones([4, 4])
    # Without buffer_ids, row k of each batch lands in cached buffer k
    # (compare the per-buffer ``done`` layout asserted below).
    done_sequence = (
        [0, 0, 1, 1],
        [0, 0, 0, 0],
        [1, 1, 1, 1],
        [0, 0, 0, 0],
    )
    for done_flags in done_sequence:
        no_main.add(Batch(obs=payload, act=payload, rew=rewards,
                          done=done_flags))
    insert_ptr, episode_rew, episode_len, episode_idx = no_main.add(
        Batch(obs=payload, act=payload, rew=rewards, done=[0, 1, 0, 1]))
    # Rows that finish an episode report -1 for ptr and ep_idx, presumably
    # because the zero-sized main buffer cannot hold the episode.
    assert (np.all(insert_ptr == [1, -1, 11, -1])
            and np.all(episode_idx == [0, -1, 10, -1]))
    assert np.all(episode_len == [0, 2, 0, 2])
    assert np.all(episode_rew == [payload, payload + 2, payload, payload + 2])
    assert np.allclose(no_main.done, [
        0, 0, 1, 0, 0,
        0, 1, 1, 0, 0,
        0, 0, 0, 0, 0,
        0, 1, 0, 0, 0,
    ])
    sampled = no_main.sample_indices(0)
    assert np.allclose(sampled, [0, 1, 10, 11])
    assert_links(no_main, sampled, [0, 0, 10, 10], [1, 1, 11, 11])