Esempio n. 1
0
def test_priortized_replaybuffer(size=32, bufsize=15):
    """Exercise a prioritized buffer and its 3-way vectorized counterpart.

    Transitions from a fixed 25-step action sequence are pushed into both
    buffers; after every push the buffer lengths and the size of a sampled
    batch are checked.  Afterwards the stored dtypes and the effect of
    ``update_weight`` are verified on each buffer.
    """
    env = MyTestEnv(size)
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    buf2 = PrioritizedVectorReplayBuffer(
        bufsize, buffer_num=3, alpha=0.5, beta=0.5)
    obs = env.reset()
    actions = [1] * 5 + [0] * 10 + [1] * 10
    for step, action in enumerate(actions, start=1):
        obs_next, rew, done, info = env.step(action)
        transition = Batch(
            obs=obs,
            act=action,
            rew=rew,
            done=done,
            obs_next=obs_next,
            info=info,
            policy=np.random.randn() - 0.5,
        )
        # The same transition goes into every sub-buffer of buf2.
        tripled = Batch.stack([transition, transition, transition])
        buf.add(Batch.stack([transition]), buffer_ids=[0])
        buf2.add(tripled, buffer_ids=[0, 1, 2])
        obs = obs_next

        half = len(buf) // 2
        data, _ = buf.sample(half)
        # A request for 0 items is expected to yield the whole buffer.
        assert len(data) == (len(buf) if half == 0 else half)
        assert len(buf) == min(bufsize, step)
        assert len(buf2) == min(bufsize, 3 * step)

    # check single buffer's data
    assert buf.info.key.shape == (buf.maxsize, )
    assert buf.rew.dtype == float
    assert buf.done.dtype == bool
    data, indices = buf.sample(len(buf) // 2)
    halved = -data.weight / 2
    buf.update_weight(indices, halved)
    # Stored priority is expected to be |new_weight| ** alpha.
    assert np.allclose(buf.weight[indices], np.abs(halved)**buf._alpha)

    # check multi buffer's data
    every_slot = np.arange(buf2.maxsize)
    assert np.allclose(buf2[every_slot].weight, 1)
    batch, indices = buf2.sample(10)
    buf2.update_weight(indices, batch.weight * 0)
    weight = buf2[every_slot].weight
    touched = np.isin(every_slot, indices)
    updated_w = weight[touched]
    untouched_w = weight[~touched]
    # Each group must be internally uniform ...
    assert np.all(updated_w == updated_w[0])
    assert np.all(untouched_w == untouched_w[0])
    # ... and the untouched group must read strictly lower, capped at 1.
    assert untouched_w[0] < updated_w[0] <= 1
Esempio n. 2
0
def test_priortized_replaybuffer(size=32, bufsize=15):
    """Fill a prioritized buffer step by step and check sampling/weights.

    After each insertion the buffer length and the sampled batch size are
    checked; at the end ``update_weight`` is verified to store
    ``|new_weight| ** alpha`` at the sampled indices.
    """
    env = MyTestEnv(size)
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    obs = env.reset()
    actions = [1] * 5 + [0] * 10 + [1] * 10
    for step, action in enumerate(actions, start=1):
        obs_next, rew, done, info = env.step(action)
        extra = np.random.randn() - 0.5
        buf.add(obs, action, rew, done, obs_next, info, extra)
        obs = obs_next

        half = len(buf) // 2
        data, _ = buf.sample(half)
        # A request for 0 items is expected to yield the whole buffer.
        assert len(data) == (len(buf) if half == 0 else half)
        assert len(buf) == min(bufsize, step)

    data, indices = buf.sample(len(buf) // 2)
    halved = -data.weight / 2
    buf.update_weight(indices, halved)
    assert np.allclose(buf.weight[indices], np.abs(halved) ** buf._alpha)
Esempio n. 3
0
def test_priortized_replaybuffer(size=32, bufsize=15):
    """Check sampling, length and weight bookkeeping of a prioritized buffer.

    Transitions from a fixed 25-step action sequence are added one at a
    time.  After each insertion the test verifies that

    * the weights of the filled slots, divided by ``buf._weight_sum``,
      sum to 1;
    * ``buf.sample`` returns the requested number of items (or the whole
      buffer when 0 is requested);
    * ``len(buf)`` grows until it saturates at ``bufsize``;
    * ``buf._weight_sum`` matches the actual sum of ``buf.weight``.

    Finally ``update_weight`` is checked to store ``|new_weight| ** alpha``
    and to keep ``buf._weight_sum`` consistent.
    """
    env = MyTestEnv(size)
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5)
        obs = obs_next
        # Only the first buf._size slots are considered when normalizing.
        assert np.isclose(np.sum((buf.weight / buf._weight_sum)[:buf._size]),
                          1,
                          rtol=1e-12)
        data, indice = buf.sample(len(buf) // 2)
        if len(buf) // 2 == 0:
            # A request for 0 items is expected to yield the whole buffer.
            assert len(data) == len(buf)
        else:
            assert len(data) == len(buf) // 2
        # Report the offending values in the assertion message itself;
        # print() returns None and would leave the failure message empty.
        assert len(buf) == min(bufsize, i + 1), \
            f"len(buf)={len(buf)}, i={i}"
        assert np.isclose(buf._weight_sum, (buf.weight).sum())
    data, indice = buf.sample(len(buf) // 2)
    buf.update_weight(indice, -data.weight / 2)
    assert np.isclose(buf.weight[indice],
                      np.power(np.abs(-data.weight / 2), buf._alpha)).all()
    # The cached sum must still agree after priorities were rewritten.
    assert np.isclose(buf._weight_sum, (buf.weight).sum())