Example #1
0
def test_replay_buffer_from_data():
    """`ReplayBuffer.from_data` keeps obs/next_obs/acts for both transition types.

    A buffer is built once from plain `types.Transitions` and once from
    `types.TransitionsWithRew`; in both cases the buffer's internal
    "obs", "next_obs" and "acts" arrays must equal the inputs.
    """
    obs = np.array([5, 2], dtype=int)
    acts = np.ones((2, 6), dtype=float)
    next_obs = np.array([7, 8], dtype=int)
    dones = np.array([True, False])
    shared_fields = dict(obs=obs, acts=acts, next_obs=next_obs, dones=dones)
    # Only these three stored arrays are compared, in this order.
    expected_arrays = (("obs", obs), ("next_obs", next_obs), ("acts", acts))

    def _check_buf(buffer):
        for key, want in expected_arrays:
            assert np.array_equal(buffer._buffer._arrays[key], want)

    _check_buf(ReplayBuffer.from_data(types.Transitions(**shared_fields)))

    rews = np.array([0.5, 1.0], dtype=float)
    _check_buf(
        ReplayBuffer.from_data(types.TransitionsWithRew(**shared_fields, rews=rews))
    )
Example #2
0
def generate_transitions(
    policy, venv, n_timesteps: int, *, truncate: bool = True, **kwargs
) -> types.TransitionsWithRew:
    """Generate obs-action-next_obs-reward tuples.

    Args:
      policy (BasePolicy or BaseAlgorithm): A stable_baselines3 policy or
          algorithm, trained on the gym environment.
      venv: The vectorized environments to interact with.
      n_timesteps: The minimum number of timesteps to sample.
      truncate: If True (the default), drop any additional samples so that at
          most `n_timesteps` samples are returned.
      **kwargs: Passed-through to generate_trajectories.

    Returns:
      A batch of Transitions. The length of the constituent arrays is guaranteed
      to be at least `n_timesteps` (if specified). When `truncate` is False it
      may be greater, because data is collected until the end of each episode.
    """
    traj = generate_trajectories(
        policy, venv, sample_until=min_timesteps(n_timesteps), **kwargs
    )
    transitions = flatten_trajectories_with_rew(traj)
    if truncate and n_timesteps is not None:
        # Build a *shallow* field dict and slice each array directly.
        # `dataclasses.asdict` would deep-copy every array first, only for the
        # copy to be sliced and discarded -- doubling peak memory for nothing.
        truncated = {
            fld.name: getattr(transitions, fld.name)[:n_timesteps]
            for fld in dataclasses.fields(transitions)
        }
        transitions = types.TransitionsWithRew(**truncated)
    return transitions
Example #3
0
def flatten_trajectories_with_rew(
    trajectories: Sequence[types.TrajectoryWithRew],
) -> types.TransitionsWithRew:
    """Flatten reward-annotated trajectories into a single TransitionsWithRew.

    Args:
      trajectories: Trajectories to flatten; each must expose a `rews` array.

    Returns:
      The fields produced by `flatten_trajectories(trajectories)` plus a `rews`
      field holding every trajectory's rewards concatenated in input order.
    """
    transitions = flatten_trajectories(trajectories)
    rews = np.concatenate([traj.rews for traj in trajectories])
    # Shallow field dict: `dataclasses.asdict` would deep-copy every freshly
    # flattened array just to hand it straight to the constructor.
    fields = {
        fld.name: getattr(transitions, fld.name)
        for fld in dataclasses.fields(transitions)
    }
    return types.TransitionsWithRew(**fields, rews=rews)
Example #4
0
def _join_transitions(
    trans_list: Sequence[types.TransitionsWithRew],
) -> types.TransitionsWithRew:
    """Concatenate several TransitionsWithRew batches into one batch.

    Each of obs, next_obs, rews, acts and dones is concatenated along the
    first axis, preserving the order of `trans_list`.
    """
    joined = {
        name: np.concatenate([getattr(batch, name) for batch in trans_list])
        for name in ("obs", "next_obs", "rews", "acts", "dones")
    }
    return types.TransitionsWithRew(**joined)
Example #5
0
def transitions_rew(
    transitions: types.Transitions, length: int
) -> types.TransitionsWithRew:
    """Return `transitions` extended with `length` standard-normal rewards.

    Rewards are drawn from NumPy's global RNG via `np.random.randn`. The
    transition fields are deep-copied by `dataclasses.asdict`, so the result
    does not share arrays with the input.
    """
    gaussian_rews = np.random.randn(length)
    field_kwargs = dataclasses.asdict(transitions)
    return types.TransitionsWithRew(rews=gaussian_rews, **field_kwargs)