def test_cachedbuffer():
    """Exercise add / sample_index / prev / next on a CachedReplayBuffer.

    Part one wraps a ``ReplayBuffer(10)`` and feeds transitions into
    individual cached buffers via ``cached_buffer_ids``, checking where the
    data ends up and which indices become sampleable.  Part two wraps a
    zero-sized main buffer and checks the ``done`` flags and neighbour
    lookups after five batched ``add`` calls.
    """
    cbuf = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
    assert cbuf.sample_index(0).tolist() == []

    # check the normal function/usage/storage in CachedReplayBuffer
    ep_len, ep_rew = cbuf.add(
        obs=[1], act=[1], rew=[1], done=[0], cached_buffer_ids=[1]
    )
    # Mirror of the observation storage we expect; the assertions below show
    # that cached buffer id 1 writes to slot 15.
    expected_obs = np.zeros(cbuf.maxsize)
    expected_obs[15] = 1
    idx = cbuf.sample_index(0)
    assert np.allclose(idx, [15])
    assert np.allclose(cbuf.prev(idx), [15])
    assert np.allclose(cbuf.next(idx), [15])
    assert np.allclose(cbuf.obs, expected_obs)
    assert np.allclose(ep_len, [0])
    assert np.allclose(ep_rew, [0.0])

    # A finished one-step episode in cached buffer id 3: the expected storage
    # has the value 2 at both slot 25 and slot 0.
    ep_len, ep_rew = cbuf.add(
        obs=[2], act=[2], rew=[2], done=[1], cached_buffer_ids=[3]
    )
    expected_obs[[0, 25]] = 2
    idx = cbuf.sample_index(0)
    assert np.allclose(idx, [0, 15])
    assert np.allclose(cbuf.prev(idx), [0, 15])
    assert np.allclose(cbuf.next(idx), [0, 15])
    assert np.allclose(cbuf.obs, expected_obs)
    assert np.allclose(ep_len, [1])
    assert np.allclose(ep_rew, [2.0])
    assert np.allclose(cbuf.unfinished_index(), [15])
    assert np.allclose(cbuf.sample_index(0), [0, 15])

    # Two transitions at once: id 3 keeps going, id 1 terminates.
    ep_len, ep_rew = cbuf.add(
        obs=[3, 4],
        act=[3, 4],
        rew=[3, 4],
        done=[0, 1],
        cached_buffer_ids=[3, 1],
    )
    assert np.allclose(ep_len, [0, 2])
    assert np.allclose(ep_rew, [0, 5.0])
    expected_obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3]
    assert np.allclose(cbuf.obs, expected_obs)
    assert np.allclose(cbuf.unfinished_index(), [25])
    idx = cbuf.sample_index(0)
    assert np.allclose(idx, [0, 1, 2, 25])
    assert np.allclose(cbuf.done[idx], [1, 0, 1, 0])
    assert np.allclose(cbuf.prev(idx), [0, 1, 1, 25])
    assert np.allclose(cbuf.next(idx), [0, 2, 2, 25])
    idx = cbuf.sample_index(10000)
    # uniform sample
    assert np.bincount(idx)[[0, 1, 2, 25]].min() > 2000

    # cached buffer with main_buffer size == 0 (no update)
    # used in test_collector
    cbuf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5)
    zeros = np.zeros(4)
    ones = np.ones([4, 4])
    done_patterns = (
        [0, 0, 1, 1],
        [0, 0, 0, 0],
        [1, 1, 1, 1],
        [0, 0, 0, 0],
        [0, 1, 0, 1],
    )
    for done_flags in done_patterns:
        cbuf.add(
            obs=zeros, act=zeros, rew=ones, done=done_flags, obs_next=zeros
        )
    assert np.allclose(cbuf.done, [
        0, 0, 1, 0, 0,
        0, 1, 1, 0, 0,
        0, 0, 0, 0, 0,
        0, 1, 0, 0, 0,
    ])
    idx = cbuf.sample_index(0)
    assert np.allclose(idx, [0, 1, 10, 11])
    assert np.allclose(cbuf.prev(idx), [0, 0, 10, 10])
    assert np.allclose(cbuf.next(idx), [1, 1, 11, 11])
def test_multibuf_stack():
    """Check frame stacking through CachedReplayBuffer wrappers.

    Three configurations are covered: a plain ``ReplayBuffer`` with
    ``stack_num`` + ``ignore_obs_next``, a ``PrioritizedReplayBuffer`` that
    additionally sets ``sample_avail``, and a ``save_only_last_obs`` buffer
    fed with image-shaped observations.
    """
    size = 5
    bufsize = 9
    stack_num = 4
    cached_num = 3
    env = MyTestEnv(size)

    # test if CachedReplayBuffer can handle stack_num + ignore_obs_next
    stacked_main = ReplayBuffer(
        bufsize, stack_num=stack_num, ignore_obs_next=True
    )
    buf4 = CachedReplayBuffer(stacked_main, cached_num, size)

    # test if CachedReplayBuffer can handle super corner case:
    # prio-buffer + stack_num + ignore_obs_next + sample_avail
    prio_main = PrioritizedReplayBuffer(
        bufsize,
        0.6,
        0.4,
        stack_num=stack_num,
        ignore_obs_next=True,
        sample_avail=True,
    )
    buf5 = CachedReplayBuffer(prio_main, cached_num, size)

    obs = env.reset(1)
    for _ in range(18):
        obs_next, rew, done, info = env.step(1)
        # One copy of the transition per cached buffer, with the observation
        # shifted by `size * k` so each buffer's data is distinguishable.
        obs_batch = np.array([obs + size * k for k in range(cached_num)])
        act_batch = [1] * cached_num
        rew_batch = [rew] * cached_num
        done_batch = [done] * cached_num
        obs_next_batch = -obs_batch
        info_batch = [info] * cached_num
        for target in (buf4, buf5):
            target.add(
                obs_batch,
                act_batch,
                rew_batch,
                done_batch,
                obs_next_batch,
                info_batch,
            )
        obs = env.reset(1) if done else obs_next

    # check the `add` order is correct
    assert np.allclose(
        buf4.obs.reshape(-1),
        [
            12, 13, 14, 4, 6, 7, 8, 9, 11,  # main_buffer
            1, 2, 3, 4, 0,  # cached_buffer[0]
            6, 7, 8, 9, 0,  # cached_buffer[1]
            11, 12, 13, 14, 0,  # cached_buffer[2]
        ],
    ), buf4.obs
    assert np.allclose(
        buf4.done,
        [
            0, 0, 1, 1, 0, 0, 0, 1, 0,  # main_buffer
            0, 0, 0, 1, 0,  # cached_buffer[0]
            0, 0, 0, 1, 0,  # cached_buffer[1]
            0, 0, 0, 1, 0,  # cached_buffer[2]
        ],
    ), buf4.done
    assert np.allclose(buf4.unfinished_index(), [10, 15, 20])

    ordered_idx = sorted(buf4.sample_index(0))
    assert np.allclose(ordered_idx, [*range(bufsize), 9, 10, 14, 15, 19, 20])
    expected_stacked_obs = [
        [11, 11, 11, 12],
        [11, 11, 12, 13],
        [11, 12, 13, 14],
        [4, 4, 4, 4],
        [6, 6, 6, 6],
        [6, 6, 6, 7],
        [6, 6, 7, 8],
        [6, 7, 8, 9],
        [11, 11, 11, 11],
        [1, 1, 1, 1],
        [1, 1, 1, 2],
        [6, 6, 6, 6],
        [6, 6, 6, 7],
        [11, 11, 11, 11],
        [11, 11, 11, 12],
    ]
    assert np.allclose(buf4[ordered_idx].obs[..., 0], expected_stacked_obs)
    expected_stacked_obs_next = [
        [11, 11, 12, 13],
        [11, 12, 13, 14],
        [11, 12, 13, 14],
        [4, 4, 4, 4],
        [6, 6, 6, 7],
        [6, 6, 7, 8],
        [6, 7, 8, 9],
        [6, 7, 8, 9],
        [11, 11, 11, 12],
        [1, 1, 1, 2],
        [1, 1, 1, 2],
        [6, 6, 6, 7],
        [6, 6, 6, 7],
        [11, 11, 11, 12],
        [11, 11, 11, 12],
    ]
    assert np.allclose(
        buf4[ordered_idx].obs_next[..., 0], expected_stacked_obs_next
    )

    assert np.all(buf4.done == buf5.done)
    avail_idx = buf5.sample_index(0)
    assert np.allclose(sorted(avail_idx), [2, 7])
    assert np.all(np.isin(buf5.sample_index(100), avail_idx))

    # manually change the stack num
    buf5.stack_num = 2
    for sub_buffer in buf5.buffers:
        sub_buffer.stack_num = 2
    avail_idx = buf5.sample_index(0)
    assert np.allclose(sorted(avail_idx), [0, 1, 2, 5, 6, 7, 10, 15, 20])

    sampled, _ = buf5.sample(0)
    every_slot = np.arange(buf5.maxsize)
    assert np.allclose(buf5[every_slot].weight, 1)
    buf5.update_weight(avail_idx, sampled.weight * 0)
    all_weights = buf5[np.arange(buf5.maxsize)].weight

    touched = all_weights[[0, 1, 2, 5, 6, 7]]
    assert touched.min() == touched.max()
    assert touched.max() < 1
    untouched = all_weights[[3, 4, 8]]
    assert untouched.min() == untouched.max()
    assert untouched.max() < 1
    cached_part = all_weights[9:]
    assert cached_part.min() == cached_part.max() == 1

    # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next
    atari_main = ReplayBuffer(
        bufsize,
        stack_num=stack_num,
        save_only_last_obs=True,
        ignore_obs_next=True,
    )
    buf6 = CachedReplayBuffer(atari_main, cached_num, size)
    frames = np.random.rand(size, 4, 84, 84)
    buf6.add(
        obs=[frames[2], frames[0]],
        act=[1, 1],
        rew=[0, 0],
        done=[0, 1],
        obs_next=[frames[3], frames[1]],
        cached_buffer_ids=[1, 2],
    )
    # Only the last frame of each 4-frame observation is expected in storage.
    assert buf6.obs.shape == (buf6.maxsize, 84, 84)
    assert np.allclose(buf6.obs[0], frames[0, -1])
    assert np.allclose(buf6.obs[14], frames[2, -1])
    assert np.allclose(buf6.obs[19], frames[0, -1])
    assert buf6[0].obs.shape == (4, 84, 84)