Example #1
0
def test_replaybuffermanager():
    """Exercise VectorReplayBuffer add / sample / prev / next / unfinished_index.

    The buffer is driven through a fixed sequence of ``add`` calls. After each
    stage, the exact global indices, ``done`` flags and episode links are
    checked. The statements are therefore order-dependent: every assertion
    relies on the buffer state left by the calls before it.

    Two facts indicate the layout of ``VectorReplayBuffer(20, 4)``:
    the asserted insertion pointers (0, 5, 10, 15 for ``buffer_ids`` 0-3),
    and the 5-element ``info`` of the last sub-buffer. Together they show four
    sub-buffers of five slots each, with sub-buffer ``k`` owning global
    indices ``5*k .. 5*k + 4``.
    """
    buf = VectorReplayBuffer(20, 4)
    # Stage 1: one transition into each of sub-buffers 0, 1 and 2. Only the
    # third one has done=1, so only it reports a finished length-1 episode
    # (reward 3); unfinished episodes report 0 for both length and reward.
    batch = Batch(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3], done=[0, 0, 1])
    ptr, ep_rew, ep_len, ep_idx = buf.add(batch, buffer_ids=[0, 1, 2])
    assert np.all(ep_len == [0, 0, 1]) and np.all(ep_rew == [0, 0, 3])
    assert np.all(ptr == [0, 5, 10]) and np.all(ep_idx == [0, 5, 10])
    with pytest.raises(NotImplementedError):
        # ReplayBufferManager cannot be updated
        buf.update(buf)
    # sample index / prev / next / unfinished_index
    # 11000 draws over 3 stored indices is ~3667 per index on average;
    # the 3000 floor leaves slack for randomness.
    indices = buf.sample_indices(11000)
    assert np.bincount(indices)[[0, 5, 10]].min() >= 3000  # uniform sample
    # batch_size=0 returns every stored index.
    batch, indices = buf.sample(0)
    assert np.allclose(indices, [0, 5, 10])
    # With a single transition per sub-buffer there is no neighbour to step
    # to, so prev and next return each index unchanged.
    indices_prev = buf.prev(indices)
    assert np.allclose(indices_prev, indices), indices_prev
    indices_next = buf.next(indices)
    assert np.allclose(indices_next, indices), indices_next
    # Sub-buffers 0 and 1 end with done=0, so their latest slots are unfinished.
    assert np.allclose(buf.unfinished_index(), [0, 5])
    # A finished (done=1) transition in sub-buffer 3 leaves the unfinished
    # set unchanged.
    buf.add(Batch(obs=[4], act=[4], rew=[4], done=[1]), buffer_ids=[3])
    assert np.allclose(buf.unfinished_index(), [0, 5])
    # Smoke-check sampling with a fixed batch size; the result is discarded.
    batch, indices = buf.sample(10)
    batch, indices = buf.sample(0)
    assert np.allclose(indices, [0, 5, 10, 15])
    indices_prev = buf.prev(indices)
    assert np.allclose(indices_prev, indices), indices_prev
    indices_next = buf.next(indices)
    assert np.allclose(indices_next, indices), indices_next
    # Stage 2: four more rounds of one transition per sub-buffer, filling
    # all 20 slots. The done pattern per round is 0, 1, 0, then [0, 1, 0, 1].
    data = np.array([0, 0, 0, 0])
    buf.add(Batch(obs=data, act=data, rew=data, done=data),
            buffer_ids=[0, 1, 2, 3])
    buf.add(Batch(obs=data, act=data, rew=data, done=1 - data),
            buffer_ids=[0, 1, 2, 3])
    assert len(buf) == 12
    buf.add(Batch(obs=data, act=data, rew=data, done=data),
            buffer_ids=[0, 1, 2, 3])
    buf.add(Batch(obs=data, act=data, rew=data, done=[0, 1, 0, 1]),
            buffer_ids=[0, 1, 2, 3])
    assert len(buf) == 20
    # 120000 draws over 20 indices is ~6000 per index on average.
    indices = buf.sample_indices(120000)
    assert np.bincount(indices).min() >= 5000
    # Smoke-check sampling with a fixed batch size; the result is discarded.
    batch, indices = buf.sample(10)
    indices = buf.sample_indices(0)
    assert np.allclose(indices, np.arange(len(buf)))
    # check the actual data stored in buf._meta
    assert np.allclose(buf.done, [
        # sub-buffer 0 (indices 0-4)
        0,
        0,
        1,
        0,
        0,
        # sub-buffer 1 (indices 5-9)
        0,
        0,
        1,
        0,
        1,
        # sub-buffer 2 (indices 10-14)
        1,
        0,
        1,
        0,
        0,
        # sub-buffer 3 (indices 15-19)
        1,
        0,
        1,
        0,
        1,
    ])
    # prev never crosses an episode boundary. An index is its own prev when
    # it is the oldest slot of its sub-buffer, or when its predecessor has
    # done=1.
    assert np.allclose(buf.prev(indices), [
        # sub-buffer 0
        0,
        0,
        1,
        3,
        3,
        # sub-buffer 1
        5,
        5,
        6,
        8,
        8,
        # sub-buffer 2
        10,
        11,
        11,
        13,
        13,
        # sub-buffer 3
        15,
        16,
        16,
        18,
        18,
    ])
    # next never crosses an episode boundary. An index is its own next when
    # it has done=1, or when it is the newest slot of its sub-buffer.
    assert np.allclose(buf.next(indices), [
        # sub-buffer 0
        1,
        2,
        2,
        4,
        4,
        # sub-buffer 1
        6,
        7,
        7,
        9,
        9,
        # sub-buffer 2
        10,
        12,
        12,
        14,
        14,
        # sub-buffer 3
        15,
        17,
        17,
        19,
        19,
    ])
    # Sub-buffers 0 and 2 end with done=0.
    assert np.allclose(buf.unfinished_index(), [4, 14])
    # Stage 3: sub-buffer 2 is full, so this write lands on its first slot
    # (ptr 10). It closes the episode 13 -> 14 -> 10, which therefore has
    # length 3, starts at index 13, and has reward 0 + 0 + 1 = 1.
    ptr, ep_rew, ep_len, ep_idx = buf.add(Batch(obs=[1],
                                                act=[1],
                                                rew=[1],
                                                done=[1]),
                                          buffer_ids=[2])
    assert np.all(ep_len == [3]) and np.all(ep_rew == [1])
    assert np.all(ptr == [10]) and np.all(ep_idx == [13])
    assert np.allclose(buf.unfinished_index(), [4])
    indices = list(sorted(buf.sample_indices(0)))
    assert np.allclose(indices, np.arange(len(buf)))
    # Only sub-buffer 2 changes from the previous tables: prev(10) is now 14
    # and next(14) is now 10, i.e. the episode wraps around the sub-buffer end.
    assert np.allclose(buf.prev(indices), [
        # sub-buffer 0
        0,
        0,
        1,
        3,
        3,
        # sub-buffer 1
        5,
        5,
        6,
        8,
        8,
        # sub-buffer 2
        14,
        11,
        11,
        13,
        13,
        # sub-buffer 3
        15,
        16,
        16,
        18,
        18,
    ])
    assert np.allclose(buf.next(indices), [
        # sub-buffer 0
        1,
        2,
        2,
        4,
        4,
        # sub-buffer 1
        6,
        7,
        7,
        9,
        9,
        # sub-buffer 2
        10,
        12,
        12,
        14,
        10,
        # sub-buffer 3
        15,
        17,
        17,
        19,
        19,
    ])
    # corner case: list, int and -1
    assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0]
    assert buf.next(-1) == buf.next([buf.maxsize - 1])[0]
    # Set an ``info`` field on the manager's storage batch and write it back.
    # The last sub-buffer must then expose five ones in its own ``info``.
    batch = buf._meta
    batch.info = np.ones(buf.maxsize)
    buf.set_batch(batch)
    assert np.allclose(buf.buffers[-1].info, [1] * 5)
    # A negative batch size yields no indices.
    assert buf.sample_indices(-1).tolist() == []
    # Wrapping a zero-capacity ReplayBuffer in np.array must give an
    # object-dtype array. Presumably this checks that numpy does not try to
    # unpack the buffer as a sequence — confirm.
    assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == object