예제 #1
0
    def compute_nstep_return(
        batch: Batch,
        buffer: ReplayBuffer,
        indice: np.ndarray,
        target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
        gamma: float = 0.99,
        n_step: int = 1,
        rew_norm: bool = False,
        use_mixed: bool = False,
    ) -> Batch:
        r"""Attach n-step bootstrapped returns to ``batch`` for Q-learning.

        The target follows

        .. math::
            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
            \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

        with discount factor :math:`\gamma \in [0, 1]` and :math:`d_t` the
        done flag at step :math:`t`.

        :param Batch batch: the sampled data, expected to equal buffer[indice].
        :param ReplayBuffer buffer: the buffer the batch was drawn from.
        :param np.ndarray indice: buffer indexes of the sampled transitions.
        :param function target_q_fn: called as ``target_q_fn(buffer, indexes)``
            to obtain the target Q value of "obs_next" at the given indexes.
        :param float gamma: discount factor, should be in [0, 1]. Default to 0.99.
        :param int n_step: number of estimation steps, should be an int greater
            than 0. Default to 1.
        :param bool rew_norm: must be False; reward normalization is not
            supported here and a truthy value trips an assertion.
            Default to False.
        :param bool use_mixed: forwarded as the ``enabled`` flag of the
            ``autocast`` context that wraps the ``target_q_fn`` call.
            Default to False.

        :return: the same Batch, with ``batch.returns`` set to a torch.Tensor
            converted to match target_q_fn's output; if the batch has a
            ``weight`` attribute it is converted the same way.
        """
        assert not rew_norm, (
            "Reward normalization in computing n-step returns is unsupported now."
        )
        rewards = buffer.rew
        batch_size = len(indice)

        # Row k holds the buffer indexes reached k steps after ``indice``.
        step_rows = [indice]
        for _ in range(1, n_step):
            step_rows.append(buffer.next(step_rows[-1]))
        step_indices = np.stack(step_rows)

        # The last row is where the bootstrap value is evaluated; per the
        # buffer's ``next`` these are truncated at the end of each episode.
        last_indices = step_indices[-1]
        with autocast(enabled=use_mixed), torch.no_grad():
            q_tensor = target_q_fn(buffer, last_indices)  # (batch_size, ?)

        # Cast to float32 before leaving torch, then apply the value mask.
        q_values = to_numpy(q_tensor.float().reshape(batch_size, -1))
        keep_mask = BasePolicy.value_mask(buffer, last_indices).reshape(-1, 1)
        q_values = q_values * keep_mask

        # Work on a copy of the done flags so the buffer itself is untouched,
        # and also flag the indexes reported by ``unfinished_index()``.
        episode_end = buffer.done.copy()
        episode_end[buffer.unfinished_index()] = True
        nstep_target = _nstep_return(
            rewards, episode_end, q_values, step_indices, gamma, n_step)

        batch.returns = to_torch_as(nstep_target, q_tensor)
        if not hasattr(batch, "weight"):
            return batch
        # prio buffer update
        batch.weight = to_torch_as(batch.weight, q_tensor)
        return batch
예제 #2
0
def test_replaybuffer(size=10, bufsize=20):
    """Check ``ReplayBuffer`` adding, sampling, dynamic keys and prev/next.

    The first part fills a buffer of capacity ``bufsize`` from ``MyTestEnv``
    and checks its length, stored dtypes/shapes and sampled data. The second
    part uses a fresh buffer of size 10 and adds hand-built transitions one
    at a time, asserting on the stored fields, the values returned by
    ``add`` and the results of ``prev`` / ``next``.

    :param int size: argument passed to ``MyTestEnv``; sampled observations
        are asserted to be smaller than it.
    :param int bufsize: capacity of the first buffer.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    # Updating the freshly created buffer with itself must leave its string
    # form as the bare class name followed by "()".
    buf.update(buf)
    assert str(buf) == buf.__class__.__name__ + '()'
    obs = env.reset()
    # 25 actions in total, i.e. more steps than the default bufsize of 20.
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(
            Batch(obs=obs,
                  act=[a],
                  rew=rew,
                  done=done,
                  obs_next=obs_next,
                  info=info))
        obs = obs_next
        # The length grows by one per add until it saturates at bufsize.
        assert len(buf) == min(bufsize, i + 1)
    # act was added as a one-element list, hence the trailing dimension of 1.
    assert buf.act.dtype == int
    assert buf.act.shape == (bufsize, 1)
    # Ask for twice as many samples as the buffer can hold.
    data, indices = buf.sample(bufsize * 2)
    assert (indices < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    b = ReplayBuffer(size=10)
    # neg bsz should return empty index
    assert b.sample_indices(-1).tolist() == []
    # Add one transition made of scalars, a string obs_next and a nested
    # info dict; done=1 so this single step is asserted to end an episode.
    ptr, ep_rew, ep_len, ep_idx = b.add(
        Batch(obs=1,
              act=1,
              rew=1,
              done=1,
              obs_next='str',
              info={
                  'a': 3,
                  'b': {
                      'c': 5.0
                  }
              }))
    assert b.obs[0] == 1
    assert b.done[0]
    assert b.obs_next[0] == 'str'
    # Slots that were never written hold 0 (or None for the string-valued
    # obs_next field).
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.obs_next[1:] == np.array(None))
    assert b.info.a[0] == 3 and b.info.a.dtype == int
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == float
    assert np.all(b.info.b.c[1:] == 0.0)
    # Every value returned by add() is a length-1 array here.
    assert ptr.shape == (1, ) and ptr[0] == 0
    assert ep_rew.shape == (1, ) and ep_rew[0] == 1
    assert ep_len.shape == (1, ) and ep_len[0] == 1
    assert ep_idx.shape == (1, ) and ep_idx[0] == 0
    # test extra keys pop up, the buffer should handle it dynamically
    batch = Batch(obs=2,
                  act=2,
                  rew=2,
                  done=0,
                  obs_next="str2",
                  info={
                      "a": 4,
                      "d": {
                          "e": -np.inf
                      }
                  })
    b.add(batch)
    info_keys = ["a", "b", "d"]
    assert set(b.info.keys()) == set(info_keys)
    # info.b was absent from this transition, so its slot reads as 0.
    assert b.info.a[1] == 4 and b.info.b.c[1] == 0
    assert b.info.d.e[1] == -np.inf
    # test batch-style adding method, where len(batch) == 1
    batch.done = [1]
    batch.info.e = np.zeros([1, 4])
    batch = Batch.stack([batch])
    ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0])
    # This add lands at index 2 and sets done, closing the episode made of
    # indexes 1 and 2: reward 2 + 2 = 4, length 2, and ep_idx 1 (presumably
    # the index where that episode started).
    assert ptr.shape == (1, ) and ptr[0] == 2
    assert ep_rew.shape == (1, ) and ep_rew[0] == 4
    assert ep_len.shape == (1, ) and ep_len[0] == 2
    assert ep_idx.shape == (1, ) and ep_idx[0] == 1
    assert set(b.info.keys()) == set(info_keys + ["e"])
    assert b.info.e.shape == (b.maxsize, 1, 4)
    # Indexing beyond the buffer's size of 10 must raise.
    with pytest.raises(IndexError):
        b[22]
    # test prev / next
    # Index 0 is a one-step episode and indexes 1-2 form a second one, so
    # prev/next are asserted to stay within those episode boundaries.
    assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1])
    assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2])
    batch.done = [0]
    b.add(batch, buffer_ids=[0])
    # The newly added, not-yet-done transition at index 3 maps to itself in
    # both directions.
    assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3])
    assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3])