import dataclasses
from typing import Sequence

import numpy as np

# Module paths below are assumed from the imitation library's layout.
from imitation.data import types
from imitation.data.buffer import ReplayBuffer
from imitation.data.rollout import (
    flatten_trajectories,
    generate_trajectories,
    min_timesteps,
)


def test_replay_buffer_from_data():
    """Check that `ReplayBuffer.from_data` preserves the transition arrays."""
    obs = np.array([5, 2], dtype=int)
    acts = np.ones((2, 6), dtype=float)
    next_obs = np.array([7, 8], dtype=int)
    dones = np.array([True, False])

    def _check_buf(buf):
        assert np.array_equal(buf._buffer._arrays["obs"], obs)
        assert np.array_equal(buf._buffer._arrays["next_obs"], next_obs)
        assert np.array_equal(buf._buffer._arrays["acts"], acts)

    buf_std = ReplayBuffer.from_data(
        types.Transitions(
            obs=obs,
            acts=acts,
            next_obs=next_obs,
            dones=dones,
        )
    )
    _check_buf(buf_std)

    rews = np.array([0.5, 1.0], dtype=float)
    buf_rew = ReplayBuffer.from_data(
        types.TransitionsWithRew(
            obs=obs,
            acts=acts,
            next_obs=next_obs,
            rews=rews,
            dones=dones,
        )
    )
    _check_buf(buf_rew)
def generate_transitions(
    policy,
    venv,
    n_timesteps: int,
    *,
    truncate: bool = True,
    **kwargs,
) -> types.TransitionsWithRew:
    """Generate obs-action-next_obs-reward tuples.

    Args:
        policy (BasePolicy or BaseAlgorithm): A stable_baselines3 policy or
            algorithm, trained on the gym environment.
        venv: The vectorized environments to interact with.
        n_timesteps: The minimum number of timesteps to sample.
        truncate: If True, drop any additional samples to ensure that exactly
            `n_timesteps` samples are returned.
        **kwargs: Passed through to `generate_trajectories`.

    Returns:
        A batch of `TransitionsWithRew`. The length of the constituent arrays
        is guaranteed to be at least `n_timesteps` (if specified), but may be
        greater unless `truncate` is set, since we collect data until the end
        of each episode.
    """
    traj = generate_trajectories(
        policy, venv, sample_until=min_timesteps(n_timesteps), **kwargs
    )
    transitions = flatten_trajectories_with_rew(traj)
    if truncate and n_timesteps is not None:
        as_dict = dataclasses.asdict(transitions)
        truncated = {k: arr[:n_timesteps] for k, arr in as_dict.items()}
        transitions = types.TransitionsWithRew(**truncated)
    return transitions
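# A minimal usage sketch of `generate_transitions` (assumptions: a
# stable-baselines3 PPO model and a CartPole DummyVecEnv; the env id,
# algorithm, and training budget are illustrative, not from the source).
# With the default truncate=True, exactly `n_timesteps` samples come back
# even though episodes are collected to completion.
def _example_generate_transitions() -> types.TransitionsWithRew:
    import gym
    from stable_baselines3 import PPO
    from stable_baselines3.common.vec_env import DummyVecEnv

    venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    policy = PPO("MlpPolicy", venv).learn(total_timesteps=1_000)
    transitions = generate_transitions(policy, venv, n_timesteps=64)
    assert len(transitions.obs) == 64  # truncated to exactly n_timesteps
    return transitions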
def flatten_trajectories_with_rew(
    trajectories: Sequence[types.TrajectoryWithRew],
) -> types.TransitionsWithRew:
    """Flatten trajectories into transitions, preserving per-step rewards."""
    transitions = flatten_trajectories(trajectories)
    rews = np.concatenate([traj.rews for traj in trajectories])
    return types.TransitionsWithRew(**dataclasses.asdict(transitions), rews=rews)
def _join_transitions(
    trans_list: Sequence[types.TransitionsWithRew],
) -> types.TransitionsWithRew:
    """Concatenate multiple `TransitionsWithRew` batches into a single batch."""

    def concat(x):
        return np.concatenate(list(x))

    obs = concat(t.obs for t in trans_list)
    next_obs = concat(t.next_obs for t in trans_list)
    rews = concat(t.rews for t in trans_list)
    acts = concat(t.acts for t in trans_list)
    dones = concat(t.dones for t in trans_list)
    return types.TransitionsWithRew(
        obs=obs, next_obs=next_obs, rews=rews, acts=acts, dones=dones
    )
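# A short sketch of `_join_transitions` (the `_make_batch` helper and its
# shapes are hypothetical, not from the source): batches of lengths 3 and 5
# are concatenated along the first axis into a single batch of length 8.
def _example_join_transitions() -> types.TransitionsWithRew:
    def _make_batch(n: int) -> types.TransitionsWithRew:
        return types.TransitionsWithRew(
            obs=np.zeros((n, 4), dtype=float),
            acts=np.zeros((n, 1), dtype=float),
            next_obs=np.zeros((n, 4), dtype=float),
            dones=np.zeros(n, dtype=bool),
            rews=np.zeros(n, dtype=float),
        )

    joined = _join_transitions([_make_batch(3), _make_batch(5)])
    assert len(joined.obs) == 8  # batches stack along the first axis
    return joined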
def transitions_rew(
    transitions: types.Transitions,
    length: int,
) -> types.TransitionsWithRew:
    """Like `transitions` but with reward randomly sampled from a Gaussian."""
    rews = np.random.randn(length)
    return types.TransitionsWithRew(**dataclasses.asdict(transitions), rews=rews)
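# A minimal sketch of `transitions_rew`, reusing the array shapes from the
# replay-buffer test above; `length` is assumed to match the number of
# transitions so the sampled rewards line up with the other arrays.
def _example_transitions_rew() -> types.TransitionsWithRew:
    base = types.Transitions(
        obs=np.array([5, 2], dtype=int),
        acts=np.ones((2, 6), dtype=float),
        next_obs=np.array([7, 8], dtype=int),
        dones=np.array([True, False]),
    )
    with_rew = transitions_rew(base, length=2)
    assert with_rew.rews.shape == (2,)  # one Gaussian reward per transition
    return with_rew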