# Example no. 1
# 0
def test_multibuf_stack():
    """Check frame stacking in CachedReplayBuffer across main and cached buffers.

    Three configurations are exercised:

    * ``buf4``: ``stack_num`` + ``ignore_obs_next``;
    * ``buf5``: the same plus ``sample_avail`` (corner case);
    * ``buf6``: Atari-style ``save_only_last_obs`` + ``ignore_obs_next``.
    """
    size = 5
    bufsize = 9
    stack_num = 4
    cached_num = 3
    env = MyTestEnv(size)
    # test if CachedReplayBuffer can handle stack_num + ignore_obs_next
    buf4 = CachedReplayBuffer(
        ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True),
        cached_num, size)
    # test if CachedReplayBuffer can handle corner case:
    # buffer + stack_num + ignore_obs_next + sample_avail
    buf5 = CachedReplayBuffer(
        ReplayBuffer(bufsize,
                     stack_num=stack_num,
                     ignore_obs_next=True,
                     sample_avail=True), cached_num, size)
    obs = env.reset(1)
    # The step counter itself is unused; a distinct comprehension variable
    # (``k``) avoids shadowing it.
    for _ in range(18):
        obs_next, rew, done, info = env.step(1)
        # Offset the observation by ``size * k`` for the k-th cached buffer so
        # each buffer stores distinguishable values.
        obs_list = np.array([obs + size * k for k in range(cached_num)])
        act_list = [1] * cached_num
        rew_list = [rew] * cached_num
        done_list = [done] * cached_num
        obs_next_list = -obs_list
        info_list = [info] * cached_num
        batch = Batch(obs=obs_list,
                      act=act_list,
                      rew=rew_list,
                      done=done_list,
                      obs_next=obs_next_list,
                      info=info_list)
        buf5.add(batch)
        buf4.add(batch)
        assert np.all(buf4.obs == buf5.obs)
        assert np.all(buf4.done == buf5.done)
        obs = obs_next
        if done:
            obs = env.reset(1)
    # check the `add` order is correct
    expected_obs = [
        12, 13, 14, 4, 6, 7, 8, 9, 11,  # main_buffer
        1, 2, 3, 4, 0,  # cached_buffer[0]
        6, 7, 8, 9, 0,  # cached_buffer[1]
        11, 12, 13, 14, 0,  # cached_buffer[2]
    ]
    assert np.allclose(buf4.obs.reshape(-1), expected_obs), buf4.obs
    expected_done = [
        0, 0, 1, 1, 0, 0, 0, 1, 0,  # main_buffer
        0, 0, 0, 1, 0,  # cached_buffer[0]
        0, 0, 0, 1, 0,  # cached_buffer[1]
        0, 0, 0, 1, 0,  # cached_buffer[2]
    ]
    assert np.allclose(buf4.done, expected_done), buf4.done
    assert np.allclose(buf4.unfinished_index(), [10, 15, 20])
    indices = sorted(buf4.sample_indices(0))
    assert np.allclose(indices, list(range(bufsize)) + [9, 10, 14, 15, 19, 20])
    assert np.allclose(buf4[indices].obs[..., 0], [
        [11, 11, 11, 12],
        [11, 11, 12, 13],
        [11, 12, 13, 14],
        [4, 4, 4, 4],
        [6, 6, 6, 6],
        [6, 6, 6, 7],
        [6, 6, 7, 8],
        [6, 7, 8, 9],
        [11, 11, 11, 11],
        [1, 1, 1, 1],
        [1, 1, 1, 2],
        [6, 6, 6, 6],
        [6, 6, 6, 7],
        [11, 11, 11, 11],
        [11, 11, 11, 12],
    ])
    assert np.allclose(buf4[indices].obs_next[..., 0], [
        [11, 11, 12, 13],
        [11, 12, 13, 14],
        [11, 12, 13, 14],
        [4, 4, 4, 4],
        [6, 6, 6, 7],
        [6, 6, 7, 8],
        [6, 7, 8, 9],
        [6, 7, 8, 9],
        [11, 11, 11, 12],
        [1, 1, 1, 2],
        [1, 1, 1, 2],
        [6, 6, 6, 7],
        [6, 6, 6, 7],
        [11, 11, 11, 12],
        [11, 11, 11, 12],
    ])
    indices = buf5.sample_indices(0)
    assert np.allclose(sorted(indices), [2, 7])
    assert np.all(np.isin(buf5.sample_indices(100), indices))
    # manually change the stack num
    buf5.stack_num = 2
    for buf in buf5.buffers:
        buf.stack_num = 2
    indices = buf5.sample_indices(0)
    assert np.allclose(sorted(indices), [0, 1, 2, 5, 6, 7, 10, 15, 20])
    # Smoke check: sampling must still return a (batch, indices) pair after
    # the stack_num change.
    batch, _ = buf5.sample(0)
    # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next
    buf6 = CachedReplayBuffer(
        ReplayBuffer(bufsize,
                     stack_num=stack_num,
                     save_only_last_obs=True,
                     ignore_obs_next=True), cached_num, size)
    # Random stacked frames; only identity of the stored slices is checked.
    frames = np.random.rand(size, 4, 84, 84)
    buf6.add(Batch(obs=[frames[2], frames[0]],
                   act=[1, 1],
                   rew=[0, 0],
                   done=[0, 1],
                   obs_next=[frames[3], frames[1]]),
             buffer_ids=[1, 2])
    # Only the last frame of each stacked observation is stored.
    assert buf6.obs.shape == (buf6.maxsize, 84, 84)
    assert np.allclose(buf6.obs[0], frames[0, -1])
    # Index 14 = bufsize + size, the start of cached_buffer[1].
    assert np.allclose(buf6.obs[14], frames[2, -1])
    # Index 19 = bufsize + 2 * size, the start of cached_buffer[2].
    assert np.allclose(buf6.obs[19], frames[0, -1])
    assert buf6[0].obs.shape == (4, 84, 84)
# Example no. 2
# 0
def test_cachedbuffer():
    """Exercise storage, episode bookkeeping and sampling of CachedReplayBuffer.

    The first part uses a main buffer of size 10 with four cached buffers of
    size 5. The second part uses an empty main buffer (``ReplayBuffer(0)``).
    """
    buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
    assert buf.sample_indices(0).tolist() == []

    # Normal usage: a single unfinished step written to cached buffer 1.
    first_step = Batch(obs=[1], act=[1], rew=[1], done=[0])
    ptr, ep_rew, ep_len, ep_idx = buf.add(first_step, buffer_ids=[1])
    expected_obs = np.zeros(buf.maxsize)
    expected_obs[15] = 1
    idx = buf.sample_indices(0)
    assert np.allclose(idx, [15])
    assert np.allclose(buf.prev(idx), [15])
    assert np.allclose(buf.next(idx), [15])
    assert np.allclose(buf.obs, expected_obs)
    assert np.all(ep_len == [0])
    assert np.all(ep_rew == [0.0])
    assert np.all(ptr == [15])
    assert np.all(ep_idx == [15])

    # A one-step finished episode in cached buffer 3.
    finished_step = Batch(obs=[2], act=[2], rew=[2], done=[1])
    ptr, ep_rew, ep_len, ep_idx = buf.add(finished_step, buffer_ids=[3])
    expected_obs[[0, 25]] = 2
    idx = buf.sample_indices(0)
    assert np.allclose(idx, [0, 15])
    assert np.allclose(buf.prev(idx), [0, 15])
    assert np.allclose(buf.next(idx), [0, 15])
    assert np.allclose(buf.obs, expected_obs)
    assert np.all(ep_len == [1])
    assert np.all(ep_rew == [2.0])
    assert np.all(ptr == [0])
    assert np.all(ep_idx == [0])
    assert np.allclose(buf.unfinished_index(), [15])
    assert np.allclose(buf.sample_indices(0), [0, 15])

    # Two buffers at once: buffer 3 keeps going, buffer 1 finishes.
    two_steps = Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], done=[0, 1])
    ptr, ep_rew, ep_len, ep_idx = buf.add(two_steps, buffer_ids=[3, 1])
    assert np.all(ep_len == [0, 2])
    assert np.all(ep_rew == [0, 5.0])
    assert np.all(ptr == [25, 2])
    assert np.all(ep_idx == [25, 1])
    expected_obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3]
    assert np.allclose(buf.obs, expected_obs)
    assert np.allclose(buf.unfinished_index(), [25])
    idx = buf.sample_indices(0)
    assert np.allclose(idx, [0, 1, 2, 25])
    assert np.allclose(buf.done[idx], [1, 0, 1, 0])
    assert np.allclose(buf.prev(idx), [0, 1, 1, 25])
    assert np.allclose(buf.next(idx), [0, 2, 2, 25])
    idx = buf.sample_indices(10000)
    hit_counts = np.bincount(idx)
    assert hit_counts[[0, 1, 2, 25]].min() > 2000  # uniform sample

    # Main buffer of size 0, so nothing is ever copied into it
    # (this setup is used in test_collector).
    buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5)
    data = np.zeros(4)
    rew = np.ones([4, 4])
    warmup_dones = ([0, 0, 1, 1], [0, 0, 0, 0], [1, 1, 1, 1], [0, 0, 0, 0])
    for done_flags in warmup_dones:
        buf.add(Batch(obs=data, act=data, rew=rew, done=done_flags))
    last_step = Batch(obs=data, act=data, rew=rew, done=[0, 1, 0, 1])
    ptr, ep_rew, ep_len, ep_idx = buf.add(last_step)
    assert np.all(ptr == [1, -1, 11, -1])
    assert np.all(ep_idx == [0, -1, 10, -1])
    assert np.all(ep_len == [0, 2, 0, 2])
    assert np.all(ep_rew == [data, data + 2, data, data + 2])
    expected_done = [
        0, 0, 1, 0, 0,  # cached_buffer[0]
        0, 1, 1, 0, 0,  # cached_buffer[1]
        0, 0, 0, 0, 0,  # cached_buffer[2]
        0, 1, 0, 0, 0,  # cached_buffer[3]
    ]
    assert np.allclose(buf.done, expected_done)
    idx = buf.sample_indices(0)
    assert np.allclose(idx, [0, 1, 10, 11])
    assert np.allclose(buf.prev(idx), [0, 0, 10, 10])
    assert np.allclose(buf.next(idx), [1, 1, 11, 11])