示例#1
0
def test_cachedbuffer():
    """Check storage and index bookkeeping of ``CachedReplayBuffer``.

    Two configurations are exercised:

    * a buffer built from ``ReplayBuffer(10)`` with the extra arguments
      ``4, 5``, filled through explicit ``cached_buffer_ids``;
    * a buffer built from ``ReplayBuffer(0, sample_avail=True)``, i.e. a
      zero-sized main buffer (the setup used in test_collector).
    """
    cached_buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
    # nothing has been added yet, so there is nothing to sample
    assert cached_buf.sample_index(0).tolist() == []

    # --- first transition: unfinished step routed with cached id 1 ---
    episode_len, episode_rew = cached_buf.add(
        obs=[1], act=[1], rew=[1], done=[0], cached_buffer_ids=[1])
    expected_obs = np.zeros(cached_buf.maxsize)
    expected_obs[15] = 1  # the assertions below expect the entry at slot 15
    idx = cached_buf.sample_index(0)
    assert np.allclose(idx, [15])
    assert np.allclose(cached_buf.prev(idx), [15])
    assert np.allclose(cached_buf.next(idx), [15])
    assert np.allclose(cached_buf.obs, expected_obs)
    assert np.allclose(episode_len, [0])
    assert np.allclose(episode_rew, [0.0])

    # --- second transition: a one-step episode ends with cached id 3 ---
    episode_len, episode_rew = cached_buf.add(
        obs=[2], act=[2], rew=[2], done=[1], cached_buffer_ids=[3])
    # the finished step is expected both at slot 0 and at slot 25
    expected_obs[[0, 25]] = 2
    idx = cached_buf.sample_index(0)
    assert np.allclose(idx, [0, 15])
    assert np.allclose(cached_buf.prev(idx), [0, 15])
    assert np.allclose(cached_buf.next(idx), [0, 15])
    assert np.allclose(cached_buf.obs, expected_obs)
    assert np.allclose(episode_len, [1])
    assert np.allclose(episode_rew, [2.0])
    assert np.allclose(cached_buf.unfinished_index(), [15])
    assert np.allclose(cached_buf.sample_index(0), [0, 15])

    # --- batched add: id 3 keeps going, id 1 finishes a two-step episode ---
    episode_len, episode_rew = cached_buf.add(
        obs=[3, 4], act=[3, 4], rew=[3, 4],
        done=[0, 1], cached_buffer_ids=[3, 1])
    assert np.allclose(episode_len, [0, 2])
    assert np.allclose(episode_rew, [0, 5.0])
    expected_obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3]
    assert np.allclose(cached_buf.obs, expected_obs)
    assert np.allclose(cached_buf.unfinished_index(), [25])
    idx = cached_buf.sample_index(0)
    assert np.allclose(idx, [0, 1, 2, 25])
    assert np.allclose(cached_buf.done[idx], [1, 0, 1, 0])
    assert np.allclose(cached_buf.prev(idx), [0, 1, 1, 25])
    assert np.allclose(cached_buf.next(idx), [0, 2, 2, 25])
    idx = cached_buf.sample_index(10000)
    # uniform sample: each of the four valid slots must show up often enough
    assert np.bincount(idx)[[0, 1, 2, 25]].min() > 2000

    # cached buffer with main_buffer size == 0 (no update)
    # used in test_collector
    cached_buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5)
    zeros = np.zeros(4)
    rewards = np.ones((4, 4))
    done_patterns = (
        [0, 0, 1, 1],
        [0, 0, 0, 0],
        [1, 1, 1, 1],
        [0, 0, 0, 0],
        [0, 1, 0, 1],
    )
    for done_flags in done_patterns:
        cached_buf.add(obs=zeros, act=zeros, rew=rewards, done=done_flags,
                       obs_next=zeros)
    # expected ``done`` flags, one row of five slots per group
    assert np.allclose(cached_buf.done, [
        0, 0, 1, 0, 0,
        0, 1, 1, 0, 0,
        0, 0, 0, 0, 0,
        0, 1, 0, 0, 0,
    ])
    idx = cached_buf.sample_index(0)
    assert np.allclose(idx, [0, 1, 10, 11])
    assert np.allclose(cached_buf.prev(idx), [0, 0, 10, 10])
    assert np.allclose(cached_buf.next(idx), [1, 1, 11, 11])
示例#2
0
def test_multibuf_stack():
    """Check ``CachedReplayBuffer`` together with frame stacking options.

    Three wrapped buffers are covered:

    * ``ReplayBuffer`` with ``stack_num`` + ``ignore_obs_next``;
    * ``PrioritizedReplayBuffer`` with ``stack_num`` + ``ignore_obs_next``
      + ``sample_avail`` (the "super corner case");
    * ``ReplayBuffer`` with ``save_only_last_obs`` + ``ignore_obs_next``
      (Atari-style observations).
    """
    size = 5
    bufsize = 9
    stack_num = 4
    cached_num = 3
    env = MyTestEnv(size)

    # test if CachedReplayBuffer can handle stack_num + ignore_obs_next
    plain_main = ReplayBuffer(
        bufsize, stack_num=stack_num, ignore_obs_next=True)
    stacked_buf = CachedReplayBuffer(plain_main, cached_num, size)

    # test if CachedReplayBuffer can handle super corner case:
    # prio-buffer + stack_num + ignore_obs_next + sample_avail
    prio_main = PrioritizedReplayBuffer(
        bufsize, 0.6, 0.4,
        stack_num=stack_num, ignore_obs_next=True, sample_avail=True)
    prio_buf = CachedReplayBuffer(prio_main, cached_num, size)

    # feed 18 environment steps into both buffers; every step is replicated
    # once per cached slot, with the observation shifted by ``size * slot``
    obs = env.reset(1)
    for _ in range(18):
        obs_next, rew, done, info = env.step(1)
        obs_list = np.array([obs + size * slot for slot in range(cached_num)])
        act_list = [1] * cached_num
        rew_list = [rew] * cached_num
        done_list = [done] * cached_num
        obs_next_list = -obs_list
        info_list = [info] * cached_num
        for target in (stacked_buf, prio_buf):
            target.add(obs_list, act_list, rew_list, done_list,
                       obs_next_list, info_list)
        # start a new episode when the current one ended
        obs = env.reset(1) if done else obs_next

    # check the `add` order is correct
    expected_obs = (
        [12, 13, 14, 4, 6, 7, 8, 9, 11]  # main_buffer
        + [1, 2, 3, 4, 0]  # cached_buffer[0]
        + [6, 7, 8, 9, 0]  # cached_buffer[1]
        + [11, 12, 13, 14, 0]  # cached_buffer[2]
    )
    assert np.allclose(stacked_buf.obs.reshape(-1),
                       expected_obs), stacked_buf.obs
    expected_done = (
        [0, 0, 1, 1, 0, 0, 0, 1, 0]  # main_buffer
        + [0, 0, 0, 1, 0]  # cached_buffer[0]
        + [0, 0, 0, 1, 0]  # cached_buffer[1]
        + [0, 0, 0, 1, 0]  # cached_buffer[2]
    )
    assert np.allclose(stacked_buf.done, expected_done), stacked_buf.done
    assert np.allclose(stacked_buf.unfinished_index(), [10, 15, 20])

    all_index = sorted(stacked_buf.sample_index(0))
    expected_index = list(range(bufsize)) + [9, 10, 14, 15, 19, 20]
    assert np.allclose(all_index, expected_index)

    # stacked observations, one row per sampled index
    expected_stack_obs = [
        [11, 11, 11, 12],
        [11, 11, 12, 13],
        [11, 12, 13, 14],
        [4, 4, 4, 4],
        [6, 6, 6, 6],
        [6, 6, 6, 7],
        [6, 6, 7, 8],
        [6, 7, 8, 9],
        [11, 11, 11, 11],
        [1, 1, 1, 1],
        [1, 1, 1, 2],
        [6, 6, 6, 6],
        [6, 6, 6, 7],
        [11, 11, 11, 11],
        [11, 11, 11, 12],
    ]
    assert np.allclose(stacked_buf[all_index].obs[..., 0], expected_stack_obs)
    expected_stack_obs_next = [
        [11, 11, 12, 13],
        [11, 12, 13, 14],
        [11, 12, 13, 14],
        [4, 4, 4, 4],
        [6, 6, 6, 7],
        [6, 6, 7, 8],
        [6, 7, 8, 9],
        [6, 7, 8, 9],
        [11, 11, 11, 12],
        [1, 1, 1, 2],
        [1, 1, 1, 2],
        [6, 6, 6, 7],
        [6, 6, 6, 7],
        [11, 11, 11, 12],
        [11, 11, 11, 12],
    ]
    assert np.allclose(stacked_buf[all_index].obs_next[..., 0],
                       expected_stack_obs_next)

    # both buffers received the same data, so ``done`` must agree
    assert np.all(stacked_buf.done == prio_buf.done)
    avail_index = prio_buf.sample_index(0)
    assert np.allclose(sorted(avail_index), [2, 7])
    # random draws must stay inside the set of available indices
    assert np.all(np.isin(prio_buf.sample_index(100), avail_index))

    # manually change the stack num
    prio_buf.stack_num = 2
    for sub_buffer in prio_buf.buffers:
        sub_buffer.stack_num = 2
    avail_index = prio_buf.sample_index(0)
    assert np.allclose(sorted(avail_index), [0, 1, 2, 5, 6, 7, 10, 15, 20])

    batch, _ = prio_buf.sample(0)
    every_slot = np.arange(prio_buf.maxsize)
    assert np.allclose(prio_buf[every_slot].weight, 1)
    prio_buf.update_weight(avail_index, batch.weight * 0)
    weight = prio_buf[every_slot].weight
    # main-buffer slots that were part of the update share one weight < 1
    modified_weight = weight[[0, 1, 2, 5, 6, 7]]
    assert modified_weight.min() == modified_weight.max()
    assert modified_weight.max() < 1
    # main-buffer slots outside the update also share one weight < 1
    unmodified_weight = weight[[3, 4, 8]]
    assert unmodified_weight.min() == unmodified_weight.max()
    assert unmodified_weight.max() < 1
    # slots past the main buffer are expected to keep weight 1
    cached_weight = weight[9:]
    assert cached_weight.min() == cached_weight.max() == 1

    # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next
    atari_main = ReplayBuffer(
        bufsize,
        stack_num=stack_num,
        save_only_last_obs=True,
        ignore_obs_next=True)
    atari_buf = CachedReplayBuffer(atari_main, cached_num, size)
    frames = np.random.rand(size, 4, 84, 84)
    atari_buf.add(
        obs=[frames[2], frames[0]],
        act=[1, 1],
        rew=[0, 0],
        done=[0, 1],
        obs_next=[frames[3], frames[1]],
        cached_buffer_ids=[1, 2])
    # only the last frame of each 4-frame observation is stored ...
    assert atari_buf.obs.shape == (atari_buf.maxsize, 84, 84)
    assert np.allclose(atari_buf.obs[0], frames[0, -1])
    assert np.allclose(atari_buf.obs[14], frames[2, -1])
    assert np.allclose(atari_buf.obs[19], frames[0, -1])
    # ... while indexing returns a 4-frame stack again
    assert atari_buf[0].obs.shape == (4, 84, 84)