Example #1
def test_replay_buffer_init_errors():
    with pytest.raises(ValueError, match=r"Specified.* and environment"):
        ReplayBuffer(15, env="MockEnv", obs_shape=(10, 10))
    with pytest.raises(ValueError, match=r"Shape or dtype missing.*"):
        ReplayBuffer(15, obs_shape=(10, 10), act_shape=(15, ), obs_dtype=bool)
    with pytest.raises(ValueError, match=r"Shape or dtype missing.*"):
        ReplayBuffer(15, obs_shape=(10, 10), obs_dtype=bool, act_dtype=bool)
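
For contrast, constructions that these checks should accept: pass either an environment or every shape and dtype explicitly (a minimal sketch based only on the tests above; "MockEnv" is the test suite's mock environment ID):

ReplayBuffer(15, env="MockEnv")                       # shapes/dtypes inferred from the env
ReplayBuffer(15, obs_shape=(10, 10), obs_dtype=bool,  # everything given explicitly
             act_shape=(15,), act_dtype=bool)
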
Example #2
def test_replay_buffer_store_errors():
    b = ReplayBuffer(10,
                     obs_shape=(),
                     obs_dtype=bool,
                     act_shape=(),
                     act_dtype=float)
    with pytest.raises(ValueError, match=".* same length.*"):
        b.store(np.ones(4), np.ones(4), np.ones(3))
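
For contrast, a store call whose arrays all have the same length should be accepted by the same buffer (a sketch; the positional order old_obs, act, new_obs follows this test and Example #3):

b.store(np.ones(4, dtype=bool),   # old observations
        np.ones(4, dtype=float),  # actions
        np.ones(4, dtype=bool))   # new observations
assert len(b) == 4
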
Example #3
def test_replay_buffer(capacity, chunk_len, obs_shape, act_shape, dtype):
    """Builds a ReplayBuffer with the provided `capacity` and inserts
  `capacity * 3` observation-action-observation samples into the buffer in
  chunks of length `chunk_len`.

  All chunks have the appropriate observation or action shape and are filled
  with sequential values derived from the insertion index (via `_fill_chunk`).

  `len(buffer)` should increase until we reach capacity.
  `buffer._idx` should loop between 0 and `capacity - 1`.
  After every insertion, sampled values should only come from entries that are
  still stored in the buffer.
  """
    buf = ReplayBuffer(capacity,
                       obs_shape=obs_shape,
                       act_shape=act_shape,
                       obs_dtype=dtype,
                       act_dtype=dtype)

    for i in range(0, capacity * 3, chunk_len):
        assert len(buf) == min(i, capacity)
        assert buf._buffer._idx == i % capacity

        old_obs_data = _fill_chunk(i, chunk_len, obs_shape, dtype=dtype)
        new_obs_data = _fill_chunk(3 * capacity + i,
                                   chunk_len,
                                   obs_shape,
                                   dtype=dtype)
        act_data = _fill_chunk(6 * capacity + i,
                               chunk_len,
                               act_shape,
                               dtype=dtype)

        buf.store(old_obs_data, act_data, new_obs_data)

        # Are samples right shape?
        old_obs, acts, new_obs = buf.sample(100)
        assert old_obs.shape == new_obs.shape == (100, ) + obs_shape
        assert acts.shape == (100, ) + act_shape

        # Are samples right data type?
        assert old_obs.dtype == dtype
        assert acts.dtype == dtype
        assert new_obs.dtype == dtype

        # Are samples in range?
        _check_bound(i + chunk_len, capacity, old_obs)
        _check_bound(i + chunk_len, capacity, new_obs, 3 * capacity)
        _check_bound(i + chunk_len, capacity, acts, 6 * capacity)

        # Are samples in-order?
        old_obs_fill = _get_fill_from_chunk(old_obs)
        new_obs_fill = _get_fill_from_chunk(new_obs)
        act_fill = _get_fill_from_chunk(acts)

        assert np.all(new_obs_fill - old_obs_fill == 3 *
                      capacity), "out of order"
        assert np.all(act_fill - new_obs_fill == 3 * capacity), "out of order"
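
The helpers `_fill_chunk`, `_get_fill_from_chunk`, and `_check_bound` are referenced but not shown in these excerpts. Below is a hypothetical reconstruction consistent with how the test calls them and with the assertions it makes; it assumes dtypes that can represent the fill values (e.g. float or int), and the real helpers may differ:

import numpy as np


def _fill_chunk(start, chunk_len, sample_shape, dtype=float):
    """Return a (chunk_len, *sample_shape) array whose r-th row is filled
    with the value start + r."""
    fill_vals = np.arange(start, start + chunk_len)
    shaped = fill_vals.reshape((chunk_len,) + (1,) * len(sample_shape))
    return np.broadcast_to(shaped, (chunk_len,) + sample_shape).astype(dtype)


def _get_fill_from_chunk(chunk):
    """Recover the per-row fill value from an array built by `_fill_chunk`."""
    chunk = np.asarray(chunk)
    return chunk.reshape(len(chunk), -1)[:, 0]


def _check_bound(end, capacity, samples, offset=0):
    """Check that sampled fill values come only from entries still stored in
    the buffer, i.e. from the last `capacity` rows inserted so far."""
    fill = _get_fill_from_chunk(samples)
    low = offset + max(0, end - capacity)
    high = offset + end
    assert np.all(fill >= low), "sampled an overwritten entry"
    assert np.all(fill < high), "sampled an entry that was never stored"
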
Example #4
def test_replay_buffer_from_data():
    old_obs = np.array([5, 2], dtype=int)
    act = np.ones((2, 6), dtype=float)
    new_obs = np.array([7, 8], dtype=int)
    buf = ReplayBuffer.from_data(old_obs, act, new_obs)
    assert np.array_equal(buf._buffer._arrays['old_obs'], old_obs)
    assert np.array_equal(buf._buffer._arrays['new_obs'], new_obs)
    assert np.array_equal(buf._buffer._arrays['act'], act)

    with pytest.raises(ValueError, match=r".*same length."):
        new_obs_toolong = np.array([7, 8, 9], dtype=int)
        ReplayBuffer.from_data(old_obs, act, new_obs_toolong)
    with pytest.raises(ValueError, match=r".*same dtype."):
        new_obs_float = np.array(new_obs, dtype=float)
        ReplayBuffer.from_data(old_obs, act, new_obs_float)
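
Sampling from a buffer built with `from_data` should return batches in the same (old_obs, act, new_obs) layout as Example #3 (a sketch under that assumption; Example #3 also suggests more rows than the buffer holds may be requested):

sampled_old_obs, sampled_act, sampled_new_obs = buf.sample(5)
assert sampled_old_obs.shape == (5,)
assert sampled_act.shape == (5, 6)
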
Example #5
def test_replay_buffer_store_errors():
  b = ReplayBuffer(10, obs_shape=(), obs_dtype=bool, act_shape=(),
                   act_dtype=float)

  dtypes = {
      'obs': np.float32,
      'next_obs': np.float32,
      'acts': np.float32,
      'rews': np.float32,
      'dones': np.bool_,
  }
  for odd_field in dtypes.keys():
    with pytest.raises(ValueError, match=".* same length.*"):
      transition = {k: np.ones(3 if k == odd_field else 4, dtype=dtype)
                    for k, dtype in dtypes.items()}
      transition = rollout.Transitions(**transition)
      b.store(transition)
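
`rollout.Transitions` is used here but not defined in these excerpts. A minimal stand-in consistent with the keyword arguments the tests pass (hypothetical; the real class may be a NamedTuple or dataclass with extra validation):

from typing import NamedTuple

import numpy as np


class Transitions(NamedTuple):
    """Batch of obs-act-next_obs transitions plus rewards and done flags."""
    obs: np.ndarray
    acts: np.ndarray
    next_obs: np.ndarray
    rews: np.ndarray
    dones: np.ndarray
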
Example #6
def test_replay_buffer_from_data():
  obs = np.array([5, 2], dtype=int)
  acts = np.ones((2, 6), dtype=float)
  next_obs = np.array([7, 8], dtype=int)
  rews = np.array([0.5, 1.0], dtype=float)
  dones = np.array([True, False])
  buf = ReplayBuffer.from_data(rollout.Transitions(
      obs=obs, acts=acts, next_obs=next_obs, rews=rews, dones=dones,
  ))
  assert np.array_equal(buf._buffer._arrays['obs'], obs)
  assert np.array_equal(buf._buffer._arrays['next_obs'], next_obs)
  assert np.array_equal(buf._buffer._arrays['acts'], acts)

  with pytest.raises(ValueError, match=r".*same length."):
    next_obs_toolong = np.array([7, 8, 9], dtype=int)
    ReplayBuffer.from_data(rollout.Transitions(
        obs=obs, acts=acts, next_obs=next_obs_toolong, rews=rews, dones=dones,
    ))
  with pytest.raises(ValueError, match=r".*same dtype."):
    next_obs_float = np.array(next_obs, dtype=float)
    ReplayBuffer.from_data(rollout.Transitions(
        obs=obs, acts=acts, next_obs=next_obs_float, rews=rews, dones=dones,
    ))
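
In this newer API, `sample` returns a `Transitions` batch rather than a tuple, as Example #9 below shows. A short usage sketch under that assumption:

batch = buf.sample(4)
assert batch.obs.shape == batch.next_obs.shape == (4,)
assert batch.acts.shape == (4, 6)
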
Example #7
    def __init__(self,
                 env: Union[gym.Env, str],
                 gen_policy: BaseRLModel,
                 discrim: DiscrimNet,
                 expert_policies: Sequence[BaseRLModel],
                 *,
                 disc_opt_cls: tf.train.Optimizer = tf.train.AdamOptimizer,
                 disc_opt_kwargs: dict = {},
                 n_disc_samples_per_buffer: int = 200,
                 n_expert_samples: int = 4000,
                 gen_replay_buffer_capacity: Optional[int] = None,
                 init_tensorboard: bool = False,
                 debug_use_ground_truth: bool = False):
        """Builds Trainer.

    Args:
        env: A Gym environment or ID that the policy is trained on.
        gen_policy: The generator policy that is trained to maximize
                    discriminator confusion.
        discrim: The discriminator network.
            For GAIL, use a DiscrimNetGAIL. For AIRL, use a DiscrimNetAIRL.
        expert_policies: An expert policy
            or a list of expert policies that are used to generate example
            obs-action-obs triples.

            WARNING:
            Due to the way VecEnvs handle episode completion states, the last
            obs-act-obs triple in every episode is omitted. (See issue #1.)
        disc_opt_cls: The optimizer for discriminator training.
        disc_opt_kwargs: Parameters for discriminator training.
        n_disc_samples_per_buffer: The number of obs-act-obs triples
            sampled from each replay buffer (expert and generator) during each
            step of discriminator training. This is also the number of triples
            stored in the replay buffer after each epoch of generator training.
        n_expert_samples: The number of expert obs-action-obs triples
            that are generated. If the number of expert policies given
            doesn't divide this number evenly, then the last expert policy
            generates more timesteps.
        gen_replay_buffer_capacity: The capacity of the
            generator replay buffer (the number of obs-action-obs samples from
            the generator that can be stored).

            By default this is equal to `20 * n_disc_samples_per_buffer`.
        init_tensorboard: If True, makes various discriminator
            TensorBoard summaries. (Generator summaries appear under a
            different runname than the discriminator summaries because they
            are configured by initializing the stable_baselines policy).
        debug_use_ground_truth: If True, use the ground truth reward.
            This disables the reward wrapping that would normally replace
            the environment reward with the learned reward. This is useful for
            sanity checking that the policy training is functional.
    """
        if n_disc_samples_per_buffer > n_expert_samples:
            warn("The discriminator batch size is larger than the number of "
                 "expert samples.")

        # TODO(adam): we're not guaranteed to use this session, see issue #31
        self._sess = tf.Session()

        self.env = maybe_load_env(env, vectorize=True)
        self.gen_policy = gen_policy
        self.expert_policies = expert_policies
        self._n_disc_samples_per_buffer = n_disc_samples_per_buffer
        self.debug_use_ground_truth = debug_use_ground_truth

        self._global_step = tf.train.create_global_step()

        # Discriminator and reward output
        self._disc_opt_cls = disc_opt_cls
        self._disc_opt_kwargs = disc_opt_kwargs
        with tf.variable_scope("trainer"):
            with tf.variable_scope("discriminator"):
                self.discrim = discrim
                self._build_disc_train()
            self._build_policy_train_reward()
            self._build_test_reward()
        self._init_tensorboard = init_tensorboard
        if init_tensorboard:
            with tf.name_scope("summaries"):
                self._build_summarize()

        self._sess.run(tf.global_variables_initializer())

        # TODO(adam): make this wrapping configurable for debugging purposes
        self.env = self.wrap_env_train_reward(self.env)
        self.gen_policy.set_env(self.env)

        if gen_replay_buffer_capacity is None:
            gen_replay_buffer_capacity = 20 * self._n_disc_samples_per_buffer
        self._gen_replay_buffer = ReplayBuffer(gen_replay_buffer_capacity,
                                               self.env)
        self._populate_gen_replay_buffer()

        exp_rollouts = rollout.generate_multiple(self.expert_policies,
                                                 self.env,
                                                 n_expert_samples)[:3]
        self._exp_replay_buffer = ReplayBuffer.from_data(*exp_rollouts)
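
A construction sketch based only on the signature and docstring above. The policy, discriminator, and expert objects are placeholders; the constructors for DiscrimNetAIRL/DiscrimNetGAIL and the stable_baselines model are not shown in this excerpt:

trainer = Trainer(
    env="CartPole-v1",             # Gym ID or gym.Env; loaded and vectorized internally
    gen_policy=gen_policy,         # a stable_baselines BaseRLModel
    discrim=discrim,               # DiscrimNetAIRL for AIRL, DiscrimNetGAIL for GAIL
    expert_policies=[expert],      # used to generate expert obs-act-obs triples
    n_disc_samples_per_buffer=200,
    n_expert_samples=4000,
    init_tensorboard=False,
)
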
Example #8
class Trainer:
    """Trainer for GAIL and AIRL."""
    def __init__(self,
                 env: Union[gym.Env, str],
                 gen_policy: BaseRLModel,
                 discrim: DiscrimNet,
                 expert_policies: Sequence[BaseRLModel],
                 *,
                 disc_opt_cls: tf.train.Optimizer = tf.train.AdamOptimizer,
                 disc_opt_kwargs: dict = {},
                 n_disc_samples_per_buffer: int = 200,
                 n_expert_samples: int = 4000,
                 gen_replay_buffer_capacity: Optional[int] = None,
                 init_tensorboard: bool = False,
                 debug_use_ground_truth: bool = False):
        """Builds Trainer.

    Args:
        env: A Gym environment or ID that the policy is trained on.
        gen_policy: The generator policy that is trained to maximize
                    discriminator confusion.
        discrim: The discriminator network.
            For GAIL, use a DiscrimNetGAIL. For AIRL, use a DiscrimNetAIRL.
        expert_policies: An expert policy
            or a list of expert policies that are used to generate example
            obs-action-obs triples.

            WARNING:
            Due to the way VecEnvs handle episode completion states, the last
            obs-act-obs triple in every episode is omitted. (See issue #1.)
        disc_opt_cls: The optimizer for discriminator training.
        disc_opt_kwargs: Parameters for discriminator training.
        n_disc_samples_per_buffer: The number of obs-act-obs triples
            sampled from each replay buffer (expert and generator) during each
            step of discriminator training. This is also the number of triples
            stored in the replay buffer after each epoch of generator training.
        n_expert_samples: The number of expert obs-action-obs triples
            that are generated. If the number of expert policies given
            doesn't divide this number evenly, then the last expert policy
            generates more timesteps.
        gen_replay_buffer_capacity: The capacity of the
            generator replay buffer (the number of obs-action-obs samples from
            the generator that can be stored).

            By default this is equal to `20 * n_disc_samples_per_buffer`.
        init_tensorboard: If True, makes various discriminator
            TensorBoard summaries. (Generator summaries appear under a
            different runname than the discriminator summaries because they
            are configured by initializing the stable_baselines policy).
        debug_use_ground_truth: If True, use the ground truth reward.
            This disables the reward wrapping that would normally replace
            the environment reward with the learned reward. This is useful for
            sanity checking that the policy training is functional.
    """
        if n_disc_samples_per_buffer > n_expert_samples:
            warn("The discriminator batch size is larger than the number of "
                 "expert samples.")

        # TODO(adam): we're not guaranteed to use this session, see issue #31
        self._sess = tf.Session()

        self.env = maybe_load_env(env, vectorize=True)
        self.gen_policy = gen_policy
        self.expert_policies = expert_policies
        self._n_disc_samples_per_buffer = n_disc_samples_per_buffer
        self.debug_use_ground_truth = debug_use_ground_truth

        self._global_step = tf.train.create_global_step()

        # Discriminator and reward output
        self._disc_opt_cls = disc_opt_cls
        self._disc_opt_kwargs = disc_opt_kwargs
        with tf.variable_scope("trainer"):
            with tf.variable_scope("discriminator"):
                self.discrim = discrim
                self._build_disc_train()
            self._build_policy_train_reward()
            self._build_test_reward()
        self._init_tensorboard = init_tensorboard
        if init_tensorboard:
            with tf.name_scope("summaries"):
                self._build_summarize()

        self._sess.run(tf.global_variables_initializer())

        # TODO(adam): make this wrapping configurable for debugging purposes
        self.env = self.wrap_env_train_reward(self.env)
        self.gen_policy.set_env(self.env)

        if gen_replay_buffer_capacity is None:
            gen_replay_buffer_capacity = 20 * self._n_disc_samples_per_buffer
        self._gen_replay_buffer = ReplayBuffer(gen_replay_buffer_capacity,
                                               self.env)
        self._populate_gen_replay_buffer()

        exp_rollouts = rollout.generate_multiple(self.expert_policies,
                                                 self.env,
                                                 n_expert_samples)[:3]
        self._exp_replay_buffer = ReplayBuffer.from_data(*exp_rollouts)

    def train_disc(self, n_steps=10, **kwargs):
        """Trains the discriminator to minimize classification cross-entropy.

    Args:
        n_steps (int): The number of training steps.
        gen_old_obs (np.ndarray): See `_build_disc_feed_dict`.
        gen_act (np.ndarray): See `_build_disc_feed_dict`.
        gen_new_obs (np.ndarray): See `_build_disc_feed_dict`.
    """
        for _ in range(n_steps):
            fd = self._build_disc_feed_dict(**kwargs)
            step, _ = self._sess.run([self._global_step, self._disc_train_op],
                                     feed_dict=fd)
            if self._init_tensorboard and step % 20 == 0:
                self._summarize(fd, step)

    def train_gen(self, n_steps=10000):
        self.gen_policy.set_env(self.env)
        # TODO(adam): learn was not intended to be called for each training batch
        # It should work, but might incur unnecessary overhead: e.g. in PPO2
        # a new Runner instance is created each time. Also a hotspot for errors:
        # algorithms not tested for this use case, may reset state accidentally.
        self.gen_policy.learn(n_steps, reset_num_timesteps=False)
        self._populate_gen_replay_buffer()

    def _populate_gen_replay_buffer(self) -> None:
        """Generate and store generator samples in the buffer.

    More specifically, rolls out generator-policy trajectories in the
    environment until `self._n_disc_samples_per_buffer` obs-act-obs samples are
    produced, and then stores these samples.
    """
        gen_rollouts = rollout.flatten_trajectories(
            rollout.generate(self.gen_policy,
                             self.env,
                             n_timesteps=self._n_disc_samples_per_buffer))[:3]
        self._gen_replay_buffer.store(*gen_rollouts)

    def train(self,
              *,
              n_epochs=100,
              n_gen_steps_per_epoch=None,
              n_disc_steps_per_epoch=None):
        """Trains the discriminator and generator against each other.

    Args:
        n_epochs (int): The number of epochs to train. Every epoch consists
            of training the discriminator and then training the generator.
        n_disc_steps_per_epoch (int): The number of steps to train the
            discriminator every epoch. More precisely, the number of full batch
            Adam optimizer steps to perform.
        n_gen_steps_per_epoch (int): The number of generator training steps
            during each epoch (i.e., the `timesteps` argument in
            `policy.learn(timesteps)`).
    """
        for i in tqdm(range(n_epochs), desc="AIRL train"):
            self.train_disc(**_n_steps_if_not_none(n_disc_steps_per_epoch))
            self.train_gen(**_n_steps_if_not_none(n_gen_steps_per_epoch))

    def eval_disc_loss(self, **kwargs):
        """Evaluates the discriminator loss.

    The generator rollout parameters of the form "gen_*" are optional,
    but if one is given, then all such parameters must be filled (otherwise
    this method will error). If none of the generator rollout parameters are
    given, then a rollout with the same length as the expert rollout
    is generated on the fly.

    Args:
        gen_old_obs (np.ndarray): See `_build_disc_feed_dict`.
        gen_act (np.ndarray): See `_build_disc_feed_dict`.
        gen_new_obs (np.ndarray): See `_build_disc_feed_dict`.

    Returns:
        discriminator_loss (float): The total cross-entropy error in the
            discriminator's classification.
    """
        fd = self._build_disc_feed_dict(**kwargs)
        return np.mean(self._sess.run(self.discrim.disc_loss, feed_dict=fd))

    def wrap_env_train_reward(self, env):
        """Returns the given Env wrapped with a reward function that returns
    the AIRL training reward (discriminator confusion).

    The wrapped `Env`'s reward is directly evaluated from the reward network,
    and therefore changes whenever `self.train()` is called.

    Args:
        env (str, Env, or VecEnv): The Env that we want to wrap. If a
            string environment name or an Env is given, then we first
            convert it to a VecEnv before continuing.

    Returns:
        wrapped_env (VecEnv): The wrapped environment with a new reward.
    """
        env = maybe_load_env(env, vectorize=True)
        if self.debug_use_ground_truth:
            return env
        else:
            return _RewardVecEnvWrapper(env, self._policy_train_reward_fn)

    def wrap_env_test_reward(self, env):
        """Returns the given Env wrapped with a reward function that returns
    the reward learned by this Trainer.

    The wrapped `Env`'s reward is directly evaluated from the reward network,
    and therefore changes whenever `self.train()` is called.

    Args:
        env (str, Env, or VecEnv): The Env that should be wrapped. If a
            string environment name or an Env is given, then we first
            make a VecEnv before continuing.

    Returns:
        wrapped_env (VecEnv): The wrapped environment with a new reward.
    """
        env = maybe_load_env(env, vectorize=True)
        if self.debug_use_ground_truth:
            return env
        else:
            return _RewardVecEnvWrapper(env, self._test_reward_fn)

    def _build_summarize(self):
        self._summary_writer = summaries.make_summary_writer(
            graph=self._sess.graph)
        self.discrim.build_summaries()
        self._summary_op = tf.summary.merge_all()

    def _summarize(self, fd, step):
        events = self._sess.run(self._summary_op, feed_dict=fd)
        self._summary_writer.add_summary(events, step)

    def _build_disc_train(self):
        # Construct Train operation.
        disc_opt = self._disc_opt_cls(**self._disc_opt_kwargs)
        self._disc_train_op = disc_opt.minimize(tf.reduce_mean(
            self.discrim.disc_loss),
                                                global_step=self._global_step)

    def _build_disc_feed_dict(
            self,
            *,
            gen_old_obs: Optional[np.ndarray] = None,
            gen_act: Optional[np.ndarray] = None,
            gen_new_obs: Optional[np.ndarray] = None) -> dict:
        """Build a feed dict that holds the next training batch of generator
    and expert obs-act-obs triples.

    Args:
        gen_old_obs (np.ndarray): A numpy array with shape
            `[self._n_disc_samples_per_buffer] + env.observation_space.shape`.
            The ith observation in this array is the observation seen when the
            generator chooses action `gen_act[i]`.
        gen_act (np.ndarray): A numpy array with shape
            `[self._n_disc_samples_per_buffer] + env.action_space.shape`.
        gen_new_obs (np.ndarray): A numpy array with shape
            `[self._n_disc_samples_per_buffer] + env.observation_space.shape`.
            The ith observation in this array is from the transition state after
            the generator chooses action `gen_act[i]`.
    """  # noqa: E501

        # Sample generator training batch from replay buffers, unless provided
        # in argument.
        none_count = sum(
            int(x is None) for x in (gen_old_obs, gen_act, gen_new_obs))
        if none_count == 3:
            tf.logging.debug("_build_disc_feed_dict: No generator rollout "
                             "parameters were "
                             "provided, so we are generating them now.")
            gen_old_obs, gen_act, gen_new_obs = self._gen_replay_buffer.sample(
                self._n_disc_samples_per_buffer)
        elif none_count != 0:
            raise ValueError("Gave some but not all of the generator params.")

        # Sample expert training batch from replay buffer.
        expert_old_obs, expert_act, expert_new_obs = self._exp_replay_buffer.sample(
            self._n_disc_samples_per_buffer)

        # Check dimensions.
        n_expert = len(expert_old_obs)
        n_gen = len(gen_old_obs)
        N = n_expert + n_gen
        assert n_expert == len(expert_act)
        assert n_expert == len(expert_new_obs)
        assert n_gen == len(gen_act)
        assert n_gen == len(gen_new_obs)

        # Concatenate rollouts, and label each row as expert or generator.
        old_obs = np.concatenate([expert_old_obs, gen_old_obs])
        act = np.concatenate([expert_act, gen_act])
        new_obs = np.concatenate([expert_new_obs, gen_new_obs])
        labels = np.concatenate(
            [np.zeros(n_expert, dtype=int),
             np.ones(n_gen, dtype=int)])

        # Calculate generator-policy log probabilities.
        log_act_prob = self.gen_policy.action_probability(old_obs,
                                                          actions=act,
                                                          logp=True)
        assert len(log_act_prob) == N
        log_act_prob = log_act_prob.reshape((N, ))

        fd = {
            self.discrim.old_obs_ph: old_obs,
            self.discrim.act_ph: act,
            self.discrim.new_obs_ph: new_obs,
            self.discrim.labels_ph: labels,
            self.discrim.log_policy_act_prob_ph: log_act_prob,
        }
        return fd

    def _build_policy_train_reward(self):
        """Sets self._policy_train_reward_fn, the reward function to use when
    running a policy optimizer (e.g. PPO).
    """
        def R(old_obs, act, new_obs):
            """Vectorized reward function.

      Args:
          old_obs (array): The observation input. Its shape is
              `((None,) + self.env.observation_space.shape)`.
          act (array): The action input. Its shape is
              `((None,) + self.env.action_space.shape)`. The None dimension is
              expected to match the None dimension of `old_obs`.
          new_obs (array): The next observation input. Its shape is
              `((None,) + self.env.observation_space.shape)`.
      """
            old_obs = np.atleast_1d(old_obs)
            act = np.atleast_1d(act)
            new_obs = np.atleast_1d(new_obs)

            n_gen = len(old_obs)
            assert len(act) == n_gen
            assert len(new_obs) == n_gen

            # Calculate generator-policy log probabilities.
            log_act_prob = self.gen_policy.action_probability(old_obs,
                                                              actions=act,
                                                              logp=True)
            assert len(log_act_prob) == n_gen
            log_act_prob = log_act_prob.reshape((n_gen, ))

            fd = {
                self.discrim.old_obs_ph: old_obs,
                self.discrim.act_ph: act,
                self.discrim.new_obs_ph: new_obs,
                self.discrim.labels_ph: np.ones(n_gen),
                self.discrim.log_policy_act_prob_ph: log_act_prob,
            }
            rew = self._sess.run(self.discrim.policy_train_reward,
                                 feed_dict=fd)
            return rew.flatten()

        self._policy_train_reward_fn = R

    def _build_test_reward(self):
        """Sets self._test_reward_fn, the transfer reward function"""
        def R(old_obs, act, new_obs):
            fd = {
                self.discrim.old_obs_ph: old_obs,
                self.discrim.act_ph: act,
                self.discrim.new_obs_ph: new_obs,
            }
            rew = self._sess.run(self.discrim._policy_test_reward,
                                 feed_dict=fd)
            return rew.flatten()

        self._test_reward_fn = R
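
`_n_steps_if_not_none` is called by `train` but not defined in this excerpt. A minimal reconstruction consistent with the call sites (hypothetical; the real helper may differ), followed by the training-loop usage it enables:

def _n_steps_if_not_none(n_steps):
    # Forward `n_steps` only when the caller supplied a value, so that
    # `train_disc` and `train_gen` otherwise fall back to their own defaults.
    if n_steps is None:
        return {}
    return dict(n_steps=n_steps)


# e.g. trainer.train(n_epochs=50, n_disc_steps_per_epoch=20) runs 50 epochs
# with 20 discriminator steps each and the default number of generator steps.
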
Example #9
def test_replay_buffer(capacity, chunk_len, obs_shape, act_shape, dtype):
  """Builds a ReplayBuffer with the provided `capacity` and inserts
  `capacity * 3` observation-action-observation samples into the buffer in
  chunks of length `chunk_len`.

  All chunks have the appropriate observation or action shape and are filled
  with sequential values derived from the insertion index (via `_fill_chunk`).

  `len(buffer)` should increase until we reach capacity.
  `buffer._idx` should loop between 0 and `capacity - 1`.
  After every insertion, sampled values should only come from entries that are
  still stored in the buffer.
  """
  buf = ReplayBuffer(capacity, obs_shape=obs_shape, act_shape=act_shape,
                     obs_dtype=dtype, act_dtype=dtype)

  for i in range(0, capacity*3, chunk_len):
    assert len(buf) == min(i, capacity)
    assert buf._buffer._idx == i % capacity

    batch = rollout.Transitions(
        obs=_fill_chunk(i, chunk_len, obs_shape, dtype=dtype),
        next_obs=_fill_chunk(3 * capacity + i, chunk_len,
                             obs_shape, dtype=dtype),
        acts=_fill_chunk(6 * capacity + i, chunk_len,
                         act_shape, dtype=dtype),
        rews=np.arange(9 * capacity + i, 9 * capacity + i + chunk_len,
                       dtype=np.float32),
        dones=np.arange(i, i + chunk_len, dtype=np.int32) % 2,
    )
    buf.store(batch)

    # Are samples right shape?
    sample = buf.sample(100)
    assert sample.obs.shape == sample.next_obs.shape == (100,) + obs_shape
    assert sample.acts.shape == (100,) + act_shape
    assert sample.rews.shape == (100,)
    assert sample.dones.shape == (100,)

    # Are samples right data type?
    assert sample.obs.dtype == dtype
    assert sample.acts.dtype == dtype
    assert sample.next_obs.dtype == dtype
    assert sample.rews.dtype == np.float32
    assert sample.dones.dtype == np.bool_

    # Are samples in range?
    _check_bound(i + chunk_len, capacity, sample.obs)
    _check_bound(i + chunk_len, capacity, sample.next_obs, 3 * capacity)
    _check_bound(i + chunk_len, capacity, sample.acts, 6 * capacity)
    _check_bound(i + chunk_len, capacity, sample.rews, 9 * capacity)

    # Are samples in-order?
    obs_fill = _get_fill_from_chunk(sample.obs)
    next_obs_fill = _get_fill_from_chunk(sample.next_obs)
    act_fill = _get_fill_from_chunk(sample.acts)

    assert np.all(next_obs_fill - obs_fill == 3 * capacity), "out of order"
    assert np.all(act_fill - next_obs_fill == 3 * capacity), "out of order"
    assert np.all(sample.rews - act_fill == 3 * capacity), "out of order"
    # Can't do much other than parity check for boolean values.
    # `samples.done` has the same parity as `obs_fill` by construction.
    assert np.all(obs_fill % 2 == sample.dones), "out of order"