Beispiel #1
0
def test_update():
    """Check ``ReplayBuffer.update`` and its rejection by CachedReplayBuffer."""
    source = ReplayBuffer(4, stack_num=2)
    target = ReplayBuffer(4, stack_num=2)
    # Only the source buffer receives data, so the two lengths must differ.
    for step in range(5):
        transition = Batch(
            obs=np.array([step]),
            act=float(step),
            rew=step * step,
            done=step % 2 == 0,
            info={'incident': 'found'},
        )
        source.add(transition)
    assert len(source) > len(target)

    target.update(source)
    assert len(source) == len(target)
    # After the update, the target's first/last entries are expected to line
    # up with the source's entries at index 1 and index 0 respectively.
    head_matches = target[0].obs == source[1].obs
    assert head_matches.all()
    tail_matches = target[-1].obs == source[0].obs
    assert tail_matches.all()

    # Updating a cached buffer must raise NotImplementedError.
    cached = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
    with pytest.raises(NotImplementedError):
        cached.update(cached)
Beispiel #2
0
def test_collector_with_atari_setting():
    """Run Collector on (4, 84, 84) observations with several buffer setups.

    Covers a plain ReplayBuffer, its ``ignore_obs_next`` and
    ``save_only_last_obs`` variants, VectorReplayBuffer, CachedReplayBuffer
    and a Collector created without an explicit buffer.
    """
    diag = np.arange(84)
    reference_obs = np.zeros([6, 4, 84, 84])
    for value in range(6):
        reference_obs[value, 3, diag, diag] = value
        reference_obs[value, 2, diag] = value
        reference_obs[value, 1, :, diag] = value
        reference_obs[value, 0] = value

    # atari single buffer
    env = MyTestEnv(size=5, sleep=0, array_state=True)
    policy = MyPolicy()
    single = Collector(policy, env, ReplayBuffer(size=100))
    single.collect(n_step=6)
    single.collect(n_episode=2)
    full_shape = (100, 4, 84, 84)
    assert single.buffer.obs.shape == full_shape
    assert single.buffer.obs_next.shape == full_shape
    assert len(single.buffer) == 15
    expected = np.zeros_like(single.buffer.obs)
    filled = np.arange(15)
    expected[filled] = reference_obs[filled % 5]
    assert np.all(expected == single.buffer.obs)

    no_next = Collector(
        policy, env, ReplayBuffer(size=100, ignore_obs_next=True))
    no_next.collect(n_episode=3)
    assert np.allclose(single.buffer.obs, no_next.buffer.obs)
    # The raw attribute is absent, yet indexing the buffer still yields
    # an ``obs_next`` field (checked right below).
    with pytest.raises(AttributeError):
        no_next.buffer.obs_next
    expected_next = reference_obs[[1, 2, 3, 4, 4] * 3]
    assert np.all(expected_next == no_next.buffer[:].obs_next)

    last_only_buffer = ReplayBuffer(
        size=100, ignore_obs_next=True, save_only_last_obs=True)
    last_only = Collector(policy, env, last_only_buffer)
    last_only.collect(n_step=8)
    assert last_only.buffer.obs.shape == (100, 84, 84)
    expected = np.zeros_like(last_only.buffer.obs)
    expected[np.arange(8)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2], -1]
    assert np.all(last_only.buffer.obs == expected)
    assert np.allclose(last_only.buffer[:].obs_next,
                       reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1])

    # atari multi buffer
    env_fns = [
        lambda x=env_size: MyTestEnv(size=x, sleep=0, array_state=True)
        for env_size in [2, 3, 4, 5]
    ]
    envs = DummyVectorEnv(env_fns)
    # (first row, expected frame ids) for each of the four sub-buffers.
    layout = (
        (0, [0, 1, 0, 1, 0, 1, 0, 1]),
        (25, [0, 1, 2, 0, 1, 2, 0, 1, 2]),
        (50, [0, 1, 2, 3, 0, 1, 2, 3]),
        (75, [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]),
    )

    vec = Collector(policy, envs,
                    VectorReplayBuffer(total_size=100, buffer_num=4))
    vec.collect(n_step=12)
    result = vec.collect(n_episode=9)
    assert result["n/ep"] == 9 and result["n/st"] == 23
    assert vec.buffer.obs.shape == (100, 4, 84, 84)
    expected = np.zeros_like(vec.buffer.obs)
    for start, frame_ids in layout:
        rows = np.arange(start, start + len(frame_ids))
        expected[rows] = reference_obs[frame_ids]
    assert np.all(expected == vec.buffer.obs)
    expected_next = np.zeros_like(vec.buffer.obs_next)
    for start, frame_ids in layout:
        rows = np.arange(start, start + len(frame_ids))
        # The next observation is always the following reference frame.
        expected_next[rows] = reference_obs[np.asarray(frame_ids) + 1]
    assert np.all(expected_next == vec.buffer.obs_next)

    stacked_vec = Collector(
        policy, envs,
        VectorReplayBuffer(total_size=100,
                           buffer_num=4,
                           stack_num=4,
                           ignore_obs_next=True,
                           save_only_last_obs=True))
    stacked_vec.collect(n_step=12)
    result = stacked_vec.collect(n_episode=9)
    assert result["n/ep"] == 9 and result["n/st"] == 23
    assert stacked_vec.buffer.obs.shape == (100, 84, 84)
    expected = np.zeros_like(stacked_vec.buffer.obs)
    slice_obs = reference_obs[:, -1]
    for start, frame_ids in layout:
        rows = np.arange(start, start + len(frame_ids))
        expected[rows] = slice_obs[frame_ids]
    assert np.all(stacked_vec.buffer.obs == expected)
    expected_next = np.zeros([len(stacked_vec.buffer), 4, 84, 84])
    # Rows mirror the 8/9/8/10 filled slots of the four sub-buffers above.
    ref_index = np.array(
        [1, 1, 1, 1, 1, 1, 1, 1]
        + [1, 2, 2, 1, 2, 2, 1, 2, 2]
        + [1, 2, 3, 3, 1, 2, 3, 3]
        + [1, 2, 3, 4, 4, 1, 2, 3, 4, 4]
    )
    # Fill the stack from newest (-1) to oldest (-4); each older slot uses
    # the previous frame id, clamped at 0.
    for depth in range(1, 5):
        expected_next[:, -depth] = slice_obs[ref_index]
        ref_index = np.maximum(ref_index - 1, 0)
    assert np.all(expected_next == stacked_vec.buffer[:].obs_next)

    buf = ReplayBuffer(100,
                       stack_num=4,
                       ignore_obs_next=True,
                       save_only_last_obs=True)
    cached = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10))
    result_ = cached.collect(n_step=12)
    assert len(buf) == 5 and len(cached.buffer) == 12
    result = cached.collect(n_episode=9)
    assert result["n/ep"] == 9 and result["n/st"] == 23
    assert len(buf) == 35
    main_obs_ids = [
        0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 0, 1, 0,
        1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4
    ]
    assert np.all(buf.obs[:len(buf)] == slice_obs[main_obs_ids])
    main_next_ids = [
        1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 3, 4, 4, 1, 1, 1, 2, 2, 1, 1, 1,
        2, 3, 3, 1, 2, 2, 1, 2, 3, 4, 4
    ]
    assert np.all(buf[:].obs_next[:, -1] == slice_obs[main_next_ids])
    assert len(buf) == len(cached.buffer)

    # test buffer=None
    compared_keys = ["n/ep", "n/st", "rews", "lens"]
    default = Collector(policy, envs)
    result1 = default.collect(n_step=12)
    for key in compared_keys:
        assert np.allclose(result1[key], result_[key])
    result2 = default.collect(n_episode=9)
    for key in compared_keys:
        assert np.allclose(result2[key], result[key])
Beispiel #3
0
def test_multibuf_hdf5():
    """Round-trip multi-buffer classes through HDF5 and keep using them.

    A ReplayBufferManager and a CachedReplayBuffer are filled, saved to
    temporary ``.hdf5`` files, reloaded through their own ``load_hdf5`` and
    compared with the originals. The originals are then mutated and extended
    to make sure they still behave normally.

    The temporary files are removed in a ``finally`` clause, so a failing
    assertion or a failing save/load does not leave files behind.
    """
    size = 100
    buffers = {
        "vector": ReplayBufferManager([ReplayBuffer(size) for _ in range(4)]),
        "cached": CachedReplayBuffer(ReplayBuffer(size), 4, size)
    }
    buffer_types = {k: b.__class__ for k, b in buffers.items()}
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    info_t = torch.tensor([1.]).to(device)
    for i in range(4):
        kwargs = {
            'obs': Batch(index=np.array([i])),
            'act': i,
            'rew': np.array([1, 2]),
            'done': i % 3 == 2,
            'info': {
                "number": {
                    "n": i,
                    "t": info_t
                },
                'extra': None
            },
        }
        buffers["vector"].add(**Batch.stack([kwargs, kwargs, kwargs]),
                              buffer_ids=[0, 1, 2])
        buffers["cached"].add(**Batch.stack([kwargs, kwargs, kwargs]),
                              cached_buffer_ids=[0, 1, 2])

    paths = {}
    try:
        # save
        for k, buf in buffers.items():
            f, path = tempfile.mkstemp(suffix='.hdf5')
            # Register the path first so it is cleaned up even if closing
            # the descriptor or saving the buffer fails.
            paths[k] = path
            os.close(f)
            buf.save_hdf5(path)

        # load replay buffer
        _buffers = {
            k: buffer_types[k].load_hdf5(path)
            for k, path in paths.items()
        }

        # compare
        for k in buffers:
            assert len(_buffers[k]) == len(buffers[k])
            assert np.allclose(_buffers[k].act, buffers[k].act)
            assert _buffers[k].stack_num == buffers[k].stack_num
            assert _buffers[k].maxsize == buffers[k].maxsize
            assert np.all(_buffers[k]._indices == buffers[k]._indices)
        # check shallow copy in ReplayBufferManager
        for k in ["vector", "cached"]:
            buffers[k].info.number.n[0] = -100
            assert buffers[k].buffers[0].info.number.n[0] == -100
        # check if still behave normally
        for k in ["vector", "cached"]:
            kwargs = {
                'obs': Batch(index=np.array([5])),
                'act': 5,
                'rew': np.array([2, 1]),
                'done': False,
                'info': {
                    "number": {
                        # ``i`` is the value left over from the fill loop
                        # above (3); kept as in the original test.
                        "n": i
                    },
                    'Timelimit.truncate': True
                },
            }
            buffers[k].add(**Batch.stack([kwargs, kwargs, kwargs, kwargs]))
            act = np.zeros(buffers[k].maxsize)
            if k == "vector":
                act[np.arange(5)] = np.array([0, 1, 2, 3, 5])
                act[np.arange(5) + size] = np.array([0, 1, 2, 3, 5])
                act[np.arange(5) + size * 2] = np.array([0, 1, 2, 3, 5])
                act[size * 3] = 5
            elif k == "cached":
                act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
                act[np.arange(3) + size] = np.array([3, 5, 2])
                act[np.arange(3) + size * 2] = np.array([3, 5, 2])
                act[np.arange(3) + size * 3] = np.array([3, 5, 2])
                act[size * 4] = 5
            assert np.allclose(buffers[k].act, act)
    finally:
        # mkstemp leaves deletion to the caller; always clean up.
        for path in paths.values():
            os.remove(path)
Beispiel #4
0
def test_multibuf_stack():
    """Check frame stacking inside CachedReplayBuffer.

    Uses a plain main buffer, a prioritized main buffer with
    ``sample_avail`` and an Atari-style ``save_only_last_obs`` main buffer.
    """
    size = 5
    bufsize = 9
    stack_num = 4
    cached_num = 3
    env = MyTestEnv(size)
    # test if CachedReplayBuffer can handle stack_num + ignore_obs_next
    plain_main = ReplayBuffer(
        bufsize, stack_num=stack_num, ignore_obs_next=True)
    buf4 = CachedReplayBuffer(plain_main, cached_num, size)
    # test if CachedReplayBuffer can handle super corner case:
    # prio-buffer + stack_num + ignore_obs_next + sample_avail
    prio_main = PrioritizedReplayBuffer(
        bufsize,
        0.6,
        0.4,
        stack_num=stack_num,
        ignore_obs_next=True,
        sample_avail=True,
    )
    buf5 = CachedReplayBuffer(prio_main, cached_num, size)

    obs = env.reset(1)
    for _ in range(18):
        obs_next, rew, done, info = env.step(1)
        # One shifted copy of the observation per cached buffer.
        obs_list = np.array([obs + size * k for k in range(cached_num)])
        transition = (
            obs_list,
            [1] * cached_num,
            [rew] * cached_num,
            [done] * cached_num,
            -obs_list,
            [info] * cached_num,
        )
        buf4.add(*transition)
        buf5.add(*transition)
        obs = env.reset(1) if done else obs_next

    # check the `add` order is correct
    expected_obs = (
        [12, 13, 14, 4, 6, 7, 8, 9, 11]  # main_buffer
        + [1, 2, 3, 4, 0]  # cached_buffer[0]
        + [6, 7, 8, 9, 0]  # cached_buffer[1]
        + [11, 12, 13, 14, 0]  # cached_buffer[2]
    )
    assert np.allclose(buf4.obs.reshape(-1), expected_obs), buf4.obs
    expected_done = (
        [0, 0, 1, 1, 0, 0, 0, 1, 0]  # main_buffer
        + [0, 0, 0, 1, 0]  # cached_buffer[0]
        + [0, 0, 0, 1, 0]  # cached_buffer[1]
        + [0, 0, 0, 1, 0]  # cached_buffer[2]
    )
    assert np.allclose(buf4.done, expected_done), buf4.done
    assert np.allclose(buf4.unfinished_index(), [10, 15, 20])

    indice = sorted(buf4.sample_index(0))
    assert np.allclose(indice, list(range(bufsize)) + [9, 10, 14, 15, 19, 20])
    stacked_obs = [
        [11, 11, 11, 12],
        [11, 11, 12, 13],
        [11, 12, 13, 14],
        [4, 4, 4, 4],
        [6, 6, 6, 6],
        [6, 6, 6, 7],
        [6, 6, 7, 8],
        [6, 7, 8, 9],
        [11, 11, 11, 11],
        [1, 1, 1, 1],
        [1, 1, 1, 2],
        [6, 6, 6, 6],
        [6, 6, 6, 7],
        [11, 11, 11, 11],
        [11, 11, 11, 12],
    ]
    assert np.allclose(buf4[indice].obs[..., 0], stacked_obs)
    stacked_obs_next = [
        [11, 11, 12, 13],
        [11, 12, 13, 14],
        [11, 12, 13, 14],
        [4, 4, 4, 4],
        [6, 6, 6, 7],
        [6, 6, 7, 8],
        [6, 7, 8, 9],
        [6, 7, 8, 9],
        [11, 11, 11, 12],
        [1, 1, 1, 2],
        [1, 1, 1, 2],
        [6, 6, 6, 7],
        [6, 6, 6, 7],
        [11, 11, 11, 12],
        [11, 11, 11, 12],
    ]
    assert np.allclose(buf4[indice].obs_next[..., 0], stacked_obs_next)

    assert np.all(buf4.done == buf5.done)
    indice = buf5.sample_index(0)
    assert np.allclose(sorted(indice), [2, 7])
    assert np.all(np.isin(buf5.sample_index(100), indice))
    # manually change the stack num
    buf5.stack_num = 2
    for sub_buffer in buf5.buffers:
        sub_buffer.stack_num = 2
    indice = buf5.sample_index(0)
    assert np.allclose(sorted(indice), [0, 1, 2, 5, 6, 7, 10, 15, 20])
    batch, _ = buf5.sample(0)
    every_slot = np.arange(buf5.maxsize)
    assert np.allclose(buf5[every_slot].weight, 1)
    buf5.update_weight(indice, batch.weight * 0)
    weight = buf5[np.arange(buf5.maxsize)].weight
    modified_weight = weight[[0, 1, 2, 5, 6, 7]]
    assert modified_weight.min() == modified_weight.max()
    assert modified_weight.max() < 1
    unmodified_weight = weight[[3, 4, 8]]
    assert unmodified_weight.min() == unmodified_weight.max()
    assert unmodified_weight.max() < 1
    cached_weight = weight[9:]
    assert cached_weight.min() == cached_weight.max() == 1

    # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next
    atari_main = ReplayBuffer(
        bufsize,
        stack_num=stack_num,
        save_only_last_obs=True,
        ignore_obs_next=True,
    )
    buf6 = CachedReplayBuffer(atari_main, cached_num, size)
    frames = np.random.rand(size, 4, 84, 84)
    buf6.add(obs=[frames[2], frames[0]],
             act=[1, 1],
             rew=[0, 0],
             done=[0, 1],
             obs_next=[frames[3], frames[1]],
             cached_buffer_ids=[1, 2])
    assert buf6.obs.shape == (buf6.maxsize, 84, 84)
    # Only the last frame of each pushed stack is expected in storage.
    assert np.allclose(buf6.obs[0], frames[0, -1])
    assert np.allclose(buf6.obs[14], frames[2, -1])
    assert np.allclose(buf6.obs[19], frames[0, -1])
    assert buf6[0].obs.shape == (4, 84, 84)
Beispiel #5
0
def test_cachedbuffer():
    """Check storage, indexing and sampling of CachedReplayBuffer."""
    buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
    assert buf.sample_index(0).tolist() == []

    # check the normal function/usage/storage in CachedReplayBuffer
    ep_len, ep_rew = buf.add(
        obs=[1], act=[1], rew=[1], done=[0], cached_buffer_ids=[1])
    expected_obs = np.zeros(buf.maxsize)
    expected_obs[15] = 1
    indice = buf.sample_index(0)
    assert np.allclose(indice, [15])
    assert np.allclose(buf.prev(indice), [15])
    assert np.allclose(buf.next(indice), [15])
    assert np.allclose(buf.obs, expected_obs)
    assert np.allclose(ep_len, [0]) and np.allclose(ep_rew, [0.0])

    ep_len, ep_rew = buf.add(
        obs=[2], act=[2], rew=[2], done=[1], cached_buffer_ids=[3])
    expected_obs[[0, 25]] = 2
    indice = buf.sample_index(0)
    assert np.allclose(indice, [0, 15])
    assert np.allclose(buf.prev(indice), [0, 15])
    assert np.allclose(buf.next(indice), [0, 15])
    assert np.allclose(buf.obs, expected_obs)
    assert np.allclose(ep_len, [1]) and np.allclose(ep_rew, [2.0])
    assert np.allclose(buf.unfinished_index(), [15])
    assert np.allclose(buf.sample_index(0), [0, 15])

    ep_len, ep_rew = buf.add(
        obs=[3, 4],
        act=[3, 4],
        rew=[3, 4],
        done=[0, 1],
        cached_buffer_ids=[3, 1],
    )
    assert np.allclose(ep_len, [0, 2]) and np.allclose(ep_rew, [0, 5.0])
    expected_obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3]
    assert np.allclose(buf.obs, expected_obs)
    assert np.allclose(buf.unfinished_index(), [25])
    indice = buf.sample_index(0)
    assert np.allclose(indice, [0, 1, 2, 25])
    assert np.allclose(buf.done[indice], [1, 0, 1, 0])
    assert np.allclose(buf.prev(indice), [0, 1, 1, 25])
    assert np.allclose(buf.next(indice), [0, 2, 2, 25])
    indice = buf.sample_index(10000)
    counts = np.bincount(indice)
    assert counts[[0, 1, 2, 25]].min() > 2000  # uniform sample

    # cached buffer with main_buffer size == 0 (no update)
    # used in test_collector
    buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5)
    data = np.zeros(4)
    rew = np.ones([4, 4])
    done_patterns = (
        [0, 0, 1, 1],
        [0, 0, 0, 0],
        [1, 1, 1, 1],
        [0, 0, 0, 0],
        [0, 1, 0, 1],
    )
    for done in done_patterns:
        buf.add(obs=data, act=data, rew=rew, done=done, obs_next=data)
    # Expected flags written as four rows of five entries.
    expected_done = (
        [0, 0, 1, 0, 0]
        + [0, 1, 1, 0, 0]
        + [0, 0, 0, 0, 0]
        + [0, 1, 0, 0, 0]
    )
    assert np.allclose(buf.done, expected_done)
    indice = buf.sample_index(0)
    assert np.allclose(indice, [0, 1, 10, 11])
    assert np.allclose(buf.prev(indice), [0, 0, 10, 10])
    assert np.allclose(buf.next(indice), [1, 1, 11, 11])