Example #1
0
def test_vectorbuffer(task="Pendulum-v0", total_count=5, num_steps=100000):
    """Repeatedly fill a single-sub-buffer VectorReplayBuffer with rollouts.

    Each of ``total_count`` rounds creates a fresh environment and a fresh
    ``VectorReplayBuffer(total_size=10000, buffer_num=1)``, then adds
    ``num_steps`` random-action transitions one at a time, resetting the
    environment whenever an episode ends. Nothing is asserted; the rounds are
    tracked with a ``tqdm`` progress bar, so this acts as a smoke/throughput
    run of ``VectorReplayBuffer.add``.

    Args:
        task: Environment id passed to ``gym.make``.
        total_count: Number of independent rounds to run.
        num_steps: Environment steps (and buffer additions) per round.
    """
    for _ in tqdm.trange(total_count, desc="VectorReplayBuffer"):
        env = gym.make(task)
        try:
            buf = VectorReplayBuffer(total_size=10000, buffer_num=1)
            obs = env.reset()
            for _step in range(num_steps):
                act = env.action_space.sample()
                # Assumes the classic 4-tuple gym step API
                # (obs, reward, done, info) — matches the installed gym version.
                obs_next, rew, done, info = env.step(act)
                # Every field is wrapped in a length-1 array so the batch holds
                # exactly one transition.
                batch = Batch(
                    obs=np.array([obs]),
                    act=np.array([act]),
                    rew=np.array([rew]),
                    done=np.array([done]),
                    obs_next=np.array([obs_next]),
                    info=np.array([info]),
                )
                buf.add(batch)
                # Start a new episode once the current one has ended.
                obs = env.reset() if done else obs_next
        finally:
            # Release the environment even if a step or an add raises.
            env.close()
Example #2
0
def test_replaybuffermanager() -> None:
    """Check VectorReplayBuffer's per-sub-buffer bookkeeping end to end.

    The buffer is built as ``VectorReplayBuffer(20, 4)``. The assertions below
    imply 4 sub-buffers of 5 slots each, occupying the contiguous global index
    ranges [0, 5), [5, 10), [10, 15) and [15, 20).

    The test covers:
    - the values returned by ``add`` (ptr, ep_rew, ep_len, ep_idx);
    - sampling;
    - ``prev``/``next`` navigation across episode boundaries;
    - ``unfinished_index``;
    - overwriting a full sub-buffer;
    - ``set_batch``;
    - a few corner cases.

    Every assertion depends on the buffer state left by the statements before
    it, so the statements must not be reordered.
    """
    buf = VectorReplayBuffer(20, 4)
    # One transition into each of sub-buffers 0, 1 and 2; only the one sent to
    # sub-buffer 2 is terminal (done=1).
    batch = Batch(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3], done=[0, 0, 1])
    ptr, ep_rew, ep_len, ep_idx = buf.add(batch, buffer_ids=[0, 1, 2])
    # Episode length/return are reported only for the episode that finished.
    assert np.all(ep_len == [0, 0, 1]) and np.all(ep_rew == [0, 0, 3])
    # Each transition landed in the first slot of its own sub-buffer.
    assert np.all(ptr == [0, 5, 10]) and np.all(ep_idx == [0, 5, 10])
    with pytest.raises(NotImplementedError):
        # ReplayBufferManager cannot be updated
        buf.update(buf)
    # sample index / prev / next / unfinished_index
    indices = buf.sample_indices(11000)
    assert np.bincount(indices)[[0, 5, 10]].min() >= 3000  # uniform sample
    # A batch size of 0 returns every stored index.
    batch, indices = buf.sample(0)
    assert np.allclose(indices, [0, 5, 10])
    # With a single transition per sub-buffer, prev/next map each index to
    # itself.
    indices_prev = buf.prev(indices)
    assert np.allclose(indices_prev, indices), indices_prev
    indices_next = buf.next(indices)
    assert np.allclose(indices_next, indices), indices_next
    # Sub-buffers 0 and 1 hold episodes that have not terminated yet.
    assert np.allclose(buf.unfinished_index(), [0, 5])
    # A terminal step in sub-buffer 3 adds no unfinished episode.
    buf.add(Batch(obs=[4], act=[4], rew=[4], done=[1]), buffer_ids=[3])
    assert np.allclose(buf.unfinished_index(), [0, 5])
    # Sampling with a positive batch size is only exercised, not checked.
    batch, indices = buf.sample(10)
    batch, indices = buf.sample(0)
    assert np.allclose(indices, [0, 5, 10, 15])
    indices_prev = buf.prev(indices)
    assert np.allclose(indices_prev, indices), indices_prev
    indices_next = buf.next(indices)
    assert np.allclose(indices_next, indices), indices_next
    # Two more steps in every sub-buffer, the second one terminal: 4 + 8 = 12.
    data = np.array([0, 0, 0, 0])
    buf.add(Batch(obs=data, act=data, rew=data, done=data),
            buffer_ids=[0, 1, 2, 3])
    buf.add(Batch(obs=data, act=data, rew=data, done=1 - data),
            buffer_ids=[0, 1, 2, 3])
    assert len(buf) == 12
    # Two more steps each fill all 20 slots; only the steps sent to
    # sub-buffers 1 and 3 are terminal this time.
    buf.add(Batch(obs=data, act=data, rew=data, done=data),
            buffer_ids=[0, 1, 2, 3])
    buf.add(Batch(obs=data, act=data, rew=data, done=[0, 1, 0, 1]),
            buffer_ids=[0, 1, 2, 3])
    assert len(buf) == 20
    # Every one of the 20 slots should be drawn roughly uniformly.
    indices = buf.sample_indices(120000)
    assert np.bincount(indices).min() >= 5000
    batch, indices = buf.sample(10)
    # sample_indices(0) returns all stored indices in order.
    indices = buf.sample_indices(0)
    assert np.allclose(indices, np.arange(len(buf)))
    # check the actual data stored in buf._meta (done flag per global index)
    assert np.allclose(buf.done, [
        0,
        0,
        1,
        0,
        0,
        0,
        0,
        1,
        0,
        1,
        1,
        0,
        1,
        0,
        0,
        1,
        0,
        1,
        0,
        1,
    ])
    # prev() stays put at the first step of an episode; next() stays put at a
    # terminal step or at the most recently written step of a sub-buffer.
    assert np.allclose(buf.prev(indices), [
        0,
        0,
        1,
        3,
        3,
        5,
        5,
        6,
        8,
        8,
        10,
        11,
        11,
        13,
        13,
        15,
        16,
        16,
        18,
        18,
    ])
    assert np.allclose(buf.next(indices), [
        1,
        2,
        2,
        4,
        4,
        6,
        7,
        7,
        9,
        9,
        10,
        12,
        12,
        14,
        14,
        15,
        17,
        17,
        19,
        19,
    ])
    # Sub-buffers 0 and 2 end with a non-terminal step.
    assert np.allclose(buf.unfinished_index(), [4, 14])
    # Sub-buffer 2 is full, so this terminal step overwrites its oldest slot
    # (global index 10). It closes the episode that started at index 13:
    # indices 13, 14, 10, so length 3, with return 0 + 0 + 1 = 1.
    ptr, ep_rew, ep_len, ep_idx = buf.add(Batch(obs=[1],
                                                act=[1],
                                                rew=[1],
                                                done=[1]),
                                          buffer_ids=[2])
    assert np.all(ep_len == [3]) and np.all(ep_rew == [1])
    assert np.all(ptr == [10]) and np.all(ep_idx == [13])
    # Only sub-buffer 0 still has an unfinished episode.
    assert np.allclose(buf.unfinished_index(), [4])
    indices = list(sorted(buf.sample_indices(0)))
    assert np.allclose(indices, np.arange(len(buf)))
    # prev/next now wrap around inside sub-buffer 2: index 10 follows 14.
    assert np.allclose(buf.prev(indices), [
        0,
        0,
        1,
        3,
        3,
        5,
        5,
        6,
        8,
        8,
        14,
        11,
        11,
        13,
        13,
        15,
        16,
        16,
        18,
        18,
    ])
    assert np.allclose(buf.next(indices), [
        1,
        2,
        2,
        4,
        4,
        6,
        7,
        7,
        9,
        9,
        10,
        12,
        12,
        14,
        10,
        15,
        17,
        17,
        19,
        19,
    ])
    # corner case: list, int and -1
    # (a scalar -1 must behave like the last global index, maxsize - 1)
    assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0]
    assert buf.next(-1) == buf.next([buf.maxsize - 1])[0]
    # Attach an ``info`` field to the whole storage through set_batch; the
    # last sub-buffer then exposes it for its 5 slots.
    batch = buf._meta
    batch.info = np.ones(buf.maxsize)
    buf.set_batch(batch)
    assert np.allclose(buf.buffers[-1].info, [1] * 5)
    # A negative batch size yields no indices.
    assert buf.sample_indices(-1).tolist() == []
    # np.array over a list holding a ReplayBuffer (here of size 0) must give
    # an object-dtype array.
    assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == object