def test_update():
    """ReplayBuffer.update copies a source buffer's transitions into the
    target, and CachedReplayBuffer refuses to be updated at all."""
    source = ReplayBuffer(4, stack_num=2)
    target = ReplayBuffer(4, stack_num=2)
    # Overfill the 4-slot source (5 adds) so it wraps around once.
    for step in range(5):
        transition = Batch(
            obs=np.array([step]),
            act=float(step),
            rew=step * step,
            done=step % 2 == 0,
            info={'incident': 'found'},
        )
        source.add(transition)
    assert len(source) > len(target)
    target.update(source)
    # After update the target holds the same number of transitions, laid out
    # in the source's chronological order (source index 1 is its oldest).
    assert len(source) == len(target)
    assert (target[0].obs == source[1].obs).all()
    assert (target[-1].obs == source[0].obs).all()
    # CachedReplayBuffer deliberately does not support update().
    cached = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
    with pytest.raises(NotImplementedError):
        cached.update(cached)
def test_collector_with_atari_setting():
    """Collector with Atari-style (4, 84, 84) frame-stacked observations.

    Exercises single env + ReplayBuffer, ignore_obs_next,
    save_only_last_obs, vectorized envs with VectorReplayBuffer /
    CachedReplayBuffer, and collection with buffer=None.
    """
    # Build 6 distinguishable reference frames: for state value i, each of
    # the 4 channels carries pattern i (diagonal / row / column / full fill),
    # so later assertions can identify which env state produced each obs.
    reference_obs = np.zeros([6, 4, 84, 84])
    for i in range(6):
        reference_obs[i, 3, np.arange(84), np.arange(84)] = i
        reference_obs[i, 2, np.arange(84)] = i
        reference_obs[i, 1, :, np.arange(84)] = i
        reference_obs[i, 0] = i

    # atari single buffer
    env = MyTestEnv(size=5, sleep=0, array_state=True)
    policy = MyPolicy()
    c0 = Collector(policy, env, ReplayBuffer(size=100))
    c0.collect(n_step=6)
    c0.collect(n_episode=2)
    assert c0.buffer.obs.shape == (100, 4, 84, 84)
    assert c0.buffer.obs_next.shape == (100, 4, 84, 84)
    # 6 steps + 2 full episodes (episode length presumably 5 here) -> 15
    # stored transitions; TODO confirm against MyTestEnv semantics.
    assert len(c0.buffer) == 15
    obs = np.zeros_like(c0.buffer.obs)
    obs[np.arange(15)] = reference_obs[np.arange(15) % 5]
    assert np.all(obs == c0.buffer.obs)

    # ignore_obs_next: obs storage identical, obs_next attribute removed and
    # reconstructed on indexing.
    c1 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=True))
    c1.collect(n_episode=3)
    assert np.allclose(c0.buffer.obs, c1.buffer.obs)
    with pytest.raises(AttributeError):
        c1.buffer.obs_next
    assert np.all(reference_obs[[1, 2, 3, 4, 4] * 3] == c1.buffer[:].obs_next)

    # save_only_last_obs: only the last frame of each stack is stored, so
    # the buffer's obs drops the channel axis.
    c2 = Collector(
        policy, env,
        ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True))
    c2.collect(n_step=8)
    assert c2.buffer.obs.shape == (100, 84, 84)
    obs = np.zeros_like(c2.buffer.obs)
    obs[np.arange(8)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2], -1]
    assert np.all(c2.buffer.obs == obs)
    assert np.allclose(c2.buffer[:].obs_next,
                       reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1])

    # atari multi buffer: 4 envs of episode lengths 2, 3, 4, 5.
    env_fns = [
        lambda x=i: MyTestEnv(size=x, sleep=0, array_state=True)
        for i in [2, 3, 4, 5]
    ]
    envs = DummyVectorEnv(env_fns)
    c3 = Collector(policy, envs,
                   VectorReplayBuffer(total_size=100, buffer_num=4))
    c3.collect(n_step=12)
    result = c3.collect(n_episode=9)
    assert result["n/ep"] == 9 and result["n/st"] == 23
    assert c3.buffer.obs.shape == (100, 4, 84, 84)
    # Each env's sub-buffer occupies a 25-slot stripe (100 / 4); expected
    # contents cycle through that env's episode states.
    obs = np.zeros_like(c3.buffer.obs)
    obs[np.arange(8)] = reference_obs[[0, 1, 0, 1, 0, 1, 0, 1]]
    obs[np.arange(25, 34)] = reference_obs[[0, 1, 2, 0, 1, 2, 0, 1, 2]]
    obs[np.arange(50, 58)] = reference_obs[[0, 1, 2, 3, 0, 1, 2, 3]]
    obs[np.arange(75, 85)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]]
    assert np.all(obs == c3.buffer.obs)
    obs_next = np.zeros_like(c3.buffer.obs_next)
    obs_next[np.arange(8)] = reference_obs[[1, 2, 1, 2, 1, 2, 1, 2]]
    obs_next[np.arange(25, 34)] = reference_obs[[1, 2, 3, 1, 2, 3, 1, 2, 3]]
    obs_next[np.arange(50, 58)] = reference_obs[[1, 2, 3, 4, 1, 2, 3, 4]]
    obs_next[np.arange(75, 85)] = reference_obs[[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]]
    assert np.all(obs_next == c3.buffer.obs_next)

    # Same vectorized setting plus stack_num + ignore_obs_next +
    # save_only_last_obs combined.
    c4 = Collector(
        policy, envs,
        VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4,
                           ignore_obs_next=True, save_only_last_obs=True))
    c4.collect(n_step=12)
    result = c4.collect(n_episode=9)
    assert result["n/ep"] == 9 and result["n/st"] == 23
    assert c4.buffer.obs.shape == (100, 84, 84)
    obs = np.zeros_like(c4.buffer.obs)
    # Only the last frame of each reference stack is stored.
    slice_obs = reference_obs[:, -1]
    obs[np.arange(8)] = slice_obs[[0, 1, 0, 1, 0, 1, 0, 1]]
    obs[np.arange(25, 34)] = slice_obs[[0, 1, 2, 0, 1, 2, 0, 1, 2]]
    obs[np.arange(50, 58)] = slice_obs[[0, 1, 2, 3, 0, 1, 2, 3]]
    obs[np.arange(75, 85)] = slice_obs[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]]
    assert np.all(c4.buffer.obs == obs)
    # Reconstruct the expected stacked obs_next: newest frame at channel -1,
    # older frames obtained by stepping ref_index back (clamped at episode
    # start, index 0).
    obs_next = np.zeros([len(c4.buffer), 4, 84, 84])
    ref_index = np.array([
        1, 1, 1, 1, 1, 1, 1, 1,
        1, 2, 2, 1, 2, 2, 1, 2, 2,
        1, 2, 3, 3, 1, 2, 3, 3,
        1, 2, 3, 4, 4, 1, 2, 3, 4, 4,
    ])
    obs_next[:, -1] = slice_obs[ref_index]
    ref_index -= 1
    ref_index[ref_index < 0] = 0
    obs_next[:, -2] = slice_obs[ref_index]
    ref_index -= 1
    ref_index[ref_index < 0] = 0
    obs_next[:, -3] = slice_obs[ref_index]
    ref_index -= 1
    ref_index[ref_index < 0] = 0
    obs_next[:, -4] = slice_obs[ref_index]
    assert np.all(obs_next == c4.buffer[:].obs_next)

    # CachedReplayBuffer: finished episodes are flushed into `buf`.
    buf = ReplayBuffer(100, stack_num=4,
                       ignore_obs_next=True, save_only_last_obs=True)
    c5 = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10))
    result_ = c5.collect(n_step=12)
    assert len(buf) == 5 and len(c5.buffer) == 12
    result = c5.collect(n_episode=9)
    assert result["n/ep"] == 9 and result["n/st"] == 23
    assert len(buf) == 35
    assert np.all(buf.obs[:len(buf)] == slice_obs[[
        0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4,
        0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4
    ]])
    assert np.all(buf[:].obs_next[:, -1] == slice_obs[[
        1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 3, 4, 4,
        1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 2, 1, 2, 3, 4, 4
    ]])
    assert len(buf) == len(c5.buffer)

    # test buffer=None: statistics must match the buffered runs above.
    c6 = Collector(policy, envs)
    result1 = c6.collect(n_step=12)
    for key in ["n/ep", "n/st", "rews", "lens"]:
        assert np.allclose(result1[key], result_[key])
    result2 = c6.collect(n_episode=9)
    for key in ["n/ep", "n/st", "rews", "lens"]:
        assert np.allclose(result2[key], result[key])
def test_multibuf_hdf5():
    """HDF5 save/load round-trip for ReplayBufferManager and
    CachedReplayBuffer, plus post-load sanity checks.
    """
    size = 100
    buffers = {
        "vector": ReplayBufferManager([ReplayBuffer(size) for i in range(4)]),
        "cached": CachedReplayBuffer(ReplayBuffer(size), 4, size)
    }
    buffer_types = {k: b.__class__ for k, b in buffers.items()}
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Tensor inside `info` checks that torch values survive HDF5 round-trip.
    info_t = torch.tensor([1.]).to(device)
    for i in range(4):
        kwargs = {
            'obs': Batch(index=np.array([i])),
            'act': i,
            'rew': np.array([1, 2]),
            'done': i % 3 == 2,
            'info': {"number": {"n": i, "t": info_t}, 'extra': None},
        }
        # Same transition added to 3 of the 4 sub-buffers each iteration.
        buffers["vector"].add(**Batch.stack([kwargs, kwargs, kwargs]),
                              buffer_ids=[0, 1, 2])
        buffers["cached"].add(**Batch.stack([kwargs, kwargs, kwargs]),
                              cached_buffer_ids=[0, 1, 2])

    # save each buffer to its own temp .hdf5 file
    paths = {}
    for k, buf in buffers.items():
        f, path = tempfile.mkstemp(suffix='.hdf5')
        os.close(f)
        buf.save_hdf5(path)
        paths[k] = path

    # load replay buffer
    _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}

    # compare loaded buffers against the originals
    for k in buffers.keys():
        assert len(_buffers[k]) == len(buffers[k])
        assert np.allclose(_buffers[k].act, buffers[k].act)
        assert _buffers[k].stack_num == buffers[k].stack_num
        assert _buffers[k].maxsize == buffers[k].maxsize
        assert np.all(_buffers[k]._indices == buffers[k]._indices)

    # check shallow copy in ReplayBufferManager: mutating the manager's data
    # must be visible through the sub-buffer views.
    for k in ["vector", "cached"]:
        buffers[k].info.number.n[0] = -100
        assert buffers[k].buffers[0].info.number.n[0] == -100

    # check if still behave normally after the save/load cycle
    for k in ["vector", "cached"]:
        kwargs = {
            'obs': Batch(index=np.array([5])),
            'act': 5,
            'rew': np.array([2, 1]),
            'done': False,
            # NOTE(review): `i` here is the leaked loop variable from the
            # add-loop above (value 3), not a fresh index — looks
            # unintentional but is harmless since info is not asserted on;
            # confirm before changing.
            'info': {"number": {"n": i}, 'Timelimit.truncate': True},
        }
        # add() without ids: one transition per sub-buffer (4 of them).
        buffers[k].add(**Batch.stack([kwargs, kwargs, kwargs, kwargs]))
        act = np.zeros(buffers[k].maxsize)
        if k == "vector":
            # Sub-buffers are laid out in `size`-wide stripes.
            act[np.arange(5)] = np.array([0, 1, 2, 3, 5])
            act[np.arange(5) + size] = np.array([0, 1, 2, 3, 5])
            act[np.arange(5) + size * 2] = np.array([0, 1, 2, 3, 5])
            act[size * 3] = 5
        elif k == "cached":
            # Main buffer first, then the 4 cached stripes.
            act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
            act[np.arange(3) + size] = np.array([3, 5, 2])
            act[np.arange(3) + size * 2] = np.array([3, 5, 2])
            act[np.arange(3) + size * 3] = np.array([3, 5, 2])
            act[size * 4] = 5
        assert np.allclose(buffers[k].act, act)

    # clean up the temp files
    for path in paths.values():
        os.remove(path)
def test_multibuf_stack():
    """CachedReplayBuffer combined with frame stacking, ignore_obs_next,
    prioritized sampling, and save_only_last_obs corner cases.
    """
    size = 5        # episode length of MyTestEnv / cached-buffer size
    bufsize = 9     # main buffer size
    stack_num = 4
    cached_num = 3  # number of cached sub-buffers
    env = MyTestEnv(size)
    # test if CachedReplayBuffer can handle stack_num + ignore_obs_next
    buf4 = CachedReplayBuffer(
        ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True),
        cached_num, size)
    # test if CachedReplayBuffer can handle super corner case:
    # prio-buffer + stack_num + ignore_obs_next + sample_avail
    buf5 = CachedReplayBuffer(
        PrioritizedReplayBuffer(bufsize, 0.6, 0.4, stack_num=stack_num,
                                ignore_obs_next=True, sample_avail=True),
        cached_num, size)
    obs = env.reset(1)
    for i in range(18):
        obs_next, rew, done, info = env.step(1)
        # One synthetic transition per cached buffer: env i's obs offset by
        # size * i so each stripe's contents are distinguishable.
        obs_list = np.array([obs + size * i for i in range(cached_num)])
        act_list = [1] * cached_num
        rew_list = [rew] * cached_num
        done_list = [done] * cached_num
        obs_next_list = -obs_list
        info_list = [info] * cached_num
        buf4.add(obs_list, act_list, rew_list, done_list,
                 obs_next_list, info_list)
        buf5.add(obs_list, act_list, rew_list, done_list,
                 obs_next_list, info_list)
        obs = obs_next
        if done:
            obs = env.reset(1)
    # check the `add` order is correct
    assert np.allclose(
        buf4.obs.reshape(-1),
        [
            12, 13, 14, 4, 6, 7, 8, 9, 11,  # main_buffer
            1, 2, 3, 4, 0,  # cached_buffer[0]
            6, 7, 8, 9, 0,  # cached_buffer[1]
            11, 12, 13, 14, 0,  # cached_buffer[2]
        ]), buf4.obs
    assert np.allclose(
        buf4.done,
        [
            0, 0, 1, 1, 0, 0, 0, 1, 0,  # main_buffer
            0, 0, 0, 1, 0,  # cached_buffer[0]
            0, 0, 0, 1, 0,  # cached_buffer[1]
            0, 0, 0, 1, 0,  # cached_buffer[2]
        ]), buf4.done
    assert np.allclose(buf4.unfinished_index(), [10, 15, 20])
    indice = sorted(buf4.sample_index(0))
    assert np.allclose(indice, list(range(bufsize)) + [9, 10, 14, 15, 19, 20])
    # Stacked observations: rows repeat the episode-start frame until
    # stack_num history exists.
    assert np.allclose(buf4[indice].obs[..., 0], [
        [11, 11, 11, 12], [11, 11, 12, 13], [11, 12, 13, 14],
        [4, 4, 4, 4], [6, 6, 6, 6], [6, 6, 6, 7], [6, 6, 7, 8],
        [6, 7, 8, 9], [11, 11, 11, 11], [1, 1, 1, 1], [1, 1, 1, 2],
        [6, 6, 6, 6], [6, 6, 6, 7], [11, 11, 11, 11], [11, 11, 11, 12],
    ])
    assert np.allclose(buf4[indice].obs_next[..., 0], [
        [11, 11, 12, 13], [11, 12, 13, 14], [11, 12, 13, 14],
        [4, 4, 4, 4], [6, 6, 6, 7], [6, 6, 7, 8], [6, 7, 8, 9],
        [6, 7, 8, 9], [11, 11, 11, 12], [1, 1, 1, 2], [1, 1, 1, 2],
        [6, 6, 6, 7], [6, 6, 6, 7], [11, 11, 11, 12], [11, 11, 11, 12],
    ])
    assert np.all(buf4.done == buf5.done)
    # sample_avail restricts sampling to indices with full stacks.
    indice = buf5.sample_index(0)
    assert np.allclose(sorted(indice), [2, 7])
    assert np.all(np.isin(buf5.sample_index(100), indice))
    # manually change the stack num
    buf5.stack_num = 2
    for buf in buf5.buffers:
        buf.stack_num = 2
    indice = buf5.sample_index(0)
    assert np.allclose(sorted(indice), [0, 1, 2, 5, 6, 7, 10, 15, 20])
    batch, _ = buf5.sample(0)
    assert np.allclose(buf5[np.arange(buf5.maxsize)].weight, 1)
    # Zeroing the sampled priorities must lower the weights of the touched
    # main-buffer entries while cached entries keep weight 1.
    buf5.update_weight(indice, batch.weight * 0)
    weight = buf5[np.arange(buf5.maxsize)].weight
    modified_weight = weight[[0, 1, 2, 5, 6, 7]]
    assert modified_weight.min() == modified_weight.max()
    assert modified_weight.max() < 1
    unmodified_weight = weight[[3, 4, 8]]
    assert unmodified_weight.min() == unmodified_weight.max()
    assert unmodified_weight.max() < 1
    cached_weight = weight[9:]
    assert cached_weight.min() == cached_weight.max() == 1
    # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next
    buf6 = CachedReplayBuffer(
        ReplayBuffer(bufsize, stack_num=stack_num,
                     save_only_last_obs=True, ignore_obs_next=True),
        cached_num, size)
    obs = np.random.rand(size, 4, 84, 84)
    buf6.add(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], done=[0, 1],
             obs_next=[obs[3], obs[1]], cached_buffer_ids=[1, 2])
    # Only the last frame of each stack is stored.
    assert buf6.obs.shape == (buf6.maxsize, 84, 84)
    assert np.allclose(buf6.obs[0], obs[0, -1])
    assert np.allclose(buf6.obs[14], obs[2, -1])
    assert np.allclose(buf6.obs[19], obs[0, -1])
    assert buf6[0].obs.shape == (4, 84, 84)
def test_cachedbuffer():
    """CachedReplayBuffer index bookkeeping: unfinished episodes live in the
    cached stripes, finished ones are flushed into the main buffer; also the
    degenerate case of a zero-size main buffer.
    """
    # Layout: main buffer occupies indices 0..9, then 4 cached stripes of 5
    # slots each (10..14, 15..19, 20..24, 25..29).
    buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
    assert buf.sample_index(0).tolist() == []
    # check the normal function/usage/storage in CachedReplayBuffer
    ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[0],
                             cached_buffer_ids=[1])
    obs = np.zeros(buf.maxsize)
    obs[15] = 1  # cached stripe 1 starts at index 15
    indice = buf.sample_index(0)
    assert np.allclose(indice, [15])
    assert np.allclose(buf.prev(indice), [15])
    assert np.allclose(buf.next(indice), [15])
    assert np.allclose(buf.obs, obs)
    # unfinished episode -> zero length/reward reported
    assert np.allclose(ep_len, [0]) and np.allclose(ep_rew, [0.0])
    ep_len, ep_rew = buf.add(obs=[2], act=[2], rew=[2], done=[1],
                             cached_buffer_ids=[3])
    # done=1 episode is flushed into the main buffer (index 0); its cached
    # slot (25) still holds the raw data.
    obs[[0, 25]] = 2
    indice = buf.sample_index(0)
    assert np.allclose(indice, [0, 15])
    assert np.allclose(buf.prev(indice), [0, 15])
    assert np.allclose(buf.next(indice), [0, 15])
    assert np.allclose(buf.obs, obs)
    assert np.allclose(ep_len, [1]) and np.allclose(ep_rew, [2.0])
    assert np.allclose(buf.unfinished_index(), [15])
    assert np.allclose(buf.sample_index(0), [0, 15])
    ep_len, ep_rew = buf.add(obs=[3, 4], act=[3, 4], rew=[3, 4],
                             done=[0, 1], cached_buffer_ids=[3, 1])
    # Second entry finishes stripe 1's 2-step episode (rew 1 + 4 = 5).
    assert np.allclose(ep_len, [0, 2]) and np.allclose(ep_rew, [0, 5.0])
    obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3]
    assert np.allclose(buf.obs, obs)
    assert np.allclose(buf.unfinished_index(), [25])
    indice = buf.sample_index(0)
    assert np.allclose(indice, [0, 1, 2, 25])
    assert np.allclose(buf.done[indice], [1, 0, 1, 0])
    # prev/next stay within episode boundaries.
    assert np.allclose(buf.prev(indice), [0, 1, 1, 25])
    assert np.allclose(buf.next(indice), [0, 2, 2, 25])
    indice = buf.sample_index(10000)
    assert np.bincount(indice)[[0, 1, 2, 25]].min() > 2000  # uniform sample
    # cached buffer with main_buffer size == 0 (no update)
    # used in test_collector
    buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5)
    data = np.zeros(4)
    rew = np.ones([4, 4])
    buf.add(obs=data, act=data, rew=rew, done=[0, 0, 1, 1], obs_next=data)
    buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data)
    buf.add(obs=data, act=data, rew=rew, done=[1, 1, 1, 1], obs_next=data)
    buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data)
    buf.add(obs=data, act=data, rew=rew, done=[0, 1, 0, 1], obs_next=data)
    # With no main buffer, each 5-slot stripe keeps only its most recent
    # (still unfinished) episode.
    assert np.allclose(buf.done, [
        0, 0, 1, 0, 0,
        0, 1, 1, 0, 0,
        0, 0, 0, 0, 0,
        0, 1, 0, 0, 0,
    ])
    indice = buf.sample_index(0)
    assert np.allclose(indice, [0, 1, 10, 11])
    assert np.allclose(buf.prev(indice), [0, 0, 10, 10])
    assert np.allclose(buf.next(indice), [1, 1, 11, 11])