예제 #1
0
def test_episodic_returns(size=2560):
    """Compare ``BasePolicy.compute_episodic_return`` with fixed targets.

    Five hand-written (done, rew) scenarios are stored in one 20-slot
    ``ReplayBuffer`` (reset between scenarios) and the first value returned
    by the function under test is compared with a precomputed array via
    ``np.allclose``.  When the module runs as a script, the function also
    prints ``timeit`` measurements for ``compute_episodic_return_base`` and
    for the function under test on ``size`` random transitions.

    :param size: number of random transitions used by the script-only timing
        section; the assertions do not depend on it.
    """
    fn = BasePolicy.compute_episodic_return
    buf = ReplayBuffer(20)

    def store(transitions, target):
        # Each transition gets the same placeholder obs/act before storage.
        for step in transitions:
            step.obs = step.act = 1
            target.add(step)

    # Scenario 1: gamma=0.1, gae_lambda=1, with a 'TimeLimit.truncated' flag
    # set on the sixth transition.
    truncated = np.array(
        [False, False, False, False, False, True, False, False])
    batch = Batch(
        done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
        rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
        info=Batch({'TimeLimit.truncated': truncated}),
    )
    store(batch, buf)
    got, _ = fn(batch, buf, buf.sample_indices(0), gamma=.1, gae_lambda=1)
    expected = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
    assert np.allclose(got, expected)
    buf.reset()

    # Scenarios 2 and 3: same settings, no info field; only the position of
    # the done flags (and therefore the expected values) differs.
    plain_cases = (
        (
            [0, 1, 0, 1, 0, 1, 0.],
            [7, 6, 1, 2, 3, 4, 5.],
            [7.6, 6, 1.2, 2, 3.4, 4, 5],
        ),
        (
            [0, 1, 0, 1, 0, 0, 1.],
            [7, 6, 1, 2, 3, 4, 5.],
            [7.6, 6, 1.2, 2, 3.45, 4.5, 5],
        ),
    )
    for done_flags, rewards, target_values in plain_cases:
        batch = Batch(done=np.array(done_flags), rew=np.array(rewards))
        store(batch, buf)
        got, _ = fn(
            batch, buf, buf.sample_indices(0), gamma=.1, gae_lambda=1)
        assert np.allclose(got, np.array(target_values))
        buf.reset()

    # Scenarios 4 and 5 share done/rew/value inputs and use gamma=0.99,
    # gae_lambda=0.95.  Fresh arrays are built from these lists for every
    # call so the two scenarios never share array objects.
    gae_done = [0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]
    gae_rew = [101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]
    gae_values = [2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3]

    # Scenario 4: no truncation info.
    batch = Batch(done=np.array(gae_done), rew=np.array(gae_rew))
    store(batch, buf)
    got, _ = fn(
        batch,
        buf,
        buf.sample_indices(0),
        np.array(gae_values),
        gamma=0.99,
        gae_lambda=0.95,
    )
    expected = np.array([
        454.8344, 376.1143, 291.298, 200.,
        464.5610, 383.1085, 295.387, 201.,
        474.2876, 390.1027, 299.476, 202.,
    ])
    assert np.allclose(got, expected)
    buf.reset()

    # Scenario 5: the first two done transitions are flagged as truncated.
    truncated = np.array([
        False, False, False, True,
        False, False, False, True,
        False, False, False, False,
    ])
    batch = Batch(
        done=np.array(gae_done),
        rew=np.array(gae_rew),
        info=Batch({'TimeLimit.truncated': truncated}),
    )
    store(batch, buf)
    got, _ = fn(
        batch,
        buf,
        buf.sample_indices(0),
        np.array(gae_values),
        gamma=0.99,
        gae_lambda=0.95,
    )
    expected = np.array([
        454.0109, 375.2386, 290.3669, 199.01,
        462.9138, 381.3571, 293.5248, 199.02,
        474.2876, 390.1027, 299.476, 202.,
    ])
    assert np.allclose(got, expected)

    # Everything below is the timing section; skip it unless run as a script.
    if __name__ != '__main__':
        return

    bench_buf = ReplayBuffer(size)
    bench_batch = Batch(
        done=np.random.randint(100, size=size) == 0,
        rew=np.random.random(size),
    )
    store(bench_batch, bench_buf)
    bench_indices = bench_buf.sample_indices(0)

    def vanilla():
        return compute_episodic_return_base(bench_batch, gamma=.1)

    def optimized():
        return fn(
            bench_batch, bench_buf, bench_indices, gamma=.1, gae_lambda=1.0)

    cnt = 3000
    for label, func in (('GAE vanilla', vanilla), ('GAE optim  ', optimized)):
        # The callable is also passed as ``setup`` so it runs once untimed.
        print(label, timeit(func, setup=func, number=cnt))
예제 #2
0
def test_replaybuffer(size=10, bufsize=20):
    """Exercise ``ReplayBuffer`` adding, sampling, indexing and prev/next.

    The first half fills a buffer of ``bufsize`` slots with transitions
    produced by stepping ``MyTestEnv(size)``; the second half feeds
    hand-built batches into a separate 10-slot buffer and checks the stored
    fields and the values returned by ``add``.

    :param size: argument passed to ``MyTestEnv``; sampled observations are
        asserted to be smaller than it.
    :param bufsize: capacity of the buffer filled from the environment.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    # Updating a freshly created buffer with itself must leave its string
    # form as the bare class name followed by '()'.
    buf.update(buf)
    assert str(buf) == buf.__class__.__name__ + '()'
    obs = env.reset()
    # 25 actions in total, i.e. more steps than the default bufsize of 20.
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(
            Batch(obs=obs,
                  act=[a],
                  rew=rew,
                  done=done,
                  obs_next=obs_next,
                  info=info))
        obs = obs_next
        # Length grows by one per add until it is capped at bufsize.
        assert len(buf) == min(bufsize, i + 1)
    assert buf.act.dtype == int
    assert buf.act.shape == (bufsize, 1)
    # Request twice as many samples as there are slots; every returned index
    # must still point inside the buffer.
    data, indices = buf.sample(bufsize * 2)
    assert (indices < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    # From here on a separate, fresh 10-slot buffer is used.
    b = ReplayBuffer(size=10)
    # neg bsz should return empty index
    assert b.sample_indices(-1).tolist() == []
    # First transition: scalar fields, a string obs_next and a nested info
    # dict. add() returns four values, checked below.
    ptr, ep_rew, ep_len, ep_idx = b.add(
        Batch(obs=1,
              act=1,
              rew=1,
              done=1,
              obs_next='str',
              info={
                  'a': 3,
                  'b': {
                      'c': 5.0
                  }
              }))
    # Slot 0 holds the added values; the remaining slots hold 0 for numeric
    # fields and None for the string-valued obs_next.
    assert b.obs[0] == 1
    assert b.done[0]
    assert b.obs_next[0] == 'str'
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.obs_next[1:] == np.array(None))
    assert b.info.a[0] == 3 and b.info.a.dtype == int
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == float
    assert np.all(b.info.b.c[1:] == 0.0)
    # Each of the four values returned by add() is an array of shape (1,).
    assert ptr.shape == (1, ) and ptr[0] == 0
    assert ep_rew.shape == (1, ) and ep_rew[0] == 1
    assert ep_len.shape == (1, ) and ep_len[0] == 1
    assert ep_idx.shape == (1, ) and ep_idx[0] == 0
    # test extra keys pop up, the buffer should handle it dynamically
    batch = Batch(obs=2,
                  act=2,
                  rew=2,
                  done=0,
                  obs_next="str2",
                  info={
                      "a": 4,
                      "d": {
                          "e": -np.inf
                      }
                  })
    b.add(batch)
    # The new info key "d" is stored next to the existing "a" and "b"; the
    # key missing from this batch ("b.c") reads as 0 in slot 1.
    info_keys = ["a", "b", "d"]
    assert set(b.info.keys()) == set(info_keys)
    assert b.info.a[1] == 4 and b.info.b.c[1] == 0
    assert b.info.d.e[1] == -np.inf
    # test batch-style adding method, where len(batch) == 1
    batch.done = [1]
    batch.info.e = np.zeros([1, 4])
    batch = Batch.stack([batch])
    ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0])
    # NOTE(review): the expected 4 / 2 / 1 presumably describe the episode
    # formed by the two rew=2 transitions in slots 1 and 2 -- confirm
    # against ReplayBuffer.add.
    assert ptr.shape == (1, ) and ptr[0] == 2
    assert ep_rew.shape == (1, ) and ep_rew[0] == 4
    assert ep_len.shape == (1, ) and ep_len[0] == 2
    assert ep_idx.shape == (1, ) and ep_idx[0] == 1
    assert set(b.info.keys()) == set(info_keys + ["e"])
    assert b.info.e.shape == (b.maxsize, 1, 4)
    # Index 22 lies outside the 10-slot buffer.
    with pytest.raises(IndexError):
        b[22]
    # test prev / next
    assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1])
    assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2])
    # Add one more transition with done=0 and re-check prev/next with
    # slot 3 included.
    batch.done = [0]
    b.add(batch, buffer_ids=[0])
    assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3])
    assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3])