def test_stack(size=5, bufsize=9, stack_num=4):
    """Check frame stacking of ``ReplayBuffer`` across episode boundaries.

    The same 16 transitions from ``MyTestEnv(size)`` are written to three
    buffers of capacity ``bufsize`` that stack ``stack_num`` observations:

    * ``buf``  -- plain stacking buffer; its stacked ``obs`` is compared
      against a hard-coded table;
    * ``buf2`` -- created with ``sample_avail=True``; ``sample(0)`` is
      expected to return exactly indices [2, 6] and ``sample(1)`` one of them;
    * ``buf3`` -- created with ``save_only_last_obs=True`` and fed padded
      lists whose last element is the real observation; its ``obs`` and
      ``obs_next`` must both equal the stacked ``obs`` of ``buf``.

    Indexing past the buffer must raise ``IndexError``.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize, stack_num=stack_num)
    buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
    obs = env.reset(1)
    for _ in range(16):
        obs_next, rew, done, info = env.step(1)
        buf.add(obs, 1, rew, done, None, info)
        buf2.add(obs, 1, rew, done, None, info)
        # The assertions below rely on buf3 storing only the last element of
        # each list, so the leading Nones are padding.
        buf3.add([None, None, obs], 1, rew, done, [None, obs], info)
        obs = obs_next
        if done:
            obs = env.reset(1)
    indices = np.arange(len(buf))
    assert np.allclose(buf.get(indices, 'obs')[..., 0], [
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]])
    assert np.allclose(buf.get(indices, 'obs'), buf3.get(indices, 'obs'))
    assert np.allclose(buf.get(indices, 'obs'), buf3.get(indices, 'obs_next'))
    _, indices = buf2.sample(0)
    assert indices.tolist() == [2, 6]
    _, indices = buf2.sample(1)
    # Check the batch size explicitly and test the scalar element:
    # ``ndarray in list`` only works by accident for 1-element arrays and
    # raises an ambiguous-truth-value error otherwise.
    assert len(indices) == 1 and indices[0] in [2, 6]
    with pytest.raises(IndexError):
        buf[bufsize * 2]
def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3):
    """Verify observation stacking, ``sample_avail`` and ``save_only_last_obs``.

    Sixteen steps of ``MyTestEnv(size)`` are recorded into three buffers of
    capacity ``bufsize`` stacking ``stack_num`` frames: a plain one, one built
    with ``sample_avail=True`` and one built with ``save_only_last_obs=True``
    (fed repeated observations). ``cached_num`` is accepted but not used by
    this test.
    """
    env = MyTestEnv(size)
    plain = ReplayBuffer(bufsize, stack_num=stack_num)
    avail = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    last_only = ReplayBuffer(
        bufsize, stack_num=stack_num, save_only_last_obs=True)

    obs = env.reset(1)
    for _ in range(16):
        obs_next, rew, done, info = env.step(1)
        # A fresh Batch per buffer, written in the same order as before.
        for target in (plain, avail):
            target.add(Batch(obs=obs, act=1, rew=rew, done=done, info=info))
        last_only.add(Batch(
            obs=[obs, obs, obs], act=1, rew=rew, done=done,
            obs_next=[obs, obs], info=info))
        obs = env.reset(1) if done else obs_next

    expected = [
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1],
    ]
    indices = np.arange(len(plain))
    stacked = plain.get(indices, 'obs')
    assert np.allclose(stacked[..., 0], expected)
    assert np.allclose(stacked, last_only.get(indices, 'obs'))
    assert np.allclose(stacked, last_only.get(indices, 'obs_next'))

    _, indices = avail.sample(0)
    assert indices.tolist() == [2, 6]
    _, indices = avail.sample(1)
    assert indices[0] in [2, 6]
    # A negative batch size yields no data at all.
    batch, indices = avail.sample(-1)
    assert indices.tolist() == [] and len(batch) == 0

    with pytest.raises(IndexError):
        plain[bufsize * 2]
def test_ReplayBuffer():
    """Exercise the basic ``ReplayBuffer`` API.

    Covers ``add``, ``get``, ``update``, ``sample``, ``reset`` and ``len``,
    asserting the invariants each call is expected to uphold instead of only
    printing the results.
    """
    buf1 = ReplayBuffer(size=15)
    for i in range(3):
        buf1.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={},
                 weight=None)
    # Fewer items than capacity: every item is kept.
    assert len(buf1) == 3

    buf2 = ReplayBuffer(size=10)
    for i in range(15):
        buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={},
                 weight=None)
    # More items than capacity: the length is capped at the buffer size.
    assert len(buf2) == 10

    buf1.update(buf2)

    index = [1, 3, 5]
    # ``key`` is an obligatory argument of ``get``.
    assert len(buf2.get(index, key='obs')) == len(index)

    sample_data, indice = buf2.sample(batch_size=4)
    assert len(indice) == 4
    # A sample must be exactly the data stored at the sampled indices.
    assert (sample_data.obs == buf2[indice].obs).all()

    # buf.reset() only resets the index, not the content.
    buf2.reset()
    assert len(buf2) == 0
    # Smoke check: the repr of an emptied buffer must not raise.
    print(buf2)
def test_stack(size=5, bufsize=9, stack_num=4):
    """Check ``stack_num`` frame stacking over episode boundaries.

    Fifteen steps of ``MyTestEnv(size)`` are written to a buffer of capacity
    ``bufsize``. The stacked ``obs`` of every stored transition is compared
    with a hard-coded expectation.
    """
    env = MyTestEnv(size)
    # Keyword form, as in the other stacking tests.
    buf = ReplayBuffer(bufsize, stack_num=stack_num)
    obs = env.reset(1)
    for _ in range(15):
        obs_next, rew, done, info = env.step(1)
        buf.add(obs, 1, rew, done, None, info)
        obs = obs_next
        if done:
            obs = env.reset(1)
    indices = np.arange(len(buf))
    expected = np.array([
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]])
    # np.allclose bounds each element's error. Summing absolute errors lets
    # them accumulate against a single tolerance.
    assert np.allclose(buf.get(indices, 'obs'), expected)
    print(buf)
def test_stack(size=5, bufsize=9, stack_num=4):
    """Check frame stacking and ``sample_avail`` sampling of ``ReplayBuffer``.

    Fifteen steps of ``MyTestEnv(size)`` are written to a plain stacking
    buffer and to one created with ``sample_avail=True``. The stacked ``obs``
    is compared with a hard-coded table. Both full sampling (``sample(0)``)
    and a single-item sample of the ``sample_avail`` buffer are expected to
    return index 2 only.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize, stack_num=stack_num)
    buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    obs = env.reset(1)
    for _ in range(15):
        obs_next, rew, done, info = env.step(1)
        buf.add(obs, 1, rew, done, None, info)
        buf2.add(obs, 1, rew, done, None, info)
        obs = obs_next
        if done:
            obs = env.reset(1)
    indices = np.arange(len(buf))
    assert np.allclose(buf.get(indices, 'obs'), np.array([
        [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
        [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
        [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]]))
    print(buf)
    _, indices = buf2.sample(0)
    # Compare as Python lists. ``ndarray == list`` is elementwise, and its
    # truth value is ambiguous (raises) once more than one index comes back.
    assert indices.tolist() == [2]
    _, indices = buf2.sample(1)
    # Stricter than the old ``indices.sum() == 2``, which also accepted
    # batches such as [0, 2].
    assert indices.tolist() == [2]
def test_stack(size=5, bufsize=9, stack_num=4):
    """Check frame stacking and ``sample_avail`` sampling of ``ReplayBuffer``.

    Sixteen steps of ``MyTestEnv(size)`` are written to a plain stacking
    buffer and to one created with ``sample_avail=True``. The stacked ``obs``
    (with a trailing axis of length 1) is compared with a hard-coded table.
    ``sample(0)`` on the ``sample_avail`` buffer is expected to return exactly
    indices [2, 6], and ``sample(1)`` a single one of them.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize, stack_num=stack_num)
    buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
    obs = env.reset(1)
    for _ in range(16):
        obs_next, rew, done, info = env.step(1)
        buf.add(obs, 1, rew, done, None, info)
        buf2.add(obs, 1, rew, done, None, info)
        obs = obs_next
        if done:
            obs = env.reset(1)
    indices = np.arange(len(buf))
    assert np.allclose(
        buf.get(indices, 'obs'),
        np.expand_dims([
            [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
            [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
            [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]], axis=-1))
    _, indices = buf2.sample(0)
    assert indices.tolist() == [2, 6]
    _, indices = buf2.sample(1)
    # Check the batch size explicitly and test the scalar element:
    # ``ndarray in list`` only works by accident for 1-element arrays.
    assert len(indices) == 1 and indices[0] in [2, 6]